IniNLP247 committed (verified)
Commit d0b211e · Parent(s): 4206b8e

Update app.py

Files changed (1):
  1. app.py +388 -186

app.py CHANGED
@@ -1,26 +1,22 @@
- #INFERENCE NLP+EMOTION DETECTION CV+TTS+THREAT DETECTION CV
  import spaces

  import gradio as gr
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
  import torch
  from deepface import DeepFace
- import threading
  import time
  from transformers import AutoProcessor, DiaForConditionalGeneration
  import numpy as np
- import supervision as sv
- import requests
- from PIL import Image
- import os
- from rfdetr import RFDETRNano

- # Model setup
  model_name = "IniNLP247/Kenko-mental-health-llama-3-model"

- print("🔄 Loading Kenko Mental Health Model...")

- # Load tokenizer and model
  tokenizer = AutoTokenizer.from_pretrained(model_name)
  if tokenizer.pad_token is None:
      tokenizer.pad_token = tokenizer.eos_token
@@ -31,7 +27,6 @@ model = AutoModelForCausalLM.from_pretrained(
      device_map="auto"
  )

- # Create pipeline for easier inference
  pipe = pipeline(
      "text-generation",
      model=model,
@@ -44,51 +39,31 @@ pipe = pipeline(
      pad_token_id=tokenizer.pad_token_id
  )

- print("✅ Model loaded successfully!")

- #Loading of TTS
  print("Loading Dia TTS Model...")
  tts_device = "cuda:0" if torch.cuda.is_available() else "cpu"
  tts_model = "nari-labs/Dia-1.6B-0626"
  tts_processor = AutoProcessor.from_pretrained(tts_model)
  tts_model = DiaForConditionalGeneration.from_pretrained(tts_model, torch_dtype=torch.float16).to(tts_device)
- print("✅ Dia TTS Model loaded successfully!")

- THREAT_CLASSES = {
-     1: "Gun",
-     2: "Explosive",
-     3: "Grenade",
-     4: "Knife"
- }

- #Loading Threat Detection Model
- print("Loading Threat Detection Model...")
- threat_weights_url = "https://huggingface.co/Subh775/Threat-Detection-RF-DETR/resolve/main/checkpoint_best_total.pth"
- threat_weights_filename = "checkpoint_best_total.pth"
-
- # Download weights if not already present
- if not os.path.exists(threat_weights_filename):
-     print(f"Downloading weights from {threat_weights_url}")
-     response = requests.get(threat_weights_url, stream=True)
-     response.raise_for_status()
-     with open(threat_weights_filename, 'wb') as f:
-         for chunk in response.iter_content(chunk_size=8192):
-             f.write(chunk)
-     print("Download complete.")
-
- threat_model = RFDETRNano(resolution=640, pretrain_weights=threat_weights_filename)
- #threat_model.optimize_for_inference()
-
- print("✅ Threat Detection Model loaded successfully!")
-
- #Global Variables For Threat Detection
- current_threat_state = {
-     "threat_detected": [],
-     "threat_count": 0,
-     "last_update": None
- }

- # Global variable to store current emotion state
  current_emotion_state = {
      "dominant": "neutral",
      "confidence": 0.0,
@@ -96,8 +71,359 @@ current_emotion_state = {
      "last_update": None
  }

  def update_emotion_status():
-     """Update emotion status text"""
      if current_emotion_state["last_update"] is None:
          return "*Waiting for emotion data...*"

@@ -109,22 +435,7 @@ def update_emotion_status():
      confidence = current_emotion_state["confidence"]
      return f"**Current Emotion:** {dominant.capitalize()} ({confidence:.1f}% confidence)\n*Last updated: {int(elapsed)}s ago*"

- def update_threat_status():
-     """Update threat status text"""
-     if current_threat_state["last_update"] is None:
-         return "*Monitoring for threats...*"
-
-     elapsed = time.time() - current_threat_state["last_update"]
-
-     threats = current_threat_state["threat_detected"]
-     if threats:
-         threat_list = ", ".join([t["type"] for t in threats])
-         return f"**⚠️ ALERT:** {threat_list} detected\n*Last updated: {int(elapsed)}s ago*"
-     else:
-         return f"**✅ Safe:** No threats detected\n*Last updated: {int(elapsed)}s ago*"
-
  def analyze_emotion(image):
-     """Analyze emotion from webcam image"""
      global current_emotion_state

      try:
@@ -145,7 +456,6 @@ def analyze_emotion(image):
          emotions = result['emotion']
          dominant = result['dominant_emotion']

-         # Update global emotion state
          current_emotion_state = {
              "dominant": dominant,
              "confidence": emotions[dominant],
@@ -153,10 +463,9 @@ def analyze_emotion(image):
              "last_update": time.time()
          }

-         # Format for display - REMOVE the % symbol and keep as numbers
          output = {}
          for emotion, score in sorted(emotions.items(), key=lambda x: x[1], reverse=True):
-             output[emotion.capitalize()] = score # Just the number, no formatting

          return output

@@ -169,7 +478,6 @@ def get_emotion_context():
      if current_emotion_state["last_update"] is None:
          return ""

-     # Check if emotion data is recent (within last 60 seconds)
      if time.time() - current_emotion_state["last_update"] > 60:
          return ""

@@ -182,28 +490,21 @@ def get_emotion_context():
  def chat_with_kenko(message, history):
      """Chat function for Gradio interface with emotion awareness"""

-     # Build conversation context
      conversation = ""
      for user_msg, bot_msg in history:
          conversation += f"User: {user_msg}\nKenko: {bot_msg}\n\n"

-     # Get emotion context
      emotion_context = get_emotion_context()

-     # Get threat context
      threat_context = get_threat_context()

-     # Create prompt in instruction format with emotion awareness
      prompt = f"""### Instruction:
  You are Kenko, a compassionate mental health therapist. Provide empathetic, helpful, and professional responses to support the user's mental wellbeing.
  {emotion_context}{threat_context}
-
  {conversation}User: {message}
-
  ### Response:
  """

-     # Generate response
      try:
          response = pipe(prompt)[0]['generated_text']
          return response.strip()
@@ -216,23 +517,19 @@ def generate_tts(text):

      print(f"[TTS] Generating speech for {len(text)} chars: '{text[:50]}...'")

-     # Prepare inputs for Dia TTS
      inputs = tts_processor(text=text, return_tensors="pt", padding=True)
      inputs = {k: v.to(tts_device) for k, v in inputs.items()}

      print(f"[TTS] Inputs prepared, generating audio codes...")

-     # Generate audio codes
      with torch.no_grad():
          generated_ids = tts_model.generate(**inputs, max_length=2500)

      print(f"[TTS] Audio codes generated, shape: {generated_ids.shape}")
      print(f"[TTS] Decoding codes to waveform...")

-     # Decode the audio codes to waveform using the processor's batch_decode
      audio_values = tts_processor.batch_decode(generated_ids, return_tensors="pt")

-     # Extract the audio waveform
      if isinstance(audio_values, dict) and 'audio_values' in audio_values:
          audio_arr = audio_values['audio_values'][0].cpu().numpy()
      elif isinstance(audio_values, torch.Tensor):
@@ -242,93 +539,24 @@ def generate_tts(text):
      else:
          audio_arr = np.array(audio_values).squeeze()

-     # Ensure float32
      audio_arr = audio_arr.astype(np.float32)

-     # Dia uses 44.1kHz sample rate
      sample_rate = 44100

-     print(f"✅ [TTS] Audio decoded: {len(audio_arr)} samples at {sample_rate}Hz = {len(audio_arr)/sample_rate:.2f} seconds")

      if len(audio_arr) == 0:
-         print("❌ Decoded audio is empty!")
          return None

      return (sample_rate, audio_arr)

  except Exception as e:
-     print(f"❌ TTS generation error: {str(e)}")
      import traceback
      traceback.print_exc()
      return None

- def threat_detection(image):
-     """Threat detection function for webcam"""
-     global current_threat_state
-
-     try:
-         if image is None:
-             return {}
-
-         # Convert numpy array to PIL Image if needed
-         if isinstance(image, np.ndarray):
-             image = Image.fromarray(image)
-
-         # Run Threat Detection
-         detections = threat_model.predict(image, threshold=0.3) # Lower threshold for testing
-
-         # Parse detections - detections.class_id and detections.confidence are ARRAYS
-         threat_found = []
-         if detections is not None and len(detections.class_id) > 0:
-             for class_id, confidence in zip(detections.class_id, detections.confidence):
-                 class_id = int(class_id)
-                 confidence = float(confidence)
-
-                 if class_id in THREAT_CLASSES:
-                     threat_name = THREAT_CLASSES[class_id]
-                     threat_found.append({"type": threat_name, "confidence": confidence})
-                     print(f"🚨 THREAT DETECTED: {threat_name} - {confidence:.2%}")
-
-         # Update global threat state
-         current_threat_state = {
-             "threat_detected": threat_found,
-             "threat_count": len(threat_found),
-             "last_update": time.time()
-         }
-
-         # Format for display
-         if threat_found:
-             output = {}
-             for threat in threat_found:
-                 output[threat["type"]] = threat["confidence"] * 100
-             return output
-         else:
-             return {"No threats detected": 100.0}
-
-     except Exception as e:
-         print(f"Threat detection error: {str(e)}")
-         import traceback
-         traceback.print_exc()
-         return {}
-
-
- def get_threat_context():
-     """Get current threat as context string for the model"""
-     if current_threat_state["last_update"] is None:
-         return ""
-
-     #Check if threat data is recent (within last 60 seconds)
-     if time.time() - current_threat_state["last_update"] > 60:
-         return ""
-
-     threats = current_threat_state["threat_detected"]
-
-     if threats:
-         threat_list = ", ".join([f"{t['type']} ({t['confidence']*100:.1f}% confidence)" for t in threats])
-         return f"\n[User currently holds a potential threat: {threat_list}]"
-
-     return ""
-
  # Custom CSS for a calming interface
  css = """
  .gradio-container {
@@ -342,7 +570,6 @@ css = """
  }
  """

- # Create Gradio interface
  with gr.Blocks(
      title="Kenko - Mental Health Assistant",
      theme=gr.themes.Soft(),
@@ -350,16 +577,13 @@ with gr.Blocks(
  ) as demo:

      gr.Markdown("""
-     # 🧠💚 Kenko - Your Emotion-Aware Mental Health Assistant
-
      Welcome! I'm Kenko, an AI mental health therapist enhanced with real-time emotion detection.
      Allow webcam access to enable emotion-aware responses that adapt to how you're feeling.
-
      *Please remember: I'm an AI assistant and cannot replace professional mental health care. In crisis situations, please contact emergency services or a mental health professional.*
      """)

      with gr.Row():
-         # Left column: Chat interface
          with gr.Column(scale=2):
              chatbot = gr.Chatbot(
                  height=500,
@@ -389,7 +613,6 @@ with gr.Blocks(
              clear_btn = gr.Button("🗑️ Clear Chat", scale=1, variant="secondary")
              examples_btn = gr.Button("💡 Example Topics", scale=1, variant="secondary")

-         # Right column: Emotion detection
          with gr.Column(scale=1):
              gr.Markdown("### 📸 Emotion Detection")
              gr.Markdown("*Your emotional state helps me provide more personalized support*")
@@ -408,17 +631,7 @@ with gr.Blocks(

              emotion_status = gr.Markdown("*Waiting for emotion data...*")

-             #Threat detection output
-             gr.Markdown("### Threat Detection")
-             threat_output = gr.Label(
-                 num_top_classes=4,
-                 label="Detected Threats"
-             )
-
-             threat_status = gr.Markdown("*Monitoring for threats...")
-
-
-     # Example prompts
      with gr.Row(visible=False) as examples_row:
          gr.Examples(
              examples=[
@@ -444,19 +657,16 @@ with gr.Blocks(
      - Relationship and communication advice
      - Mindfulness and self-care suggestions
      - Building healthy habits and routines
-
      **Emotion Detection Feature:**
      - Real-time facial emotion analysis
      - Adapts responses based on your current emotional state
      - Updates automatically every 30 seconds
      - Completely optional - works without webcam too
-
      **Important Notes:**
      - I'm an AI trained to provide mental health support
      - For immediate crisis support, contact emergency services (911) or crisis hotlines
      - Consider professional therapy for ongoing mental health needs
      - I don't diagnose conditions or prescribe medications
-
      **Privacy:** Your conversations and emotion data are not stored or shared.
      """)
      @spaces.GPU
@@ -485,35 +695,27 @@ with gr.Blocks(
          return gr.Row(visible=True)


-     # Event handlers
      submit = msg.submit(fn=respond, inputs=[msg, chatbot], outputs=[msg, chatbot, audio_output])
      send = send_btn.click(fn=respond, inputs=[msg, chatbot], outputs=[msg, chatbot, audio_output])
      clear_btn.click(lambda: [], None, outputs=[chatbot, audio_output])
      examples_btn.click(toggle_examples, outputs=examples_row)

-     # Emotion detection with streaming (analyzes continuously)
      webcam_input.stream(
          analyze_emotion,
          inputs=webcam_input,
          outputs=emotion_output,
-         stream_every=1, # Update every 1 second instead of 30
-         time_limit=60 # Keep processing for 60 seconds
      )

-     timer = gr.Timer(value=5) # Update every 5 seconds

-     # Threat detection with streaming
-     webcam_input.stream(
-         threat_detection, # Corrected function name
-         inputs=webcam_input, # Corrected inputs
-         outputs=threat_output,
-         stream_every=2,
-         time_limit=60
-     )
-     # Add to timer tick
      timer.tick(
-         fn=lambda: (update_emotion_status(), update_threat_status()),
-         outputs=[emotion_status, threat_status]
      )
+ #INFERENCE NLP+EMOTION DETECTION CV+TTS+Memory Management
  import spaces

  import gradio as gr
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
  import torch
  from deepface import DeepFace

  import time
  from transformers import AutoProcessor, DiaForConditionalGeneration
  import numpy as np
+ import chromadb
+ from langchain_community.vectorstores import Chroma
+ from collections import defaultdict
+ from sklearn.cluster import DBSCAN
+ from sentence_transformers import SentenceTransformer  # needed for embedding_model below

  model_name = "IniNLP247/Kenko-mental-health-llama-3-model"

+ print("Loading Kenko Mental Health Model...")

  tokenizer = AutoTokenizer.from_pretrained(model_name)
  if tokenizer.pad_token is None:
      tokenizer.pad_token = tokenizer.eos_token

      device_map="auto"
  )

  pipe = pipeline(
      "text-generation",
      model=model,

      pad_token_id=tokenizer.pad_token_id
  )

+ print("Model loaded successfully!")

  print("Loading Dia TTS Model...")
  tts_device = "cuda:0" if torch.cuda.is_available() else "cpu"
  tts_model = "nari-labs/Dia-1.6B-0626"
  tts_processor = AutoProcessor.from_pretrained(tts_model)
  tts_model = DiaForConditionalGeneration.from_pretrained(tts_model, torch_dtype=torch.float16).to(tts_device)
+ print("Dia TTS Model loaded successfully!")

+ print("Initializing Memory Components...")
+ chroma_client = chromadb.Client()
+ embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

+ def embed_function(texts):
+     if isinstance(texts, str):
+         texts = [texts]
+     return embedding_model.encode(texts).tolist()
+
+ global_vector_store = Chroma(
+     client=chroma_client,
+     embedding_function=embed_function
+ )
+ print("Memory components initialized!")
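A quick way to sanity-check this wiring (a sketch, not part of the commit; it assumes the Chroma wrapper accepts the plain callable configured above, and the texts are made up):

    # embed_function returns one vector per input; all-MiniLM-L6-v2 produces 384-dim embeddings
    vecs = embed_function("I feel anxious about work")
    assert len(vecs) == 1 and len(vecs[0]) == 384

    # texts added to the store become retrievable by semantic similarity
    global_vector_store.add_texts(texts=["User mentioned panic attacks at night"])
    hits = global_vector_store.similarity_search("trouble sleeping due to anxiety", k=1)
    print(hits[0].page_content)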

  current_emotion_state = {
      "dominant": "neutral",
      "confidence": 0.0,

      "last_update": None
  }

+ class AdvancedMemorySystem:
+     """
+     Multi-tier memory system inspired by human memory:
+     - Working Memory: Current conversation (high priority)
+     - Short-term Memory: Recent session with decay
+     - Long-term Memory: Semantic clusters of important themes
+     - Emotional Memory: Affective associations and patterns
+     """
+
+     def __init__(self, embedding_model, vector_store):
+         self.embedding_model = embedding_model
+         self.vector_store = vector_store
+
+         # Working memory - immediate context
+         self.working_memory = []
+
+         # Short-term memory with temporal decay
+         self.short_term_memory = []  # (timestamp, text, embedding, importance_score)
+
+         # Long-term semantic clusters
+         self.semantic_clusters = defaultdict(list)  # {cluster_id: [memories]}
+
+         # Emotional memory graph
+         self.emotional_memory = {
+             "emotion_transitions": [],  # Track emotional journey
+             "trigger_patterns": defaultdict(list),  # What triggers what emotion
+             "coping_effectiveness": {}  # Track which strategies work
+         }
+
+         # Meta-cognitive tracking
+         self.conversation_themes = []
+         self.user_model = {
+             "communication_style": None,
+             "recurring_concerns": [],
+             "progress_indicators": [],
+             "relational_patterns": []
+         }
+
+     def calculate_importance(self, text, emotion, user_engagement):
+         """Calculate memory importance using multiple factors"""
+         importance = 0.5  # Base importance
+
+         # Emotional weight
+         high_intensity_emotions = ["fear", "angry", "sad", "surprise"]
+         if emotion in high_intensity_emotions:
+             importance += 0.3
+
+         # Engagement weight
+         if len(text.split()) > 30:  # Longer, detailed message
+             importance += 0.2
+
+         # Therapeutic keywords
+         therapeutic_keywords = [
+             "trauma", "suicide", "self-harm", "abuse", "panic",
+             "breakthrough", "progress", "better", "worse", "relationship"
+         ]
+         if any(kw in text.lower() for kw in therapeutic_keywords):
+             importance += 0.3
+
+         return min(importance, 1.0)
+
+     def add_to_working_memory(self, user_msg, bot_msg, emotion, timestamp):
+         """Add to immediate working memory (sliding window)"""
+         self.working_memory.append({
+             "user": user_msg,
+             "bot": bot_msg,
+             "emotion": emotion,
+             "timestamp": timestamp
+         })
+
+         # Keep only last 5 exchanges in working memory
+         if len(self.working_memory) > 5:
+             # Consolidate oldest to short-term before removing
+             oldest = self.working_memory.pop(0)
+             self._consolidate_to_short_term(oldest)
+
+     def _consolidate_to_short_term(self, memory_item):
+         """Move from working to short-term memory with importance scoring"""
+         text = f"User: {memory_item['user']}\nKenko: {memory_item['bot']}"
+         embedding = self.embedding_model.encode(text)
+
+         # Calculate importance
+         importance = self.calculate_importance(
+             memory_item['user'],
+             memory_item['emotion'],
+             len(memory_item['user'].split())
+         )
+
+         self.short_term_memory.append({
+             "text": text,
+             "embedding": embedding,
+             "importance": importance,
+             "timestamp": memory_item['timestamp'],
+             "emotion": memory_item['emotion']
+         })
+
+         # Add to vector store with importance weighting
+         try:
+             self.vector_store.add_texts(
+                 texts=[text],
+                 metadatas=[{"importance": importance, "timestamp": memory_item['timestamp']}]
+             )
+         except Exception as e:
+             print(f"Vector store error: {e}")
+
+     def apply_temporal_decay(self, current_time):
+         """Apply decay to short-term memories over time"""
+         decay_rate = 0.01  # Decay 1% per time unit (adjusted for slower decay)
+
+         for memory in self.short_term_memory:
+             time_elapsed = (current_time - memory['timestamp']) / 60  # Convert to minutes
+             decay_factor = np.exp(-decay_rate * time_elapsed)
+             memory['importance'] *= decay_factor
+
+             # If importance drops below threshold, cluster into long-term
+             if memory['importance'] < 0.15:
+                 self._consolidate_to_long_term(memory)
+
+     def _consolidate_to_long_term(self, memory):
+         """Cluster similar memories into semantic long-term memory"""
+         # Get embeddings of all long-term memories
+         if not self.semantic_clusters:
+             self.semantic_clusters[0] = [memory]
+             self.short_term_memory.remove(memory)
+             return
+
+         # Find semantic cluster using cosine similarity
+         best_cluster = 0
+         best_similarity = -1
+
+         for cluster_id, cluster_memories in self.semantic_clusters.items():
+             # Compare with cluster centroid
+             cluster_embeddings = [m['embedding'] for m in cluster_memories]
+             centroid = np.mean(cluster_embeddings, axis=0)
+
+             similarity = np.dot(memory['embedding'], centroid) / (
+                 np.linalg.norm(memory['embedding']) * np.linalg.norm(centroid)
+             )
+
+             if similarity > best_similarity:
+                 best_similarity = similarity
+                 best_cluster = cluster_id
+
+         # Add to cluster if similar enough, else create new cluster
+         if best_similarity > 0.7:
+             self.semantic_clusters[best_cluster].append(memory)
+         else:
+             new_cluster_id = max(self.semantic_clusters.keys()) + 1
+             self.semantic_clusters[new_cluster_id] = [memory]
+
+         # Remove from short-term
+         if memory in self.short_term_memory:
+             self.short_term_memory.remove(memory)
+
+     def track_emotional_transition(self, prev_emotion, current_emotion, context):
+         """Track emotional state transitions for pattern recognition"""
+         self.emotional_memory["emotion_transitions"].append({
+             "from": prev_emotion,
+             "to": current_emotion,
+             "context": context,
+             "timestamp": time.time()
+         })
+
+         # Analyze if certain topics trigger emotional shifts
+         if prev_emotion != current_emotion:
+             self.emotional_memory["trigger_patterns"][current_emotion].append(context)
+
+     def analyze_conversation_themes(self):
+         """Use topic modeling on conversation to identify recurring themes"""
+         if len(self.short_term_memory) < 3:
+             return []
+
+         # Simple keyword extraction
+         all_text = " ".join([m['text'] for m in self.short_term_memory])
+
+         # Extract key phrases
+         words = all_text.lower().split()
+         word_freq = defaultdict(int)
+
+         stopwords = {"the", "a", "is", "in", "and", "to", "of", "i", "my", "me", "you", "that", "it"}
+         for word in words:
+             if word not in stopwords and len(word) > 4:
+                 word_freq[word] += 1
+
+         # Get top themes
+         themes = sorted(word_freq.items(), key=lambda x: x[1], reverse=True)[:5]
+         self.conversation_themes = [theme[0] for theme in themes]
+
+         return self.conversation_themes
+
+     def retrieve_contextual_memory(self, query, current_emotion):
+         """Advanced retrieval using multiple memory tiers"""
+         context = {
+             "working": [],
+             "short_term": [],
+             "long_term": [],
+             "emotional": [],
+             "themes": []
+         }
+
+         # 1. Working memory (full recent conversation)
+         context["working"] = self.working_memory[-3:]  # Last 3 exchanges
+
+         # 2. Short-term memory (importance-weighted retrieval)
+         if self.short_term_memory:
+             query_embedding = self.embedding_model.encode(query)
+
+             scored_memories = []
+             for memory in self.short_term_memory:
+                 # Semantic similarity
+                 similarity = np.dot(query_embedding, memory['embedding']) / (
+                     np.linalg.norm(query_embedding) * np.linalg.norm(memory['embedding'])
+                 )
+
+                 # Boost by importance score
+                 final_score = similarity * memory['importance']
+
+                 # Emotional congruence bonus
+                 if memory['emotion'] == current_emotion:
+                     final_score *= 1.2
+
+                 scored_memories.append((final_score, memory))
+
+             # Get top 3 short-term memories
+             scored_memories.sort(reverse=True, key=lambda x: x[0])
+             context["short_term"] = [m[1] for m in scored_memories[:3]]
+
+         # 3. Long-term memory (semantic clusters)
+         if self.semantic_clusters:
+             query_embedding = self.embedding_model.encode(query)
+             best_cluster_id = None
+             best_cluster_score = -1
+
+             for cluster_id, cluster_memories in self.semantic_clusters.items():
+                 cluster_embeddings = [m['embedding'] for m in cluster_memories]
+                 centroid = np.mean(cluster_embeddings, axis=0)
+
+                 similarity = np.dot(query_embedding, centroid) / (
+                     np.linalg.norm(query_embedding) * np.linalg.norm(centroid)
+                 )
+
+                 if similarity > best_cluster_score:
+                     best_cluster_score = similarity
+                     best_cluster_id = cluster_id
+
+             if best_cluster_id is not None and best_cluster_score > 0.6:
+                 # Return summary of cluster
+                 cluster = self.semantic_clusters[best_cluster_id]
+                 context["long_term"] = cluster[:2]  # Top 2 from cluster
+
+         # 4. Emotional memory patterns
+         if current_emotion in self.emotional_memory["trigger_patterns"]:
+             triggers = self.emotional_memory["trigger_patterns"][current_emotion]
+             context["emotional"] = triggers[-2:]  # Recent triggers for this emotion
+
+         # 5. Conversation themes
+         context["themes"] = self.analyze_conversation_themes()
+
+         return context
+
+     def update_user_model(self, message, emotion):
+         """Build a psychological profile of the user over time"""
+         # Communication style detection
+         if len(message.split()) > 50:
+             style = "detailed"
+         elif len(message.split()) < 10:
+             style = "concise"
+         else:
+             style = "moderate"
+
+         self.user_model["communication_style"] = style
+
+         # Track recurring concerns
+         concern_keywords = {
+             "anxiety": ["anxious", "worried", "panic", "nervous", "anxiety"],
+             "depression": ["sad", "depressed", "hopeless", "empty", "depression"],
+             "relationships": ["partner", "relationship", "friend", "family"],
+             "work_stress": ["work", "job", "career", "boss", "stress"]
+         }
+
+         for concern, keywords in concern_keywords.items():
+             if any(kw in message.lower() for kw in keywords):
+                 if concern not in self.user_model["recurring_concerns"]:
+                     self.user_model["recurring_concerns"].append(concern)
+
+     def generate_memory_context_string(self, contextual_memory):
+         """Format retrieved memories into prompt context"""
+         context_parts = []
+
+         # Working memory (recent conversation)
+         if contextual_memory["working"]:
+             recent = "\n".join([
+                 f"User: {m['user']}\nKenko: {m['bot']}"
+                 for m in contextual_memory["working"]
+             ])
+             context_parts.append(f"### Recent Conversation:\n{recent}")
+
+         # Important short-term memories
+         if contextual_memory["short_term"]:
+             important = "\n".join([m['text'] for m in contextual_memory["short_term"]])
+             context_parts.append(f"### Important Recent Context:\n{important}")
+
+         # Long-term thematic memories
+         if contextual_memory["long_term"]:
+             longterm = "\n".join([m['text'] for m in contextual_memory["long_term"]])
+             context_parts.append(f"### Related Past Discussions:\n{longterm}")
+
+         # Emotional patterns
+         if contextual_memory["emotional"]:
+             emotional = ", ".join(contextual_memory["emotional"][:3])
+             context_parts.append(f"### Emotional Pattern: Previously triggered by: {emotional}")
+
+         # Conversation themes
+         if contextual_memory["themes"]:
+             themes = ", ".join(contextual_memory["themes"])
+             context_parts.append(f"### Session Themes: {themes}")
+
+         # User model insights
+         if self.user_model["recurring_concerns"]:
+             concerns = ", ".join(self.user_model["recurring_concerns"])
+             context_parts.append(f"### Recurring Concerns: {concerns}")
+
+         return "\n\n".join(context_parts)
+
+     def reset(self):
+         """Reset all memory tiers"""
+         self.working_memory = []
+         self.short_term_memory = []
+         self.semantic_clusters = defaultdict(list)
+         self.emotional_memory = {
+             "emotion_transitions": [],
+             "trigger_patterns": defaultdict(list),
+             "coping_effectiveness": {}
+         }
+         self.conversation_themes = []
+         self.user_model = {
+             "communication_style": None,
+             "recurring_concerns": [],
+             "progress_indicators": [],
+             "relational_patterns": []
+         }
+
+
+ # Initialize the advanced memory system
+ print("🔄 Initializing Advanced Memory System...")
+ advanced_memory = AdvancedMemorySystem(embedding_model, global_vector_store)
+ print("✅ Advanced Memory System initialized!")
+
+ # Track previous emotion for transition analysis
+ previous_emotion = "neutral"

  def update_emotion_status():

      if current_emotion_state["last_update"] is None:
          return "*Waiting for emotion data...*"

      confidence = current_emotion_state["confidence"]
      return f"**Current Emotion:** {dominant.capitalize()} ({confidence:.1f}% confidence)\n*Last updated: {int(elapsed)}s ago*"
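For orientation: with decay_rate = 0.01 per minute, the exponential in apply_temporal_decay above has a half-life of ln(2)/0.01 ≈ 69 minutes, and a memory starting at importance 1.0 crosses the 0.15 consolidation threshold near 190 minutes. A quick check of that curve (a sketch, not part of the commit; note the method also multiplies importance in place on each call, so repeated calls compound the decay):

    import numpy as np
    # one application of the decay curve used in apply_temporal_decay
    for minutes in (10, 69, 190):
        print(minutes, round(float(np.exp(-0.01 * minutes)), 3))
    # -> 10 0.905, 69 0.502, 190 0.15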
  def analyze_emotion(image):

      global current_emotion_state

      try:

          emotions = result['emotion']
          dominant = result['dominant_emotion']

          current_emotion_state = {
              "dominant": dominant,
              "confidence": emotions[dominant],

              "last_update": time.time()
          }

          output = {}
          for emotion, score in sorted(emotions.items(), key=lambda x: x[1], reverse=True):
+             output[emotion.capitalize()] = score

          return output

      if current_emotion_state["last_update"] is None:
          return ""

      if time.time() - current_emotion_state["last_update"] > 60:
          return ""

  def chat_with_kenko(message, history):
      """Chat function for Gradio interface with emotion awareness"""

      conversation = ""
      for user_msg, bot_msg in history:
          conversation += f"User: {user_msg}\nKenko: {bot_msg}\n\n"

      emotion_context = get_emotion_context()

      prompt = f"""### Instruction:
  You are Kenko, a compassionate mental health therapist. Provide empathetic, helpful, and professional responses to support the user's mental wellbeing.
  {emotion_context}
  {conversation}User: {message}
  ### Response:
  """

      try:
          response = pipe(prompt)[0]['generated_text']
          return response.strip()

      print(f"[TTS] Generating speech for {len(text)} chars: '{text[:50]}...'")

      inputs = tts_processor(text=text, return_tensors="pt", padding=True)
      inputs = {k: v.to(tts_device) for k, v in inputs.items()}

      print(f"[TTS] Inputs prepared, generating audio codes...")

      with torch.no_grad():
          generated_ids = tts_model.generate(**inputs, max_length=2500)

      print(f"[TTS] Audio codes generated, shape: {generated_ids.shape}")
      print(f"[TTS] Decoding codes to waveform...")

      audio_values = tts_processor.batch_decode(generated_ids, return_tensors="pt")

      if isinstance(audio_values, dict) and 'audio_values' in audio_values:
          audio_arr = audio_values['audio_values'][0].cpu().numpy()
      elif isinstance(audio_values, torch.Tensor):

      else:
          audio_arr = np.array(audio_values).squeeze()

      audio_arr = audio_arr.astype(np.float32)

      sample_rate = 44100

+     print(f"[TTS] Audio decoded: {len(audio_arr)} samples at {sample_rate}Hz = {len(audio_arr)/sample_rate:.2f} seconds")

      if len(audio_arr) == 0:
+         print("Decoded audio is empty!")
          return None

      return (sample_rate, audio_arr)

  except Exception as e:
+     print(f"TTS generation error: {str(e)}")
      import traceback
      traceback.print_exc()
      return None
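The (sample_rate, float32 array) tuple returned above is the format gr.Audio consumes directly; for offline inspection it can also be written to disk (a sketch outside the commit: soundfile is an assumed extra dependency, and the [S1] speaker tag follows Dia's dialogue convention):

    import soundfile as sf  # assumed extra dependency; app.py does not import it
    result = generate_tts("[S1] Hello, how are you feeling today?")
    if result is not None:
        sr, audio = result
        sf.write("kenko_tts_sample.wav", audio, sr)  # 44.1 kHz float32 waveform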
  # Custom CSS for a calming interface
  css = """
  .gradio-container {

  }
  """

  with gr.Blocks(
      title="Kenko - Mental Health Assistant",
      theme=gr.themes.Soft(),

  ) as demo:

      gr.Markdown("""
+     # 💚 Kenko - Your Emotion-Aware Mental Health Assistant
      Welcome! I'm Kenko, an AI mental health therapist enhanced with real-time emotion detection.
      Allow webcam access to enable emotion-aware responses that adapt to how you're feeling.
      *Please remember: I'm an AI assistant and cannot replace professional mental health care. In crisis situations, please contact emergency services or a mental health professional.*
      """)

      with gr.Row():
          with gr.Column(scale=2):
              chatbot = gr.Chatbot(
                  height=500,

              clear_btn = gr.Button("🗑️ Clear Chat", scale=1, variant="secondary")
              examples_btn = gr.Button("💡 Example Topics", scale=1, variant="secondary")

          with gr.Column(scale=1):
              gr.Markdown("### 📸 Emotion Detection")
              gr.Markdown("*Your emotional state helps me provide more personalized support*")

              emotion_status = gr.Markdown("*Waiting for emotion data...*")

+
      with gr.Row(visible=False) as examples_row:
          gr.Examples(
              examples=[

      - Relationship and communication advice
      - Mindfulness and self-care suggestions
      - Building healthy habits and routines
      **Emotion Detection Feature:**
      - Real-time facial emotion analysis
      - Adapts responses based on your current emotional state
      - Updates automatically every 30 seconds
      - Completely optional - works without webcam too
      **Important Notes:**
      - I'm an AI trained to provide mental health support
      - For immediate crisis support, contact emergency services (911) or crisis hotlines
      - Consider professional therapy for ongoing mental health needs
      - I don't diagnose conditions or prescribe medications
      **Privacy:** Your conversations and emotion data are not stored or shared.
      """)
      @spaces.GPU

          return gr.Row(visible=True)


+
      submit = msg.submit(fn=respond, inputs=[msg, chatbot], outputs=[msg, chatbot, audio_output])
      send = send_btn.click(fn=respond, inputs=[msg, chatbot], outputs=[msg, chatbot, audio_output])
      clear_btn.click(lambda: [], None, outputs=[chatbot, audio_output])
      examples_btn.click(toggle_examples, outputs=examples_row)

+
      webcam_input.stream(
          analyze_emotion,
          inputs=webcam_input,
          outputs=emotion_output,
+         stream_every=1,
+         time_limit=60
      )

+     timer = gr.Timer(value=5)

      timer.tick(
+         fn=update_emotion_status,
+         outputs=[emotion_status]
      )
721