vibingvoice commited on
Commit
ff79aeb
·
verified ·
1 Parent(s): cd0b70a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +213 -0
app.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import torch
4
+ import numpy as np
5
+ import soundfile as sf
6
+ import librosa
7
+ import logging
8
+ import gradio as gr
9
+ import tempfile
10
+ from typing import Dict, Optional, List
11
+
12
+ # --- 1. Setup Environment ---
13
+
14
+ # Add the project root to the Python path to allow importing local modules
15
+ project_root = os.path.dirname(os.path.abspath(__file__))
16
+ if project_root not in sys.path:
17
+ sys.path.insert(0, project_root)
18
+
19
+ # Configure logging to see VibeVoice messages
20
+ logging.basicConfig(level=logging.INFO, format='[%(name)s] %(message)s')
21
+ logger = logging.getLogger("VibeVoiceGradio")
22
+
23
+ # Mock ComfyUI's folder_paths module for model caching
24
+ class MockFolderPaths:
25
+ def get_folder_paths(self, folder_name):
26
+ if folder_name == "checkpoints":
27
+ models_dir = os.path.join(project_root, "models")
28
+ os.makedirs(models_dir, exist_ok=True)
29
+ return [models_dir]
30
+ return []
31
+
32
+ sys.modules['folder_paths'] = MockFolderPaths()
33
+
34
+ # Import the node class after setting up the environment
35
+ # We use MultiSpeakerNode as it can handle single-speaker text too.
36
+ from nodes.multi_speaker_node import VibeVoiceMultipleSpeakersNode
37
+
38
+ # --- 2. Load Model Globally ---
39
+
40
+ logger.info("Initializing VibeVoice node...")
41
+ # We use the multi-speaker node as it can handle single-speaker cases gracefully.
42
+ # This instance will hold the model in memory for all Gradio calls.
43
+ vibevoice_node = VibeVoiceMultipleSpeakersNode()
44
+
45
+ try:
46
+ logger.info("Loading VibeVoice-Large model. This may take a while on the first run...")
47
+ # Pre-load the model into the node instance.
48
+ vibevoice_node.load_model(
49
+ model_name='VibeVoice-Large',
50
+ model_path='aoi-ot/VibeVoice-Large',
51
+ attention_type='auto'
52
+ )
53
+ logger.info("VibeVoice-Large model loaded successfully!")
54
+ except Exception as e:
55
+ logger.error(f"Failed to load the model: {e}")
56
+ logger.error("Please ensure you have an internet connection for the first run and sufficient VRAM.")
57
+ sys.exit(1)
58
+
59
+
60
+ # --- 3. Helper Functions ---
61
+
62
+ def load_audio_for_node(file_path: Optional[str]) -> Optional[Dict]:
63
+ """Loads an audio file from a path and formats it for the VibeVoice node."""
64
+ if file_path is None:
65
+ return None
66
+ try:
67
+ waveform, sr = librosa.load(file_path, sr=24000, mono=True)
68
+ waveform_tensor = torch.from_numpy(waveform).float().unsqueeze(0).unsqueeze(0)
69
+ return {"waveform": waveform_tensor, "sample_rate": 24000}
70
+ except Exception as e:
71
+ logger.error(f"Failed to load audio file {file_path}: {e}")
72
+ return None
73
+
74
+ def save_audio_to_tempfile(audio_dict: Dict) -> Optional[str]:
75
+ """Saves the node's audio output to a temporary WAV file for Gradio."""
76
+ if not audio_dict or "waveform" not in audio_dict:
77
+ logger.error("Invalid audio dictionary received from node.")
78
+ return None
79
+
80
+ waveform_tensor = audio_dict["waveform"]
81
+ sample_rate = audio_dict["sample_rate"]
82
+
83
+ waveform_np = waveform_tensor.squeeze().cpu().numpy()
84
+
85
+ # Create a temporary file
86
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
87
+ sf.write(tmpfile.name, waveform_np, sample_rate)
88
+ return tmpfile.name
89
+
90
+ # --- 4. Gradio Core Logic ---
91
+
92
+ def generate_speech_gradio(
93
+ text: str,
94
+ speaker1_audio_path: Optional[str],
95
+ speaker2_audio_path: Optional[str],
96
+ speaker3_audio_path: Optional[str],
97
+ speaker4_audio_path: Optional[str],
98
+ seed: int,
99
+ diffusion_steps: int,
100
+ cfg_scale: float,
101
+ use_sampling: bool,
102
+ temperature: float,
103
+ top_p: float,
104
+ progress=gr.Progress(track_tqdm=True)
105
+ ):
106
+ """The main function that Gradio will call to generate speech."""
107
+ if not text or not text.strip():
108
+ raise gr.Error("Please provide some text to generate.")
109
+
110
+ progress(0, desc="Processing audio inputs...")
111
+ logger.info("Processing user inputs...")
112
+
113
+ # Load uploaded voices
114
+ speaker_voices = [
115
+ load_audio_for_node(speaker1_audio_path),
116
+ load_audio_for_node(speaker2_audio_path),
117
+ load_audio_for_node(speaker3_audio_path),
118
+ load_audio_for_node(speaker4_audio_path),
119
+ ]
120
+
121
+ progress(0.2, desc="Generating speech... (this can take a moment)")
122
+ logger.info("Calling VibeVoice model to generate speech...")
123
+
124
+ try:
125
+ # Call the generate_speech method on our globally loaded node
126
+ audio_output_tuple = vibevoice_node.generate_speech(
127
+ text=text,
128
+ model='VibeVoice-Large',
129
+ attention_type='auto',
130
+ free_memory_after_generate=False, # Keep model in memory for next call
131
+ diffusion_steps=int(diffusion_steps),
132
+ seed=int(seed),
133
+ cfg_scale=cfg_scale,
134
+ use_sampling=use_sampling,
135
+ speaker1_voice=speaker_voices[0],
136
+ speaker2_voice=speaker_voices[1],
137
+ speaker3_voice=speaker_voices[2],
138
+ speaker4_voice=speaker_voices[3],
139
+ temperature=temperature,
140
+ top_p=top_p
141
+ )
142
+ except Exception as e:
143
+ logger.error(f"Error during speech generation: {e}")
144
+ raise gr.Error(f"An error occurred during generation: {e}")
145
+
146
+ progress(0.9, desc="Saving audio file...")
147
+ logger.info("Generation complete. Saving audio output.")
148
+
149
+ # Save the output to a temporary file for Gradio to serve
150
+ output_audio_path = save_audio_to_tempfile(audio_output_tuple[0])
151
+
152
+ if output_audio_path is None:
153
+ raise gr.Error("Failed to process the generated audio.")
154
+
155
+ return output_audio_path
156
+
157
+ # --- 5. Gradio UI Layout ---
158
+
159
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
160
+ gr.Markdown(
161
+ "# VibeVoice Text-to-Speech Demo\n"
162
+ "Generate multi-speaker conversations with optional voice cloning using Microsoft's VibeVoice-Large model."
163
+ )
164
+
165
+ with gr.Row():
166
+ with gr.Column(scale=2):
167
+ text_input = gr.Textbox(
168
+ label="Text Input",
169
+ placeholder=(
170
+ "Enter text using speaker tags like [1]:, [2]:, etc.\n\n"
171
+ "[1]: Hello, I'm the first speaker.\n"
172
+ "[2]: Hi there, I'm the second! How are you?\n"
173
+ "[1]: I'm doing great, thanks for asking!"
174
+ ),
175
+ lines=8,
176
+ max_lines=20
177
+ )
178
+ with gr.Accordion("Upload Speaker Voices (Optional)", open=False):
179
+ gr.Markdown("Upload a short audio clip (3-30 seconds, clear audio) for each speaker you want to clone.")
180
+ with gr.Row():
181
+ speaker1_audio = gr.Audio(label="Speaker 1 Voice", type="filepath")
182
+ speaker2_audio = gr.Audio(label="Speaker 2 Voice", type="filepath")
183
+ with gr.Row():
184
+ speaker3_audio = gr.Audio(label="Speaker 3 Voice", type="filepath")
185
+ speaker4_audio = gr.Audio(label="Speaker 4 Voice", type="filepath")
186
+
187
+ with gr.Accordion("Advanced Options", open=False):
188
+ seed = gr.Slider(label="Seed", minimum=0, maximum=2**32-1, step=1, value=42, interactive=True)
189
+ diffusion_steps = gr.Slider(label="Diffusion Steps", minimum=5, maximum=100, step=1, value=20, interactive=True, info="More steps = better quality, but slower.")
190
+ cfg_scale = gr.Slider(label="CFG Scale", minimum=0.5, maximum=3.5, step=0.05, value=1.3, interactive=True, info="Guidance scale.")
191
+ use_sampling = gr.Checkbox(label="Use Sampling", value=False, interactive=True, info="Enable for more varied, less deterministic output.")
192
+ temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.05, value=0.95, interactive=True, info="Only used when sampling is enabled.")
193
+ top_p = gr.Slider(label="Top P", minimum=0.1, maximum=1.0, step=0.05, value=0.95, interactive=True, info="Only used when sampling is enabled.")
194
+
195
+ with gr.Column(scale=1):
196
+ generate_button = gr.Button("Generate Speech", variant="primary")
197
+ audio_output = gr.Audio(label="Generated Speech", type="filepath", interactive=False)
198
+
199
+ inputs = [
200
+ text_input,
201
+ speaker1_audio, speaker2_audio, speaker3_audio, speaker4_audio,
202
+ seed, diffusion_steps, cfg_scale, use_sampling, temperature, top_p
203
+ ]
204
+
205
+ generate_button.click(
206
+ fn=generate_speech_gradio,
207
+ inputs=inputs,
208
+ outputs=audio_output
209
+ )
210
+
211
+ if __name__ == "__main__":
212
+ # Launch the Gradio app
213
+ demo.launch(share=True) # Add share=True to create a public link: demo.launch(share=True)