import os
import sys
import torch
import numpy as np
import soundfile as sf
import librosa
import logging
import gradio as gr
import tempfile
from typing import Dict, Optional, List

# --- 1. Setup Environment ---
# Add the project root to the Python path to allow importing local modules
project_root = os.path.dirname(os.path.abspath(__file__))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# Configure logging to see VibeVoice messages
logging.basicConfig(level=logging.INFO, format='[%(name)s] %(message)s')
logger = logging.getLogger("VibeVoiceGradio")

# Mock ComfyUI's folder_paths module for model caching
class MockFolderPaths:
    def get_folder_paths(self, folder_name):
        if folder_name == "checkpoints":
            models_dir = os.path.join(project_root, "models")
            os.makedirs(models_dir, exist_ok=True)
            return [models_dir]
        return []

sys.modules['folder_paths'] = MockFolderPaths()

# Import the node class after setting up the environment.
# We use the multi-speaker node because it handles single-speaker text too.
from nodes.multi_speaker_node import VibeVoiceMultipleSpeakersNode

# --- 2. Load Model Globally ---
logger.info("Initializing VibeVoice node...")
# This instance will hold the model in memory for all Gradio calls.
vibevoice_node = VibeVoiceMultipleSpeakersNode()

try:
    logger.info("Loading VibeVoice-Large model. This may take a while on the first run...")
    # Pre-load the model into the node instance.
    vibevoice_node.load_model(
        model_name='VibeVoice-Large',
        model_path='aoi-ot/VibeVoice-Large',
        attention_type='auto'
    )
    logger.info("VibeVoice-Large model loaded successfully!")
except Exception as e:
    logger.error(f"Failed to load the model: {e}")
    logger.error("Please ensure you have an internet connection for the first run and sufficient VRAM.")
    sys.exit(1)

# --- 3. Helper Functions ---
def load_audio_for_node(file_path: Optional[str]) -> Optional[Dict]:
    """Loads an audio file from a path and formats it for the VibeVoice node."""
    if file_path is None:
        return None
    try:
        waveform, sr = librosa.load(file_path, sr=24000, mono=True)
        waveform_tensor = torch.from_numpy(waveform).float().unsqueeze(0).unsqueeze(0)
        return {"waveform": waveform_tensor, "sample_rate": 24000}
    except Exception as e:
        logger.error(f"Failed to load audio file {file_path}: {e}")
        return None

def save_audio_to_tempfile(audio_dict: Dict) -> Optional[str]:
    """Saves the node's audio output to a temporary WAV file for Gradio."""
    if not audio_dict or "waveform" not in audio_dict:
        logger.error("Invalid audio dictionary received from node.")
        return None
    waveform_tensor = audio_dict["waveform"]
    sample_rate = audio_dict["sample_rate"]
    waveform_np = waveform_tensor.squeeze().cpu().numpy()
    # Create a temporary file; delete=False keeps it on disk for Gradio to serve
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
        sf.write(tmpfile.name, waveform_np, sample_rate)
        return tmpfile.name
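# Optional sketch: the UI below asks for 3-30 second reference clips, and very
# long uploads slow cloning down. A server-side guard like the helper below
# could be called from load_audio_for_node before the tensor is built. The
# function name and the 30-second cap are illustrative assumptions, not part
# of the VibeVoice node's API.
def trim_reference_clip(waveform: np.ndarray, sr: int, max_seconds: float = 30.0) -> np.ndarray:
    """Truncate an overly long reference clip to at most max_seconds."""
    max_samples = int(max_seconds * sr)
    return waveform[:max_samples] if waveform.shape[0] > max_samples else waveform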
# --- 4. Gradio Core Logic ---
def generate_speech_gradio(
    text: str,
    speaker1_audio_path: Optional[str],
    speaker2_audio_path: Optional[str],
    speaker3_audio_path: Optional[str],
    speaker4_audio_path: Optional[str],
    seed: int,
    diffusion_steps: int,
    cfg_scale: float,
    use_sampling: bool,
    temperature: float,
    top_p: float,
    progress=gr.Progress(track_tqdm=True)
):
    """The main function that Gradio calls to generate speech."""
    if not text or not text.strip():
        raise gr.Error("Please provide some text to generate.")

    progress(0, desc="Processing audio inputs...")
    logger.info("Processing user inputs...")

    # Load uploaded voices
    speaker_voices = [
        load_audio_for_node(speaker1_audio_path),
        load_audio_for_node(speaker2_audio_path),
        load_audio_for_node(speaker3_audio_path),
        load_audio_for_node(speaker4_audio_path),
    ]

    progress(0.2, desc="Generating speech... (this can take a moment)")
    logger.info("Calling VibeVoice model to generate speech...")
    try:
        # Call the generate_speech method on our globally loaded node
        audio_output_tuple = vibevoice_node.generate_speech(
            text=text,
            model='VibeVoice-Large',
            attention_type='auto',
            free_memory_after_generate=False,  # Keep model in memory for the next call
            diffusion_steps=int(diffusion_steps),
            seed=int(seed),
            cfg_scale=cfg_scale,
            use_sampling=use_sampling,
            speaker1_voice=speaker_voices[0],
            speaker2_voice=speaker_voices[1],
            speaker3_voice=speaker_voices[2],
            speaker4_voice=speaker_voices[3],
            temperature=temperature,
            top_p=top_p
        )
    except Exception as e:
        logger.error(f"Error during speech generation: {e}")
        raise gr.Error(f"An error occurred during generation: {e}")

    progress(0.9, desc="Saving audio file...")
    logger.info("Generation complete. Saving audio output.")

    # Save the output to a temporary file for Gradio to serve
    output_audio_path = save_audio_to_tempfile(audio_output_tuple[0])
    if output_audio_path is None:
        raise gr.Error("Failed to process the generated audio.")

    return output_audio_path
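# Optional sketch: save_audio_to_tempfile uses delete=False, so generated WAV
# files accumulate in the temp directory for as long as the process runs (and
# beyond). One stdlib-only mitigation, shown here as an assumption rather than
# part of the original script, is to record each returned path and delete it
# at interpreter exit, e.g.:
#
#     output_audio_path = track_temp_output(save_audio_to_tempfile(...))
import atexit

_temp_outputs: List[str] = []

def track_temp_output(path: Optional[str]) -> Optional[str]:
    """Record a generated file so it can be removed at process exit."""
    if path is not None:
        _temp_outputs.append(path)
    return path

@atexit.register
def _cleanup_temp_outputs() -> None:
    for path in _temp_outputs:
        try:
            os.remove(path)
        except OSError:
            pass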
# --- 5. Gradio UI Layout ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "# VibeVoice Text-to-Speech Demo\n"
        "Generate multi-speaker conversations with optional voice cloning using Microsoft's VibeVoice-Large model."
    )
    with gr.Row():
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                label="Text Input",
                placeholder=(
                    "Enter text using speaker tags like [1]:, [2]:, etc.\n\n"
                    "[1]: Hello, I'm the first speaker.\n"
                    "[2]: Hi there, I'm the second! How are you?\n"
                    "[1]: I'm doing great, thanks for asking!"
                ),
                lines=8,
                max_lines=20
            )
            with gr.Accordion("Upload Speaker Voices (Optional)", open=False):
                gr.Markdown("Upload a short audio clip (3-30 seconds of clear audio) for each speaker you want to clone.")
                with gr.Row():
                    speaker1_audio = gr.Audio(label="Speaker 1 Voice", type="filepath")
                    speaker2_audio = gr.Audio(label="Speaker 2 Voice", type="filepath")
                with gr.Row():
                    speaker3_audio = gr.Audio(label="Speaker 3 Voice", type="filepath")
                    speaker4_audio = gr.Audio(label="Speaker 4 Voice", type="filepath")
            with gr.Accordion("Advanced Options", open=False):
                seed = gr.Slider(label="Seed", minimum=0, maximum=2**32 - 1, step=1, value=42, interactive=True)
                diffusion_steps = gr.Slider(label="Diffusion Steps", minimum=5, maximum=100, step=1, value=20, interactive=True, info="More steps = better quality, but slower.")
                cfg_scale = gr.Slider(label="CFG Scale", minimum=0.5, maximum=3.5, step=0.05, value=1.3, interactive=True, info="Guidance scale.")
                use_sampling = gr.Checkbox(label="Use Sampling", value=False, interactive=True, info="Enable for more varied, less deterministic output.")
                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.05, value=0.95, interactive=True, info="Only used when sampling is enabled.")
                top_p = gr.Slider(label="Top P", minimum=0.1, maximum=1.0, step=0.05, value=0.95, interactive=True, info="Only used when sampling is enabled.")
        with gr.Column(scale=1):
            generate_button = gr.Button("Generate Speech", variant="primary")
            audio_output = gr.Audio(label="Generated Speech", type="filepath", interactive=False)

    inputs = [
        text_input,
        speaker1_audio, speaker2_audio, speaker3_audio, speaker4_audio,
        seed, diffusion_steps, cfg_scale, use_sampling, temperature, top_p
    ]
    generate_button.click(
        fn=generate_speech_gradio,
        inputs=inputs,
        outputs=audio_output
    )

if __name__ == "__main__":
    # Launch the Gradio app; share=True creates a temporary public link
    demo.launch(share=True)
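# Note (an optional variant, not in the original script): the model is a
# single shared instance, so concurrent requests would contend for the GPU.
# Gradio can serialize them with its built-in request queue before launch:
#
#     demo.queue(max_size=8)
#     demo.launch(share=True)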