"""Gradio app exposing speaker diarization with NVIDIA's streaming Sortformer model."""

from io import BytesIO
import json
import os
import tempfile

import gradio as gr
import requests
import spaces
from nemo.collections.asr.models import SortformerEncLabelModel
from pydub import AudioSegment

# Load the pretrained diarization model once at import time and switch it to
# inference mode.
diar_model = SortformerEncLabelModel.from_pretrained("nvidia/diar_streaming_sortformer_4spk-v2")
diar_model.eval()

# Streaming configuration of the Sortformer modules.
# NOTE(review): the units/meaning of these values are defined by NeMo, not by
# this file — confirm against the model card before changing them.
diar_model.sortformer_modules.chunk_len = 340
diar_model.sortformer_modules.chunk_right_context = 40
diar_model.sortformer_modules.fifo_len = 40
diar_model.sortformer_modules.spkcache_update_period = 300
diar_model.sortformer_modules.spkcache_len = 188
diar_model.sortformer_modules._check_streaming_parameters()


def preprocess_audio(audio_path):
    """Convert audio to a mono, 16 kHz WAV file and return that file's path.

    Args:
        audio_path: Either a filesystem path (``str``) or the raw encoded
            audio content as a bytes-like object.

    Returns:
        Path of a newly created temporary WAV file. The caller owns the file
        and is responsible for deleting it.

    Raises:
        ValueError: If the audio cannot be decoded, converted, or written.
    """
    temp_wav = None
    try:
        # A str is treated as a path; anything else is wrapped as in-memory data.
        source = audio_path if isinstance(audio_path, str) else BytesIO(audio_path)
        audio = AudioSegment.from_file(source)

        # Convert to mono and resample to 16 kHz.
        audio = audio.set_channels(1).set_frame_rate(16000)

        # Use a unique file per call: the click handler allows concurrent
        # requests, and a shared fixed filename would let one request
        # overwrite or delete another request's audio.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            temp_wav = tmp.name

        # export() hands back the open output handle; close it explicitly so
        # the data is flushed before the path is given to the model.
        audio.export(temp_wav, format="wav").close()
        return temp_wav
    except Exception as e:
        # Do not leave a half-written temp file behind on failure.
        if temp_wav is not None and os.path.exists(temp_wav):
            os.remove(temp_wav)
        raise ValueError(f"Error preprocessing audio: {str(e)}") from e


def handle_audio(url, audio_path):
    """Download or load audio, diarize it, and return the result as JSON text.

    Args:
        url: Optional URL of an audio file. When non-empty it takes
            precedence over ``audio_path``.
        audio_path: Optional path of an uploaded audio file.

    Returns:
        JSON string produced from the value returned by
        ``diarize_audio_diar1``.

    Raises:
        ValueError: If neither input is provided or the audio cannot be
            preprocessed.
        requests.RequestException: If the download fails or the server
            answers with an HTTP error status.
    """
    if url:
        # NOTE(review): the URL is user-supplied and fetched server-side;
        # consider restricting allowed hosts/schemes if this is exposed publicly.
        response = requests.get(url, timeout=60)
        # Fail early on 4xx/5xx instead of feeding an error page to the decoder.
        response.raise_for_status()
        audio_path = response.content

    if not audio_path:
        raise ValueError("No audio provided: supply a URL or upload a file.")

    audio_path = preprocess_audio(audio_path)
    try:
        res = diarize_audio_diar1(audio_path)
    finally:
        # Always remove the temporary WAV, even if diarization raises.
        if os.path.exists(audio_path):
            os.remove(audio_path)

    return json.dumps(res)


@spaces.GPU(duration=120)
def diarize_audio_diar1(audio_path):
    """Run speaker diarization on a WAV file and return formatted segments.

    Args:
        audio_path: Path of the audio file passed to the model.

    Returns:
        The list built by ``format_results`` on success. On any exception a
        two-element tuple ``("Error: <message>", "")`` is returned instead;
        this shape is kept as-is so existing consumers of the JSON output
        keep working.
    """
    try:
        predicted_segments = diar_model.diarize(audio=audio_path, batch_size=1)
        # batch_size=1 with a single input: only the first entry is used.
        return format_results(predicted_segments[0])
    except Exception as e:
        return f"Error: {str(e)}", ""


def format_results(results):
    """Normalize diarization output into a list of segment dictionaries.

    Args:
        results: A list of segments, or a JSON string encoding such a list.
            Each segment is either a whitespace-separated string
            ``"<start> <end> <speaker>"`` or a dict with ``start``, ``end``
            and ``speaker``/``speaker_id`` keys.

    Returns:
        A list of ``{"start", "end", "speaker_id"}`` dicts sorted by
        ``start``. Returns an empty list when the (decoded) input is not a
        list. String items that do not have exactly three fields, and items
        of any other type, are skipped.
    """
    if isinstance(results, str):
        results = json.loads(results)

    if not isinstance(results, list):
        return []

    formatted_results = []
    for item in results:
        if isinstance(item, str):
            parts = item.strip().split()
            if len(parts) == 3:
                formatted_results.append({
                    "start": float(parts[0]),
                    "end": float(parts[1]),
                    "speaker_id": parts[2],
                })
        elif isinstance(item, dict):
            formatted_results.append({
                "start": item.get("start", 0),
                "end": item.get("end", 0),
                # Accept either key name for the speaker label.
                "speaker_id": item.get("speaker", item.get("speaker_id", "unknown")),
            })

    formatted_results.sort(key=lambda x: x["start"])
    return formatted_results


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Speaker Diarization with nvidia/diar_streaming_sortformer_4spk-v2")
    # NOTE(review): the text below mentions specifying the number of speakers,
    # but no such input exists in this interface — confirm the intended wording.
    gr.Markdown("Upload an audio file and specify the number of speakers to diarize the audio.")

    with gr.Row():
        url_input = gr.Textbox(label="URL")
        audio_input = gr.Audio(label="Upload Audio File", type="filepath")

    submit_btn = gr.Button("Diarize")

    with gr.Row():
        json_output = gr.Textbox(label="Diarization Results (JSON)")

    submit_btn.click(
        fn=handle_audio,
        inputs=[url_input, audio_input],
        outputs=[json_output],
        concurrency_limit=20,
    )

# Launch the Gradio app
demo.launch()