"""
FILE: intent_classifier.py (ENHANCED VERSION)

PURPOSE:
- Detect user intent from natural-language text using semantic similarity.
- Route each query into one of the domains aligned with the expanded dataset:
  FOOD / HERITAGE / TRAVEL / CULTURAL / ARCHITECTURE / NATURE.
- Use SentenceTransformer embeddings for robust understanding beyond keywords.
"""
|
|
|
|
|
|
from sentence_transformers import SentenceTransformer, util
|
|
|
import torch
|
|
|
|
|
|
|
|
|
# One natural-language "prototype" sentence per domain. Each query is compared
# against the embeddings of these sentences, so the wording of a prototype
# directly influences which queries are routed to its domain.
# Insertion order matters: the classifier derives its label list and its
# embedding rows from this dict in the same order.
DOMAIN_PROTOTYPES = {
    "food": "I want to eat delicious food, snacks, dishes, sweets, desserts, breakfast, lunch, dinner, restaurants, street food, traditional cuisine, beverages.",
    "heritage": "I want to visit historical sites, ancient monuments, forts, palaces, museums, archaeological ruins, tombs, heritage buildings, UNESCO sites, historical architecture.",
    "travel": "I want to travel and explore places like hill stations, valleys, mountains, passes, trekking spots, adventure destinations, offbeat locations, hidden gems, scenic viewpoints, roads.",
    "nature": "I want to experience nature through waterfalls, lakes, rivers, forests, wildlife sanctuaries, national parks, caves, islands, beaches, natural landscapes, gardens.",
    "cultural": "I want to experience culture through festivals, traditional events, art forms, folk performances, local customs, tribal culture, villages, markets, handlooms, crafts.",
    "architecture": "I want to see beautiful architecture, design, structures, buildings, temples, churches, mosques, monasteries, modern architecture, engineering marvels.",
}
|
|
|
|
|
|
class IntentClassifier:
    """Route free-text queries to a domain via sentence-embedding similarity.

    Every prototype sentence in ``DOMAIN_PROTOTYPES`` is embedded once at
    construction time. ``predict_intent`` then embeds the incoming query and
    selects the domain whose prototype has the highest cosine similarity.
    """

    def __init__(self, model_name: str = "all-MiniLM-L6-v2") -> None:
        """Load the embedding model and pre-compute the prototype embeddings.

        Args:
            model_name: Model identifier handed to ``SentenceTransformer``.
        """
        print(f"🧠 Loading Intent Classifier Model: {model_name}...")
        self.model = SentenceTransformer(model_name)

        # keys() and values() iterate in the same order, so position i in
        # ``domains`` is the label for row i of ``prototype_embeddings``.
        self.domains = list(DOMAIN_PROTOTYPES.keys())
        self.prototypes = list(DOMAIN_PROTOTYPES.values())
        self.prototype_embeddings = self.model.encode(
            self.prototypes, convert_to_tensor=True
        )
        print("✅ Intent Classifier Ready")

    def predict_intent(self, query: str, threshold: float = 0.25) -> str:
        """Predict the domain of *query* from its similarity to the prototypes.

        Args:
            query: User text to classify.
            threshold: Minimum best cosine similarity required to accept a
                domain.

        Returns:
            The best-matching domain name, or ``"general"`` when the query is
            empty / whitespace-only or the highest similarity score is below
            ``threshold``.
        """
        # An empty query carries no intent; skip the model call entirely
        # instead of classifying the embedding of an empty string.
        if not query or not query.strip():
            return "general"

        query_embedding = self.model.encode(query, convert_to_tensor=True)

        # cos_sim yields a similarity matrix; row 0 holds the single query's
        # score against each prototype.
        cosine_scores = util.cos_sim(query_embedding, self.prototype_embeddings)[0]

        best_score, best_index = torch.max(cosine_scores, dim=0)
        best_score = best_score.item()

        if best_score < threshold:
            return "general"

        # best_index is a 0-dim tensor; convert explicitly before list indexing.
        return self.domains[int(best_index)]
|
|
|
|
|
|
|
|
|
# Module-wide classifier instance, created on the first classify_intent() call
# so importing this module does not load the embedding model.
_shared_classifier = None


def classify_intent(query: str) -> str:
    """Classify *query* using a lazily created, shared ``IntentClassifier``.

    The classifier is built on the first call and reused afterwards, so the
    model is loaded at most once through this function.

    Args:
        query: User text to classify.

    Returns:
        Whatever ``IntentClassifier.predict_intent`` returns for *query* with
        its default threshold.
    """
    global _shared_classifier
    if _shared_classifier is None:
        _shared_classifier = IntentClassifier()
    return _shared_classifier.predict_intent(query)
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: classify a fixed set of labelled queries and print
    # per-query results followed by the overall accuracy.
    classifier = IntentClassifier()

    # (query, expected domain) pairs, grouped by domain.
    test_queries = [
        # food
        ("spicy masala dosa bangalore", "food"),
        ("sweet Indian dessert", "food"),
        ("iconic restaurant butter chicken", "food"),
        # heritage
        ("ancient fort in rajasthan", "heritage"),
        ("historical monuments", "heritage"),
        ("mughal palace", "heritage"),
        # travel
        ("valley of flowers nagaland", "travel"),
        ("hill station honeymoon", "travel"),
        ("trekking adventure ladakh", "travel"),
        ("mountain pass", "travel"),
        # nature
        ("hidden waterfall meghalaya", "nature"),
        ("wildlife sanctuary", "nature"),
        ("national park tigers", "nature"),
        ("beautiful lake", "nature"),
        # cultural
        ("traditional festival", "cultural"),
        ("tribal village", "cultural"),
        ("folk art performance", "cultural"),
        # architecture
        ("temple architecture", "architecture"),
        ("beautiful building design", "architecture"),
        # fallback
        ("random gibberish text", "general"),
    ]

    print("\n" + "=" * 60)
    print("INTENT CLASSIFIER TEST RESULTS")
    print("=" * 60)

    correct = 0
    total = len(test_queries)

    for query, expected in test_queries:
        predicted = classifier.predict_intent(query)
        is_correct = predicted == expected
        correct += is_correct  # bool counts as 0/1

        status = "✅" if is_correct else "❌"
        print(f"{status} Query: '{query}'")
        print(f" Expected: {expected} | Predicted: {predicted}\n")

    print("=" * 60)
    print(f"ACCURACY: {correct}/{total} ({round(correct/total * 100, 1)}%)")
    print("=" * 60)
|
|
|
|