import os
import re
import tempfile
import torch
import gradio as gr
from faster_whisper import BatchedInferencePipeline, WhisperModel
from pydub import AudioSegment, effects
from pyannote.audio import Pipeline as DiarizationPipeline
import opencc
import spaces  # ZeroGPU support
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess
from termcolor import cprint
import torchaudio
from pyannote.audio.pipelines.utils.hook import ProgressHook
# —————— Model Lists ——————
WHISPER_MODELS = [
    "SoybeanMilk/faster-whisper-Breeze-ASR-25",
    "asadfgglie/faster-whisper-large-v3-zh-TW",
    "deepdml/faster-whisper-large-v3-turbo-ct2",
    "guillaumekln/faster-whisper-tiny",
    "Systran/faster-whisper-large-v3",
    "XA9/Belle-faster-whisper-large-v3-zh-punct",
    "guillaumekln/faster-whisper-medium",
    "guillaumekln/faster-whisper-small",
    "guillaumekln/faster-whisper-base",
    "Luigi/whisper-small-zh_tw-ct2",
]
SENSEVOICE_MODELS = [
    "FunAudioLLM/SenseVoiceSmall",
    "funasr/paraformer-zh",
]
# —————— Language Options ——————
WHISPER_LANGUAGES = [
    "zh", "af", "am", "ar", "as", "az", "ba", "be", "bg", "bn", "bo",
    "br", "bs", "ca", "cs", "cy", "da", "de", "el", "en", "es", "et",
    "eu", "fa", "fi", "fo", "fr", "gl", "gu", "ha", "haw", "he", "hi",
    "hr", "ht", "hu", "hy", "id", "is", "it", "ja", "jw", "ka", "kk",
    "km", "kn", "ko", "la", "lb", "ln", "lo", "lt", "lv", "mg", "mi",
    "mk", "ml", "mn", "mr", "ms", "mt", "my", "ne", "nl", "nn", "no",
    "oc", "pa", "pl", "ps", "pt", "ro", "ru", "sa", "sd", "si", "sk",
    "sl", "sn", "so", "sq", "sr", "su", "sv", "sw", "ta", "te", "tg",
    "th", "tk", "tl", "tr", "tt", "uk", "ur", "uz", "vi", "yi", "yo",
    "yue", "auto",
]
SENSEVOICE_LANGUAGES = ["zh", "yue", "en", "ja", "ko", "auto", "nospeech"]
# —————— Caches ——————
sense_models = {}
dar_pipe = None
# OpenCC 's2t' converts Simplified Chinese model output to Traditional Chinese.
converter = opencc.OpenCC('s2t')
# —————— Diarization Formatter ——————
def format_diarization_html(snippets):
    """Render '[SPEAKER] text' snippets as color-coded HTML paragraphs."""
    palette = ["#e74c3c", "#3498db", "#27ae60", "#e67e22", "#9b59b6", "#16a085", "#f1c40f"]
    speaker_colors = {}
    html_lines = []
    last_spk = None
    for s in snippets:
        if s.startswith("[") and "]" in s:
            spk, txt = s[1:].split("]", 1)
            spk, txt = spk.strip(), txt.strip()
        else:
            spk, txt = "", s.strip()
        # Hide empty lines
        if not txt:
            continue
        # Assign a color the first time a speaker appears
        if spk not in speaker_colors:
            speaker_colors[spk] = palette[len(speaker_colors) % len(palette)]
        color = speaker_colors[spk]
        # Drop the speaker tag when the same speaker continues
        if spk == last_spk:
            display = txt
        else:
            display = f"<strong>{spk}:</strong> {txt}"
        last_spk = spk
        html_lines.append(
            f"<p style='margin:4px 0; font-family:monospace; color:{color};'>{display}</p>"
        )
    return "<div>" + "".join(html_lines) + "</div>"
# —————— Helpers ——————
# —————— Faster-Whisper Cache & Factory ——————
_fwhisper_models: dict[tuple[str, str], BatchedInferencePipeline] = {}

def get_fwhisper_model(model_id: str, device: str) -> BatchedInferencePipeline:
    """
    Lazily load and cache a BatchedInferencePipeline around WhisperModel(model_id)
    on 'cpu' or 'cuda'. Uses float16 on GPU and int8 on CPU for speed.
    """
    key = (model_id, device)
    if key not in _fwhisper_models:
        compute_type = "float16" if device.startswith("cuda") else "int8"
        model = WhisperModel(
            model_id,
            device=device,
            compute_type=compute_type,
        )
        _fwhisper_models[key] = BatchedInferencePipeline(model=model)
    return _fwhisper_models[key]
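
# Usage sketch (the wav path is hypothetical; the model id is from WHISPER_MODELS):
#   pipe = get_fwhisper_model("Systran/faster-whisper-large-v3", "cpu")
#   segments, info = pipe.transcribe("sample.wav", batch_size=16)
# A second call with the same (model_id, device) pair returns the cached pipeline.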
def get_sense_model(model_id: str, device_str: str):
    """Lazily load and cache a FunASR AutoModel with FSMN VAD on the given device."""
    key = (model_id, device_str)
    if key not in sense_models:
        sense_models[key] = AutoModel(
            model=model_id,
            vad_model="fsmn-vad",
            vad_kwargs={"max_single_segment_time": 300000},
            device=device_str,
            ban_emo_unk=False,
            hub="hf",
        )
    return sense_models[key]
def get_diarization_pipe():
    """Lazily load the pyannote diarization pipeline, falling back to the 2.1 release."""
    global dar_pipe
    if dar_pipe is None:
        token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")
        try:
            dar_pipe = DiarizationPipeline.from_pretrained(
                "pyannote/speaker-diarization-3.1",
                use_auth_token=token or True,
            )
        except Exception as e:
            print(f"Failed to load pyannote/speaker-diarization-3.1: {e}\n"
                  f"Falling back to pyannote/speaker-diarization@2.1.")
            dar_pipe = DiarizationPipeline.from_pretrained(
                "pyannote/speaker-diarization@2.1",
                use_auth_token=token or True,
            )
    return dar_pipe
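
# Both pyannote checkpoints are gated on the Hugging Face Hub: accept their user
# conditions once, then export a token before launching (placeholder value shown):
#   HF_TOKEN=hf_xxx python app.py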
# —————— Whisper Transcription ——————
def _transcribe_fwhisper_stream_common(
    model_id,
    language,
    audio_path,
    whisper_multilingual_en,
    enable_punct,
    backend,
    device,
    banner_text,
    banner_color,
):
    """
    Core generator for streaming diarized transcription with Faster-Whisper.
    Handles both CPU and CUDA backends, merges consecutive turns by the same
    speaker, and appends a Chinese period to segments that lack trailing
    punctuation (when punctuation is enabled).
    Args:
        model_id: Whisper model identifier
        language: language code or "auto"
        audio_path: path to audio file
        whisper_multilingual_en: allow code-switching in multilingual mode
        enable_punct: whether to append a Chinese period when a segment has none
        backend: "cpu" or "cuda"
        device: torch.device for the waveform and diarizer
        banner_text: label for cprint (e.g. "CPU" or "CUDA")
        banner_color: color for cprint
    Yields:
        ("", format_diarization_html(snippets))
    """
    # Pattern to detect trailing punctuation (fullwidth and ASCII)
    end_punct_pattern = r'[。!?…~~.!?]+$'
    # Initialize the whisper pipeline
    pipe = get_fwhisper_model(model_id, backend)
    cprint(f'Whisper (faster-whisper) using {banner_text} [stream]', banner_color)
    # Load diarizer and audio
    diarizer = get_diarization_pipe()
    waveform, sample_rate = torchaudio.load(audio_path)
    if device.type == 'cuda':
        waveform = waveform.to(device)
        diarizer.to(device)
    # Run diarization
    with ProgressHook() as hook:
        diary = diarizer({"waveform": waveform, "sample_rate": sample_rate}, hook=hook)
    # Decode the audio once; slice it per speaker turn below
    audio = AudioSegment.from_file(audio_path)
    snippets = []
    for turn, _, speaker in diary.itertracks(yield_label=True):
        # Extract the audio segment for this turn
        start_ms = int(turn.start * 1000)
        end_ms = int(turn.end * 1000)
        segment = audio[start_ms:end_ms]
        # Normalize, export, and transcribe with faster-whisper
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            segment = effects.normalize(segment)
            segment.export(tmp.name, format="wav")
            segments, _ = pipe.transcribe(
                tmp.name,
                beam_size=3,
                best_of=3,
                language=None if language == "auto" else language,
                vad_filter=True,
                batch_size=16,
                multilingual=whisper_multilingual_en,
            )
        os.unlink(tmp.name)
        # Convert to Traditional Chinese and clean the text
        raw_text = "".join(s.text for s in segments).strip()
        text = converter.convert(raw_text)
        if text:
            tag = f"[{speaker}]"
            if enable_punct and not re.search(end_punct_pattern, text):
                text = f'{text}。'
            else:
                text = f'{text} '
            if snippets and snippets[-1].startswith(tag):
                # Same speaker: merge with the previous snippet
                prev_text = snippets[-1].split('] ', 1)[1]
                snippets[-1] = f"{tag} {prev_text}{text}"
            else:
                # New speaker: start a new snippet
                snippets.append(f"{tag} {text}")
        # Yield accumulated HTML
        yield "", format_diarization_html(snippets)
    return
def _transcribe_fwhisper_cpu_stream(
    model_id,
    language,
    audio_path,
    whisper_multilingual_en,
    enable_punct,
):
    """CPU wrapper for Faster-Whisper streaming transcription."""
    yield from _transcribe_fwhisper_stream_common(
        model_id,
        language,
        audio_path,
        whisper_multilingual_en,
        enable_punct,
        backend="cpu",
        device=torch.device('cpu'),
        banner_text="CPU",
        banner_color="red",
    )
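
# ZeroGPU sketch: on Hugging Face ZeroGPU hardware, CUDA is only available inside
# functions wrapped with the `spaces.GPU` decorator (the reason for the `spaces`
# import above). Assuming this Space targets ZeroGPU, the GPU wrapper is decorated;
# on dedicated-GPU or CPU machines the decorator is a no-op.
@spaces.GPU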
def _transcribe_fwhisper_gpu_stream(
    model_id,
    language,
    audio_path,
    whisper_multilingual_en,
    enable_punct,
):
    """CUDA wrapper for Faster-Whisper streaming transcription."""
    yield from _transcribe_fwhisper_stream_common(
        model_id,
        language,
        audio_path,
        whisper_multilingual_en,
        enable_punct,
        backend="cuda",
        device=torch.device('cuda'),
        banner_text="CUDA",
        banner_color="green",
    )
def transcribe_fwhisper_stream(model_id, language, audio_path, device_sel, whisper_multilingual_en, enable_punct):
    """Dispatch to the CPU or GPU streaming generator, preserving two-value yields."""
    if device_sel == "GPU" and torch.cuda.is_available():
        yield from _transcribe_fwhisper_gpu_stream(model_id, language, audio_path, whisper_multilingual_en, enable_punct)
    else:
        yield from _transcribe_fwhisper_cpu_stream(model_id, language, audio_path, whisper_multilingual_en, enable_punct)
# —————— SenseVoice Transcription ——————
def _transcribe_sense_stream_common(
    model_id: str,
    language: str,
    audio_path: str,
    enable_punct: bool,
    backend: str,
    device: torch.device,
    banner_text: str,
    banner_color: str,
):
    """
    Core generator for SenseVoiceSmall streaming diarized transcription.
    Handles CPU and CUDA, merges consecutive turns by the same speaker, and
    appends a Chinese period to segments that lack trailing punctuation
    (when punctuation is enabled).
    Args:
        model_id: model identifier for SenseVoiceSmall
        language: language code
        audio_path: path to audio file
        enable_punct: whether to keep ITN punctuation and append periods
        backend: device spec for get_sense_model ("cpu" or "cuda:0")
        device: torch.device for the waveform and diarizer
        banner_text: label for the console banner
        banner_color: color for the console banner
    Yields:
        ("", format_diarization_html(snippets))
    """
    # Pattern to detect trailing punctuation (fullwidth and ASCII)
    end_punct_pattern = r'[。!?…~~.!?]+$'
    # Load the model
    model = get_sense_model(model_id, backend)
    cprint(f'SenseVoiceSmall using {banner_text} [stream]', banner_color)
    # Prepare diarizer and audio
    diarizer = get_diarization_pipe()
    diarizer.to(device)
    waveform, sample_rate = torchaudio.load(audio_path)
    if device.type == 'cuda':
        waveform = waveform.to(device)
    # Run diarization
    with ProgressHook() as hook:
        diary = diarizer({"waveform": waveform, "sample_rate": sample_rate}, hook=hook)
    # Decode the audio once; slice it per speaker turn below
    audio = AudioSegment.from_file(audio_path)
    snippets = []
    cache = {}
    for turn, _, speaker in diary.itertracks(yield_label=True):
        start_ms = int(turn.start * 1000)
        end_ms = int(turn.end * 1000)
        segment = audio[start_ms:end_ms]
        # Export and transcribe the segment
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            segment.export(tmp.name, format="wav")
            try:
                segs = model.generate(
                    input=tmp.name,
                    cache=cache,
                    language=language,
                    use_itn=enable_punct,
                    batch_size_s=300,
                )
            except Exception as e:
                cprint(f'Error: {e}', 'red')
                segs = None
        os.unlink(tmp.name)
        # Post-process the text
        if segs:
            txt = rich_transcription_postprocess(segs[0]['text'])
            # Strip all punctuation if it is disabled
            if not enable_punct:
                txt = re.sub(r"[^\w\s]", "", txt)
            if txt:
                txt = converter.convert(txt)
                tag = f"[{speaker}]"
                if enable_punct and not re.search(end_punct_pattern, txt):
                    txt = f'{txt}。'
                else:
                    txt = f'{txt} '
                if snippets and snippets[-1].startswith(tag):
                    # Same speaker: merge with the previous snippet
                    prev_text = snippets[-1].split('] ', 1)[1]
                    snippets[-1] = f"{tag} {prev_text}{txt}"
                else:
                    # New speaker: start a new snippet
                    snippets.append(f"{tag} {txt}")
        # Yield accumulated HTML
        yield "", format_diarization_html(snippets)
    return
def _transcribe_sense_cpu_stream(
    model_id: str,
    language: str,
    audio_path: str,
    enable_punct: bool,
):
    """CPU wrapper for SenseVoiceSmall streaming transcription."""
    yield from _transcribe_sense_stream_common(
        model_id=model_id,
        language=language,
        audio_path=audio_path,
        enable_punct=enable_punct,
        backend="cpu",
        device=torch.device('cpu'),
        banner_text="CPU",
        banner_color="red",
    )
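
# ZeroGPU sketch (same assumption as the Faster-Whisper wrapper above): request a
# GPU for the duration of this call via the `spaces.GPU` decorator.
@spaces.GPU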
def _transcribe_sense_gpu_stream(
    model_id: str,
    language: str,
    audio_path: str,
    enable_punct: bool,
):
    """CUDA wrapper for SenseVoiceSmall streaming transcription."""
    yield from _transcribe_sense_stream_common(
        model_id=model_id,
        language=language,
        audio_path=audio_path,
        enable_punct=enable_punct,
        backend="cuda:0",
        device=torch.device('cuda'),
        banner_text="CUDA",
        banner_color="green",
    )
def transcribe_sense_stream(model_id: str,
                            language: str,
                            audio_path: str,
                            enable_punct: bool,
                            device_sel: str):
    """Dispatch to the CPU or GPU SenseVoice streaming generator."""
    if device_sel == "GPU" and torch.cuda.is_available():
        yield from _transcribe_sense_gpu_stream(model_id, language, audio_path, enable_punct)
    else:
        yield from _transcribe_sense_cpu_stream(model_id, language, audio_path, enable_punct)
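
# Both dispatchers are generators that yield ("", html) after every speaker turn,
# which is what lets Gradio repaint the transcript incrementally. A minimal
# non-Gradio sketch (the wav path is hypothetical):
#   for _, html in transcribe_fwhisper_stream(WHISPER_MODELS[0], "auto", "sample.wav", "CPU", False, True):
#       print(html)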
# —————— Gradio UI ——————
DEMO_CSS = """
.diar {
  padding: 0.5rem;
  color: #f1f1f1;
  font-family: monospace;
  font-size: 0.9rem;
}
"""
Demo = gr.Blocks(css=DEMO_CSS)
with Demo:
    gr.Markdown("## Faster-Whisper vs. SenseVoice")
    audio_input = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Audio Input")
    examples = gr.Examples(
        examples=[["interview.mp3"], ["news.mp3"], ["meeting.mp3"]],
        inputs=[audio_input],
        label="Example Audio Files",
    )
    # ────────────────────────────────────────────────────────────────
    # 1) CONTROL PANELS (side-by-side)
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Faster-Whisper ASR")
            whisper_dd = gr.Dropdown(choices=WHISPER_MODELS, value=WHISPER_MODELS[0], label="Whisper Model")
            whisper_lang = gr.Dropdown(choices=WHISPER_LANGUAGES, value="auto", label="Whisper Language")
            device_radio = gr.Radio(choices=["GPU", "CPU"], value="GPU", label="Device")
            whisper_punct_chk = gr.Checkbox(label="Enable Punctuation", value=True)
            whisper_multilingual_en = gr.Checkbox(label="Multilingual", value=False)
            btn_w = gr.Button("Transcribe with Faster-Whisper")
        with gr.Column():
            gr.Markdown("### FunASR SenseVoice ASR")
            sense_dd = gr.Dropdown(choices=SENSEVOICE_MODELS, value=SENSEVOICE_MODELS[0], label="SenseVoice Model")
            sense_lang = gr.Dropdown(choices=SENSEVOICE_LANGUAGES, value="auto", label="SenseVoice Language")
            device_radio_s = gr.Radio(choices=["GPU", "CPU"], value="GPU", label="Device")
            sense_punct_chk = gr.Checkbox(label="Enable Punctuation", value=True)
            btn_s = gr.Button("Transcribe with SenseVoice")
    # ────────────────────────────────────────────────────────────────
    # 2) SHARED TRANSCRIPT ROW (aligned side-by-side)
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Faster-Whisper Output")
            out_w = gr.Textbox(label="Raw Transcript", visible=False)
            out_w_d = gr.HTML(label="Diarized Transcript", elem_classes=["diar"])
        with gr.Column():
            gr.Markdown("### SenseVoice Output")
            out_s = gr.Textbox(label="Raw Transcript", visible=False)
            out_s_d = gr.HTML(label="Diarized Transcript", elem_classes=["diar"])
    # ────────────────────────────────────────────────────────────────
    # 3) WIRING UP THE BUTTONS
    # Stream each backend's callback into its shared output boxes
    btn_w.click(
        fn=transcribe_fwhisper_stream,
        inputs=[whisper_dd, whisper_lang, audio_input, device_radio, whisper_multilingual_en, whisper_punct_chk],
        outputs=[out_w, out_w_d],
    )
    btn_s.click(
        fn=transcribe_sense_stream,
        inputs=[sense_dd, sense_lang, audio_input, sense_punct_chk, device_radio_s],
        outputs=[out_s, out_s_d],
    )
if __name__ == "__main__":
    Demo.launch()