Y Phung Nguyen committed on
Commit d0e54ed · 1 Parent(s): b61cc05

Upd maya configs

Files changed (5)
  1. .gitignore +1 -0
  2. models.py +46 -8
  3. requirements.txt +1 -8
  4. ui.py +3 -2
  5. voice.py +142 -11
.gitignore CHANGED
@@ -1,3 +1,4 @@
 .env
 .setup.txt
+maya.txt
 __pycache__/
models.py CHANGED
@@ -9,6 +9,14 @@ from logger import logger
 import config
 import spaces
 
+try:
+    from snac import SNAC
+    SNAC_AVAILABLE = True
+except ImportError:
+    SNAC_AVAILABLE = False
+    SNAC = None
+
+# For backward compatibility, check TTS library too (but we use Maya1 directly)
 try:
     from TTS.api import TTS
     TTS_AVAILABLE = True
@@ -242,10 +250,12 @@ def move_model_to_gpu(model_name: str):
     return model
 
 def initialize_tts_model():
-    """Initialize TTS model for text-to-speech"""
-    if not TTS_AVAILABLE:
-        logger.warning("TTS library not installed. TTS features will be disabled.")
+    """Initialize Maya1 TTS model for text-to-speech using transformers and SNAC"""
+    if not SNAC_AVAILABLE:
+        logger.warning("SNAC library not installed. Maya1 TTS features will be disabled.")
+        logger.warning("Install with: pip install snac")
         return None
+
     if config.global_tts_model is None:
         try:
             # Clear GPU cache before loading
@@ -253,17 +263,45 @@ def initialize_tts_model():
                 torch.cuda.empty_cache()
                 logger.debug("Cleared GPU cache before TTS model loading")
 
-            logger.info("Initializing TTS model for voice generation...")
-            config.global_tts_model = TTS(model_name=config.TTS_MODEL, progress_bar=False)
-            logger.info("TTS model initialized successfully")
+            logger.info("Initializing Maya1 TTS model with Transformers...")
+
+            # Load Maya1 model and tokenizer
+            model = AutoModelForCausalLM.from_pretrained(
+                config.TTS_MODEL,
+                torch_dtype=torch.bfloat16,
+                device_map="auto",
+                trust_remote_code=True,
+                token=config.HF_TOKEN
+            )
+            tokenizer = AutoTokenizer.from_pretrained(
+                config.TTS_MODEL,
+                trust_remote_code=True,
+                token=config.HF_TOKEN
+            )
+
+            logger.info("Loading SNAC decoder...")
+            snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval()
+            if torch.cuda.is_available():
+                snac_model = snac_model.to("cuda")
+
+            # Store as a dictionary with model, tokenizer, and snac_model
+            config.global_tts_model = {
+                "model": model,
+                "tokenizer": tokenizer,
+                "snac_model": snac_model
+            }
+
+            logger.info("Maya1 TTS model initialized successfully")
 
             # Clear cache after loading
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
                 logger.debug("Cleared GPU cache after TTS model loading")
         except Exception as e:
-            logger.warning(f"TTS model initialization failed: {e}")
-            logger.warning("TTS features will be disabled. If pyworld dependency is missing, try: pip install TTS --no-deps && pip install coqui-tts")
+            logger.warning(f"Maya1 TTS model initialization failed: {e}")
+            import traceback
+            logger.warning(f"TTS initialization traceback: {traceback.format_exc()}")
+            logger.warning("TTS features will be disabled. Install dependencies: pip install snac transformers")
            config.global_tts_model = None
            # Clear cache on error
            if torch.cuda.is_available():
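Note: these hunks call AutoModelForCausalLM and AutoTokenizer without adding an import, so models.py presumably already imports them from transformers elsewhere. For reference, a minimal sketch (not the commit's code) of the bundle initialize_tts_model() now stores in config.global_tts_model, assuming config.TTS_MODEL names the Maya1 checkpoint and config.HF_TOKEN is set:

    # Sketch only - mirrors the structure built in the hunk above.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from snac import SNAC

    MODEL_ID = "maya-research/maya1"   # hypothetical value of config.TTS_MODEL

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, torch_dtype=torch.bfloat16, device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval()

    # Downstream code (voice.py) expects exactly these three keys.
    tts_bundle = {"model": model, "tokenizer": tokenizer, "snac_model": snac_model}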
requirements.txt CHANGED
@@ -15,8 +15,6 @@ gradio
 gradio[mcp]
 fastmcp
 # MCP dependencies (required for Gemini MCP)
-# Install MCP SDK: pip install mcp
-# The MCP package provides Model Context Protocol server and client functionality
 mcp>=0.1.0
 nest-asyncio
 google-generativeai
@@ -28,12 +26,7 @@ spaces
 soundfile
 numpy<2.0.0
 setuptools>=65.0.0
-# TTS installation (OPTIONAL) - TTS features work without it
-# If you want TTS functionality, install manually due to pyworld build issues:
-# Option 1: pip install TTS --no-deps && pip install coqui-tts
-# Option 2: pip install TTS (may fail on pyworld, but TTS will work for most models without it)
-# The app will run without TTS - voice generation will be disabled
-# TTS
 
 # ASR (Automatic Speech Recognition) - Whisper for speech-to-text (via Hugging Face transformers)
 torchaudio
+snac
ui.py CHANGED
@@ -16,6 +16,7 @@ from models import (
     initialize_tts_model,
     initialize_whisper_model,
     TTS_AVAILABLE,
+    SNAC_AVAILABLE,
     WHISPER_AVAILABLE,
 )
 from logger import logger
@@ -362,7 +363,7 @@ def create_demo():
         status_lines.append(f"⏳ MedSwin ({model_name}): loading...")
 
         # TTS model status (only show if available or if there's an issue)
-        if TTS_AVAILABLE:
+        if SNAC_AVAILABLE:
             if config.global_tts_model is not None:
                 status_lines.append("✅ TTS (maya1): loaded and ready")
             else:
@@ -402,7 +403,7 @@ def create_demo():
         status_lines.append(f"⚠️ MedSwin ({model_name}): not loaded")
 
         # TTS model status (only show if available and loaded)
-        if TTS_AVAILABLE:
+        if SNAC_AVAILABLE:
            if config.global_tts_model is not None:
                status_lines.append("✅ TTS (maya1): loaded and ready")
            # Don't show if TTS library available but model not loaded (optional feature)
voice.py CHANGED
@@ -8,7 +8,22 @@ import numpy as np
 from logger import logger
 from client import MCP_AVAILABLE, call_agent, get_mcp_session, get_cached_mcp_tools
 import config
-from models import TTS_AVAILABLE, WHISPER_AVAILABLE, initialize_tts_model, initialize_whisper_model
+from models import TTS_AVAILABLE, SNAC_AVAILABLE, WHISPER_AVAILABLE, initialize_tts_model, initialize_whisper_model
+
+# Maya1 constants (from maya1 docs)
+CODE_START_TOKEN_ID = 128257
+CODE_END_TOKEN_ID = 128258
+CODE_TOKEN_OFFSET = 128266
+SNAC_MIN_ID = 128266
+SNAC_MAX_ID = 156937
+SOH_ID = 128259
+EOH_ID = 128260
+SOA_ID = 128261
+TEXT_EOT_ID = 128009
+AUDIO_SAMPLE_RATE = 24000
+
+# Default voice description for Maya1
+DEFAULT_VOICE_DESCRIPTION = "Realistic male voice in the 30s age with a american accent. Normal pitch, warm timbre, conversational pacing, neutral tone delivery at medium intensity, podcast domain, narrator role, neutral delivery"
 import spaces
 
 try:
@@ -408,7 +423,52 @@ def _generate_speech_via_mcp(text: str):
         logger.warning(f"MCP TTS error (sync wrapper): {e}")
         return None
 
-def _generate_speech_with_gpu(text: str):
+def build_maya1_prompt(tokenizer, description: str, text: str) -> str:
+    """Build formatted prompt for Maya1."""
+    soh_token = tokenizer.decode([SOH_ID])
+    eoh_token = tokenizer.decode([EOH_ID])
+    soa_token = tokenizer.decode([SOA_ID])
+    sos_token = tokenizer.decode([CODE_START_TOKEN_ID])
+    eot_token = tokenizer.decode([TEXT_EOT_ID])
+    bos_token = tokenizer.bos_token
+
+    formatted_text = f'<description="{description}"> {text}'
+    prompt = (
+        soh_token + bos_token + formatted_text + eot_token +
+        eoh_token + soa_token + sos_token
+    )
+    return prompt
+
+def unpack_snac_from_7(snac_tokens: list) -> list:
+    """Unpack 7-token SNAC frames to 3 hierarchical levels."""
+    if snac_tokens and snac_tokens[-1] == CODE_END_TOKEN_ID:
+        snac_tokens = snac_tokens[:-1]
+
+    frames = len(snac_tokens) // 7
+    snac_tokens = snac_tokens[:frames * 7]
+
+    if frames == 0:
+        return [[], [], []]
+
+    l1, l2, l3 = [], [], []
+
+    for i in range(frames):
+        slots = snac_tokens[i*7:(i+1)*7]
+        l1.append((slots[0] - CODE_TOKEN_OFFSET) % 4096)
+        l2.extend([
+            (slots[1] - CODE_TOKEN_OFFSET) % 4096,
+            (slots[4] - CODE_TOKEN_OFFSET) % 4096,
+        ])
+        l3.extend([
+            (slots[2] - CODE_TOKEN_OFFSET) % 4096,
+            (slots[3] - CODE_TOKEN_OFFSET) % 4096,
+            (slots[5] - CODE_TOKEN_OFFSET) % 4096,
+            (slots[6] - CODE_TOKEN_OFFSET) % 4096,
+        ])
+
+    return [l1, l2, l3]
+
+def _generate_speech_with_gpu(text: str, description: str = None):
     """Internal GPU-decorated function for TTS generation when TTS is available."""
     if config.global_tts_model is None:
         logger.info("[TTS] TTS model not loaded, initializing...")
@@ -418,13 +478,83 @@ def _generate_speech_with_gpu(text: str):
         logger.error("[TTS] TTS model not available. Please check dependencies.")
         return None
 
+    # Check if it's the new Maya1 format (dictionary) or old format
+    if not isinstance(config.global_tts_model, dict):
+        logger.error("[TTS] TTS model format is incorrect. Expected dictionary with model, tokenizer, snac_model.")
+        return None
+
     try:
-        logger.info("[TTS] Running TTS generation...")
-        wav = config.global_tts_model.tts(text)
-        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
-            sf.write(tmp_file.name, wav, samplerate=22050)
-            logger.info(f"[TTS] Speech generated successfully: {tmp_file.name}")
-            return tmp_file.name
+        model = config.global_tts_model["model"]
+        tokenizer = config.global_tts_model["tokenizer"]
+        snac_model = config.global_tts_model["snac_model"]
+
+        # Use default description if not provided
+        if description is None:
+            description = DEFAULT_VOICE_DESCRIPTION
+
+        logger.info("[TTS] Running Maya1 TTS generation...")
+
+        # Build prompt
+        prompt = build_maya1_prompt(tokenizer, description, text)
+        inputs = tokenizer(prompt, return_tensors="pt")
+
+        if torch.cuda.is_available():
+            inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+        # Generate tokens
+        with torch.inference_mode():
+            outputs = model.generate(
+                **inputs,
+                max_new_tokens=1500,
+                min_new_tokens=28,
+                temperature=0.4,
+                top_p=0.9,
+                repetition_penalty=1.1,
+                do_sample=True,
+                eos_token_id=CODE_END_TOKEN_ID,
+                pad_token_id=tokenizer.pad_token_id,
+            )
+
+        # Extract SNAC tokens
+        generated_ids = outputs[0, inputs['input_ids'].shape[1]:].tolist()
+
+        # Find EOS and extract SNAC codes
+        eos_idx = generated_ids.index(CODE_END_TOKEN_ID) if CODE_END_TOKEN_ID in generated_ids else len(generated_ids)
+        snac_tokens = [t for t in generated_ids[:eos_idx] if SNAC_MIN_ID <= t <= SNAC_MAX_ID]
+
+        if len(snac_tokens) < 7:
+            logger.error(f"[TTS] Not enough tokens generated ({len(snac_tokens)}). Try different text or increase max_tokens.")
+            return None
+
+        # Unpack and decode
+        levels = unpack_snac_from_7(snac_tokens)
+        frames = len(levels[0])
+
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        codes_tensor = [torch.tensor(level, dtype=torch.long, device=device).unsqueeze(0) for level in levels]
+
+        with torch.inference_mode():
+            z_q = snac_model.quantizer.from_codes(codes_tensor)
+            audio = snac_model.decoder(z_q)[0, 0].cpu().numpy()
+
+        # Trim warmup
+        if len(audio) > 2048:
+            audio = audio[2048:]
+
+        # Convert to WAV and save to temporary file
+        audio_int16 = (audio * 32767).astype(np.int16)
+
+        # Create temporary file
+        with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
+            tmp_path = tmp_file.name
+
+        # Save audio
+        sf.write(tmp_path, audio_int16, AUDIO_SAMPLE_RATE)
+
+        duration = len(audio) / AUDIO_SAMPLE_RATE
+        logger.info(f"[TTS] ✅ Speech generated successfully: {tmp_path} ({duration:.2f}s)")
+        return tmp_path
+
     except Exception as e:
         logger.error(f"[TTS] TTS error (local maya1): {e}")
         import traceback
@@ -452,15 +582,16 @@ def generate_speech(text: str):
     logger.info(f"[TTS] Generating speech for text: {text[:50]}...")
 
     # Check TTS availability first - avoid GPU allocation if not available
-    if not TTS_AVAILABLE:
-        logger.warning("[TTS] TTS library not installed. Trying MCP fallback...")
+    # Use SNAC_AVAILABLE for Maya1, but keep TTS_AVAILABLE check for backward compatibility
+    if not SNAC_AVAILABLE:
+        logger.warning("[TTS] SNAC library not installed (required for Maya1). Trying MCP fallback...")
         # Try MCP-based TTS if available (doesn't require GPU)
         audio_path = _generate_speech_via_mcp(text)
         if audio_path:
            logger.info(f"[TTS] ✅ Generated via MCP fallback: {audio_path}")
            return audio_path
        else:
-            logger.error("[TTS] ❌ TTS library not installed and MCP fallback failed. Please install TTS: pip install TTS --no-deps && pip install coqui-tts")
+            logger.error("[TTS] ❌ SNAC library not installed and MCP fallback failed. Please install: pip install snac")
            return None
 
    # TTS is available - use GPU-decorated function
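For clarity, unpack_snac_from_7() above maps each 7-token Maya1 frame onto SNAC's three codebook levels: slot 0 feeds level 1 (one code per frame), slots 1 and 4 feed level 2 (two codes per frame), and slots 2, 3, 5, 6 feed level 3 (four codes per frame), each reduced modulo 4096 after subtracting CODE_TOKEN_OFFSET. A small illustrative sketch with made-up slot values:

    # Illustration only: one 7-slot frame of generated token IDs.
    CODE_TOKEN_OFFSET = 128266
    frame = [CODE_TOKEN_OFFSET + c for c in (7, 11, 3, 9, 21, 5, 2)]    # slots 0..6

    l1 = [(frame[0] - CODE_TOKEN_OFFSET) % 4096]                        # [7]
    l2 = [(frame[i] - CODE_TOKEN_OFFSET) % 4096 for i in (1, 4)]        # [11, 21]
    l3 = [(frame[i] - CODE_TOKEN_OFFSET) % 4096 for i in (2, 3, 5, 6)]  # [3, 9, 5, 2]

Callers go through the existing entry point; a hypothetical call site would look like:

    from voice import generate_speech

    wav_path = generate_speech("Hello, this is a Maya1 test.")  # path to a 24 kHz WAV, or None
    if wav_path:
        print(f"Audio written to {wav_path}")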