IniNLP247 committed on
Commit
298147b
·
verified ·
1 Parent(s): 2ce55f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -43
app.py CHANGED
@@ -6,7 +6,9 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
6
  import torch
7
  from deepface import DeepFace
8
  import time
9
- from transformers import AutoProcessor, DiaForConditionalGeneration
 
 
10
  from sentence_transformers import SentenceTransformer
11
  import numpy as np
12
  import chromadb
@@ -44,10 +46,7 @@ print("Model loaded successfully!")
44
 
45
 
46
  print("Loading Dia TTS Model...")
47
- tts_device = "cuda:0" if torch.cuda.is_available() else "cpu"
48
- tts_model = "nari-labs/Dia-1.6B-0626"
49
- tts_processor = AutoProcessor.from_pretrained(tts_model)
50
- tts_model = DiaForConditionalGeneration.from_pretrained(tts_model, torch_dtype=torch.float16).to(tts_device)
51
  print("Dia TTS Model loaded successfully!")
52
 
53
  print("Initializing Memory Components...")
@@ -479,50 +478,31 @@ You are Kenko, a compassionate mental health therapist. Provide empathetic, help
479
  return f"I'm sorry, I'm having trouble processing your message right now. Error: {str(e)}"
480
 
481
  def generate_tts(text):
482
- try:
483
- text = text[:600]
484
-
485
- print(f"[TTS] Generating speech for {len(text)} chars: '{text[:50]}...'")
486
-
487
- inputs = tts_processor(text=text, return_tensors="pt", padding=True)
488
- inputs = {k: v.to(tts_device) for k, v in inputs.items()}
489
-
490
- print(f"[TTS] Inputs prepared, generating audio codes...")
491
-
492
- with torch.no_grad():
493
- generated_ids = tts_model.generate(**inputs, max_length=2500)
494
 
495
- print(f"[TTS] Audio codes generated, shape: {generated_ids.shape}")
496
- print(f"[TTS] Decoding codes to waveform...")
 
 
497
 
498
- audio_values = tts_processor.batch_decode(generated_ids, return_tensors="pt")
 
 
499
 
500
- if isinstance(audio_values, dict) and 'audio_values' in audio_values:
501
- audio_arr = audio_values['audio_values'][0].cpu().numpy()
502
- elif isinstance(audio_values, torch.Tensor):
503
- audio_arr = audio_values[0].cpu().numpy()
504
- elif isinstance(audio_values, list):
505
- audio_arr = np.array(audio_values[0])
506
- else:
507
- audio_arr = np.array(audio_values).squeeze()
508
-
509
- audio_arr = audio_arr.astype(np.float32)
510
-
511
- sample_rate = 44100
512
 
513
- print(f"[TTS] Audio decoded: {len(audio_arr)} samples at {sample_rate}Hz = {len(audio_arr)/sample_rate:.2f} seconds")
 
 
514
 
515
- if len(audio_arr) == 0:
516
- print("Decoded audio is empty!")
517
- return None
518
 
519
- return (sample_rate, audio_arr)
520
-
521
- except Exception as e:
522
- print(f"TTS generation error: {str(e)}")
523
- import traceback
524
- traceback.print_exc()
525
- return None
526
 
527
  css = """
528
  .gradio-container {
 
6
  import torch
7
  from deepface import DeepFace
8
  import time
9
+ from kokoro import KPipeline
10
+ from IPython.display import display, Audio
11
+ import soundfile as sf
12
  from sentence_transformers import SentenceTransformer
13
  import numpy as np
14
  import chromadb
 
46
 
47
 
48
  print("Loading Dia TTS Model...")
49
+ tts_pipeline = KPipeline(lang_code='b')
 
 
 
50
  print("Dia TTS Model loaded successfully!")
51
 
52
  print("Initializing Memory Components...")
 
478
  return f"I'm sorry, I'm having trouble processing your message right now. Error: {str(e)}"
479
 
480
  def generate_tts(text):
481
+ try:
482
+ text = text[:600]
 
 
 
 
 
 
 
 
 
 
483
 
484
+ generator = tts_pipeline(
485
+ text, voice='af_heart',
486
+ speed=1, split_pattern=r'\n+'
487
+ )
488
 
489
+ audio_chunks = []
490
+ for gs, ps, audio in generator:
491
+ audio_chunks.append(audio)
492
 
493
+ if not audio_chunks:
494
+ print("TTS generation failed")
495
+ return None
 
 
 
 
 
 
 
 
 
496
 
497
+ audio_array = np.concatenate(audio_chunks, axis=0)
498
+ audio_array = audio_array.astype(np.float32)
499
+ sample_rate = 24000
500
 
501
+ return (sample_rate, audio_array)
 
 
502
 
503
+ except Exception as e:
504
+ print(f"TTS generation error: {str(e)}")
505
+ return None
 
 
 
 
506
 
507
  css = """
508
  .gradio-container {