import gradio as gr
import os
import argparse
import torch
import logging
import threading
from datetime import datetime
import torchaudio
import librosa
import soundfile as sf

# ZeroGPU support
try:
    import spaces
    ZEROGPU_AVAILABLE = True
except ImportError:
    ZEROGPU_AVAILABLE = False

    # Create a dummy decorator for non-ZeroGPU environments
    class spaces:
        @staticmethod
        def GPU(duration=10):
            def decorator(func):
                return func
            return decorator
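# Note: with the fallback above, `@spaces.GPU(...)` becomes a no-op outside
# Hugging Face ZeroGPU Spaces, so the same code path runs on a local GPU/CPU.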

# Project imports
from tokenizer import StepAudioTokenizer
from tts import StepAudioTTS
from model_loader import ModelSource
from config.edit_config import get_supported_edit_types
from whisper_wrapper import WhisperWrapper

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Global variables for ZeroGPU-optimized loading
encoder = None
common_tts_engine = None
whisper_asr = None
args_global = None
_model_lock = threading.Lock()  # Thread lock for model initialization
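# Models are created lazily, on the first request, using double-checked locking:
# a cheap check without the lock, then a re-check after acquiring it. On ZeroGPU
# the heavy models are built inside the GPU-decorated call so that CUDA is never
# initialized in the main process.
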
def initialize_whisper():
    global whisper_asr
    if whisper_asr is not None:
        return
    with _model_lock:
        if whisper_asr is not None:
            return
        try:
            whisper_asr = WhisperWrapper()
            logger.info("✅ WhisperWrapper initialized for ASR")
        except Exception as e:
            logger.error(f"❌ Error loading Whisper ASR model: {e}")
            raise

def initialize_models():
    """Initialize models on first GPU call (ZeroGPU optimization: load inside GPU context)"""
    global encoder, common_tts_engine, args_global

    # Fast path: check if already initialized (without lock)
    if common_tts_engine is not None:
        return  # Already initialized

    # Slow path: acquire lock and double-check
    with _model_lock:
        # Double-check pattern: another thread might have initialized while waiting for lock
        if common_tts_engine is not None:
            return  # Already initialized by another thread

        if args_global is None:
            raise RuntimeError("Global args not set. Cannot initialize models.")

        try:
            logger.info("🚀 Initializing models inside GPU context (first call)...")

            # Determine model source
            source_mapping = {
                "auto": ModelSource.AUTO,
                "local": ModelSource.LOCAL,
                "modelscope": ModelSource.MODELSCOPE,
                "huggingface": ModelSource.HUGGINGFACE
            }
            model_source = source_mapping[args_global.model_source]

            # Resolve torch dtype from args (instead of relying on a module-level global)
            dtype_mapping = {
                "float16": torch.float16,
                "bfloat16": torch.bfloat16,
                "float32": torch.float32
            }
            torch_dtype = dtype_mapping[args_global.torch_dtype]

            # Load StepAudioTokenizer (avoid CUDA initialization in main process)
            encoder = StepAudioTokenizer(
                os.path.join(args_global.model_path, "Step-Audio-Tokenizer"),
                model_source=model_source,
                funasr_model_id=args_global.tokenizer_model_id
            )
            logger.info("✅ StepAudioTokenizer loaded")

            # Initialize common TTS engine (avoid CUDA initialization in main process)
            common_tts_engine = StepAudioTTS(
                os.path.join(args_global.model_path, "Step-Audio-EditX"),
                encoder,
                model_source=model_source,
                tts_model_id=args_global.tts_model_id,
                quantization_config=args_global.quantization,
                torch_dtype=torch_dtype,
                device_map=args_global.device_map,
            )
            logger.info("✅ StepAudioTTS loaded")
            print("Models initialized inside GPU context.")

            if ZEROGPU_AVAILABLE:
                logger.info("💡 Models loaded inside GPU context - ready for inference")
            else:
                logger.info("💡 Models loaded - ready for inference")
        except Exception as e:
            logger.error(f"❌ Error loading models: {e}")
            raise

def get_model_config():
    """Get model configuration without initializing GPU models"""
    if args_global is None:
        raise RuntimeError("Global args not set. Cannot get model config.")
    return {
        "encoder_path": os.path.join(args_global.model_path, "Step-Audio-Tokenizer"),
        "tts_path": os.path.join(args_global.model_path, "Step-Audio-EditX"),
        "model_source": args_global.model_source,
        "tokenizer_model_id": args_global.tokenizer_model_id,
        "tts_model_id": args_global.tts_model_id
    }

def get_gpu_duration(audio_input, text_input, target_text, task_type, task_info):
    """Dynamic GPU duration based on whether models need initialization"""
    global common_tts_engine
    if common_tts_engine is None:
        # First call - the budget must cover model loading plus inference
        return 120
    else:
        # Subsequent calls - only inference time needed
        return 120
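# ZeroGPU accepts a callable for `duration`; it is evaluated on each request with
# the same arguments as the decorated function, which is why this helper mirrors
# the signature of process_audio_with_gpu below.
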
# Dynamic duration based on model state
@spaces.GPU(duration=get_gpu_duration)
def process_audio_with_gpu(audio_input, text_input, target_text, task_type, task_info):
    """Process audio using GPU (models are loaded inside GPU context to avoid main process errors)"""
    global common_tts_engine

    # Initialize models if not already loaded (inside GPU context to avoid main process errors)
    if common_tts_engine is None:
        print("Initializing common_tts_engine inside GPU context...")
        logger.info("🎯 GPU allocated - first call, loading models before inference...")
        initialize_models()
        logger.info("✅ Models loaded successfully inside GPU context")
    else:
        print("common_tts_engine already initialized.")
        logger.info("🎯 GPU allocated - running inference with loaded models...")

    try:
        # Use loaded models (first call may include loading time, subsequent calls are fast)
        if task_type == "clone":
            output_audio, sr = common_tts_engine.clone(audio_input, text_input, target_text)
        else:
            output_audio, sr = common_tts_engine.edit(audio_input, text_input, task_type, task_info, target_text)
        logger.info("✅ Audio processing completed")
        return output_audio, sr
    except Exception as e:
        logger.error(f"❌ Audio processing failed: {e}")
        raise
    # GPU automatically deallocated when function exits

def transcribe_audio(audio_input, current_text):
    """Transcribe audio using Whisper ASR when prompt text is empty"""
    global whisper_asr

    # Only transcribe if current text is empty
    if current_text and current_text.strip():
        return current_text  # Keep existing text
    if not audio_input:
        return ""  # No audio to transcribe
    if whisper_asr is None:
        initialize_whisper()
    try:
        # Transcribe audio
        transcribed_text = whisper_asr(audio_input)
        logger.info(f"Audio transcribed: {transcribed_text}")
        return transcribed_text
    except Exception as e:
        logger.error(f"Failed to transcribe audio: {e}")
        return ""

# Save audio to temporary directory
def save_audio(audio_type, audio_data, sr, tmp_dir):
    """Save audio data to a temporary file with timestamp"""
    current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    save_path = os.path.join(tmp_dir, audio_type, f"{current_time}.wav")
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    try:
        if isinstance(audio_data, torch.Tensor):
            torchaudio.save(save_path, audio_data, sr)
        else:
            sf.write(save_path, audio_data, sr)
        logger.debug(f"Audio saved to: {save_path}")
        return save_path
    except Exception as e:
        logger.error(f"Failed to save audio: {e}")
        raise
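# Illustrative usage (values are examples): save_audio("temp", audio_numpy, 24000, "/tmp/gradio")
# writes /tmp/gradio/temp/<timestamp>.wav and returns that path.
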
class EditxTab:
    """Audio editing and voice cloning interface tab"""

    def __init__(self, args):
        self.args = args
        self.edit_type_list = list(get_supported_edit_types().keys())
        self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")
        self.enable_auto_transcribe = getattr(args, 'enable_auto_transcribe', False)

    def history_messages_to_show(self, messages):
        """Convert message history to gradio chatbot format"""
        show_msgs = []
        for message in messages:
            edit_type = message['edit_type']
            edit_info = message['edit_info']
            source_text = message['source_text']
            target_text = message['target_text']
            raw_audio_part = message['raw_wave']
            edit_audio_part = message['edit_wave']
            type_str = f"{edit_type}-{edit_info}" if edit_info is not None else f"{edit_type}"
            show_msgs.extend([
                {"role": "user", "content": f"Task type: {type_str}\nText: {source_text}"},
                {"role": "user", "content": gr.Audio(value=raw_audio_part, interactive=False)},
                {"role": "assistant", "content": f"Output audio:\nText: {target_text}"},
                {"role": "assistant", "content": gr.Audio(value=edit_audio_part, interactive=False)}
            ])
        return show_msgs
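    # Each history entry expands into four chat messages: the user's task/text plus
    # the input audio, then the assistant's text plus the edited/cloned audio.
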
    def generate_clone(self, prompt_text_input, prompt_audio_input, generated_text, edit_type, edit_info, state):
        """Generate cloned audio (models are loaded on first GPU call)"""
        self.logger.info("Starting voice cloning process")
        state['history_audio'] = []
        state['history_messages'] = []

        # Input validation
        if not prompt_text_input or prompt_text_input.strip() == "":
            error_msg = "[Error] Uploaded text cannot be empty."
            self.logger.error(error_msg)
            return [{"role": "user", "content": error_msg}], state
        if not prompt_audio_input:
            error_msg = "[Error] Uploaded audio cannot be empty."
            self.logger.error(error_msg)
            return [{"role": "user", "content": error_msg}], state
        if not generated_text or generated_text.strip() == "":
            error_msg = "[Error] Clone content cannot be empty."
            self.logger.error(error_msg)
            return [{"role": "user", "content": error_msg}], state
        if edit_type != "clone":
            error_msg = "[Error] The CLONE button requires the clone task."
            self.logger.error(error_msg)
            return [{"role": "user", "content": error_msg}], state

        try:
            # Use GPU inference with models loaded inside GPU context
            output_audio, output_sr = process_audio_with_gpu(
                prompt_audio_input, prompt_text_input, generated_text, "clone", edit_info
            )
            if output_audio is not None and output_sr is not None:
                # Convert tensor to numpy if needed
                if isinstance(output_audio, torch.Tensor):
                    audio_numpy = output_audio.cpu().numpy().squeeze()
                else:
                    audio_numpy = output_audio

                # Load original audio for comparison
                input_audio_data_numpy, input_sample_rate = librosa.load(prompt_audio_input)

                # Create message for history
                cur_assistant_msg = {
                    "edit_type": edit_type,
                    "edit_info": edit_info,
                    "source_text": prompt_text_input,
                    "target_text": generated_text,
                    "raw_wave": (input_sample_rate, input_audio_data_numpy),
                    "edit_wave": (output_sr, audio_numpy),
                }
                state["history_audio"].append((output_sr, audio_numpy, generated_text))
                state["history_messages"].append(cur_assistant_msg)
                show_msgs = self.history_messages_to_show(state["history_messages"])
                self.logger.info("Voice cloning completed successfully")
                return show_msgs, state
            else:
                error_msg = "[Error] Clone failed"
                self.logger.error(error_msg)
                return [{"role": "user", "content": error_msg}], state
        except Exception as e:
            error_msg = f"[Error] Clone failed: {str(e)}"
            self.logger.error(error_msg)
            return [{"role": "user", "content": error_msg}], state

    def generate_edit(self, prompt_text_input, prompt_audio_input, generated_text, edit_type, edit_info, state):
        """Generate edited audio (models are loaded on first GPU call)"""
        self.logger.info("Starting audio editing process")

        # Input validation
        if not prompt_audio_input:
            error_msg = "[Error] Uploaded audio cannot be empty."
            self.logger.error(error_msg)
            return [{"role": "user", "content": error_msg}], state

        try:
            # Determine which audio to use
            if len(state["history_audio"]) == 0:
                # First edit - use uploaded audio
                audio_to_edit = prompt_audio_input
                text_to_use = prompt_text_input
                self.logger.debug("Using prompt audio, no history found")
            else:
                # Use previous edited audio - save it to temp file first
                sample_rate, audio_numpy, previous_text = state["history_audio"][-1]
                temp_path = save_audio("temp", audio_numpy, sample_rate, self.args.tmp_dir)
                audio_to_edit = temp_path
                text_to_use = previous_text
                self.logger.debug(f"Using previous audio from history, count: {len(state['history_audio'])}")

            # For para-linguistic editing, use generated_text; otherwise keep the source text
            if edit_type not in {"paralinguistic"}:
                generated_text = text_to_use

            # Use GPU inference with models loaded inside GPU context
            output_audio, output_sr = process_audio_with_gpu(
                audio_to_edit, text_to_use, generated_text, edit_type, edit_info
            )
            if output_audio is not None and output_sr is not None:
                # Convert tensor to numpy if needed
                if isinstance(output_audio, torch.Tensor):
                    audio_numpy = output_audio.cpu().numpy().squeeze()
                else:
                    audio_numpy = output_audio

                # Load original audio for comparison
                if len(state["history_audio"]) == 0:
                    input_audio_data_numpy, input_sample_rate = librosa.load(prompt_audio_input)
                else:
                    input_sample_rate, input_audio_data_numpy, _ = state["history_audio"][-1]

                # Create message for history
                cur_assistant_msg = {
                    "edit_type": edit_type,
                    "edit_info": edit_info,
                    "source_text": text_to_use,
                    "target_text": generated_text,
                    "raw_wave": (input_sample_rate, input_audio_data_numpy),
                    "edit_wave": (output_sr, audio_numpy),
                }
                state["history_audio"].append((output_sr, audio_numpy, generated_text))
                state["history_messages"].append(cur_assistant_msg)
                show_msgs = self.history_messages_to_show(state["history_messages"])
                self.logger.info("Audio editing completed successfully")
                return show_msgs, state
            else:
                error_msg = "[Error] Edit failed"
                self.logger.error(error_msg)
                return [{"role": "user", "content": error_msg}], state
        except Exception as e:
            error_msg = f"[Error] Edit failed: {str(e)}"
            self.logger.error(error_msg)
            return [{"role": "user", "content": error_msg}], state

    def clear_history(self, state):
        """Clear conversation history"""
        state["history_messages"] = []
        state["history_audio"] = []
        return [], state

    def init_state(self):
        """Initialize conversation state"""
        return {
            "history_messages": [],
            "history_audio": []
        }

    def register_components(self):
        """Register gradio components - maintaining exact layout from original"""
        with gr.Tab("Editx"):
            with gr.Row():
                with gr.Column():
                    self.model_input = gr.Textbox(label="Model Name", value="Step-Audio-EditX", scale=1)
                    self.prompt_text_input = gr.Textbox(label="Prompt Text", value="", scale=1)
                    self.prompt_audio_input = gr.Audio(
                        sources=["upload", "microphone"],
                        format="wav",
                        type="filepath",
                        label="Input Audio",
                    )
                    self.generated_text = gr.Textbox(label="Target Text", lines=1, max_lines=200, max_length=1000)
                with gr.Column():
                    with gr.Row():
                        self.edit_type = gr.Dropdown(label="Task", choices=self.edit_type_list, value="clone")
                        self.edit_info = gr.Dropdown(label="Sub-task", choices=[], value=None)
                    self.chat_box = gr.Chatbot(label="History", type="messages", height=480)
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        self.button_tts = gr.Button("CLONE", variant="primary")
                        self.button_edit = gr.Button("EDIT", variant="primary")
                with gr.Column():
                    self.clean_history_submit = gr.Button("Clear History", variant="primary")
| gr.Markdown("---") | |
| gr.Markdown(""" | |
| **Button Description:** | |
| - CLONE: Synthesizes audio based on uploaded audio and text, only used for clone mode, will clear history information when used. | |
| - EDIT: Edits based on uploaded audio, or continues to stack edit effects based on the previous round of generated audio. | |
| """) | |
| gr.Markdown(""" | |
| **Operation Workflow:** | |
| - Upload the audio to be edited on the left side and fill in the corresponding text content of the audio; | |
| - If the task requires modifying text content (such as clone, para-linguistic), fill in the text to be synthesized in the "target text" field. For all other tasks, keep the uploaded audio text content unchanged; | |
| - Select tasks and subtasks on the right side (some tasks have no subtasks, such as vad, etc.); | |
| - Click the "CLONE" or "EDIT" button on the left side, and audio will be generated in the dialog box on the right side. | |
| """) | |
| gr.Markdown(""" | |
| **Para-linguistic Description:** | |
| - Supported tags include: [Breathing] [Laughter] [Surprise-oh] [Confirmation-en] [Uhm] [Surprise-ah] [Surprise-wa] [Sigh] [Question-ei] [Dissatisfaction-hnn] | |
| - Example: | |
| - Fill in "target text" field: "Great, the weather is so nice today." Click the "CLONE" button to get audio. | |
| - Change "target text" field to: "Great[Laughter], the weather is so nice today[Surprise-ah]." Click the "EDIT" button to get para-linguistic audio. | |
| """) | |
    def register_events(self):
        """Register event handlers"""
        # Create independent state for each session
        state = gr.State(self.init_state())

        self.button_tts.click(
            self.generate_clone,
            inputs=[self.prompt_text_input, self.prompt_audio_input, self.generated_text, self.edit_type, self.edit_info, state],
            outputs=[self.chat_box, state]
        )
        self.button_edit.click(
            self.generate_edit,
            inputs=[self.prompt_text_input, self.prompt_audio_input, self.generated_text, self.edit_type, self.edit_info, state],
            outputs=[self.chat_box, state]
        )
        self.clean_history_submit.click(self.clear_history, inputs=[state], outputs=[self.chat_box, state])
        self.edit_type.change(
            fn=self.update_edit_info,
            inputs=self.edit_type,
            outputs=self.edit_info,
        )

        # Add audio transcription event only if enabled
        if self.enable_auto_transcribe:
            self.prompt_audio_input.change(
                fn=transcribe_audio,
                inputs=[self.prompt_audio_input, self.prompt_text_input],
                outputs=self.prompt_text_input,
            )

    def update_edit_info(self, category):
        """Update sub-task dropdown based on main task selection"""
        category_items = get_supported_edit_types()
        choices = category_items.get(category, [])
        value = None if len(choices) == 0 else choices[0]
        return gr.Dropdown(label="Sub-task", choices=choices, value=value)

def launch_demo(args, editx_tab):
    """Launch the gradio demo"""
    with gr.Blocks(
        theme=gr.themes.Soft(),
        title="🎙️ Step-Audio-EditX",
        css="""
        :root {
            --font: "Helvetica Neue", Helvetica, Arial, sans-serif;
            --font-mono: "SFMono-Regular", Consolas, "Liberation Mono", Menlo, monospace;
        }
        """
    ) as demo:
        gr.Markdown("## 🎙️ Step-Audio-EditX")
        gr.Markdown("Audio Editing and Zero-Shot Cloning using Step-Audio-EditX")

        # Register components
        editx_tab.register_components()
        # Register events
        editx_tab.register_events()

    # Launch demo
    demo.queue().launch(
        server_name=args.server_name,
        server_port=args.server_port,
        share=args.share if hasattr(args, 'share') else False
    )

if __name__ == "__main__":
    # Parse command line arguments
    parser = argparse.ArgumentParser(description="Step-Audio Edit Demo")
    parser.add_argument("--model-path", type=str, default="stepfun-ai", help="Model path.")
    parser.add_argument("--server-name", type=str, default="0.0.0.0", help="Demo server name.")
    parser.add_argument("--server-port", type=int, default=7860, help="Demo server port.")
    parser.add_argument("--tmp-dir", type=str, default="/tmp/gradio", help="Save path.")
    parser.add_argument("--share", action="store_true", help="Share gradio app.")
    # Multi-source loading support parameters
    parser.add_argument(
        "--model-source",
        type=str,
        default="huggingface",
        choices=["auto", "local", "modelscope", "huggingface"],
        help="Model source: auto (detect automatically), local, modelscope, or huggingface"
    )
    parser.add_argument(
        "--tokenizer-model-id",
        type=str,
        default="dengcunqin/speech_paraformer-large_asr_nat-zh-cantonese-en-16k-vocab8501-online",
        help="Tokenizer model ID for online loading"
    )
    parser.add_argument(
        "--tts-model-id",
        type=str,
        default=None,
        help="TTS model ID for online loading (if different from model-path)"
    )
    parser.add_argument(
        "--quantization",
        type=str,
        default=None,
        choices=["int4", "int8"],
        help="Enable quantization for the TTS model to reduce memory usage. "
             "Choices: int4 (online), int8 (online). "
             "When quantization is enabled, data types are handled automatically by the quantization library."
    )
    parser.add_argument(
        "--torch-dtype",
        type=str,
        default="bfloat16",
        choices=["float16", "bfloat16", "float32"],
        help="PyTorch data type for model operations. This setting only applies when quantization is disabled. "
             "When quantization is enabled, data types are managed automatically."
    )
    parser.add_argument(
        "--device-map",
        type=str,
        default="cuda",
        help="Device mapping for model loading (default: cuda)"
    )
    parser.add_argument(
        "--enable-auto-transcribe",
        action="store_true",
        help="Enable automatic audio transcription when uploading audio files (enabled by default via set_defaults below)"
    )
    parser.set_defaults(enable_auto_transcribe=True)
    args = parser.parse_args()

    # Store args globally for model configuration
    args_global = args
| logger.info(f"Configuration loaded:") | |
| # Map string arguments to actual types | |
| source_mapping = { | |
| "auto": ModelSource.AUTO, | |
| "local": ModelSource.LOCAL, | |
| "modelscope": ModelSource.MODELSCOPE, | |
| "huggingface": ModelSource.HUGGINGFACE | |
| } | |
| model_source = source_mapping[args.model_source] | |
| # Map torch dtype string to actual torch dtype | |
| dtype_mapping = { | |
| "float16": torch.float16, | |
| "bfloat16": torch.bfloat16, | |
| "float32": torch.float32 | |
| } | |
| torch_dtype = dtype_mapping[args.torch_dtype] | |
| logger.info(f"Loading models with source: {args.model_source}") | |
| logger.info(f"Model path: {args.model_path}") | |
| logger.info(f"Tokenizer model ID: {args.tokenizer_model_id}") | |
| logger.info(f"Torch dtype: {args.torch_dtype}") | |
| logger.info(f"Device map: {args.device_map}") | |
| if args.tts_model_id: | |
| logger.info(f"TTS model ID: {args.tts_model_id}") | |
| if args.quantization: | |
| logger.info(f"π§ {args.quantization.upper()} quantization enabled") | |
    # Create EditxTab instance
    editx_tab = EditxTab(args)

    # Launch demo
    launch_demo(args, editx_tab)