bharatverse11 committed on
Commit
5837b53
·
verified ·
1 Parent(s): 490929f

Update intent_classifier.py

Browse files
Files changed (1) hide show
  1. intent_classifier.py +130 -130
intent_classifier.py CHANGED
@@ -1,130 +1,130 @@
1
- """
2
- FILE: intent_classifier.py (ENHANCED VERSION)
3
-
4
- PURPOSE:
5
- - Detect user intent from natural text using Semantic Similarity
6
- - Route query into domains aligned with the expanded dataset: FOOD / HERITAGE / TRAVEL / CULTURAL / ARCHITECTURE / NATURE
7
- - Uses SentenceTransformer for robust understanding beyond keywords
8
- """
9
-
10
- from sentence_transformers import SentenceTransformer, util
11
- import torch
12
-
13
# One natural-language prototype sentence per dataset domain. The dictionary
# keys are the intent labels; the values are the texts that get embedded.
# Long sentences are split with implicit string concatenation purely for
# readability -- the resulting values are single strings.
DOMAIN_PROTOTYPES = {
    "food": (
        "I want to eat delicious food, snacks, dishes, sweets, desserts, "
        "breakfast, lunch, dinner, restaurants, street food, traditional cuisine, "
        "beverages."
    ),
    "heritage": (
        "I want to visit historical sites, ancient monuments, forts, palaces, "
        "museums, archaeological ruins, tombs, heritage buildings, UNESCO sites, "
        "historical architecture."
    ),
    "travel": (
        "I want to travel and explore places like hill stations, valleys, "
        "mountains, passes, trekking spots, adventure destinations, "
        "offbeat locations, hidden gems, scenic viewpoints, roads."
    ),
    "nature": (
        "I want to experience nature through waterfalls, lakes, rivers, forests, "
        "wildlife sanctuaries, national parks, caves, islands, beaches, "
        "natural landscapes, gardens."
    ),
    "cultural": (
        "I want to experience culture through festivals, traditional events, "
        "art forms, folk performances, local customs, tribal culture, villages, "
        "markets, handlooms, crafts."
    ),
    "architecture": (
        "I want to see beautiful architecture, design, structures, buildings, "
        "temples, churches, mosques, monasteries, modern architecture, "
        "engineering marvels."
    ),
}
27
-
28
class IntentClassifier:
    """Route a free-text query to a domain by semantic similarity.

    Every entry of ``DOMAIN_PROTOTYPES`` is embedded once at construction
    time. A query is then embedded and compared (cosine similarity) with
    each prototype; the domain of the closest prototype is the prediction.
    """

    def __init__(self, model_name="all-MiniLM-L6-v2"):
        """Load the embedding model and pre-compute the prototype embeddings.

        Args:
            model_name: Model identifier passed to ``SentenceTransformer``.
        """
        print(f"🧠 Loading Intent Classifier Model: {model_name}...")
        self.model = SentenceTransformer(model_name)

        # Keys and values are read in the same (insertion) order, so row i of
        # the embedding tensor belongs to self.domains[i].
        self.domains = list(DOMAIN_PROTOTYPES.keys())
        self.prototypes = list(DOMAIN_PROTOTYPES.values())
        self.prototype_embeddings = self.model.encode(
            self.prototypes, convert_to_tensor=True
        )
        print("✅ Intent Classifier Ready")

    def predict_intent(self, query: str, threshold: float = 0.25) -> str:
        """Return the domain whose prototype is most similar to ``query``.

        Args:
            query: Free-text user query.
            threshold: Minimum cosine similarity the best match must reach.

        Returns:
            One of the keys of ``DOMAIN_PROTOTYPES``, or ``"general"`` when
            the query is empty/blank or the best similarity score is below
            ``threshold``.
        """
        # A missing or blank query carries no intent; skip the model call
        # instead of classifying the embedding of an empty string.
        if not query or not query.strip():
            return "general"

        query_embedding = self.model.encode(query, convert_to_tensor=True)

        # Row 0 of the similarity matrix holds this query's score against
        # every prototype.
        cosine_scores = util.cos_sim(query_embedding, self.prototype_embeddings)[0]

        best_score, best_index = torch.max(cosine_scores, dim=0)
        if best_score.item() < threshold:
            return "general"

        # Convert the 0-dim index tensor to a plain int before list indexing.
        return self.domains[int(best_index)]
61
-
62
# Module-level classifier instance, created on first use by classify_intent()
# so that importing this module does not load the model.
_shared_classifier = None


def classify_intent(query: str):
    """Classify ``query`` with a lazily created, shared ``IntentClassifier``.

    Backwards-compatible functional entry point: the classifier is built on
    the first call and reused by every later call.
    """
    global _shared_classifier
    classifier = _shared_classifier
    if classifier is None:
        classifier = _shared_classifier = IntentClassifier()
    return classifier.predict_intent(query)
70
-
71
-
72
# Manual smoke test: run this file directly to classify a fixed set of
# queries and print per-query results plus the overall accuracy.
if __name__ == "__main__":
    classifier = IntentClassifier()

    # (query, expected intent) pairs, grouped by domain.
    test_queries = [
        # Food queries
        ("spicy masala dosa bangalore", "food"),
        ("sweet Indian dessert", "food"),
        ("iconic restaurant butter chicken", "food"),

        # Heritage queries
        ("ancient fort in rajasthan", "heritage"),
        ("historical monuments", "heritage"),
        ("mughal palace", "heritage"),

        # Travel queries
        ("valley of flowers nagaland", "travel"),
        ("hill station honeymoon", "travel"),
        ("trekking adventure ladakh", "travel"),
        ("mountain pass", "travel"),

        # Nature queries
        ("hidden waterfall meghalaya", "nature"),
        ("wildlife sanctuary", "nature"),
        ("national park tigers", "nature"),
        ("beautiful lake", "nature"),

        # Cultural queries
        ("traditional festival", "cultural"),
        ("tribal village", "cultural"),
        ("folk art performance", "cultural"),

        # Architecture queries
        ("temple architecture", "architecture"),
        ("beautiful building design", "architecture"),

        # General/Ambiguous
        ("random gibberish text", "general"),
    ]

    print("\n" + "="*60)
    print("INTENT CLASSIFIER TEST RESULTS")
    print("="*60)

    correct = 0
    total = len(test_queries)

    for query, expected in test_queries:
        predicted = classifier.predict_intent(query)
        is_correct = predicted == expected
        correct += is_correct

        # Pass/fail marker; the check mark replaces a mis-encoded sequence.
        status = "✅" if is_correct else "❌"
        print(f"{status} Query: '{query}'")
        print(f" Expected: {expected} | Predicted: {predicted}\n")

    print("="*60)
    print(f"ACCURACY: {correct}/{total} ({round(correct/total * 100, 1)}%)")
    print("="*60)
 
1
+ """
2
+ FILE: intent_classifier.py (ENHANCED VERSION)
3
+
4
+ PURPOSE:
5
+ - Detect user intent from natural text using Semantic Similarity
6
+ - Route query into domains aligned with the expanded dataset: FOOD / HERITAGE / TRAVEL / CULTURAL / ARCHITECTURE / NATURE
7
+ - Uses SentenceTransformer for robust understanding beyond keywords
8
+ """
9
+
10
+ from sentence_transformers import SentenceTransformer, util
11
+ import torch
12
+
13
# One natural-language prototype sentence per dataset domain. The dictionary
# keys are the intent labels; the values are the texts that get embedded.
# Long sentences are split with implicit string concatenation purely for
# readability -- the resulting values are single strings.
DOMAIN_PROTOTYPES = {
    "food": (
        "I want to eat delicious food, snacks, dishes, sweets, desserts, "
        "breakfast, lunch, dinner, restaurants, street food, traditional cuisine, "
        "beverages. biryani, dosa, idli, thali, spicy food, local recipes, "
        "sweets like mysore pak, halwa."
    ),
    "heritage": (
        "I want to visit historical sites, ancient monuments, forts, palaces, "
        "museums, archaeological ruins, tombs, heritage buildings, UNESCO sites, "
        "historical architecture. temples like tirupati, meenakshi, "
        "forts like golconda, ancient caves."
    ),
    "travel": (
        "I want to travel and explore places like hill stations, valleys, "
        "mountains, passes, trekking spots, adventure destinations, "
        "offbeat locations, hidden gems, scenic viewpoints, roads. "
        "beaches like kothapatnam, varkala, hill stations like ooty, munnar."
    ),
    "nature": (
        "I want to experience nature through waterfalls, lakes, rivers, forests, "
        "wildlife sanctuaries, national parks, caves, islands, beaches, "
        "natural landscapes, gardens. waterfalls, tiger reserves, "
        "sprawling lakes, botanical gardens."
    ),
    "cultural": (
        "I want to experience culture through festivals, traditional events, "
        "art forms, folk performances, local customs, tribal culture, villages, "
        "markets, handlooms, crafts. dance forms, music festivals, "
        "handicraft markets, silk saree weaving."
    ),
    "architecture": (
        "I want to see beautiful architecture, design, structures, buildings, "
        "temples, churches, mosques, monasteries, modern architecture, "
        "engineering marvels. dravidian style, mughal architecture, "
        "intricate carvings, massive domes."
    ),
}
27
+
28
class IntentClassifier:
    """Route a free-text query to a domain by semantic similarity.

    Every entry of ``DOMAIN_PROTOTYPES`` is embedded once at construction
    time. A query is then embedded and compared (cosine similarity) with
    each prototype; the domain of the closest prototype is the prediction.
    """

    def __init__(self, model_name="all-MiniLM-L6-v2"):
        """Load the embedding model and pre-compute the prototype embeddings.

        Args:
            model_name: Model identifier passed to ``SentenceTransformer``.
        """
        print(f"🧠 Loading Intent Classifier Model: {model_name}...")
        self.model = SentenceTransformer(model_name)

        # Keys and values are read in the same (insertion) order, so row i of
        # the embedding tensor belongs to self.domains[i].
        self.domains = list(DOMAIN_PROTOTYPES.keys())
        self.prototypes = list(DOMAIN_PROTOTYPES.values())
        self.prototype_embeddings = self.model.encode(
            self.prototypes, convert_to_tensor=True
        )
        print("✅ Intent Classifier Ready")

    def predict_intent(self, query: str, threshold: float = 0.25) -> str:
        """Return the domain whose prototype is most similar to ``query``.

        Args:
            query: Free-text user query.
            threshold: Minimum cosine similarity the best match must reach.

        Returns:
            One of the keys of ``DOMAIN_PROTOTYPES``, or ``"general"`` when
            the query is empty/blank or the best similarity score is below
            ``threshold``.
        """
        # A missing or blank query carries no intent; skip the model call
        # instead of classifying the embedding of an empty string.
        if not query or not query.strip():
            return "general"

        query_embedding = self.model.encode(query, convert_to_tensor=True)

        # Row 0 of the similarity matrix holds this query's score against
        # every prototype.
        cosine_scores = util.cos_sim(query_embedding, self.prototype_embeddings)[0]

        best_score, best_index = torch.max(cosine_scores, dim=0)
        if best_score.item() < threshold:
            return "general"

        # Convert the 0-dim index tensor to a plain int before list indexing.
        return self.domains[int(best_index)]
61
+
62
# Module-level classifier instance, created on first use by classify_intent()
# so that importing this module does not load the model.
_shared_classifier = None


def classify_intent(query: str):
    """Classify ``query`` with a lazily created, shared ``IntentClassifier``.

    Backwards-compatible functional entry point: the classifier is built on
    the first call and reused by every later call.
    """
    global _shared_classifier
    classifier = _shared_classifier
    if classifier is None:
        classifier = _shared_classifier = IntentClassifier()
    return classifier.predict_intent(query)
70
+
71
+
72
# Manual smoke test: run this file directly to classify a fixed set of
# queries and print per-query results plus the overall accuracy.
if __name__ == "__main__":
    classifier = IntentClassifier()

    # (query, expected intent) pairs, grouped by domain.
    test_queries = [
        # Food queries
        ("spicy masala dosa bangalore", "food"),
        ("sweet Indian dessert", "food"),
        ("iconic restaurant butter chicken", "food"),

        # Heritage queries
        ("ancient fort in rajasthan", "heritage"),
        ("historical monuments", "heritage"),
        ("mughal palace", "heritage"),

        # Travel queries
        ("valley of flowers nagaland", "travel"),
        ("hill station honeymoon", "travel"),
        ("trekking adventure ladakh", "travel"),
        ("mountain pass", "travel"),

        # Nature queries
        ("hidden waterfall meghalaya", "nature"),
        ("wildlife sanctuary", "nature"),
        ("national park tigers", "nature"),
        ("beautiful lake", "nature"),

        # Cultural queries
        ("traditional festival", "cultural"),
        ("tribal village", "cultural"),
        ("folk art performance", "cultural"),

        # Architecture queries
        ("temple architecture", "architecture"),
        ("beautiful building design", "architecture"),

        # General/Ambiguous
        ("random gibberish text", "general"),
    ]

    print("\n" + "="*60)
    print("INTENT CLASSIFIER TEST RESULTS")
    print("="*60)

    correct = 0
    total = len(test_queries)

    for query, expected in test_queries:
        predicted = classifier.predict_intent(query)
        is_correct = predicted == expected
        correct += is_correct

        # Pass/fail marker; the check mark replaces a mis-encoded sequence.
        status = "✅" if is_correct else "❌"
        print(f"{status} Query: '{query}'")
        print(f" Expected: {expected} | Predicted: {predicted}\n")

    print("="*60)
    print(f"ACCURACY: {correct}/{total} ({round(correct/total * 100, 1)}%)")
    print("="*60)