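"""Gradio demo for Step-Audio-EditX: audio editing and zero-shot voice cloning.

Models are loaded lazily inside the GPU context so the demo also runs on
Hugging Face ZeroGPU Spaces, where CUDA must not be initialized in the main process.

Example launch (illustrative only; the script name and flag values are assumptions,
adjust them to your setup):

    python app.py --model-path stepfun-ai --model-source huggingface --server-port 7860
"""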
import gradio as gr
import os
import argparse
import torch
import logging
import threading
from datetime import datetime
import torchaudio
import librosa
import soundfile as sf
# ZeroGPU support
try:
import spaces
ZEROGPU_AVAILABLE = True
except ImportError:
ZEROGPU_AVAILABLE = False
# Create a dummy decorator for non-ZeroGPU environments
class spaces:
@staticmethod
def GPU(duration=10):
def decorator(func):
return func
return decorator
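# Note: the stub above only needs to accept whatever `duration` it is given (an int or,
# as with get_gpu_duration below, a callable) and return the function unchanged.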
# Project imports
from tokenizer import StepAudioTokenizer
from tts import StepAudioTTS
from model_loader import ModelSource
from config.edit_config import get_supported_edit_types
from whisper_wrapper import WhisperWrapper
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Global variables for ZeroGPU-optimized loading
encoder = None
common_tts_engine = None
whisper_asr = None
args_global = None
_model_lock = threading.Lock() # Thread lock for model initialization
def initialize_whisper():
global whisper_asr
if whisper_asr is not None:
return
with _model_lock:
if whisper_asr is not None:
return
try:
whisper_asr = WhisperWrapper()
logger.info("βœ“ WhisperWrapper initialized for ASR")
except Exception as e:
logger.error(f"❌ Error loading Whisper ASR model: {e}")
raise
def initialize_models():
"""Initialize models on first GPU call (ZeroGPU optimization: load inside GPU context)"""
global encoder, common_tts_engine, args_global
# Fast path: check if already initialized (without lock)
if common_tts_engine is not None:
return # Already initialized
# Slow path: acquire lock and double-check
with _model_lock:
# Double-check pattern: another thread might have initialized while waiting for lock
if common_tts_engine is not None:
return # Already initialized by another thread
if args_global is None:
raise RuntimeError("Global args not set. Cannot initialize models.")
try:
logger.info("πŸš€ Initializing models inside GPU context (first call)...")
# Determine model source
source_mapping = {
"auto": ModelSource.AUTO,
"local": ModelSource.LOCAL,
"modelscope": ModelSource.MODELSCOPE,
"huggingface": ModelSource.HUGGINGFACE
}
model_source = source_mapping[args_global.model_source]
# Load StepAudioTokenizer (avoid CUDA initialization in main process)
encoder = StepAudioTokenizer(
os.path.join(args_global.model_path, "Step-Audio-Tokenizer"),
model_source=model_source,
funasr_model_id=args_global.tokenizer_model_id
)
logger.info("βœ“ StepAudioTokenizer loaded")
            # Initialize common TTS engine (avoid CUDA initialization in main process).
            # Resolve the torch dtype from the CLI string here so this function does not
            # rely on the module-level `torch_dtype` that only exists when run as a script.
            torch_dtype = getattr(torch, args_global.torch_dtype)
common_tts_engine = StepAudioTTS(
os.path.join(args_global.model_path, "Step-Audio-EditX"),
encoder,
model_source=model_source,
tts_model_id=args_global.tts_model_id,
quantization_config=args_global.quantization,
torch_dtype=torch_dtype,
device_map=args_global.device_map,
)
logger.info("βœ“ StepCommonAudioTTS loaded")
print("Models initialized inside GPU context.")
if ZEROGPU_AVAILABLE:
logger.info("πŸ’‘ Models loaded inside GPU context - ready for inference")
else:
logger.info("πŸ’‘ Models loaded - ready for inference")
except Exception as e:
logger.error(f"❌ Error loading models: {e}")
raise
def get_model_config():
"""Get model configuration without initializing GPU models"""
if args_global is None:
raise RuntimeError("Global args not set. Cannot get model config.")
return {
"encoder_path": os.path.join(args_global.model_path, "Step-Audio-Tokenizer"),
"tts_path": os.path.join(args_global.model_path, "Step-Audio-EditX"),
"model_source": args_global.model_source,
"tokenizer_model_id": args_global.tokenizer_model_id,
"tts_model_id": args_global.tts_model_id
}
def get_gpu_duration(audio_input, text_input, target_text, task_type, task_info):
"""Dynamic GPU duration based on whether models need initialization"""
global common_tts_engine
    if common_tts_engine is None:
        # First call - allow extra time for model loading (up to 5 minutes)
        return 300
    else:
        # Subsequent calls - only inference time is needed
        return 120
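# ZeroGPU evaluates a callable `duration` with the same arguments as the decorated
# function, which is why get_gpu_duration mirrors process_audio_with_gpu's signature.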
@spaces.GPU(duration=get_gpu_duration) # Dynamic duration based on model state
def process_audio_with_gpu(audio_input, text_input, target_text, task_type, task_info):
"""Process audio using GPU (models are loaded inside GPU context to avoid main process errors)"""
global common_tts_engine
# Initialize models if not already loaded (inside GPU context to avoid main process errors)
if common_tts_engine is None:
print("Initializing common_tts_engine inside GPU context...")
logger.info("🎯 GPU allocated for 300s (first call with model loading)...")
initialize_models()
logger.info("βœ… Models loaded successfully inside GPU context")
else:
print("common_tts_engine already initialized.")
logger.info("🎯 GPU allocated for 120s (inference with loaded models)...")
try:
# Use loaded models (first call may include loading time, subsequent calls are fast)
if task_type == "clone":
output_audio, sr = common_tts_engine.clone(audio_input, text_input, target_text)
else:
output_audio, sr = common_tts_engine.edit(audio_input, text_input, task_type, task_info, target_text)
logger.info("βœ… Audio processing completed")
return output_audio, sr
except Exception as e:
logger.error(f"❌ Audio processing failed: {e}")
raise
# GPU automatically deallocated when function exits
@spaces.GPU(duration=30)
def transcribe_audio(audio_input, current_text):
"""Transcribe audio using Whisper ASR when prompt text is empty"""
global whisper_asr
# Only transcribe if current text is empty
if current_text and current_text.strip():
return current_text # Keep existing text
if not audio_input:
return "" # No audio to transcribe
if whisper_asr is None:
initialize_whisper()
try:
# Transcribe audio
transcribed_text = whisper_asr(audio_input)
logger.info(f"Audio transcribed: {transcribed_text}")
return transcribed_text
except Exception as e:
logger.error(f"Failed to transcribe audio: {e}")
return ""
# Save audio to temporary directory
def save_audio(audio_type, audio_data, sr, tmp_dir):
"""Save audio data to a temporary file with timestamp"""
current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
save_path = os.path.join(tmp_dir, audio_type, f"{current_time}.wav")
os.makedirs(os.path.dirname(save_path), exist_ok=True)
try:
if isinstance(audio_data, torch.Tensor):
torchaudio.save(save_path, audio_data, sr)
else:
sf.write(save_path, audio_data, sr)
logger.debug(f"Audio saved to: {save_path}")
return save_path
except Exception as e:
logger.error(f"Failed to save audio: {e}")
raise
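# Illustrative usage (hypothetical values), mirroring how generate_edit replays history:
#   sr, audio_numpy, _ = state["history_audio"][-1]
#   temp_path = save_audio("temp", audio_numpy, sr, args.tmp_dir)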
class EditxTab:
"""Audio editing and voice cloning interface tab"""
def __init__(self, args):
self.args = args
self.edit_type_list = list(get_supported_edit_types().keys())
self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")
self.enable_auto_transcribe = getattr(args, 'enable_auto_transcribe', False)
def history_messages_to_show(self, messages):
"""Convert message history to gradio chatbot format"""
show_msgs = []
for message in messages:
edit_type = message['edit_type']
edit_info = message['edit_info']
source_text = message['source_text']
target_text = message['target_text']
raw_audio_part = message['raw_wave']
edit_audio_part = message['edit_wave']
type_str = f"{edit_type}-{edit_info}" if edit_info is not None else f"{edit_type}"
show_msgs.extend([
{"role": "user", "content": f"δ»»εŠ‘η±»εž‹οΌš{type_str}\nζ–‡ζœ¬οΌš{source_text}"},
{"role": "user", "content": gr.Audio(value=raw_audio_part, interactive=False)},
{"role": "assistant", "content": f"θΎ“ε‡ΊιŸ³ι’‘οΌš\nζ–‡ζœ¬οΌš{target_text}"},
{"role": "assistant", "content": gr.Audio(value=edit_audio_part, interactive=False)}
])
return show_msgs
def generate_clone(self, prompt_text_input, prompt_audio_input, generated_text, edit_type, edit_info, state):
"""Generate cloned audio (models are loaded on first GPU call)"""
self.logger.info("Starting voice cloning process")
state['history_audio'] = []
state['history_messages'] = []
# Input validation
if not prompt_text_input or prompt_text_input.strip() == "":
error_msg = "[Error] Uploaded text cannot be empty."
self.logger.error(error_msg)
return [{"role": "user", "content": error_msg}], state
if not prompt_audio_input:
error_msg = "[Error] Uploaded audio cannot be empty."
self.logger.error(error_msg)
return [{"role": "user", "content": error_msg}], state
if not generated_text or generated_text.strip() == "":
error_msg = "[Error] Clone content cannot be empty."
self.logger.error(error_msg)
return [{"role": "user", "content": error_msg}], state
if edit_type != "clone":
error_msg = "[Error] CLONE button must use clone task."
self.logger.error(error_msg)
return [{"role": "user", "content": error_msg}], state
try:
# Use GPU inference with models loaded inside GPU context
output_audio, output_sr = process_audio_with_gpu(
prompt_audio_input, prompt_text_input, generated_text, "clone", edit_info
)
if output_audio is not None and output_sr is not None:
# Convert tensor to numpy if needed
if isinstance(output_audio, torch.Tensor):
audio_numpy = output_audio.cpu().numpy().squeeze()
else:
audio_numpy = output_audio
# Load original audio for comparison
                input_audio_data_numpy, input_sample_rate = librosa.load(prompt_audio_input, sr=None)
# Create message for history
cur_assistant_msg = {
"edit_type": edit_type,
"edit_info": edit_info,
"source_text": prompt_text_input,
"target_text": generated_text,
"raw_wave": (input_sample_rate, input_audio_data_numpy),
"edit_wave": (output_sr, audio_numpy),
}
state["history_audio"].append((output_sr, audio_numpy, generated_text))
state["history_messages"].append(cur_assistant_msg)
show_msgs = self.history_messages_to_show(state["history_messages"])
self.logger.info("Voice cloning completed successfully")
return show_msgs, state
else:
error_msg = "[Error] Clone failed"
self.logger.error(error_msg)
return [{"role": "user", "content": error_msg}], state
except Exception as e:
error_msg = f"[Error] Clone failed: {str(e)}"
self.logger.error(error_msg)
return [{"role": "user", "content": error_msg}], state
def generate_edit(self, prompt_text_input, prompt_audio_input, generated_text, edit_type, edit_info, state):
"""Generate edited audio (models are loaded on first GPU call)"""
self.logger.info("Starting audio editing process")
# Input validation
if not prompt_audio_input:
error_msg = "[Error] Uploaded audio cannot be empty."
self.logger.error(error_msg)
return [{"role": "user", "content": error_msg}], state
try:
# Determine which audio to use
if len(state["history_audio"]) == 0:
# First edit - use uploaded audio
audio_to_edit = prompt_audio_input
text_to_use = prompt_text_input
self.logger.debug("Using prompt audio, no history found")
else:
# Use previous edited audio - save it to temp file first
sample_rate, audio_numpy, previous_text = state["history_audio"][-1]
temp_path = save_audio("temp", audio_numpy, sample_rate, self.args.tmp_dir)
audio_to_edit = temp_path
text_to_use = previous_text
self.logger.debug(f"Using previous audio from history, count: {len(state['history_audio'])}")
# For para-linguistic, use generated_text; otherwise use source text
if edit_type not in {"paralinguistic"}:
generated_text = text_to_use
# Use GPU inference with models loaded inside GPU context
output_audio, output_sr = process_audio_with_gpu(
audio_to_edit, text_to_use, generated_text, edit_type, edit_info
)
if output_audio is not None and output_sr is not None:
# Convert tensor to numpy if needed
if isinstance(output_audio, torch.Tensor):
audio_numpy = output_audio.cpu().numpy().squeeze()
else:
audio_numpy = output_audio
# Load original audio for comparison
if len(state["history_audio"]) == 0:
                    input_audio_data_numpy, input_sample_rate = librosa.load(prompt_audio_input, sr=None)
else:
input_sample_rate, input_audio_data_numpy, _ = state["history_audio"][-1]
# Create message for history
cur_assistant_msg = {
"edit_type": edit_type,
"edit_info": edit_info,
"source_text": text_to_use,
"target_text": generated_text,
"raw_wave": (input_sample_rate, input_audio_data_numpy),
"edit_wave": (output_sr, audio_numpy),
}
state["history_audio"].append((output_sr, audio_numpy, generated_text))
state["history_messages"].append(cur_assistant_msg)
show_msgs = self.history_messages_to_show(state["history_messages"])
self.logger.info("Audio editing completed successfully")
return show_msgs, state
else:
error_msg = "[Error] Edit failed"
self.logger.error(error_msg)
return [{"role": "user", "content": error_msg}], state
except Exception as e:
error_msg = f"[Error] Edit failed: {str(e)}"
self.logger.error(error_msg)
return [{"role": "user", "content": error_msg}], state
def clear_history(self, state):
"""Clear conversation history"""
state["history_messages"] = []
state["history_audio"] = []
return [], state
def init_state(self):
"""Initialize conversation state"""
return {
"history_messages": [],
"history_audio": []
}
def register_components(self):
"""Register gradio components - maintaining exact layout from original"""
with gr.Tab("Editx"):
with gr.Row():
with gr.Column():
self.model_input = gr.Textbox(label="Model Name", value="Step-Audio-EditX", scale=1)
self.prompt_text_input = gr.Textbox(label="Prompt Text", value="", scale=1)
self.prompt_audio_input = gr.Audio(
sources=["upload", "microphone"],
format="wav",
type="filepath",
label="Input Audio",
)
self.generated_text = gr.Textbox(label="Target Text", lines=1, max_lines=200, max_length=1000)
with gr.Column():
with gr.Row():
self.edit_type = gr.Dropdown(label="Task", choices=self.edit_type_list, value="clone")
self.edit_info = gr.Dropdown(label="Sub-task", choices=[], value=None)
self.chat_box = gr.Chatbot(label="History", type="messages", height=480*1)
with gr.Row():
with gr.Column():
with gr.Row():
self.button_tts = gr.Button("CLONE", variant="primary")
self.button_edit = gr.Button("EDIT", variant="primary")
with gr.Column():
self.clean_history_submit = gr.Button("Clear History", variant="primary")
gr.Markdown("---")
gr.Markdown("""
**Button Description:**
- CLONE: Synthesizes audio based on uploaded audio and text, only used for clone mode, will clear history information when used.
- EDIT: Edits based on uploaded audio, or continues to stack edit effects based on the previous round of generated audio.
""")
gr.Markdown("""
**Operation Workflow:**
- Upload the audio to be edited on the left side and fill in the corresponding text content of the audio;
- If the task requires modifying text content (such as clone, para-linguistic), fill in the text to be synthesized in the "target text" field. For all other tasks, keep the uploaded audio text content unchanged;
- Select tasks and subtasks on the right side (some tasks have no subtasks, such as vad, etc.);
- Click the "CLONE" or "EDIT" button on the left side, and audio will be generated in the dialog box on the right side.
""")
gr.Markdown("""
**Para-linguistic Description:**
- Supported tags include: [Breathing] [Laughter] [Surprise-oh] [Confirmation-en] [Uhm] [Surprise-ah] [Surprise-wa] [Sigh] [Question-ei] [Dissatisfaction-hnn]
- Example:
- Fill in "target text" field: "Great, the weather is so nice today." Click the "CLONE" button to get audio.
- Change "target text" field to: "Great[Laughter], the weather is so nice today[Surprise-ah]." Click the "EDIT" button to get para-linguistic audio.
""")
def register_events(self):
"""Register event handlers"""
# Create independent state for each session
state = gr.State(self.init_state())
self.button_tts.click(self.generate_clone,
inputs=[self.prompt_text_input, self.prompt_audio_input, self.generated_text, self.edit_type, self.edit_info, state],
outputs=[self.chat_box, state])
self.button_edit.click(self.generate_edit,
inputs=[self.prompt_text_input, self.prompt_audio_input, self.generated_text, self.edit_type, self.edit_info, state],
outputs=[self.chat_box, state])
self.clean_history_submit.click(self.clear_history, inputs=[state], outputs=[self.chat_box, state])
self.edit_type.change(
fn=self.update_edit_info,
inputs=self.edit_type,
outputs=self.edit_info,
)
# Add audio transcription event only if enabled
if self.enable_auto_transcribe:
self.prompt_audio_input.change(
fn=transcribe_audio,
inputs=[self.prompt_audio_input, self.prompt_text_input],
outputs=self.prompt_text_input,
)
def update_edit_info(self, category):
"""Update sub-task dropdown based on main task selection"""
category_items = get_supported_edit_types()
choices = category_items.get(category, [])
value = None if len(choices) == 0 else choices[0]
return gr.Dropdown(label="Sub-task", choices=choices, value=value)
def launch_demo(args, editx_tab):
"""Launch the gradio demo"""
with gr.Blocks(
theme=gr.themes.Soft(),
title="πŸŽ™οΈ Step-Audio-EditX",
css="""
:root {
--font: "Helvetica Neue", Helvetica, Arial, sans-serif;
--font-mono: "SFMono-Regular", Consolas, "Liberation Mono", Menlo, monospace;
}
""") as demo:
gr.Markdown("## πŸŽ™οΈ Step-Audio-EditX")
gr.Markdown("Audio Editing and Zero-Shot Cloning using Step-Audio-EditX")
# Register components
editx_tab.register_components()
# Register events
editx_tab.register_events()
# Launch demo
demo.queue().launch(
server_name=args.server_name,
server_port=args.server_port,
share=args.share if hasattr(args, 'share') else False
)
if __name__ == "__main__":
# Parse command line arguments
parser = argparse.ArgumentParser(description="Step-Audio Edit Demo")
parser.add_argument("--model-path", type=str, default="stepfun-ai", help="Model path.")
parser.add_argument("--server-name", type=str, default="0.0.0.0", help="Demo server name.")
parser.add_argument("--server-port", type=int, default=7860, help="Demo server port.")
parser.add_argument("--tmp-dir", type=str, default="/tmp/gradio", help="Save path.")
parser.add_argument("--share", action="store_true", help="Share gradio app.")
# Multi-source loading support parameters
parser.add_argument(
"--model-source",
type=str,
default="huggingface",
choices=["auto", "local", "modelscope", "huggingface"],
help="Model source: auto (detect automatically), local, modelscope, or huggingface"
)
parser.add_argument(
"--tokenizer-model-id",
type=str,
default="dengcunqin/speech_paraformer-large_asr_nat-zh-cantonese-en-16k-vocab8501-online",
help="Tokenizer model ID for online loading"
)
parser.add_argument(
"--tts-model-id",
type=str,
default=None,
help="TTS model ID for online loading (if different from model-path)"
)
parser.add_argument(
"--quantization",
type=str,
default=None,
choices=["int4", "int8"],
help="Enable quantization for the TTS model to reduce memory usage."
"Choices: int4 (online), int8 (online)."
"When quantization is enabled, data types are handled automatically by the quantization library."
)
parser.add_argument(
"--torch-dtype",
type=str,
default="bfloat16",
choices=["float16", "bfloat16", "float32"],
help="PyTorch data type for model operations. This setting only applies when quantization is disabled. "
"When quantization is enabled, data types are managed automatically."
)
parser.add_argument(
"--device-map",
type=str,
default="cuda",
help="Device mapping for model loading (default: cuda)"
)
    parser.add_argument(
        "--enable-auto-transcribe",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Automatically transcribe uploaded audio into the prompt text field when it is empty "
        "(default: enabled; use --no-enable-auto-transcribe to disable)"
    )
args = parser.parse_args()
# Store args globally for model configuration
args_global = args
logger.info(f"Configuration loaded:")
# Map string arguments to actual types
source_mapping = {
"auto": ModelSource.AUTO,
"local": ModelSource.LOCAL,
"modelscope": ModelSource.MODELSCOPE,
"huggingface": ModelSource.HUGGINGFACE
}
model_source = source_mapping[args.model_source]
# Map torch dtype string to actual torch dtype
dtype_mapping = {
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"float32": torch.float32
}
torch_dtype = dtype_mapping[args.torch_dtype]
logger.info(f"Loading models with source: {args.model_source}")
logger.info(f"Model path: {args.model_path}")
logger.info(f"Tokenizer model ID: {args.tokenizer_model_id}")
logger.info(f"Torch dtype: {args.torch_dtype}")
logger.info(f"Device map: {args.device_map}")
if args.tts_model_id:
logger.info(f"TTS model ID: {args.tts_model_id}")
if args.quantization:
logger.info(f"πŸ”§ {args.quantization.upper()} quantization enabled")
# Create EditxTab instance
editx_tab = EditxTab(args)
# Launch demo
launch_demo(args, editx_tab)