Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import os | |
| import gradio as gr | |
| import torch | |
| import torchaudio | |
| import spaces | |
| import nemo.collections.asr as nemo_asr | |
# Display name -> language code handed to the model as ``language_id``.
# Insertion order is the order in which the language dropdowns list the choices.
LANGUAGE_NAME_TO_CODE = dict(
    Assamese="as",
    Bengali="bn",
    # NOTE(review): "br" is the ISO 639-1 code for Breton; Bodo's ISO 639-3 code is
    # "brx". Confirm which ID the model was actually trained with.
    Bodo="br",
    Dogri="doi",
    Gujarati="gu",
    Hindi="hi",
    Kannada="kn",
    Kashmiri="ks",
    Konkani="kok",
    Maithili="mai",
    Malayalam="ml",
    Manipuri="mni",
    Marathi="mr",
    Nepali="ne",
    Odia="or",
    Punjabi="pa",
    Sanskrit="sa",
    Santali="sat",
    Sindhi="sd",
    Tamil="ta",
    Telugu="te",
    Urdu="ur",
)
# Markdown rendered at the top of the app. Paragraphs, the heading and the list are
# separated by blank lines; without them the closing "🚀" line is a lazy continuation
# of list item 4 and renders inside the list.
DESCRIPTION = """\
### **IndicConformer: Speech Recognition for Indian Languages** 🎙️➡️📜

This Gradio demo showcases **IndicConformer**, a speech recognition model for **22 Indian languages**. The model operates in two modes: **CTC (Connectionist Temporal Classification)** and **RNNT (Recurrent Neural Network Transducer)**, providing robust and accurate transcriptions across diverse linguistic and acoustic conditions.

#### **How to Use:**

1. **Upload or record** an audio clip in any supported Indian language.
2. Select the **mode** tab (CTC or RNNT) and pick the clip's language from the **Target language** dropdown.
3. Click **"Transcribe"** to generate the corresponding text in the target language.
4. View or copy the output for further use.

🚀 Try it out and experience seamless speech recognition for Indian languages!
"""
# Hugging Face token from the environment.
# NOTE(review): not referenced anywhere else in this file.
hf_token = os.getenv("HF_TOKEN")

# Device preference: first CUDA GPU, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda:0"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# bfloat16 on accelerators, float32 on CPU.
# NOTE(review): not referenced anywhere else in this file and not passed to the model below.
torch_dtype = torch.float32 if device == "cpu" else torch.bfloat16

model_name_or_path = "ai4bharat/IndicConformer"
# NOTE(review): the transcription handlers set `model.cur_decoder` to "ctc" or "rnnt",
# which suggests a hybrid RNNT/CTC checkpoint; confirm EncDecCTCModel is the intended class.
model = nemo_asr.models.EncDecCTCModel.from_pretrained(model_name_or_path).to(device)
# Local-checkpoint alternative:
# model = nemo_asr.models.EncDecCTCModel.restore_from("indicconformer_stt_bn_hybrid_rnnt_large.nemo").to(device)
model.eval()  # evaluation mode for inference

# Pre-compute example outputs only when explicitly requested and a CUDA GPU is present.
CACHE_EXAMPLES = os.getenv("CACHE_EXAMPLES") == "1" and torch.cuda.is_available()
AUDIO_SAMPLE_RATE = 16000  # target sample rate for model input, in Hz
MAX_INPUT_AUDIO_LENGTH = 60  # maximum input duration, in seconds
DEFAULT_TARGET_LANGUAGE = "Bengali"  # preselected value of the language dropdowns
def _load_audio(input_audio: str | None) -> torch.Tensor:
    """Load an audio file and prepare it for the ASR model.

    The clip is down-mixed to mono, trimmed to ``MAX_INPUT_AUDIO_LENGTH`` seconds
    and resampled to ``AUDIO_SAMPLE_RATE``.

    Args:
        input_audio: Path to the audio file (the value of a ``type="filepath"``
            ``gr.Audio``), or ``None`` when nothing was uploaded or recorded.

    Returns:
        A 1-D tensor of samples at ``AUDIO_SAMPLE_RATE``.

    Raises:
        gr.Error: If no audio was provided.
    """
    if input_audio is None:
        raise gr.Error("Please upload or record an audio clip first.")
    # torchaudio.load returns (waveform of shape (channels, frames), sample_rate).
    waveform, orig_freq = torchaudio.load(input_audio)
    # Average the channels into one mono track of shape (frames,); for a single
    # channel this just drops the channel axis.
    waveform = waveform.mean(dim=0)
    # Trim before resampling so overly long uploads are not resampled in full.
    waveform = waveform[: int(orig_freq * MAX_INPUT_AUDIO_LENGTH)]
    return torchaudio.functional.resample(
        waveform, orig_freq=orig_freq, new_freq=AUDIO_SAMPLE_RATE
    )


def _transcribe(input_audio: str | None, target_language: str, decoder: str) -> str:
    """Transcribe ``input_audio`` with the shared model using the given decoder.

    Args:
        input_audio: Path to the audio file to transcribe.
        target_language: Display name of the spoken language; must be a key of
            ``LANGUAGE_NAME_TO_CODE``.
        decoder: Value assigned to ``model.cur_decoder`` ("ctc" or "rnnt").

    Returns:
        The transcription: the first item of the first element returned by
        ``model.transcribe``.

    Raises:
        KeyError: If ``target_language`` is not a known language name.
        gr.Error: If no audio was provided.
    """
    lang_id = LANGUAGE_NAME_TO_CODE[target_language]
    waveform = _load_audio(input_audio)
    # NOTE(review): `cur_decoder` is mutable state on the single global model. If the
    # CTC and RNNT handlers can run concurrently they may race on it; confirm the
    # queue's concurrency settings.
    model.cur_decoder = decoder
    outputs = model.transcribe(
        [waveform.numpy()], batch_size=1, logprobs=False, language_id=lang_id
    )
    return outputs[0][0]


# `@spaces.GPU` requests a GPU for the call on ZeroGPU Spaces. Per the Hugging Face
# ZeroGPU docs it has no effect in other environments.
@spaces.GPU
def run_asr_ctc(input_audio: str, target_language: str) -> str:
    """Transcribe ``input_audio`` (spoken in ``target_language``) with the CTC decoder."""
    return _transcribe(input_audio, target_language, "ctc")


@spaces.GPU
def run_asr_rnnt(input_audio: str, target_language: str) -> str:
    """Transcribe ``input_audio`` (spoken in ``target_language``) with the RNNT decoder."""
    return _transcribe(input_audio, target_language, "rnnt")
def _build_asr_demo(transcribe_fn, api_name: str) -> gr.Blocks:
    """Build one transcription tab as a standalone ``gr.Blocks``.

    The tab contains an audio input, a language dropdown, a Transcribe button, a
    text output and clickable examples, all wired to ``transcribe_fn``.

    Args:
        transcribe_fn: Callable ``(audio_filepath, language_name) -> str`` used by
            both the button and the examples.
        api_name: Endpoint name for the button's click event. It should be unique
            across the app; Gradio auto-suffixes duplicates when tabs are rendered
            together.

    Returns:
        The un-rendered ``gr.Blocks``; the caller places it with ``.render()``.
    """
    with gr.Blocks() as asr_demo:
        with gr.Row():
            with gr.Column():
                with gr.Group():
                    input_audio = gr.Audio(label="Input speech", type="filepath")
                    target_language = gr.Dropdown(
                        label="Target language",
                        choices=list(LANGUAGE_NAME_TO_CODE),
                        value=DEFAULT_TARGET_LANGUAGE,
                    )
                btn = gr.Button("Transcribe")
            with gr.Column():
                output_text = gr.Textbox(label="Transcribed text")
        # One value per input component: (audio file path, language name).
        gr.Examples(
            examples=[
                ["assets/Bengali.wav", "Bengali"],
                ["assets/Gujarati.wav", "Gujarati"],
                ["assets/Punjabi.wav", "Punjabi"],
            ],
            inputs=[input_audio, target_language],
            outputs=output_text,
            fn=transcribe_fn,
            cache_examples=CACHE_EXAMPLES,
            api_name=False,
        )
        btn.click(
            fn=transcribe_fn,
            inputs=[input_audio, target_language],
            outputs=output_text,
            api_name=api_name,
        )
    return asr_demo


demo_asr_ctc = _build_asr_demo(run_asr_ctc, api_name="asr")
# Distinct endpoint name (both tabs previously used "asr").
demo_asr_rnnt = _build_asr_demo(run_asr_rnnt, api_name="asr_rnnt")
# Top-level app: description, optional duplicate button, and one tab per decoding mode.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    # The duplicate button is shown only when SHOW_DUPLICATE_BUTTON is set to "1".
    _show_duplicate = os.getenv("SHOW_DUPLICATE_BUTTON") == "1"
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=_show_duplicate,
    )
    with gr.Tabs():
        for _tab_label, _tab_demo in (("CTC", demo_asr_ctc), ("RNNT", demo_asr_rnnt)):
            with gr.Tab(label=_tab_label):
                _tab_demo.render()
if __name__ == "__main__":
    # Cap the pending-request queue at 50, then start the server.
    queued_demo = demo.queue(max_size=50)
    queued_demo.launch()