Update app.py
Browse files
app.py
CHANGED
|
@@ -6,7 +6,9 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
|
|
| 6 |
import torch
|
| 7 |
from deepface import DeepFace
|
| 8 |
import time
|
| 9 |
-
from
|
|
|
|
|
|
|
| 10 |
from sentence_transformers import SentenceTransformer
|
| 11 |
import numpy as np
|
| 12 |
import chromadb
|
|
@@ -44,10 +46,7 @@ print("Model loaded successfully!")
|
|
| 44 |
|
| 45 |
|
| 46 |
print("Loading Dia TTS Model...")
|
| 47 |
-
|
| 48 |
-
tts_model = "nari-labs/Dia-1.6B-0626"
|
| 49 |
-
tts_processor = AutoProcessor.from_pretrained(tts_model)
|
| 50 |
-
tts_model = DiaForConditionalGeneration.from_pretrained(tts_model, torch_dtype=torch.float16).to(tts_device)
|
| 51 |
print("Dia TTS Model loaded successfully!")
|
| 52 |
|
| 53 |
print("Initializing Memory Components...")
|
|
@@ -479,50 +478,31 @@ You are Kenko, a compassionate mental health therapist. Provide empathetic, help
|
|
| 479 |
return f"I'm sorry, I'm having trouble processing your message right now. Error: {str(e)}"
|
| 480 |
|
| 481 |
def generate_tts(text):
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
print(f"[TTS] Generating speech for {len(text)} chars: '{text[:50]}...'")
|
| 486 |
-
|
| 487 |
-
inputs = tts_processor(text=text, return_tensors="pt", padding=True)
|
| 488 |
-
inputs = {k: v.to(tts_device) for k, v in inputs.items()}
|
| 489 |
-
|
| 490 |
-
print(f"[TTS] Inputs prepared, generating audio codes...")
|
| 491 |
-
|
| 492 |
-
with torch.no_grad():
|
| 493 |
-
generated_ids = tts_model.generate(**inputs, max_length=2500)
|
| 494 |
|
| 495 |
-
|
| 496 |
-
|
|
|
|
|
|
|
| 497 |
|
| 498 |
-
|
|
|
|
|
|
|
| 499 |
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
audio_arr = audio_values[0].cpu().numpy()
|
| 504 |
-
elif isinstance(audio_values, list):
|
| 505 |
-
audio_arr = np.array(audio_values[0])
|
| 506 |
-
else:
|
| 507 |
-
audio_arr = np.array(audio_values).squeeze()
|
| 508 |
-
|
| 509 |
-
audio_arr = audio_arr.astype(np.float32)
|
| 510 |
-
|
| 511 |
-
sample_rate = 44100
|
| 512 |
|
| 513 |
-
|
|
|
|
|
|
|
| 514 |
|
| 515 |
-
|
| 516 |
-
print("Decoded audio is empty!")
|
| 517 |
-
return None
|
| 518 |
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
print(f"TTS generation error: {str(e)}")
|
| 523 |
-
import traceback
|
| 524 |
-
traceback.print_exc()
|
| 525 |
-
return None
|
| 526 |
|
| 527 |
css = """
|
| 528 |
.gradio-container {
|
|
|
|
| 6 |
import torch
|
| 7 |
from deepface import DeepFace
|
| 8 |
import time
|
| 9 |
+
from kokoro import KPipeline
|
| 10 |
+
from IPython.display import display, Audio
|
| 11 |
+
import soundfile as sf
|
| 12 |
from sentence_transformers import SentenceTransformer
|
| 13 |
import numpy as np
|
| 14 |
import chromadb
|
|
|
|
| 46 |
|
| 47 |
|
| 48 |
# Build the Kokoro text-to-speech pipeline once at import time so that
# generate_tts() can reuse it for every request.
# The log messages previously still named the Dia model this code replaced.
# NOTE(review): lang_code='b' is used here while generate_tts() requests the
# voice 'af_heart'; the 'af_' prefix looks like a different language variant
# than 'b' -- confirm that this pairing is intended.
print("Loading Kokoro TTS Model...")
tts_pipeline = KPipeline(lang_code='b')
print("Kokoro TTS Model loaded successfully!")
|
| 51 |
|
| 52 |
print("Initializing Memory Components...")
|
|
|
|
| 478 |
return f"I'm sorry, I'm having trouble processing your message right now. Error: {str(e)}"
|
| 479 |
|
| 480 |
def generate_tts(text, *, voice='af_heart', speed=1, max_chars=600):
    """Synthesize speech for ``text`` with the module-level ``tts_pipeline``.

    Args:
        text: Text to speak. Only the first ``max_chars`` characters are
            synthesized.
        voice: Voice name forwarded to the pipeline.
        speed: Speed value forwarded to the pipeline.
        max_chars: Maximum number of characters of ``text`` to synthesize.

    Returns:
        A ``(sample_rate, audio)`` tuple where ``audio`` is a float32 NumPy
        array made by concatenating every chunk the pipeline yields, or
        ``None`` when there is nothing to speak, the pipeline yields no
        audio, or synthesis raises.
    """
    # Nothing to speak: return early instead of relying on the broad
    # exception handler below (slicing ``None`` would raise TypeError).
    if not isinstance(text, str) or not text.strip():
        return None

    try:
        # NOTE(review): the default voice 'af_heart' is used with a pipeline
        # created with lang_code='b' elsewhere in this file -- confirm that
        # the voice/language pairing is intended.
        generator = tts_pipeline(
            text[:max_chars], voice=voice,
            speed=speed, split_pattern=r'\n+'
        )

        audio_chunks = []
        for _graphemes, _phonemes, audio in generator:
            if audio is None:
                # A chunk without audio contributes nothing to the output.
                continue
            if hasattr(audio, "detach"):
                # Tensor-like chunk: move to host memory before NumPy sees it.
                audio = audio.detach().cpu().numpy()
            audio_chunks.append(np.asarray(audio, dtype=np.float32))

        if not audio_chunks:
            print("TTS generation failed")
            return None

        audio_array = np.concatenate(audio_chunks, axis=0)
        sample_rate = 24000

        return (sample_rate, audio_array)

    except Exception as e:
        # Best-effort: TTS failure must not break the caller, but keep the
        # full traceback in the logs so the failure can be diagnosed.
        print(f"TTS generation error: {str(e)}")
        import traceback
        traceback.print_exc()
        return None
|
|
|
|
|
|
|
|
|
|
|
|
|
| 506 |
|
| 507 |
css = """
|
| 508 |
.gradio-container {
|