import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, BitsAndBytesConfig, AutoTokenizer
from IndicTransToolkit import IndicProcessor
import speech_recognition as sr
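# IndicProcessor ships with the IndicTransToolkit package (pip install IndicTransToolkit);
# speech_recognition is provided by the SpeechRecognition package on PyPI.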
# Constants
BATCH_SIZE = 4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
quantization = None
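# Set quantization to "4-bit" or "8-bit" to load the model through bitsandbytes
# (CUDA only); None loads full-precision weights, which are cast to fp16 on GPU below.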
# ---- IndicTrans2 Model Initialization ----
def initialize_model_and_tokenizer(ckpt_dir, quantization):
    if quantization == "4-bit":
        qconfig = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    elif quantization == "8-bit":
        qconfig = BitsAndBytesConfig(
            load_in_8bit=True,
            bnb_8bit_use_double_quant=True,
            bnb_8bit_compute_dtype=torch.bfloat16,
        )
    else:
        qconfig = None

    tokenizer = AutoTokenizer.from_pretrained(ckpt_dir, trust_remote_code=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        ckpt_dir,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        quantization_config=qconfig,
    )

    # Quantized models are placed on the GPU by bitsandbytes; otherwise move the
    # model ourselves and use fp16 on CUDA to halve memory use.
    if qconfig is None:
        model = model.to(DEVICE)
        if DEVICE == "cuda":
            model.half()

    model.eval()
    return tokenizer, model
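# The same helper can load the reverse-direction checkpoint if needed, e.g.
# (hypothetical usage, not exercised by this app):
#   en_indic_tokenizer, en_indic_model = initialize_model_and_tokenizer(
#       "ai4bharat/indictrans2-en-indic-1B", quantization
#   )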
def batch_translate(input_sentences, src_lang, tgt_lang, model, tokenizer, ip):
    translations = []
    for i in range(0, len(input_sentences), BATCH_SIZE):
        batch = input_sentences[i : i + BATCH_SIZE]

        # Add language tags and normalize the batch for IndicTrans2.
        batch = ip.preprocess_batch(batch, src_lang=src_lang, tgt_lang=tgt_lang)

        inputs = tokenizer(
            batch,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)

        with torch.no_grad():
            generated_tokens = model.generate(
                **inputs,
                use_cache=True,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )

        # Decode with the target-language vocabulary.
        with tokenizer.as_target_tokenizer():
            generated_tokens = tokenizer.batch_decode(
                generated_tokens.detach().cpu().tolist(),
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )

        # Strip tags and restore punctuation/script for the target language.
        translations += ip.postprocess_batch(generated_tokens, lang=tgt_lang)

        # Free the batch tensors (empty_cache is a no-op on CPU-only runs).
        del inputs
        torch.cuda.empty_cache()

    return translations
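# Minimal usage sketch (assumes the model, tokenizer, and processor created below;
# the sample sentence is illustrative):
#   batch_translate(["നമസ്കാരം"], "mal_Mlym", "eng_Latn",
#                   indic_en_model, indic_en_tokenizer, ip)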
# Initialize IndicTrans2 (Indic-to-English checkpoint, loaded once at startup)
indic_en_ckpt_dir = "ai4bharat/indictrans2-indic-en-1B"
indic_en_tokenizer, indic_en_model = initialize_model_and_tokenizer(indic_en_ckpt_dir, quantization)
ip = IndicProcessor(inference=True)
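# IndicTrans2 uses FLORES-style language tags, e.g. "mal_Mlym" for Malayalam in
# Malayalam script and "eng_Latn" for English in Latin script.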
# ---- Gradio Function ----
def transcribe_and_translate(audio):
    recognizer = sr.Recognizer()
    with sr.AudioFile(audio) as source:
        audio_data = recognizer.record(source)

    try:
        # Malayalam transcription via Google's web Speech API (requires internet).
        malayalam_text = recognizer.recognize_google(audio_data, language="ml-IN")
    except sr.UnknownValueError:
        return "Could not understand audio", ""
    except sr.RequestError as e:
        return f"Google API Error: {e}", ""

    # Translate the Malayalam transcription to English.
    ml_sents = [malayalam_text]
    src_lang, tgt_lang = "mal_Mlym", "eng_Latn"
    translations = batch_translate(ml_sents, src_lang, tgt_lang, indic_en_model, indic_en_tokenizer, ip)
    return malayalam_text, translations[0]
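# Quick local smoke test (hypothetical file name; sr.AudioFile expects WAV/AIFF/FLAC):
#   print(transcribe_and_translate("sample_malayalam.wav"))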
# ---- Gradio Interface ----
iface = gr.Interface(
    fn=transcribe_and_translate,
    inputs=gr.Audio(sources=["microphone", "upload"], type="filepath"),
    outputs=[
        gr.Textbox(label="Malayalam Transcription"),
        gr.Textbox(label="English Translation"),
    ],
    title="Malayalam Speech Recognition & Translation",
    description="Speak in Malayalam → Transcribe using Google Speech Recognition → Translate to English using IndicTrans2.",
)

iface.launch(debug=True, share=True)
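# debug=True keeps the launch call blocking and surfaces tracebacks in the console;
# share=True asks Gradio for a temporary public link (redundant on hosted Spaces).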