# new_recommender_system_nlp2 / intent_classifier.py
# Uploaded by bharatverse11 ("Upload 11 files"), commit 7c8312b, 5.31 kB.
"""
FILE: intent_classifier.py (ENHANCED VERSION)
PURPOSE:
- Detect user intent from natural text using Semantic Similarity
- Route query into domains aligned with the expanded dataset: FOOD / HERITAGE / TRAVEL / CULTURAL / ARCHITECTURE / NATURE
- Uses SentenceTransformer for robust understanding beyond keywords
"""
from sentence_transformers import SentenceTransformer, util
import torch
# One natural-language "prototype" sentence per domain. The keys are the
# domain labels the classifier can emit (in addition to its "general"
# fallback); each sentence is embedded and compared against incoming queries.
DOMAIN_PROTOTYPES = dict(
    food=(
        "I want to eat delicious food, snacks, dishes, sweets, desserts, "
        "breakfast, lunch, dinner, restaurants, street food, "
        "traditional cuisine, beverages."
    ),
    heritage=(
        "I want to visit historical sites, ancient monuments, forts, palaces, "
        "museums, archaeological ruins, tombs, heritage buildings, "
        "UNESCO sites, historical architecture."
    ),
    travel=(
        "I want to travel and explore places like hill stations, valleys, "
        "mountains, passes, trekking spots, adventure destinations, "
        "offbeat locations, hidden gems, scenic viewpoints, roads."
    ),
    nature=(
        "I want to experience nature through waterfalls, lakes, rivers, "
        "forests, wildlife sanctuaries, national parks, caves, islands, "
        "beaches, natural landscapes, gardens."
    ),
    cultural=(
        "I want to experience culture through festivals, traditional events, "
        "art forms, folk performances, local customs, tribal culture, "
        "villages, markets, handlooms, crafts."
    ),
    architecture=(
        "I want to see beautiful architecture, design, structures, buildings, "
        "temples, churches, mosques, monasteries, modern architecture, "
        "engineering marvels."
    ),
)
class IntentClassifier:
    """Zero-shot intent router based on sentence-embedding similarity.

    Each domain in ``DOMAIN_PROTOTYPES`` is represented by one prototype
    sentence. A query is embedded with the same SentenceTransformer model and
    assigned to the domain whose prototype has the highest cosine similarity,
    or to ``"general"`` when no prototype is similar enough.
    """

    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        """Load the embedding model and pre-compute the prototype embeddings.

        Args:
            model_name: Model name or path passed to ``SentenceTransformer``.
        """
        print(f"🧠 Loading Intent Classifier Model: {model_name}...")
        self.model = SentenceTransformer(model_name)
        # Parallel sequences: self.domains[i] labels self.prototypes[i]. The
        # lookup in predict_intent relies on encode() returning one embedding
        # per prototype, in the same order.
        self.domains = list(DOMAIN_PROTOTYPES.keys())
        self.prototypes = list(DOMAIN_PROTOTYPES.values())
        # Encoded once here so each prediction only needs to embed the query.
        self.prototype_embeddings = self.model.encode(self.prototypes, convert_to_tensor=True)
        print("✅ Intent Classifier Ready")

    def predict_intent(self, query: str, threshold: float = 0.25) -> str:
        """Return the domain whose prototype is most similar to ``query``.

        Args:
            query: Free-text user query.
            threshold: Minimum cosine similarity (range [-1, 1]) that the
                best-matching prototype must reach; below it the query is
                treated as out-of-domain.

        Returns:
            One of the keys of ``DOMAIN_PROTOTYPES``, or ``"general"`` when the
            best similarity is below ``threshold`` or the query is blank.
        """
        # A blank query carries no intent; skip the (comparatively expensive)
        # model call entirely.
        if isinstance(query, str) and not query.strip():
            return "general"

        query_embedding = self.model.encode(query, convert_to_tensor=True)
        # cos_sim yields a 2-D score matrix; [0] selects the row for this
        # single query, leaving one score per domain prototype.
        cosine_scores = util.cos_sim(query_embedding, self.prototype_embeddings)[0]

        best_score, best_index = torch.max(cosine_scores, dim=0)
        # Convert the 0-d tensors to plain Python numbers before comparing
        # and before using the index on a Python list.
        best_score = best_score.item()
        best_domain = self.domains[best_index.item()]

        if best_score < threshold:
            return "general"
        return best_domain
# Backwards-compatibility wrapper: a lazily created, module-wide classifier so
# callers can use a plain function without managing an instance themselves.
_shared_classifier = None


def classify_intent(query: str, threshold: float = 0.25) -> str:
    """Classify ``query`` with a shared, lazily loaded ``IntentClassifier``.

    The classifier (and its model) is created on the first call only; later
    calls reuse the same instance.

    Args:
        query: Free-text user query.
        threshold: Forwarded to ``IntentClassifier.predict_intent``; the
            default mirrors that method's default of 0.25.

    Returns:
        The predicted domain label, or ``"general"``.
    """
    global _shared_classifier
    if _shared_classifier is None:
        # NOTE(review): not guarded by a lock; concurrent first calls could
        # each construct a classifier. Harmless for single-threaded callers.
        _shared_classifier = IntentClassifier()
    return _shared_classifier.predict_intent(query, threshold=threshold)
# Enhanced test suite: running this module directly loads the embedding model
# and reports accuracy on a small set of hand-labelled queries.
if __name__ == "__main__":
    classifier = IntentClassifier()

    # (query, expected label) pairs.
    test_queries = [
        # Food queries
        ("spicy masala dosa bangalore", "food"),
        ("sweet Indian dessert", "food"),
        ("iconic restaurant butter chicken", "food"),
        # Heritage queries
        ("ancient fort in rajasthan", "heritage"),
        ("historical monuments", "heritage"),
        ("mughal palace", "heritage"),
        # Travel queries
        ("valley of flowers nagaland", "travel"),
        ("hill station honeymoon", "travel"),
        ("trekking adventure ladakh", "travel"),
        ("mountain pass", "travel"),
        # Nature queries
        ("hidden waterfall meghalaya", "nature"),
        ("wildlife sanctuary", "nature"),
        ("national park tigers", "nature"),
        ("beautiful lake", "nature"),
        # Cultural queries
        ("traditional festival", "cultural"),
        ("tribal village", "cultural"),
        ("folk art performance", "cultural"),
        # Architecture queries
        ("temple architecture", "architecture"),
        ("beautiful building design", "architecture"),
        # General/Ambiguous
        ("random gibberish text", "general"),
    ]

    print("\n" + "=" * 60)
    print("INTENT CLASSIFIER TEST RESULTS")
    print("=" * 60)

    correct = 0
    total = len(test_queries)
    for query, expected in test_queries:
        predicted = classifier.predict_intent(query)
        is_correct = predicted == expected
        correct += is_correct  # bool counts as 0/1
        status = "✅" if is_correct else "❌"
        print(f"{status} Query: '{query}'")
        print(f" Expected: {expected} | Predicted: {predicted}\n")

    print("=" * 60)
    print(f"ACCURACY: {correct}/{total} ({round(correct/total * 100, 1)}%)")
    print("=" * 60)