# DEPENDENCIES
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

from loguru import logger

from config.threshold_config import Domain
from config.threshold_config import get_threshold_for_domain
from config.threshold_config import interpolate_thresholds
from models.model_manager import get_model_manager


@dataclass
class DomainPrediction:
    """ Result of domain classification """
    primary_domain : Domain
    secondary_domain : Optional[Domain]
    confidence : float
    domain_scores : Dict[str, float]


class DomainClassifier:
    """ Classifies text into domains using zero-shot classification """
    # Enhanced domain labels for zero-shot classification
    DOMAIN_LABELS = {Domain.ACADEMIC : ["academic paper", "research article", "scientific paper", "scholarly writing", "thesis", "dissertation", "academic research"],
                     Domain.CREATIVE : ["creative writing", "fiction", "story", "narrative", "poetry", "literary work", "imaginative writing"],
                     Domain.AI_ML : ["artificial intelligence", "machine learning", "neural networks", "data science", "AI research", "deep learning"],
                     Domain.SOFTWARE_DEV : ["software development", "programming", "coding", "software engineering", "web development", "application development"],
                     Domain.TECHNICAL_DOC : ["technical documentation", "user manual", "API documentation", "technical guide", "system documentation"],
                     Domain.ENGINEERING : ["engineering document", "technical design", "engineering analysis", "mechanical engineering", "electrical engineering"],
                     Domain.SCIENCE : ["scientific research", "physics", "chemistry", "biology", "scientific study", "experimental results"],
                     Domain.BUSINESS : ["business document", "corporate communication", "business report", "professional writing", "executive summary"],
                     Domain.JOURNALISM : ["news article", "journalism", "press release", "news report", "media content", "reporting"],
                     Domain.SOCIAL_MEDIA : ["social media post", "casual writing", "online content", "informal text", "social media content"],
                     Domain.BLOG_PERSONAL : ["personal blog", "personal writing", "lifestyle blog", "personal experience", "opinion piece", "diary entry"],
                     Domain.LEGAL : ["legal document", "contract", "legal writing", "law", "legal agreement", "legal analysis"],
                     Domain.MEDICAL : ["medical document", "healthcare", "clinical", "medical report", "health information", "medical research"],
                     Domain.MARKETING : ["marketing content", "advertising", "brand content", "promotional writing", "sales copy", "marketing material"],
                     Domain.TUTORIAL : ["tutorial", "how-to guide", "instructional content", "step-by-step guide", "educational guide", "learning material"],
                     Domain.GENERAL : ["general content", "everyday writing", "common text", "standard writing", "normal text", "general information"],
                     }

    def __init__(self):
        self.model_manager = get_model_manager()
        self.primary_classifier = None
        self.fallback_classifier = None
        self.is_initialized = False

    def initialize(self) -> bool:
        """ Initialize the domain classifier with zero-shot models """
        try:
            logger.info("Initializing domain classifier...")

            # Load primary domain classifier (zero-shot)
            self.primary_classifier = self.model_manager.load_model(model_name = "domain_classifier")

            # Load fallback classifier
            try:
                self.fallback_classifier = self.model_manager.load_model(model_name = "domain_classifier_fallback")
                logger.info("Fallback classifier loaded successfully")

            except Exception as e:
                logger.warning(f"Could not load fallback classifier: {repr(e)}")
                self.fallback_classifier = None

            self.is_initialized = True
            logger.success("Domain classifier initialized successfully")
            return True

        except Exception as e:
            logger.error(f"Failed to initialize domain classifier: {repr(e)}")
            return False

    def classify(self, text: str, top_k: int = 2, min_confidence: float = 0.3) -> DomainPrediction:
        """
        Classify text into domain using zero-shot classification

        Arguments:
        ----------
            text           { str }   : Input text

            top_k          { int }   : Number of top domains to consider

            min_confidence { float } : Minimum confidence threshold

        Returns:
        --------
            { DomainPrediction }     : DomainPrediction object
        """
        if not self.is_initialized:
            logger.warning("Domain classifier not initialized, initializing now...")

            if not self.initialize():
                return self._get_default_prediction()

        try:
            # First try with primary classifier
            primary_result = self._classify_with_model(text = text,
                                                       classifier = self.primary_classifier,
                                                       model_type = "primary",
                                                       )

            # If primary result meets confidence threshold, return it
            if (primary_result.confidence >= min_confidence):
                return primary_result

            # If primary is low confidence but we have fallback, try fallback
            if self.fallback_classifier:
                logger.info("Primary classifier low confidence, trying fallback model...")

                fallback_result = self._classify_with_model(text = text,
                                                            classifier = self.fallback_classifier,
                                                            model_type = "fallback",
                                                            )

                # Use fallback if it has higher confidence
                if fallback_result.confidence > primary_result.confidence:
                    return fallback_result

            # Return primary result even if low confidence
            return primary_result

        except Exception as e:
            logger.error(f"Error in domain classification: {repr(e)}")

            # Try fallback classifier if primary failed
            if self.fallback_classifier:
                try:
                    logger.info("Trying fallback classifier after primary failure...")
                    return self._classify_with_model(text = text,
                                                     classifier = self.fallback_classifier,
                                                     model_type = "fallback",
                                                     )

                except Exception as fallback_error:
                    logger.error(f"Fallback classifier also failed: {repr(fallback_error)}")

            # Both models failed, return default
            return self._get_default_prediction()

    def _classify_with_model(self, text: str, classifier, model_type: str) -> DomainPrediction:
        """ Classify using a zero-shot classification model """
        # Preprocess text
        processed_text = self._preprocess_text(text)

        # Get all candidate labels
        all_labels = list()
        label_to_domain = dict()

        for domain, labels in self.DOMAIN_LABELS.items():
            # Use the first label as the primary one for this domain
            primary_label = labels[0]
            all_labels.append(primary_label)
            label_to_domain[primary_label] = domain

        # Perform zero-shot classification
        result = classifier(processed_text,
                            candidate_labels = all_labels,
                            multi_label = False,
                            hypothesis_template = "This text is about {}.",
                            )

        # Convert to domain scores
        domain_scores = dict()

        for label, score in zip(result['labels'], result['scores']):
            domain = label_to_domain[label]
            domain_key = domain.value

            if (domain_key not in domain_scores):
                domain_scores[domain_key] = list()

            domain_scores[domain_key].append(score)

        # Average scores for each domain
        avg_domain_scores = {domain: sum(scores) / len(scores) for domain, scores in domain_scores.items()}

        # Sort by score
        sorted_domains = sorted(avg_domain_scores.items(), key = lambda x: x[1], reverse = True)

        # Get primary and secondary domains
        primary_domain_str, primary_score = sorted_domains[0]
        primary_domain = Domain(primary_domain_str)

        secondary_domain = None
        secondary_score = 0.0

        if ((len(sorted_domains) > 1) and (sorted_domains[1][1] >= 0.1)):
            secondary_domain = Domain(sorted_domains[1][0])
            secondary_score = sorted_domains[1][1]

        # Calculate confidence
        confidence = primary_score

        # If we have mixed domains with close scores, adjust confidence
        if (secondary_domain and (primary_score < 0.7) and (secondary_score > 0.3)):
            score_ratio = secondary_score / primary_score

            # Secondary is at least 60% of primary
            if (score_ratio > 0.6):
                # Lower confidence for mixed domains
                confidence = (primary_score + secondary_score) / 2 * 0.8
                logger.info(f"Mixed domain detected: {primary_domain.value} + {secondary_domain.value}, will use interpolated thresholds")

        # If primary score is low and we have a secondary, it's uncertain
        elif ((primary_score < 0.5) and secondary_domain):
            # Reduce confidence
            confidence *= 0.8

        logger.info(f"{model_type.capitalize()} model classified domain: {primary_domain.value} (confidence: {confidence:.3f})")

        return DomainPrediction(primary_domain = primary_domain,
                                secondary_domain = secondary_domain,
                                confidence = confidence,
                                domain_scores = avg_domain_scores,
                                )

    def _preprocess_text(self, text: str) -> str:
        """ Preprocess text for classification """
        # Truncate to reasonable length
        words = text.split()

        if (len(words) > 400):
            text = ' '.join(words[:400])

        # Clean up text
        text = text.strip()

        if not text:
            return "general content"

        return text

    def _get_default_prediction(self) -> DomainPrediction:
        """ Get default prediction when classification fails """
        return DomainPrediction(primary_domain = Domain.GENERAL,
                                secondary_domain = None,
                                confidence = 0.5,
                                domain_scores = {Domain.GENERAL.value: 1.0},
                                )

    def get_adaptive_thresholds(self, domain_prediction: DomainPrediction):
        """ Get adaptive thresholds based on domain prediction """
        if ((domain_prediction.confidence > 0.7) and (not domain_prediction.secondary_domain)):
            return get_threshold_for_domain(domain_prediction.primary_domain)

        if domain_prediction.secondary_domain:
            primary_score = domain_prediction.domain_scores.get(domain_prediction.primary_domain.value, 0)
            secondary_score = domain_prediction.domain_scores.get(domain_prediction.secondary_domain.value, 0)

            if (primary_score + secondary_score > 0):
                weight1 = primary_score / (primary_score + secondary_score)

            else:
                weight1 = domain_prediction.confidence

            return interpolate_thresholds(domain1 = domain_prediction.primary_domain,
                                          domain2 = domain_prediction.secondary_domain,
                                          weight1 = weight1,
                                          )

        if (domain_prediction.confidence < 0.6):
            return interpolate_thresholds(domain1 = domain_prediction.primary_domain,
                                          domain2 = Domain.GENERAL,
                                          weight1 = domain_prediction.confidence,
                                          )

        return get_threshold_for_domain(domain_prediction.primary_domain)

    def cleanup(self):
        """ Clean up resources """
        self.primary_classifier = None
        self.fallback_classifier = None
        self.is_initialized = False


# Export
__all__ = ["DomainClassifier",
           "DomainPrediction",
           ]
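

# Illustrative usage sketch (not part of the original module): shows the expected
# end-to-end flow -- initialize, classify, derive adaptive thresholds, clean up.
# It assumes the model names registered in models.model_manager ("domain_classifier"
# and "domain_classifier_fallback") resolve to zero-shot classification pipelines;
# the sample text below is a placeholder, not project data.
if __name__ == "__main__":
    classifier = DomainClassifier()

    if classifier.initialize():
        sample = "We fine-tuned a transformer model and report accuracy on a held-out test set."
        prediction = classifier.classify(sample)

        logger.info(f"Primary domain   : {prediction.primary_domain.value}")
        logger.info(f"Secondary domain : {prediction.secondary_domain.value if prediction.secondary_domain else 'None'}")
        logger.info(f"Confidence       : {prediction.confidence:.3f}")

        # Thresholds adapt to the predicted (possibly mixed or low-confidence) domain
        thresholds = classifier.get_adaptive_thresholds(prediction)
        logger.info(f"Adaptive thresholds: {thresholds}")

        classifier.cleanup()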