Text Classification
Transformers
ONNX
Safetensors
English
distilbert
intent-classification
multitask
iab
conversational-ai
adtech
calibrated-confidence
text-embeddings-inference
Instructions to use admesh/agentic-intent-classifier with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use admesh/agentic-intent-classifier with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-classification", model="admesh/agentic-intent-classifier")

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("admesh/agentic-intent-classifier", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| from __future__ import annotations | |
| import json | |
| import re | |
| from functools import lru_cache | |
| import torch | |
| import torch.nn.functional as F | |
| from transformers import AutoModel, AutoTokenizer | |
| from config import ( | |
| IAB_RETRIEVAL_DEPTH_BONUS, | |
| IAB_RETRIEVAL_MODEL_MAX_LENGTH, | |
| IAB_RETRIEVAL_MODEL_NAME, | |
| IAB_RETRIEVAL_PREFIX_CONFIDENCE_THRESHOLDS, | |
| IAB_RETRIEVAL_TOP_K, | |
| IAB_TAXONOMY_EMBEDDINGS_PATH, | |
| IAB_TAXONOMY_NODES_PATH, | |
| IAB_TAXONOMY_VERSION, | |
| ensure_artifact_dirs, | |
| ) | |
| from iab_taxonomy import IabNode, get_iab_taxonomy, path_to_label | |
# Tokens that are discarded when computing lexical overlap between a query
# and taxonomy keywords (see ``_keyword_tokens``).
RETRIEVAL_STOPWORDS = {
    "a", "an", "and", "are", "as", "at",
    "best", "buy",
    "for", "from",
    "how",
    "i", "in", "is", "it",
    "me", "my",
    "need",
    "of", "on", "or",
    "should",
    "the", "to", "tonight",
    "what", "which", "with",
}

# Instruction text placed in front of each query when the embedding model is
# a GTE-Qwen variant.
GTE_QWEN_QUERY_INSTRUCTION = (
    "Given a user query, retrieve the most relevant IAB content taxonomy category."
)
def round_score(value: float) -> float:
    """Coerce *value* to ``float`` and round it to four decimal places."""
    as_float = float(value)
    return round(as_float, 4)
| def _normalize_keyword(value: str) -> str: | |
| value = value.lower().replace("&", " and ") | |
| value = re.sub(r"[^a-z0-9]+", " ", value) | |
| return " ".join(value.split()) | |
def _keyword_tokens(value: str) -> set[str]:
    """Return the distinct content words of *value*.

    The text is normalised with ``_normalize_keyword`` and split on
    whitespace; single-character words and entries of
    ``RETRIEVAL_STOPWORDS`` are left out.
    """
    tokens = set()
    for word in _normalize_keyword(value).split():
        if len(word) <= 1 or word in RETRIEVAL_STOPWORDS:
            continue
        tokens.add(word)
    return tokens
| def _is_gte_qwen_model(model_name: str) -> bool: | |
| normalized = model_name.lower() | |
| return "gte-qwen" in normalized | |
| def _last_token_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: | |
| left_padding = bool(torch.all(attention_mask[:, -1] == 1)) | |
| if left_padding: | |
| return last_hidden_state[:, -1] | |
| sequence_lengths = attention_mask.sum(dim=1) - 1 | |
| batch_indices = torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device) | |
| return last_hidden_state[batch_indices, sequence_lengths] | |
def _node_keywords(node: IabNode) -> list[str]:
    """Return the sorted, de-duplicated normalised keywords for *node*.

    Keywords are drawn from the node's label, its path label, and every
    element of its path. Blank entries and entries that normalise to an empty
    string are skipped.
    """
    raw_terms = set(node.path)
    raw_terms.add(node.label)
    raw_terms.add(node.path_label)

    cleaned = set()
    for term in raw_terms:
        if not term.strip():
            continue
        normalized_term = _normalize_keyword(term)
        if normalized_term:
            cleaned.add(normalized_term)
    return sorted(cleaned)
def _node_retrieval_text(node: IabNode) -> str:
    """Compose the text that is embedded to represent *node* in the index.

    The text always states the path label, the label and the level; a parent
    path is added for nodes whose path has more than one element, and a
    keyword list is added when ``_node_keywords`` yields anything.
    """
    segments = [
        f"IAB category path: {node.path_label}",
        f"Canonical label: {node.label}",
        f"Tier depth: {node.level}",
    ]
    if len(node.path) > 1:
        parent_path = " > ".join(node.path[:-1])
        segments.append(f"Parent path: {parent_path}")

    keyword_list = _node_keywords(node)
    if keyword_list:
        segments.append("Keywords: " + ", ".join(keyword_list))
    return ". ".join(segments)
def _serialize_node(node: IabNode) -> dict:
    """Convert *node* into a plain dict suitable for JSON serialisation.

    Besides the node's own fields, the dict carries the derived ``keywords``
    list and the ``retrieval_text`` that gets embedded.
    """
    return dict(
        unique_id=node.unique_id,
        parent_id=node.parent_id,
        label=node.label,
        path=list(node.path),
        path_label=node.path_label,
        level=node.level,
        keywords=_node_keywords(node),
        retrieval_text=_node_retrieval_text(node),
    )
| class LocalTextEmbedder: | |
| def __init__(self, model_name: str, max_length: int): | |
| self.model_name = model_name | |
| self.max_length = max_length | |
| self._tokenizer = None | |
| self._model = None | |
| self._batch_size = 32 | |
| self._device = "cuda" if torch.cuda.is_available() else "cpu" | |
| self._is_gte_qwen = _is_gte_qwen_model(model_name) | |
| def tokenizer(self): | |
| if self._tokenizer is None: | |
| self._tokenizer = AutoTokenizer.from_pretrained( | |
| self.model_name, | |
| trust_remote_code=self._is_gte_qwen, | |
| ) | |
| return self._tokenizer | |
| def model(self): | |
| if self._model is None: | |
| model_kwargs = {"trust_remote_code": self._is_gte_qwen} | |
| if self._device == "cuda": | |
| model_kwargs["torch_dtype"] = torch.float16 | |
| self._model = AutoModel.from_pretrained(self.model_name, **model_kwargs) | |
| self._model.to(self._device) | |
| self._model.eval() | |
| return self._model | |
| def encode_documents(self, texts: list[str], batch_size: int | None = None) -> torch.Tensor: | |
| return self._encode_texts(texts, batch_size=batch_size, treat_as_query=False) | |
| def encode_queries(self, texts: list[str], batch_size: int | None = None) -> torch.Tensor: | |
| return self._encode_texts(texts, batch_size=batch_size, treat_as_query=True) | |
| def _encode_texts( | |
| self, | |
| texts: list[str], | |
| batch_size: int | None = None, | |
| treat_as_query: bool = False, | |
| ) -> torch.Tensor: | |
| if not texts: | |
| return torch.empty(0, 0) | |
| effective_batch_size = batch_size or self._batch_size | |
| rows: list[torch.Tensor] = [] | |
| for start in range(0, len(texts), effective_batch_size): | |
| batch_texts = texts[start : start + effective_batch_size] | |
| if treat_as_query and self._is_gte_qwen: | |
| batch_texts = [ | |
| f"Instruct: {GTE_QWEN_QUERY_INSTRUCTION}\nQuery: {text}" | |
| for text in batch_texts | |
| ] | |
| inputs = self.tokenizer( | |
| batch_texts, | |
| return_tensors="pt", | |
| truncation=True, | |
| padding=True, | |
| max_length=self.max_length, | |
| ) | |
| inputs = {key: value.to(self._device) for key, value in inputs.items()} | |
| with torch.no_grad(): | |
| outputs = self.model(**inputs) | |
| hidden = outputs.last_hidden_state | |
| if self._is_gte_qwen: | |
| pooled = _last_token_pool(hidden, inputs["attention_mask"]) | |
| else: | |
| mask = inputs["attention_mask"].unsqueeze(-1) | |
| pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1) | |
| rows.append(F.normalize(pooled.float(), p=2, dim=1).cpu()) | |
| return torch.cat(rows, dim=0) | |
@lru_cache(maxsize=1)
def get_iab_text_embedder() -> LocalTextEmbedder:
    """Return the shared ``LocalTextEmbedder`` for the configured retrieval model.

    The instance is cached so that the tokenizer and model it lazily loads
    are reused across calls instead of being loaded again for every new
    embedder.
    """
    return LocalTextEmbedder(IAB_RETRIEVAL_MODEL_NAME, IAB_RETRIEVAL_MODEL_MAX_LENGTH)
def build_iab_taxonomy_embedding_index(batch_size: int = 32) -> dict:
    """Embed every taxonomy node and write the index artifacts to disk.

    The serialised nodes are written as JSON to ``IAB_TAXONOMY_NODES_PATH``
    and the embedding payload is saved with ``torch.save`` to
    ``IAB_TAXONOMY_EMBEDDINGS_PATH``.

    Args:
        batch_size: Batch size handed to the embedder.

    Returns:
        A summary dict with the taxonomy version, model name, node count,
        embedding dimension and both artifact paths (as strings).
    """
    ensure_artifact_dirs()

    serialized_nodes = [_serialize_node(node) for node in get_iab_taxonomy().nodes]
    retrieval_texts = [entry["retrieval_text"] for entry in serialized_nodes]

    embedder = get_iab_text_embedder()
    embeddings = embedder.encode_documents(retrieval_texts, batch_size=batch_size)

    nodes_json = json.dumps(serialized_nodes, indent=2, sort_keys=True) + "\n"
    IAB_TAXONOMY_NODES_PATH.write_text(nodes_json, encoding="utf-8")

    node_count = len(serialized_nodes)
    embedding_dim = int(embeddings.shape[1])
    payload = {
        "model_name": embedder.model_name,
        "taxonomy_version": IAB_TAXONOMY_VERSION,
        "embedding_dim": embedding_dim,
        "node_count": node_count,
        "embeddings": embeddings,
    }
    torch.save(payload, IAB_TAXONOMY_EMBEDDINGS_PATH)

    summary = {
        "taxonomy_version": IAB_TAXONOMY_VERSION,
        "model_name": embedder.model_name,
        "node_count": node_count,
        "embedding_dim": embedding_dim,
        "nodes_path": str(IAB_TAXONOMY_NODES_PATH),
        "embeddings_path": str(IAB_TAXONOMY_EMBEDDINGS_PATH),
    }
    return summary
| class IabEmbeddingRetriever: | |
| def __init__(self): | |
| self.taxonomy = get_iab_taxonomy() | |
| self.embedder = get_iab_text_embedder() | |
| self._nodes: list[dict] | None = None | |
| self._embeddings: torch.Tensor | None = None | |
| def _load_index(self) -> bool: | |
| if self._nodes is not None and self._embeddings is not None: | |
| return True | |
| if not IAB_TAXONOMY_NODES_PATH.exists() or not IAB_TAXONOMY_EMBEDDINGS_PATH.exists(): | |
| return False | |
| nodes = json.loads(IAB_TAXONOMY_NODES_PATH.read_text(encoding="utf-8")) | |
| payload = torch.load(IAB_TAXONOMY_EMBEDDINGS_PATH, map_location="cpu") | |
| if payload.get("model_name") != IAB_RETRIEVAL_MODEL_NAME: | |
| return False | |
| if payload.get("taxonomy_version") != IAB_TAXONOMY_VERSION: | |
| return False | |
| embeddings = payload.get("embeddings") | |
| if not isinstance(embeddings, torch.Tensor): | |
| embeddings = torch.tensor(embeddings, dtype=torch.float32) | |
| if len(nodes) != embeddings.shape[0]: | |
| return False | |
| self._nodes = nodes | |
| self._embeddings = F.normalize(embeddings.float(), p=2, dim=1) | |
| return True | |
| def ready(self) -> bool: | |
| return self._load_index() | |
| def _score_to_confidence(score: float) -> float: | |
| return min(max((score + 1.0) / 2.0, 0.0), 1.0) | |
| def _candidate_from_index(self, score: float, index: int) -> dict: | |
| assert self._nodes is not None | |
| node = self._nodes[index] | |
| confidence = self._score_to_confidence(float(score)) | |
| adjusted_confidence = confidence + (IAB_RETRIEVAL_DEPTH_BONUS * max(int(node["level"]) - 1, 0)) | |
| return { | |
| "unique_id": node["unique_id"], | |
| "label": node["label"], | |
| "path": tuple(node["path"]), | |
| "path_label": node["path_label"], | |
| "level": int(node["level"]), | |
| "confidence": round_score(confidence), | |
| "adjusted_confidence": round_score(adjusted_confidence), | |
| "keywords": list(node.get("keywords", [])), | |
| } | |
| def _rerank_candidates(self, query_text: str, candidates: list[dict]) -> list[dict]: | |
| if not candidates: | |
| return [] | |
| query_normalized = _normalize_keyword(query_text) | |
| query_tokens = _keyword_tokens(query_text) | |
| reranked = [] | |
| for candidate in candidates: | |
| keyword_tokens = set() | |
| for keyword in candidate.get("keywords", []): | |
| keyword_tokens.update(_keyword_tokens(keyword)) | |
| token_overlap = len(query_tokens & keyword_tokens) | |
| path_overlap = len(query_tokens & _keyword_tokens(candidate["path_label"])) | |
| lexical_bonus = min(0.04, (0.008 * token_overlap) + (0.004 * path_overlap)) | |
| reranked.append( | |
| { | |
| **candidate, | |
| "token_overlap": token_overlap, | |
| "path_overlap": path_overlap, | |
| "lexical_bonus": round_score(lexical_bonus), | |
| "rerank_score": round_score(candidate["adjusted_confidence"] + lexical_bonus), | |
| } | |
| ) | |
| reranked.sort( | |
| key=lambda item: ( | |
| item["rerank_score"], | |
| item["adjusted_confidence"], | |
| item["confidence"], | |
| ), | |
| reverse=True, | |
| ) | |
| return reranked | |
| def _top_candidates_from_embedding(self, query_text: str, query_embedding: torch.Tensor) -> list[dict]: | |
| if not self._load_index(): | |
| return [] | |
| assert self._embeddings is not None | |
| scores = torch.mv(self._embeddings, query_embedding) | |
| top_k = min(IAB_RETRIEVAL_TOP_K, scores.shape[0]) | |
| top_scores, top_indices = torch.topk(scores, k=top_k) | |
| candidates = [self._candidate_from_index(score, index) for score, index in zip(top_scores.tolist(), top_indices.tolist())] | |
| return self._rerank_candidates(query_text, candidates) | |
| def _top_candidates(self, text: str) -> list[dict]: | |
| if not self._load_index(): | |
| return [] | |
| query_embedding = self.embedder.encode_queries([text])[0] | |
| return self._top_candidates_from_embedding(text, query_embedding) | |
| def _select_path(self, candidates: list[dict]) -> dict | None: | |
| if not candidates: | |
| return None | |
| top_candidate = candidates[0] | |
| top_path = tuple(top_candidate["path"]) | |
| top_margin = round_score( | |
| top_candidate["confidence"] - candidates[1]["confidence"] if len(candidates) > 1 else top_candidate["confidence"] | |
| ) | |
| prefix_support: dict[tuple[str, ...], float] = {} | |
| for depth in range(1, len(top_path) + 1): | |
| prefix = top_path[:depth] | |
| prefix_support[prefix] = max( | |
| candidate["confidence"] | |
| for candidate in candidates | |
| if tuple(candidate["path"][:depth]) == prefix | |
| ) | |
| selected_path: tuple[str, ...] | None = None | |
| selected_threshold = 0.0 | |
| for depth in range(1, len(top_path) + 1): | |
| threshold = IAB_RETRIEVAL_PREFIX_CONFIDENCE_THRESHOLDS.get(depth, 0.62) | |
| prefix = top_path[:depth] | |
| if prefix_support[prefix] >= threshold: | |
| selected_path = prefix | |
| selected_threshold = threshold | |
| continue | |
| break | |
| if selected_path is None: | |
| return None | |
| stopped_reason = "accepted" if selected_path == top_path else "parent_fallback" | |
| if len(top_path) > 1: | |
| ambiguous_sibling = any( | |
| tuple(candidate["path"][:-1]) == top_path[:-1] | |
| and (top_candidate["confidence"] - candidate["confidence"]) <= 0.03 | |
| for candidate in candidates[1:] | |
| ) | |
| if ambiguous_sibling: | |
| selected_path = top_path[:-1] | |
| selected_threshold = IAB_RETRIEVAL_PREFIX_CONFIDENCE_THRESHOLDS.get(len(selected_path), 0.62) | |
| stopped_reason = "ambiguous_sibling_parent_fallback" | |
| mapping_confidence = prefix_support[selected_path] | |
| return { | |
| "path": selected_path, | |
| "path_label": path_to_label(selected_path), | |
| "mapping_mode": "nearest_equivalent", | |
| "mapping_confidence": round_score(mapping_confidence), | |
| "confidence_threshold": round_score(selected_threshold), | |
| "top_candidate_confidence": round_score(top_candidate["confidence"]), | |
| "top_margin": top_margin, | |
| "stopped_reason": stopped_reason, | |
| } | |
| def predict(self, text: str) -> dict | None: | |
| candidates = self._top_candidates(text) | |
| return self._prediction_from_candidates(candidates) | |
| def _prediction_from_candidates(self, candidates: list[dict]) -> dict | None: | |
| selection = self._select_path(candidates) | |
| if selection is None: | |
| return None | |
| content = self.taxonomy.build_content_object( | |
| path=selection["path"], | |
| mapping_mode=selection["mapping_mode"], | |
| mapping_confidence=selection["mapping_confidence"], | |
| ) | |
| return { | |
| "label": selection["path_label"], | |
| "confidence": selection["mapping_confidence"], | |
| "raw_confidence": selection["top_candidate_confidence"], | |
| "confidence_threshold": selection["confidence_threshold"], | |
| "calibrated": False, | |
| "meets_confidence_threshold": True, | |
| "content": content, | |
| "path": selection["path"], | |
| "mapping_mode": selection["mapping_mode"], | |
| "mapping_confidence": selection["mapping_confidence"], | |
| "source": "embedding_retrieval", | |
| "retrieval_model_name": IAB_RETRIEVAL_MODEL_NAME, | |
| "stopped_reason": selection["stopped_reason"], | |
| "top_margin": selection["top_margin"], | |
| "top_candidates": [ | |
| { | |
| **candidate, | |
| "path": list(candidate["path"]), | |
| "keywords": candidate["keywords"][:12], | |
| } | |
| for candidate in candidates | |
| ], | |
| } | |
| def predict_batch(self, texts: list[str], batch_size: int | None = None) -> list[dict | None]: | |
| if not texts: | |
| return [] | |
| if not self._load_index(): | |
| return [None for _ in texts] | |
| query_embeddings = self.embedder.encode_queries(texts, batch_size=batch_size) | |
| return [ | |
| self._prediction_from_candidates(self._top_candidates_from_embedding(text, query_embedding)) | |
| for text, query_embedding in zip(texts, query_embeddings) | |
| ] | |
@lru_cache(maxsize=1)
def get_iab_embedding_retriever() -> IabEmbeddingRetriever:
    """Return the shared ``IabEmbeddingRetriever`` instance.

    The retriever is cached so the index it loads from disk is kept in memory
    between predictions rather than being re-read for every call.
    """
    return IabEmbeddingRetriever()
def predict_iab_content_retrieval(text: str) -> dict | None:
    """Predict the IAB content mapping for *text* via embedding retrieval.

    Returns None when the retriever reports that it is not ready, or when the
    retriever itself produces no prediction.
    """
    retriever = get_iab_embedding_retriever()
    return retriever.predict(text) if retriever.ready() else None
def predict_iab_content_retrieval_batch(texts: list[str], batch_size: int | None = None) -> list[dict | None]:
    """Predict the IAB content mapping for each entry of *texts*.

    Returns one result per input; every result is None when the retriever
    reports that it is not ready.
    """
    retriever = get_iab_embedding_retriever()
    if retriever.ready():
        return retriever.predict_batch(texts, batch_size=batch_size)
    return [None for _ in texts]