IniNLP247 committed on
Commit 66954e0 · verified · 1 Parent(s): c253a21

Update app.py

Files changed (1): app.py (+200 −41)
app.py CHANGED
@@ -1,15 +1,19 @@
- #INFERENCE NLP+EMOTION DETECTION CV+TTS
  import spaces
  
  import gradio as gr
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
  import torch
  from deepface import DeepFace
  import threading
  import time
- from parler_tts import ParlerTTSForConditionalGeneration
- import soundfile as sf
  import numpy as np
  
  # Model setup
  model_name = "IniNLP247/Kenko-mental-health-llama-3-model"
@@ -20,6 +24,7 @@ print("🔄 Loading Kenko Mental Health Model...")
  tokenizer = AutoTokenizer.from_pretrained(model_name)
  if tokenizer.pad_token is None:
      tokenizer.pad_token = tokenizer.eos_token
  
  model = AutoModelForCausalLM.from_pretrained(
      model_name,
@@ -43,12 +48,46 @@ pipe = pipeline(
  print("✅ Model loaded successfully!")
  
  #Loading of TTS
- print("Loading Parler TTS Model...")
  tts_device = "cuda:0" if torch.cuda.is_available() else "cpu"
- tts_model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-mini-v1", torch_dtype=torch.float16).to(tts_device)
- tts_tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-v1")
- print("✅ Parler TTS Model loaded successfully!")
  
  # Global variable to store current emotion state
  current_emotion_state = {
@@ -125,10 +164,13 @@ def chat_with_kenko(message, history):
      # Get emotion context
      emotion_context = get_emotion_context()
  
      # Create prompt in instruction format with emotion awareness
      prompt = f"""### Instruction:
You are Kenko, a compassionate mental health therapist. Provide empathetic, helpful, and professional responses to support the user's mental wellbeing.
- {emotion_context}
  
  {conversation}User: {message}
@@ -144,35 +186,49 @@ You are Kenko, a compassionate mental health therapist. Provide empathetic, help
  
  def generate_tts(text):
      try:
-         # Limit text severely for testing
-         text = text[:200]  # Even shorter for testing
  
-         print(f"[TTS] Starting generation for {len(text)} chars: '{text[:50]}...'")
  
-         description = "A calm, empathetic voice speaking at a moderate pace."
  
-         input_ids = tts_tokenizer(description, return_tensors="pt").input_ids.to(tts_device)
-         prompt_input_ids = tts_tokenizer(text, return_tensors="pt").input_ids.to(tts_device)
  
-         print(f"[TTS] Tokenization complete. Generating audio...")
  
-         # Use proper generation parameters for Parler TTS
-         generation = tts_model.generate(
-             input_ids=input_ids,
-             prompt_input_ids=prompt_input_ids,
-             do_sample=True,
-             temperature=1.0,
-             min_new_tokens=10,
-             max_new_tokens=500  # Use max_new_tokens instead of max_length
-         )
  
-         print(f"[TTS] Generation complete. Processing audio...")
  
-         audio_arr = generation.cpu().numpy().squeeze()
  
-         print(f"[TTS] Audio array shape: {audio_arr.shape}")
  
-         return (tts_model.config.sampling_rate, audio_arr)
  
      except Exception as e:
          print(f"❌ TTS generation error: {str(e)}")
@@ -180,8 +236,66 @@ def generate_tts(text):
          traceback.print_exc()
          return None
  
- print(f"TTS Model Device: {tts_model.device}")
- print(f"TTS Device Variable: {tts_device}")
  
  # Custom CSS for a calming interface
  css = """
@@ -262,6 +376,16 @@ with gr.Blocks(
  
  emotion_status = gr.Markdown("*Waiting for emotion data...*")
  
  # Example prompts
  with gr.Row(visible=False) as examples_row:
      gr.Examples(
@@ -303,7 +427,6 @@ with gr.Blocks(
  
  **Privacy:** Your conversations and emotion data are not stored or shared.
  """)
-
  @spaces.GPU
  def respond(message, chat_history):
      if not message.strip():
@@ -325,7 +448,6 @@ with gr.Blocks(
      print(f"TOTAL TIME: {time.time() - start:.2f}s")
  
      return "", chat_history, audio
-     return "", chat_history, audio
  
  def toggle_examples():
      return gr.Row(visible=True)
@@ -343,6 +465,21 @@ with gr.Blocks(
      confidence = current_emotion_state["confidence"]
      return f"**Current Emotion:** {dominant.capitalize()} ({confidence:.1f}% confidence)\n*Last updated: {int(elapsed)}s ago*"
  
  # Event handlers
  submit = msg.submit(fn=respond, inputs=[msg, chatbot], outputs=[msg, chatbot, audio_output])
  send = send_btn.click(fn=respond, inputs=[msg, chatbot], outputs=[msg, chatbot, audio_output])
@@ -351,18 +488,40 @@ with gr.Blocks(
  
  # Emotion detection with streaming (analyzes continuously)
  webcam_input.stream(
-     analyze_emotion,
-     inputs=webcam_input,
-     outputs=emotion_output,
-     time_limit=30,   # Analyze every 30 seconds
-     stream_every=30  # Update interval
- )
  
  timer = gr.Timer(value=5)  # Update every 5 seconds
  timer.tick(
      fn=update_emotion_status,
      outputs=emotion_status
  )
-
  if __name__ == "__main__":
-     demo.launch()
 
+ #INFERENCE NLP+EMOTION DETECTION CV+TTS+THREAT DETECTION CV
  import spaces
  
  import gradio as gr
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
  import torch
  from deepface import DeepFace
  import threading
  import time
+ from transformers import AutoProcessor, DiaForConditionalGeneration
  import numpy as np
+ import supervision as sv
+ import requests
+ from PIL import Image
+ import os
+ from rfdetr import RFDETRNano
  
  # Model setup
  model_name = "IniNLP247/Kenko-mental-health-llama-3-model"
 
  tokenizer = AutoTokenizer.from_pretrained(model_name)
  if tokenizer.pad_token is None:
      tokenizer.pad_token = tokenizer.eos_token
+ )
  
  model = AutoModelForCausalLM.from_pretrained(
      model_name,
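Note: the new pane imports `BitsAndBytesConfig` and adds a stray `+ )` after the pad-token line, which looks like the tail of a quantization config whose opening lines are not visible in this diff. A minimal sketch of how such a config is usually wired; all parameter values here are assumptions, not taken from the commit:

```python
# Hypothetical reconstruction: the commit imports BitsAndBytesConfig but the
# config body is not shown in the diff, so these 4-bit settings are assumptions.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # quantize weights to 4 bits at load time
    bnb_4bit_compute_dtype=torch.float16,  # run the matmuls in fp16
)

model = AutoModelForCausalLM.from_pretrained(
    "IniNLP247/Kenko-mental-health-llama-3-model",
    quantization_config=bnb_config,
    device_map="auto",
)
```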
 
  print("✅ Model loaded successfully!")
  
  #Loading of TTS
+ print("Loading Dia TTS Model...")
  tts_device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ tts_model_id = "nari-labs/Dia-1.6B-0626"
+ tts_processor = AutoProcessor.from_pretrained(tts_model_id)
+ tts_model = DiaForConditionalGeneration.from_pretrained(tts_model_id, torch_dtype=torch.float16).to(tts_device)
+ print("✅ Dia TTS Model loaded successfully!")
+
+ THREAT_CLASSES = {
+     1: "Gun",
+     2: "Explosive",
+     3: "Grenade",
+     4: "Knife"
+ }
+
+ #Loading Threat Detection Model
+ print("Loading Threat Detection Model...")
+ threat_weights_url = "https://huggingface.co/Subh775/Threat-Detection-RF-DETR/resolve/main/checkpoint_best_total.pth"
+ threat_weights_filename = "checkpoint_best_total.pth"
+
+ # Download weights if not already present
+ if not os.path.exists(threat_weights_filename):
+     print(f"Downloading weights from {threat_weights_url}")
+     response = requests.get(threat_weights_url, stream=True)
+     response.raise_for_status()
+     with open(threat_weights_filename, 'wb') as f:
+         for chunk in response.iter_content(chunk_size=8192):
+             f.write(chunk)
+     print("Download complete.")
+
+ threat_model = RFDETRNano(resolution=640, pretrain_weights=threat_weights_filename)
+ #threat_model.optimize_for_inference()
+
+ print("✅ Threat Detection Model loaded successfully!")
+
+ #Global Variables For Threat Detection
+ current_threat_state = {
+     "threat_detected": [],
+     "threat_count": 0,
+     "last_update": None
+ }
  
  # Global variable to store current emotion state
  current_emotion_state = {
 
      # Get emotion context
      emotion_context = get_emotion_context()
  
+     # Get threat context
+     threat_context = get_threat_context()
+
      # Create prompt in instruction format with emotion awareness
      prompt = f"""### Instruction:
You are Kenko, a compassionate mental health therapist. Provide empathetic, helpful, and professional responses to support the user's mental wellbeing.
+ {emotion_context}{threat_context}
  
  {conversation}User: {message}
 
  
  def generate_tts(text):
      try:
+         text = text[:600]
  
+         print(f"[TTS] Generating speech for {len(text)} chars: '{text[:50]}...'")
  
+         # Prepare inputs for Dia TTS
+         inputs = tts_processor(text=text, return_tensors="pt", padding=True)
+         inputs = {k: v.to(tts_device) for k, v in inputs.items()}
  
+         print(f"[TTS] Inputs prepared, generating audio codes...")
  
+         # Generate audio codes
+         with torch.no_grad():
+             generated_ids = tts_model.generate(**inputs, max_length=5000)
  
+         print(f"[TTS] Audio codes generated, shape: {generated_ids.shape}")
+         print(f"[TTS] Decoding codes to waveform...")
+
+         # Decode the audio codes to a waveform using the processor's batch_decode
+         audio_values = tts_processor.batch_decode(generated_ids, return_tensors="pt")
+
+         # Extract the audio waveform
+         if isinstance(audio_values, dict) and 'audio_values' in audio_values:
+             audio_arr = audio_values['audio_values'][0].cpu().numpy()
+         elif isinstance(audio_values, torch.Tensor):
+             audio_arr = audio_values[0].cpu().numpy()
+         elif isinstance(audio_values, list):
+             audio_arr = np.array(audio_values[0])
+         else:
+             audio_arr = np.array(audio_values).squeeze()
+
+         # Ensure float32
+         audio_arr = audio_arr.astype(np.float32)
  
+         # Dia uses a 44.1 kHz sample rate
+         sample_rate = 44100
  
+         print(f"✅ [TTS] Audio decoded: {len(audio_arr)} samples at {sample_rate}Hz = {len(audio_arr)/sample_rate:.2f} seconds")
  
+         if len(audio_arr) == 0:
+             print("❌ Decoded audio is empty!")
+             return None
  
+         return (sample_rate, audio_arr)
  
      except Exception as e:
          print(f"❌ TTS generation error: {str(e)}")
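For reference, a standalone sketch of the Dia pipeline used above, following the `transformers` Dia examples: the model expects speaker tags such as `[S1]` in the input text, and the processor can write decoded audio straight to disk with `save_audio` (the tag and generation settings below are illustrative):

```python
# Standalone Dia TTS sketch; the speaker tag and max_new_tokens are illustrative.
import torch
from transformers import AutoProcessor, DiaForConditionalGeneration

device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = "nari-labs/Dia-1.6B-0626"

processor = AutoProcessor.from_pretrained(checkpoint)
model = DiaForConditionalGeneration.from_pretrained(checkpoint).to(device)

# Dia is a dialogue model: input text is tagged with speakers like [S1]/[S2].
inputs = processor(text=["[S1] Hello, how are you feeling today?"],
                   padding=True, return_tensors="pt").to(device)

with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=1024)

# batch_decode turns generated audio codes into waveforms;
# save_audio writes them out as a .wav file.
decoded = processor.batch_decode(outputs)
processor.save_audio(decoded, "kenko_reply.wav")
```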
  
+ def threat_detection(image):
+     """Threat detection function for webcam frames"""
+     global current_threat_state
+
+     try:
+         if image is None:
+             return {}
+
+         # Run threat detection (returns a supervision Detections object)
+         detections = threat_model.predict(image, threshold=0)
+
+         # Parse detections
+         threat_found = []
+         if detections is not None and len(detections) > 0:
+             # Extract class IDs and confidences
+             for class_id, confidence in zip(detections.class_id, detections.confidence):
+                 class_id = int(class_id)
+                 confidence = float(confidence)
+                 if class_id in THREAT_CLASSES:
+                     threat_name = THREAT_CLASSES[class_id]
+                     threat_found.append({"type": threat_name, "confidence": confidence})
+
+         # Update global threat state
+         current_threat_state = {
+             "threat_detected": threat_found,
+             "threat_count": len(threat_found),
+             "last_update": time.time()
+         }
+
+         # Format for display
+         if threat_found:
+             output = {}
+             for threat in threat_found:
+                 output[threat["type"]] = threat["confidence"] * 100
+             return output
+         else:
+             return {"No threats detected, all clear": 100.0}
+
+     except Exception as e:
+         print(f"Threat detection error: {str(e)}")
+         return {}
+
+
+ def get_threat_context():
+     """Get current threat as a context string for the model"""
+     if current_threat_state["last_update"] is None:
+         return ""
+
+     # Only use threat data that is recent (within the last 60 seconds)
+     if time.time() - current_threat_state["last_update"] > 60:
+         return ""
+
+     threats = current_threat_state["threat_detected"]
+
+     if threats:
+         threat_list = ", ".join([f"{t['type']} ({t['confidence']*100:.1f}% confidence)" for t in threats])
+         return f"\n[Alert: the user appears to be holding a potential threat: {threat_list}]"
+
+     return ""
  
  # Custom CSS for a calming interface
  css = """
 
  
  emotion_status = gr.Markdown("*Waiting for emotion data...*")
  
+ # Threat detection output
+ gr.Markdown("### Threat Detection")
+ threat_output = gr.Label(
+     num_top_classes=4,
+     label="Detected Threats"
+ )
+
+ threat_status = gr.Markdown("*Monitoring for threats...*")
  
  # Example prompts
  with gr.Row(visible=False) as examples_row:
      gr.Examples(
 
  
  **Privacy:** Your conversations and emotion data are not stored or shared.
  """)
  @spaces.GPU
  def respond(message, chat_history):
      if not message.strip():
  
      print(f"TOTAL TIME: {time.time() - start:.2f}s")
  
      return "", chat_history, audio
  
  def toggle_examples():
      return gr.Row(visible=True)
 
      confidence = current_emotion_state["confidence"]
      return f"**Current Emotion:** {dominant.capitalize()} ({confidence:.1f}% confidence)\n*Last updated: {int(elapsed)}s ago*"
  
+ def update_threat_status():
+     """Update the threat status text"""
+     if current_threat_state["last_update"] is None:
+         return "*Monitoring for threats...*"
+
+     elapsed = time.time() - current_threat_state["last_update"]
+
+     threats = current_threat_state["threat_detected"]
+     if threats:
+         threat_list = ", ".join([t["type"] for t in threats])
+         return f"**⚠️ ALERT:** {threat_list} detected\n*Last updated: {int(elapsed)}s ago*"
+     else:
+         return f"**✅ Safe:** No threats detected\n*Last updated: {int(elapsed)}s ago*"
+
+
  # Event handlers
  submit = msg.submit(fn=respond, inputs=[msg, chatbot], outputs=[msg, chatbot, audio_output])
  send = send_btn.click(fn=respond, inputs=[msg, chatbot], outputs=[msg, chatbot, audio_output])
 
  
  # Emotion detection with streaming (analyzes continuously)
  webcam_input.stream(
+     analyze_emotion,
+     inputs=webcam_input,
+     outputs=emotion_output,
+     stream_every=1,  # Update every 1 second instead of 30
+     time_limit=60    # Keep processing for 60 seconds
+ )
  
  timer = gr.Timer(value=5)  # Update every 5 seconds
  timer.tick(
      fn=update_emotion_status,
      outputs=emotion_status
  )
+
+ # Threat detection with streaming
+ webcam_input.stream(
+     threat_detection,
+     inputs=webcam_input,
+     outputs=threat_output,
+     stream_every=2,
+     time_limit=60
+ )
+
+ # Refresh both status lines on the same timer
+ timer.tick(
+     fn=lambda: (update_emotion_status(), update_threat_status()),
+     outputs=[emotion_status, threat_status]
+ )
  
  if __name__ == "__main__":
+     print("🚀 Starting Kenko Mental Health Assistant with Emotion Detection...")
+     demo.launch(
+         server_name="0.0.0.0",
+         server_port=7890,
+         share=True,
+         show_error=True
+     )