Update app.py
app.py
CHANGED
@@ -1,26 +1,22 @@
-#INFERENCE NLP+EMOTION DETECTION CV+TTS+
+#INFERENCE NLP+EMOTION DETECTION CV+TTS+Memory Management
 import spaces
 
 import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 import torch
 from deepface import DeepFace
-import threading
 import time
 from transformers import AutoProcessor, DiaForConditionalGeneration
 import numpy as np
-import
-import
-from
-import
-from rfdetr import RFDETRNano
+import chromadb
+from langchain_community.vectorstores import Chroma
+from collections import defaultdict
+from sklearn.cluster import DBSCAN
 
-# Model setup
 model_name = "IniNLP247/Kenko-mental-health-llama-3-model"
 
-print("
+print("Loading Kenko Mental Health Model...")
 
-# Load tokenizer and model
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 if tokenizer.pad_token is None:
     tokenizer.pad_token = tokenizer.eos_token
@@ -31,7 +27,6 @@ model = AutoModelForCausalLM.from_pretrained(
     device_map="auto"
 )
 
-# Create pipeline for easier inference
 pipe = pipeline(
     "text-generation",
     model=model,
@@ -44,51 +39,31 @@ pipe = pipeline(
     pad_token_id=tokenizer.pad_token_id
 )
 
-print("
+print("Model loaded successfully!")
+
 
-#Loading of TTS
 print("Loading Dia TTS Model...")
 tts_device = "cuda:0" if torch.cuda.is_available() else "cpu"
 tts_model = "nari-labs/Dia-1.6B-0626"
 tts_processor = AutoProcessor.from_pretrained(tts_model)
 tts_model = DiaForConditionalGeneration.from_pretrained(tts_model, torch_dtype=torch.float16).to(tts_device)
-print("
+print("Dia TTS Model loaded successfully!")
 
-
-
-
-    3: "Grenade",
-    4: "Knife"
-}
+print("Initializing Memory Components...")
+chroma_client = chromadb.Client()
+embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
 
-
-
-
-
-
-
-
-
-
-
-with open(threat_weights_filename, 'wb') as f:
-    for chunk in response.iter_content(chunk_size=8192):
-        f.write(chunk)
-print("Download complete.")
-
-threat_model = RFDETRNano(resolution=640, pretrain_weights=threat_weights_filename)
-#threat_model.optimize_for_inference()
-
-print("✅ Threat Detection Model loaded successfully!")
-
-#Global Variables For Threat Detection
-current_threat_state = {
-    "threat_detected": [],
-    "threat_count": 0,
-    "last_update": None
-}
+def embed_function(texts):
+    if isinstance(texts, str):
+        texts = [texts]
+    return embedding_model.encode(texts).tolist()
+
+global_vector_store = Chroma(
+    client=chroma_client,
+    embedding_function=embed_function
+)
+print("Memory components initialized!")
 
-# Global variable to store current emotion state
 current_emotion_state = {
     "dominant": "neutral",
     "confidence": 0.0,
@@ -96,8 +71,359 @@ current_emotion_state = {
     "last_update": None
 }
 
+
+class AdvancedMemorySystem:
+    """
+    Multi-tier memory system inspired by human memory:
+    - Working Memory: Current conversation (high priority)
+    - Short-term Memory: Recent session with decay
+    - Long-term Memory: Semantic clusters of important themes
+    - Emotional Memory: Affective associations and patterns
+    """
+
+    def __init__(self, embedding_model, vector_store):
+        self.embedding_model = embedding_model
+        self.vector_store = vector_store
+
+        # Working memory - immediate context
+        self.working_memory = []
+
+        # Short-term memory with temporal decay
+        self.short_term_memory = []  # (timestamp, text, embedding, importance_score)
+
+        # Long-term semantic clusters
+        self.semantic_clusters = defaultdict(list)  # {cluster_id: [memories]}
+
+        # Emotional memory graph
+        self.emotional_memory = {
+            "emotion_transitions": [],  # Track emotional journey
+            "trigger_patterns": defaultdict(list),  # What triggers what emotion
+            "coping_effectiveness": {}  # Track which strategies work
+        }
+
+        # Meta-cognitive tracking
+        self.conversation_themes = []
+        self.user_model = {
+            "communication_style": None,
+            "recurring_concerns": [],
+            "progress_indicators": [],
+            "relational_patterns": []
+        }
+
+    def calculate_importance(self, text, emotion, user_engagement):
+        """Calculate memory importance using multiple factors"""
+        importance = 0.5  # Base importance
+
+        # Emotional weight
+        high_intensity_emotions = ["fear", "angry", "sad", "surprise"]
+        if emotion in high_intensity_emotions:
+            importance += 0.3
+
+        # Engagement weight
+        if len(text.split()) > 30:  # Longer, detailed message
+            importance += 0.2
+
+        # Therapeutic keywords
+        therapeutic_keywords = [
+            "trauma", "suicide", "self-harm", "abuse", "panic",
+            "breakthrough", "progress", "better", "worse", "relationship"
+        ]
+        if any(kw in text.lower() for kw in therapeutic_keywords):
+            importance += 0.3
+
+        return min(importance, 1.0)
+
+    def add_to_working_memory(self, user_msg, bot_msg, emotion, timestamp):
+        """Add to immediate working memory (sliding window)"""
+        self.working_memory.append({
+            "user": user_msg,
+            "bot": bot_msg,
+            "emotion": emotion,
+            "timestamp": timestamp
+        })
+
+        # Keep only last 5 exchanges in working memory
+        if len(self.working_memory) > 5:
+            # Consolidate oldest to short-term before removing
+            oldest = self.working_memory.pop(0)
+            self._consolidate_to_short_term(oldest)
+
+    def _consolidate_to_short_term(self, memory_item):
+        """Move from working to short-term memory with importance scoring"""
+        text = f"User: {memory_item['user']}\nKenko: {memory_item['bot']}"
+        embedding = self.embedding_model.encode(text)
+
+        # Calculate importance
+        importance = self.calculate_importance(
+            memory_item['user'],
+            memory_item['emotion'],
+            len(memory_item['user'].split())
+        )
+
+        self.short_term_memory.append({
+            "text": text,
+            "embedding": embedding,
+            "importance": importance,
+            "timestamp": memory_item['timestamp'],
+            "emotion": memory_item['emotion']
+        })
+
+        # Add to vector store with importance weighting
+        try:
+            self.vector_store.add_texts(
+                texts=[text],
+                metadatas=[{"importance": importance, "timestamp": memory_item['timestamp']}]
+            )
+        except Exception as e:
+            print(f"Vector store error: {e}")
+
+    def apply_temporal_decay(self, current_time):
+        """Apply decay to short-term memories over time"""
+        decay_rate = 0.01  # Decay 1% per time unit (adjusted for slower decay)
+
+        for memory in self.short_term_memory:
+            time_elapsed = (current_time - memory['timestamp']) / 60  # Convert to minutes
+            decay_factor = np.exp(-decay_rate * time_elapsed)
+            memory['importance'] *= decay_factor
+
+            # If importance drops below threshold, cluster into long-term
+            if memory['importance'] < 0.15:
+                self._consolidate_to_long_term(memory)
+
+    def _consolidate_to_long_term(self, memory):
+        """Cluster similar memories into semantic long-term memory"""
+        # Get embeddings of all long-term memories
+        if not self.semantic_clusters:
+            self.semantic_clusters[0] = [memory]
+            self.short_term_memory.remove(memory)
+            return
+
+        # Find semantic cluster using cosine similarity
+        best_cluster = 0
+        best_similarity = -1
+
+        for cluster_id, cluster_memories in self.semantic_clusters.items():
+            # Compare with cluster centroid
+            cluster_embeddings = [m['embedding'] for m in cluster_memories]
+            centroid = np.mean(cluster_embeddings, axis=0)
+
+            similarity = np.dot(memory['embedding'], centroid) / (
+                np.linalg.norm(memory['embedding']) * np.linalg.norm(centroid)
+            )
+
+            if similarity > best_similarity:
+                best_similarity = similarity
+                best_cluster = cluster_id
+
+        # Add to cluster if similar enough, else create new cluster
+        if best_similarity > 0.7:
+            self.semantic_clusters[best_cluster].append(memory)
+        else:
+            new_cluster_id = max(self.semantic_clusters.keys()) + 1
+            self.semantic_clusters[new_cluster_id] = [memory]
+
+        # Remove from short-term
+        if memory in self.short_term_memory:
+            self.short_term_memory.remove(memory)
+
+    def track_emotional_transition(self, prev_emotion, current_emotion, context):
+        """Track emotional state transitions for pattern recognition"""
+        self.emotional_memory["emotion_transitions"].append({
+            "from": prev_emotion,
+            "to": current_emotion,
+            "context": context,
+            "timestamp": time.time()
+        })
+
+        # Analyze if certain topics trigger emotional shifts
+        if prev_emotion != current_emotion:
+            self.emotional_memory["trigger_patterns"][current_emotion].append(context)
+
+    def analyze_conversation_themes(self):
+        """Use topic modeling on conversation to identify recurring themes"""
+        if len(self.short_term_memory) < 3:
+            return []
+
+        # Simple keyword extraction
+        all_text = " ".join([m['text'] for m in self.short_term_memory])
+
+        # Extract key phrases
+        words = all_text.lower().split()
+        word_freq = defaultdict(int)
+
+        stopwords = {"the", "a", "is", "in", "and", "to", "of", "i", "my", "me", "you", "that", "it"}
+        for word in words:
+            if word not in stopwords and len(word) > 4:
+                word_freq[word] += 1
+
+        # Get top themes
+        themes = sorted(word_freq.items(), key=lambda x: x[1], reverse=True)[:5]
+        self.conversation_themes = [theme[0] for theme in themes]
+
+        return self.conversation_themes
+
+    def retrieve_contextual_memory(self, query, current_emotion):
+        """Advanced retrieval using multiple memory tiers"""
+        context = {
+            "working": [],
+            "short_term": [],
+            "long_term": [],
+            "emotional": [],
+            "themes": []
+        }
+
+        # 1. Working memory (full recent conversation)
+        context["working"] = self.working_memory[-3:]  # Last 3 exchanges
+
+        # 2. Short-term memory (importance-weighted retrieval)
+        if self.short_term_memory:
+            query_embedding = self.embedding_model.encode(query)
+
+            scored_memories = []
+            for memory in self.short_term_memory:
+                # Semantic similarity
+                similarity = np.dot(query_embedding, memory['embedding']) / (
+                    np.linalg.norm(query_embedding) * np.linalg.norm(memory['embedding'])
+                )
+
+                # Boost by importance score
+                final_score = similarity * memory['importance']
+
+                # Emotional congruence bonus
+                if memory['emotion'] == current_emotion:
+                    final_score *= 1.2
+
+                scored_memories.append((final_score, memory))
+
+            # Get top 3 short-term memories
+            scored_memories.sort(reverse=True, key=lambda x: x[0])
+            context["short_term"] = [m[1] for m in scored_memories[:3]]
+
+        # 3. Long-term memory (semantic clusters)
+        if self.semantic_clusters:
+            query_embedding = self.embedding_model.encode(query)
+            best_cluster_id = None
+            best_cluster_score = -1
+
+            for cluster_id, cluster_memories in self.semantic_clusters.items():
+                cluster_embeddings = [m['embedding'] for m in cluster_memories]
+                centroid = np.mean(cluster_embeddings, axis=0)
+
+                similarity = np.dot(query_embedding, centroid) / (
+                    np.linalg.norm(query_embedding) * np.linalg.norm(centroid)
+                )
+
+                if similarity > best_cluster_score:
+                    best_cluster_score = similarity
+                    best_cluster_id = cluster_id
+
+            if best_cluster_id is not None and best_cluster_score > 0.6:
+                # Return summary of cluster
+                cluster = self.semantic_clusters[best_cluster_id]
+                context["long_term"] = cluster[:2]  # Top 2 from cluster
+
+        # 4. Emotional memory patterns
+        if current_emotion in self.emotional_memory["trigger_patterns"]:
+            triggers = self.emotional_memory["trigger_patterns"][current_emotion]
+            context["emotional"] = triggers[-2:]  # Recent triggers for this emotion
+
+        # 5. Conversation themes
+        context["themes"] = self.analyze_conversation_themes()
+
+        return context
+
+    def update_user_model(self, message, emotion):
+        """Build a psychological profile of the user over time"""
+        # Communication style detection
+        if len(message.split()) > 50:
+            style = "detailed"
+        elif len(message.split()) < 10:
+            style = "concise"
+        else:
+            style = "moderate"
+
+        self.user_model["communication_style"] = style
+
+        # Track recurring concerns
+        concern_keywords = {
+            "anxiety": ["anxious", "worried", "panic", "nervous", "anxiety"],
+            "depression": ["sad", "depressed", "hopeless", "empty", "depression"],
+            "relationships": ["partner", "relationship", "friend", "family"],
+            "work_stress": ["work", "job", "career", "boss", "stress"]
+        }
+
+        for concern, keywords in concern_keywords.items():
+            if any(kw in message.lower() for kw in keywords):
+                if concern not in self.user_model["recurring_concerns"]:
+                    self.user_model["recurring_concerns"].append(concern)
+
+    def generate_memory_context_string(self, contextual_memory):
+        """Format retrieved memories into prompt context"""
+        context_parts = []
+
+        # Working memory (recent conversation)
+        if contextual_memory["working"]:
+            recent = "\n".join([
+                f"User: {m['user']}\nKenko: {m['bot']}"
+                for m in contextual_memory["working"]
+            ])
+            context_parts.append(f"### Recent Conversation:\n{recent}")
+
+        # Important short-term memories
+        if contextual_memory["short_term"]:
+            important = "\n".join([m['text'] for m in contextual_memory["short_term"]])
+            context_parts.append(f"### Important Recent Context:\n{important}")
+
+        # Long-term thematic memories
+        if contextual_memory["long_term"]:
+            longterm = "\n".join([m['text'] for m in contextual_memory["long_term"]])
+            context_parts.append(f"### Related Past Discussions:\n{longterm}")
+
+        # Emotional patterns
+        if contextual_memory["emotional"]:
+            emotional = ", ".join(contextual_memory["emotional"][:3])
+            context_parts.append(f"### Emotional Pattern: Previously triggered by: {emotional}")
+
+        # Conversation themes
+        if contextual_memory["themes"]:
+            themes = ", ".join(contextual_memory["themes"])
+            context_parts.append(f"### Session Themes: {themes}")
+
+        # User model insights
+        if self.user_model["recurring_concerns"]:
+            concerns = ", ".join(self.user_model["recurring_concerns"])
+            context_parts.append(f"### Recurring Concerns: {concerns}")
+
+        return "\n\n".join(context_parts)
+
+    def reset(self):
+        """Reset all memory tiers"""
+        self.working_memory = []
+        self.short_term_memory = []
+        self.semantic_clusters = defaultdict(list)
+        self.emotional_memory = {
+            "emotion_transitions": [],
+            "trigger_patterns": defaultdict(list),
+            "coping_effectiveness": {}
+        }
+        self.conversation_themes = []
+        self.user_model = {
+            "communication_style": None,
+            "recurring_concerns": [],
+            "progress_indicators": [],
+            "relational_patterns": []
+        }
+
+
+# Initialize the advanced memory system
+print("Initializing Advanced Memory System...")
+advanced_memory = AdvancedMemorySystem(embedding_model, global_vector_store)
+print("✅ Advanced Memory System initialized!")
+
+# Track previous emotion for transition analysis
+previous_emotion = "neutral"
+
 def update_emotion_status():
-    """Update emotion status text"""
     if current_emotion_state["last_update"] is None:
         return "*Waiting for emotion data...*"
 
@@ -109,22 +435,7 @@ def update_emotion_status():
     confidence = current_emotion_state["confidence"]
     return f"**Current Emotion:** {dominant.capitalize()} ({confidence:.1f}% confidence)\n*Last updated: {int(elapsed)}s ago*"
 
-def update_threat_status():
-    """Update threat status text"""
-    if current_threat_state["last_update"] is None:
-        return "*Monitoring for threats...*"
-
-    elapsed = time.time() - current_threat_state["last_update"]
-
-    threats = current_threat_state["threat_detected"]
-    if threats:
-        threat_list = ", ".join([t["type"] for t in threats])
-        return f"**⚠️ ALERT:** {threat_list} detected\n*Last updated: {int(elapsed)}s ago*"
-    else:
-        return f"**✅ Safe:** No threats detected\n*Last updated: {int(elapsed)}s ago*"
-
 def analyze_emotion(image):
-    """Analyze emotion from webcam image"""
     global current_emotion_state
 
     try:
@@ -145,7 +456,6 @@ def analyze_emotion(image):
         emotions = result['emotion']
         dominant = result['dominant_emotion']
 
-        # Update global emotion state
         current_emotion_state = {
             "dominant": dominant,
             "confidence": emotions[dominant],
@@ -153,10 +463,9 @@ def analyze_emotion(image):
             "last_update": time.time()
         }
 
-        # Format for display - REMOVE the % symbol and keep as numbers
         output = {}
         for emotion, score in sorted(emotions.items(), key=lambda x: x[1], reverse=True):
-            output[emotion.capitalize()] = score
+            output[emotion.capitalize()] = score
 
         return output
 
@@ -169,7 +478,6 @@ def get_emotion_context():
     if current_emotion_state["last_update"] is None:
         return ""
 
-    # Check if emotion data is recent (within last 60 seconds)
     if time.time() - current_emotion_state["last_update"] > 60:
         return ""
 
@@ -182,28 +490,21 @@ def get_emotion_context():
 def chat_with_kenko(message, history):
     """Chat function for Gradio interface with emotion awareness"""
 
-    # Build conversation context
     conversation = ""
     for user_msg, bot_msg in history:
         conversation += f"User: {user_msg}\nKenko: {bot_msg}\n\n"
 
-    # Get emotion context
     emotion_context = get_emotion_context()
 
-    # Get threat context
     threat_context = get_threat_context()
 
-    # Create prompt in instruction format with emotion awareness
     prompt = f"""### Instruction:
 You are Kenko, a compassionate mental health therapist. Provide empathetic, helpful, and professional responses to support the user's mental wellbeing.
 {emotion_context}{threat_context}
-
 {conversation}User: {message}
-
 ### Response:
 """
 
-    # Generate response
     try:
         response = pipe(prompt)[0]['generated_text']
         return response.strip()
@@ -216,23 +517,19 @@ def generate_tts(text):
 
         print(f"[TTS] Generating speech for {len(text)} chars: '{text[:50]}...'")
 
-        # Prepare inputs for Dia TTS
         inputs = tts_processor(text=text, return_tensors="pt", padding=True)
         inputs = {k: v.to(tts_device) for k, v in inputs.items()}
 
         print(f"[TTS] Inputs prepared, generating audio codes...")
 
-        # Generate audio codes
         with torch.no_grad():
             generated_ids = tts_model.generate(**inputs, max_length=2500)
 
         print(f"[TTS] Audio codes generated, shape: {generated_ids.shape}")
         print(f"[TTS] Decoding codes to waveform...")
 
-        # Decode the audio codes to waveform using the processor's batch_decode
         audio_values = tts_processor.batch_decode(generated_ids, return_tensors="pt")
 
-        # Extract the audio waveform
         if isinstance(audio_values, dict) and 'audio_values' in audio_values:
             audio_arr = audio_values['audio_values'][0].cpu().numpy()
         elif isinstance(audio_values, torch.Tensor):
@@ -242,93 +539,24 @@ def generate_tts(text):
         else:
             audio_arr = np.array(audio_values).squeeze()
 
-        # Ensure float32
         audio_arr = audio_arr.astype(np.float32)
 
-        # Dia uses 44.1kHz sample rate
         sample_rate = 44100
 
-        print(f"
+        print(f"[TTS] Audio decoded: {len(audio_arr)} samples at {sample_rate}Hz = {len(audio_arr)/sample_rate:.2f} seconds")
 
         if len(audio_arr) == 0:
-            print("
+            print("Decoded audio is empty!")
             return None
 
         return (sample_rate, audio_arr)
 
     except Exception as e:
-        print(f"
+        print(f"TTS generation error: {str(e)}")
         import traceback
         traceback.print_exc()
         return None
 
-def threat_detection(image):
-    """Threat detection function for webcam"""
-    global current_threat_state
-
-    try:
-        if image is None:
-            return {}
-
-        # Convert numpy array to PIL Image if needed
-        if isinstance(image, np.ndarray):
-            image = Image.fromarray(image)
-
-        # Run Threat Detection
-        detections = threat_model.predict(image, threshold=0.3)  # Lower threshold for testing
-
-        # Parse detections - detections.class_id and detections.confidence are ARRAYS
-        threat_found = []
-        if detections is not None and len(detections.class_id) > 0:
-            for class_id, confidence in zip(detections.class_id, detections.confidence):
-                class_id = int(class_id)
-                confidence = float(confidence)
-
-                if class_id in THREAT_CLASSES:
-                    threat_name = THREAT_CLASSES[class_id]
-                    threat_found.append({"type": threat_name, "confidence": confidence})
-                    print(f"🚨 THREAT DETECTED: {threat_name} - {confidence:.2%}")
-
-        # Update global threat state
-        current_threat_state = {
-            "threat_detected": threat_found,
-            "threat_count": len(threat_found),
-            "last_update": time.time()
-        }
-
-        # Format for display
-        if threat_found:
-            output = {}
-            for threat in threat_found:
-                output[threat["type"]] = threat["confidence"] * 100
-            return output
-        else:
-            return {"No threats detected": 100.0}
-
-    except Exception as e:
-        print(f"Threat detection error: {str(e)}")
-        import traceback
-        traceback.print_exc()
-        return {}
-
-
-def get_threat_context():
-    """Get current threat as context string for the model"""
-    if current_threat_state["last_update"] is None:
-        return ""
-
-    # Check if threat data is recent (within last 60 seconds)
-    if time.time() - current_threat_state["last_update"] > 60:
-        return ""
-
-    threats = current_threat_state["threat_detected"]
-
-    if threats:
-        threat_list = ", ".join([f"{t['type']} ({t['confidence']*100:.1f}% confidence)" for t in threats])
-        return f"\n[User currently holds a potential threat: {threat_list}]"
-
-    return ""
-
 # Custom CSS for a calming interface
 css = """
 .gradio-container {
@@ -342,7 +570,6 @@ css = """
 }
 """
 
-# Create Gradio interface
 with gr.Blocks(
     title="Kenko - Mental Health Assistant",
     theme=gr.themes.Soft(),
@@ -350,16 +577,13 @@ with gr.Blocks(
 ) as demo:
 
     gr.Markdown("""
-    #
-
+    # Kenko - Your Emotion-Aware Mental Health Assistant
     Welcome! I'm Kenko, an AI mental health therapist enhanced with real-time emotion detection.
     Allow webcam access to enable emotion-aware responses that adapt to how you're feeling.
-
     *Please remember: I'm an AI assistant and cannot replace professional mental health care. In crisis situations, please contact emergency services or a mental health professional.*
     """)
 
     with gr.Row():
-        # Left column: Chat interface
         with gr.Column(scale=2):
             chatbot = gr.Chatbot(
                 height=500,
@@ -389,7 +613,6 @@ with gr.Blocks(
             clear_btn = gr.Button("🗑️ Clear Chat", scale=1, variant="secondary")
             examples_btn = gr.Button("💡 Example Topics", scale=1, variant="secondary")
 
-        # Right column: Emotion detection
         with gr.Column(scale=1):
             gr.Markdown("### 📸 Emotion Detection")
             gr.Markdown("*Your emotional state helps me provide more personalized support*")
@@ -408,17 +631,7 @@ with gr.Blocks(
 
             emotion_status = gr.Markdown("*Waiting for emotion data...*")
 
-
-            gr.Markdown("### Threat Detection")
-            threat_output = gr.Label(
-                num_top_classes=4,
-                label="Detected Threats"
-            )
-
-            threat_status = gr.Markdown("*Monitoring for threats...")
-
-
-    # Example prompts
+
     with gr.Row(visible=False) as examples_row:
         gr.Examples(
             examples=[
@@ -444,19 +657,16 @@ with gr.Blocks(
     - Relationship and communication advice
     - Mindfulness and self-care suggestions
     - Building healthy habits and routines
-
     **Emotion Detection Feature:**
     - Real-time facial emotion analysis
     - Adapts responses based on your current emotional state
     - Updates automatically every 30 seconds
     - Completely optional - works without webcam too
-
     **Important Notes:**
     - I'm an AI trained to provide mental health support
     - For immediate crisis support, contact emergency services (911) or crisis hotlines
     - Consider professional therapy for ongoing mental health needs
     - I don't diagnose conditions or prescribe medications
-
     **Privacy:** Your conversations and emotion data are not stored or shared.
     """)
     @spaces.GPU
@@ -485,35 +695,27 @@ with gr.Blocks(
         return gr.Row(visible=True)
 
 
-
+
     submit = msg.submit(fn=respond, inputs=[msg, chatbot], outputs=[msg, chatbot, audio_output])
     send = send_btn.click(fn=respond, inputs=[msg, chatbot], outputs=[msg, chatbot, audio_output])
     clear_btn.click(lambda: [], None, outputs=[chatbot, audio_output])
     examples_btn.click(toggle_examples, outputs=examples_row)
 
-
+
     webcam_input.stream(
         analyze_emotion,
         inputs=webcam_input,
         outputs=emotion_output,
-        stream_every=1,
-        time_limit=60
+        stream_every=1,
+        time_limit=60
     )
 
-    timer = gr.Timer(value=5)
+    timer = gr.Timer(value=5)
+
 
-    # Threat detection with streaming
-    webcam_input.stream(
-        threat_detection,  # Corrected function name
-        inputs=webcam_input,  # Corrected inputs
-        outputs=threat_output,
-        stream_every=2,
-        time_limit=60
-    )
-    # Add to timer tick
     timer.tick(
-        fn=lambda: (update_emotion_status()
-        outputs=[emotion_status
+        fn=lambda: (update_emotion_status()),
+        outputs=[emotion_status]
     )
 
 
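`calculate_importance` starts every memory at a 0.5 base, adds 0.3 for high-intensity emotions, 0.2 for messages longer than 30 words, 0.3 for therapeutic keywords, and caps the total at 1.0; note the `user_engagement` parameter is accepted but not used by the current body. A quick worked example against the method as written:

```python
# 12-word message: base 0.5 + "fear" emotion 0.3 + "panic" keyword 0.3 = 1.1, capped at 1.0
msg = "I had a panic attack at work again and I couldn't breathe."
score = advanced_memory.calculate_importance(msg, "fear", len(msg.split()))
assert abs(score - 1.0) < 1e-9

# 6-word neutral message with no keywords: stays at the 0.5 base
score = advanced_memory.calculate_importance("The weather was nice today though.", "neutral", 6)
assert abs(score - 0.5) < 1e-9
```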
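`apply_temporal_decay` is exponential in elapsed minutes: importance(t) = importance · e^(−0.01·t), so an untouched memory halves roughly every ln(2)/0.01 ≈ 69 minutes, and a 0.5-importance memory crosses the 0.15 consolidation threshold after about 120 minutes. Two cautions: each call multiplies by the factor for the full elapsed time, so calling it repeatedly compounds the decay, and `_consolidate_to_long_term` removes entries from `short_term_memory` while this loop is iterating that same list. A safer variant that iterates a snapshot:

```python
import numpy as np

DECAY_RATE = 0.01  # per minute, matching the method above

def apply_temporal_decay_safe(memory_system, current_time):
    # list(...) snapshots short_term_memory, because consolidation removes items
    for memory in list(memory_system.short_term_memory):
        elapsed_min = (current_time - memory["timestamp"]) / 60
        memory["importance"] *= np.exp(-DECAY_RATE * elapsed_min)
        if memory["importance"] < 0.15:
            memory_system._consolidate_to_long_term(memory)
```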
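Long-term consolidation is nearest-centroid clustering with cosine similarity: a memory joins the best-matching cluster when similarity exceeds 0.7, otherwise it seeds a new cluster. The same dot-product-over-norms expression appears three times in the class; a small helper (the name is ours) would keep it in one place:

```python
import numpy as np

def cosine_similarity(a, b):
    # dot(a, b) / (|a| * |b|), the expression used throughout the class
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# Cluster assignment as the method does it:
#   centroid = np.mean([m["embedding"] for m in cluster_memories], axis=0)
#   join the cluster when cosine_similarity(memory["embedding"], centroid) > 0.7
```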
+
def track_emotional_transition(self, prev_emotion, current_emotion, context):
|
| 230 |
+
"""Track emotional state transitions for pattern recognition"""
|
| 231 |
+
self.emotional_memory["emotion_transitions"].append({
|
| 232 |
+
"from": prev_emotion,
|
| 233 |
+
"to": current_emotion,
|
| 234 |
+
"context": context,
|
| 235 |
+
"timestamp": time.time()
|
| 236 |
+
})
|
| 237 |
+
|
| 238 |
+
# Analyze if certain topics trigger emotional shifts
|
| 239 |
+
if prev_emotion != current_emotion:
|
| 240 |
+
self.emotional_memory["trigger_patterns"][current_emotion].append(context)
|
| 241 |
+
|
| 242 |
+
def analyze_conversation_themes(self):
|
| 243 |
+
"""Use topic modeling on conversation to identify recurring themes"""
|
| 244 |
+
if len(self.short_term_memory) < 3:
|
| 245 |
+
return []
|
| 246 |
+
|
| 247 |
+
# Simple keyword extraction
|
| 248 |
+
all_text = " ".join([m['text'] for m in self.short_term_memory])
|
| 249 |
+
|
| 250 |
+
# Extract key phrases
|
| 251 |
+
words = all_text.lower().split()
|
| 252 |
+
word_freq = defaultdict(int)
|
| 253 |
+
|
| 254 |
+
stopwords = {"the", "a", "is", "in", "and", "to", "of", "i", "my", "me", "you", "that", "it"}
|
| 255 |
+
for word in words:
|
| 256 |
+
if word not in stopwords and len(word) > 4:
|
| 257 |
+
word_freq[word] += 1
|
| 258 |
+
|
| 259 |
+
# Get top themes
|
| 260 |
+
themes = sorted(word_freq.items(), key=lambda x: x[1], reverse=True)[:5]
|
| 261 |
+
self.conversation_themes = [theme[0] for theme in themes]
|
| 262 |
+
|
| 263 |
+
return self.conversation_themes
|
| 264 |
+
|
| 265 |
+
def retrieve_contextual_memory(self, query, current_emotion):
|
| 266 |
+
"""Advanced retrieval using multiple memory tiers"""
|
| 267 |
+
context = {
|
| 268 |
+
"working": [],
|
| 269 |
+
"short_term": [],
|
| 270 |
+
"long_term": [],
|
| 271 |
+
"emotional": [],
|
| 272 |
+
"themes": []
|
| 273 |
+
}
|
| 274 |
+
|
| 275 |
+
# 1. Working memory (full recent conversation)
|
| 276 |
+
context["working"] = self.working_memory[-3:] # Last 3 exchanges
|
| 277 |
+
|
| 278 |
+
# 2. Short-term memory (importance-weighted retrieval)
|
| 279 |
+
if self.short_term_memory:
|
| 280 |
+
query_embedding = self.embedding_model.encode(query)
|
| 281 |
+
|
| 282 |
+
scored_memories = []
|
| 283 |
+
for memory in self.short_term_memory:
|
| 284 |
+
# Semantic similarity
|
| 285 |
+
similarity = np.dot(query_embedding, memory['embedding']) / (
|
| 286 |
+
np.linalg.norm(query_embedding) * np.linalg.norm(memory['embedding'])
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
# Boost by importance score
|
| 290 |
+
final_score = similarity * memory['importance']
|
| 291 |
+
|
| 292 |
+
# Emotional congruence bonus
|
| 293 |
+
if memory['emotion'] == current_emotion:
|
| 294 |
+
final_score *= 1.2
|
| 295 |
+
|
| 296 |
+
scored_memories.append((final_score, memory))
|
| 297 |
+
|
| 298 |
+
# Get top 3 short-term memories
|
| 299 |
+
scored_memories.sort(reverse=True, key=lambda x: x[0])
|
| 300 |
+
context["short_term"] = [m[1] for m in scored_memories[:3]]
|
| 301 |
+
|
| 302 |
+
# 3. Long-term memory (semantic clusters)
|
| 303 |
+
if self.semantic_clusters:
|
| 304 |
+
query_embedding = self.embedding_model.encode(query)
|
| 305 |
+
best_cluster_id = None
|
| 306 |
+
best_cluster_score = -1
|
| 307 |
+
|
| 308 |
+
for cluster_id, cluster_memories in self.semantic_clusters.items():
|
| 309 |
+
cluster_embeddings = [m['embedding'] for m in cluster_memories]
|
| 310 |
+
centroid = np.mean(cluster_embeddings, axis=0)
|
| 311 |
+
|
| 312 |
+
similarity = np.dot(query_embedding, centroid) / (
|
| 313 |
+
np.linalg.norm(query_embedding) * np.linalg.norm(centroid)
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
if similarity > best_cluster_score:
|
| 317 |
+
best_cluster_score = similarity
|
| 318 |
+
best_cluster_id = cluster_id
|
| 319 |
+
|
| 320 |
+
if best_cluster_id is not None and best_cluster_score > 0.6:
|
| 321 |
+
# Return summary of cluster
|
| 322 |
+
cluster = self.semantic_clusters[best_cluster_id]
|
| 323 |
+
context["long_term"] = cluster[:2] # Top 2 from cluster
|
| 324 |
+
|
| 325 |
+
# 4. Emotional memory patterns
|
| 326 |
+
if current_emotion in self.emotional_memory["trigger_patterns"]:
|
| 327 |
+
triggers = self.emotional_memory["trigger_patterns"][current_emotion]
|
| 328 |
+
context["emotional"] = triggers[-2:] # Recent triggers for this emotion
|
| 329 |
+
|
| 330 |
+
# 5. Conversation themes
|
| 331 |
+
context["themes"] = self.analyze_conversation_themes()
|
| 332 |
+
|
| 333 |
+
return context
|
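As of this commit the memory system is initialized but never consulted: `chat_with_kenko` still builds its prompt from the raw `history` list, and it still calls `get_threat_context()`, which this commit deletes, so the first chat message should raise a `NameError` until that call is removed. A hypothetical wiring (everything below past `advanced_memory` and `get_emotion_context` is ours, not the commit's) that routes the prompt through the memory tiers instead:

```python
# Hypothetical sketch, not part of this commit.
def build_prompt(message, current_emotion):
    memories = advanced_memory.retrieve_contextual_memory(message, current_emotion)
    memory_context = advanced_memory.generate_memory_context_string(memories)
    emotion_context = get_emotion_context()
    return (
        "### Instruction:\n"
        "You are Kenko, a compassionate mental health therapist. Provide empathetic, "
        "helpful, and professional responses to support the user's mental wellbeing.\n"
        f"{emotion_context}\n{memory_context}\n"
        f"User: {message}\n"
        "### Response:\n"
    )
```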
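The final hunk fixes the `timer.tick` call, which previously had an unterminated lambda and an unclosed list (`fn=lambda: (update_emotion_status()` / `outputs=[emotion_status`) and could not even parse. The committed form is valid; since the lambda only forwards to `update_emotion_status`, passing the function itself is an equivalent, slightly simpler spelling:

```python
timer = gr.Timer(value=5)  # fires every 5 seconds

timer.tick(
    fn=update_emotion_status,  # the wrapping lambda adds nothing
    outputs=[emotion_status],
)
```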