# File size: 5,305 Bytes (revision 7c8312b)
"""
FILE: intent_classifier.py (ENHANCED VERSION)
PURPOSE:
- Detect user intent from natural text using Semantic Similarity
- Route query into domains aligned with the expanded dataset: FOOD / HERITAGE / TRAVEL / CULTURAL / ARCHITECTURE / NATURE
- Uses SentenceTransformer for robust understanding beyond keywords
"""
import logging

import torch
from sentence_transformers import SentenceTransformer, util
# One natural-language "prototype" sentence per domain. The classifier embeds
# each sentence once and routes a query to the domain whose prototype is most
# similar. Key order matters: it fixes the index order of the embeddings.
DOMAIN_PROTOTYPES = {
    "food": (
        "I want to eat delicious food, snacks, dishes, sweets, desserts, "
        "breakfast, lunch, dinner, restaurants, street food, "
        "traditional cuisine, beverages."
    ),
    "heritage": (
        "I want to visit historical sites, ancient monuments, forts, palaces, "
        "museums, archaeological ruins, tombs, heritage buildings, "
        "UNESCO sites, historical architecture."
    ),
    "travel": (
        "I want to travel and explore places like hill stations, valleys, "
        "mountains, passes, trekking spots, adventure destinations, "
        "offbeat locations, hidden gems, scenic viewpoints, roads."
    ),
    "nature": (
        "I want to experience nature through waterfalls, lakes, rivers, "
        "forests, wildlife sanctuaries, national parks, caves, islands, "
        "beaches, natural landscapes, gardens."
    ),
    "cultural": (
        "I want to experience culture through festivals, traditional events, "
        "art forms, folk performances, local customs, tribal culture, "
        "villages, markets, handlooms, crafts."
    ),
    "architecture": (
        "I want to see beautiful architecture, design, structures, buildings, "
        "temples, churches, mosques, monasteries, modern architecture, "
        "engineering marvels."
    ),
}
class IntentClassifier:
    """Classify free-text queries into domains by semantic similarity.

    Each domain in ``DOMAIN_PROTOTYPES`` is represented by one prototype
    sentence. The prototypes are embedded once at construction time; each
    query is embedded on demand and assigned to the domain whose prototype
    has the highest cosine similarity.
    """

    def __init__(self, model_name="all-MiniLM-L6-v2"):
        """Load the SentenceTransformer *model_name* and embed all prototypes.

        Args:
            model_name: Model name or path passed to ``SentenceTransformer``.
        """
        print(f"🧠 Loading Intent Classifier Model: {model_name}...")
        self.model = SentenceTransformer(model_name)

        # Pre-compute prototype embeddings once; every query is compared
        # against this fixed set, so re-encoding per call would be wasted work.
        # Index i of self.domains corresponds to row i of the embeddings.
        self.domains = list(DOMAIN_PROTOTYPES.keys())
        self.prototypes = list(DOMAIN_PROTOTYPES.values())
        self.prototype_embeddings = self.model.encode(
            self.prototypes, convert_to_tensor=True
        )
        print("✅ Intent Classifier Ready")

    def predict_intent(self, query: str, threshold: float = 0.25) -> str:
        """Return the domain whose prototype is most similar to *query*.

        Args:
            query: Free-text user query.
            threshold: Minimum cosine similarity required to accept the best
                matching domain.

        Returns:
            One of the keys of ``DOMAIN_PROTOTYPES``, or ``"general"`` when
            the query is empty/blank or the best similarity is below
            *threshold*.
        """
        # Blank input carries no intent; skip the (expensive) model call.
        if not query or not query.strip():
            return "general"

        query_embedding = self.model.encode(query, convert_to_tensor=True)
        # One query vs. all prototypes; row 0 holds the per-domain scores.
        cosine_scores = util.cos_sim(query_embedding, self.prototype_embeddings)[0]

        best_score, best_index = torch.max(cosine_scores, dim=0)
        best_score = best_score.item()
        # Convert the 0-d index tensor to a plain int before list indexing.
        best_domain = self.domains[int(best_index)]

        logging.getLogger(__name__).debug(
            "Query=%r | Best Match=%r (%.4f)", query, best_domain, best_score
        )

        if best_score < threshold:
            return "general"
        return best_domain
# Backwards-compatible functional API: a single module-wide IntentClassifier
# is created lazily on first use, so the model is loaded only once.
_shared_classifier = None


def classify_intent(query: str, threshold: float = 0.25) -> str:
    """Classify *query* with a lazily created, module-wide IntentClassifier.

    Args:
        query: Free-text user query.
        threshold: Forwarded to ``IntentClassifier.predict_intent``; the
            default matches that method's default.

    Returns:
        The label returned by ``IntentClassifier.predict_intent``.
    """
    global _shared_classifier
    if _shared_classifier is None:
        _shared_classifier = IntentClassifier()
    return _shared_classifier.predict_intent(query, threshold=threshold)
# Enhanced test suite (run this module directly to execute it).
def _main() -> None:
    """Classify hand-labelled queries, print each result and overall accuracy."""
    classifier = IntentClassifier()

    # (query, expected domain) pairs covering every domain plus "general".
    test_queries = [
        # Food queries
        ("spicy masala dosa bangalore", "food"),
        ("sweet Indian dessert", "food"),
        ("iconic restaurant butter chicken", "food"),
        # Heritage queries
        ("ancient fort in rajasthan", "heritage"),
        ("historical monuments", "heritage"),
        ("mughal palace", "heritage"),
        # Travel queries
        ("valley of flowers nagaland", "travel"),
        ("hill station honeymoon", "travel"),
        ("trekking adventure ladakh", "travel"),
        ("mountain pass", "travel"),
        # Nature queries
        ("hidden waterfall meghalaya", "nature"),
        ("wildlife sanctuary", "nature"),
        ("national park tigers", "nature"),
        ("beautiful lake", "nature"),
        # Cultural queries
        ("traditional festival", "cultural"),
        ("tribal village", "cultural"),
        ("folk art performance", "cultural"),
        # Architecture queries
        ("temple architecture", "architecture"),
        ("beautiful building design", "architecture"),
        # General/Ambiguous
        ("random gibberish text", "general"),
    ]

    print("\n" + "=" * 60)
    print("INTENT CLASSIFIER TEST RESULTS")
    print("=" * 60)

    correct = 0
    total = len(test_queries)
    for query, expected in test_queries:
        predicted = classifier.predict_intent(query)
        is_correct = predicted == expected
        correct += int(is_correct)
        status = "✅" if is_correct else "❌"
        print(f"{status} Query: '{query}'")
        print(f" Expected: {expected} | Predicted: {predicted}\n")

    print("=" * 60)
    print(f"ACCURACY: {correct}/{total} ({round(correct/total * 100, 1)}%)")
    print("=" * 60)


if __name__ == "__main__":
    _main()