Spaces:
Sleeping
Sleeping
import gradio as gr
import onnxruntime as ort
from transformers import AutoTokenizer
import numpy as np

# Path to the exported seq2seq ONNX model (expected alongside this script).
MODEL_FILE = "./model.onnx"

# Module-level singletons: load the inference session and tokenizer once at
# startup so every Gradio request reuses them.
# NOTE(review): the ONNX export appears to be a single full-model graph taking
# decoder_input_ids (see gradio_predict) — confirm against the export script.
session = ort.InferenceSession(MODEL_FILE)
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-fr")
def gradio_predict(input_text):
    """Translate English text to French with the ONNX model via greedy decoding.

    Tokenizes ``input_text``, then runs the full seq2seq ONNX graph one step at
    a time, appending the argmax token until EOS or a length cap is reached.

    Args:
        input_text: Source (English) string to translate.

    Returns:
        The decoded translation string, or a human-readable error message if
        anything goes wrong (this function is the Gradio callback boundary).
    """
    try:
        # Tokenize to numpy tensors for onnxruntime.
        tokenized_input = tokenizer(
            input_text,
            return_tensors="np",
            padding=True,
            truncation=True,
            max_length=512,
        )

        input_ids = tokenized_input["input_ids"].astype(np.int64)
        attention_mask = tokenized_input["attention_mask"].astype(np.int64)

        # Pick the decoder start token. BUG FIX: the original used
        # `tokenizer.cls_token_id or tokenizer.pad_token_id`, which silently
        # discards a valid id of 0 because of truthiness; compare with None
        # explicitly. (MarianMT tokenizers have no cls token and start the
        # decoder with pad_token_id.)
        decoder_start_token_id = tokenizer.cls_token_id
        if decoder_start_token_id is None:
            decoder_start_token_id = tokenizer.pad_token_id
        if decoder_start_token_id is None:
            raise ValueError(
                "Tokenizer defines neither cls_token_id nor pad_token_id; "
                "cannot choose a decoder start token"
            )
        decoder_input_ids = np.array([[decoder_start_token_id]], dtype=np.int64)

        # Greedy decoding loop. Each step re-runs the whole graph (including
        # the encoder) — correct but slow; an encoder/decoder split export
        # with KV cache would be faster for long outputs.
        max_decoder_length = 512  # hard cap to guarantee termination
        for _ in range(max_decoder_length):
            outputs = session.run(
                None,
                {
                    "input_ids": input_ids,
                    "attention_mask": attention_mask,
                    "decoder_input_ids": decoder_input_ids,
                },
            )
            # outputs[0] is the logits tensor; take the argmax of the last
            # decoder position as the next token.
            logits = outputs[0]
            next_token_id = int(np.argmax(logits[:, -1, :], axis=-1).item())

            decoder_input_ids = np.concatenate(
                [decoder_input_ids, np.array([[next_token_id]], dtype=np.int64)],
                axis=1,
            )

            # Stop once the model emits end-of-sequence.
            if next_token_id == tokenizer.eos_token_id:
                break

        # Decode the generated ids, dropping start/eos/pad specials.
        translated_text = tokenizer.decode(
            decoder_input_ids[0], skip_special_tokens=True
        )
        return translated_text
    except Exception as e:
        # Top-level boundary for the web callback: log the full traceback to
        # the server console and return a readable message to the UI instead
        # of crashing the request.
        import traceback
        print(f"Detailed error: {str(e)}")
        print(traceback.format_exc())
        return f"Error during translation: {str(e)}"
# Build the web UI: a single text box in, translated text out, with live
# (as-you-type) updates, then start the Gradio server.
demo = gr.Interface(
    fn=gradio_predict,
    inputs="text",
    outputs="text",
    live=True,
)
demo.launch()