"""Gradio web app that translates text with an ONNX model and greedy decoding.

The tokenizer is ``Helsinki-NLP/opus-mt-en-fr``; the ONNX graph in
``MODEL_FILE`` is presumably an export of the same model (TODO confirm).
"""

import traceback

import gradio as gr
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

MODEL_FILE = "./model.onnx"
# Source text longer than this many tokens is truncated by the tokenizer.
MAX_INPUT_LENGTH = 512
# Upper bound on the number of decoding steps per request.
MAX_DECODER_LENGTH = 512

session = ort.InferenceSession(MODEL_FILE)
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-fr")


def _decoder_start_token_id() -> int:
    """Return the token id used to seed the decoder.

    Prefers the tokenizer's CLS token and falls back to its PAD token.

    Raises:
        ValueError: If the tokenizer defines neither token.
    """
    # Compare against None explicitly: 0 is a valid token id, and a plain
    # ``a or b`` would silently skip it.
    for candidate in (tokenizer.cls_token_id, tokenizer.pad_token_id):
        if candidate is not None:
            return int(candidate)
    raise ValueError(
        "Tokenizer defines neither a CLS nor a PAD token to start decoding"
    )


def _greedy_decode(
    input_ids: np.ndarray,
    attention_mask: np.ndarray,
    start_token_id: int,
) -> np.ndarray:
    """Generate target token ids by repeatedly picking the arg-max token.

    Args:
        input_ids: int64 token ids of the source text, as produced by the
            tokenizer (assumed to hold a single sequence - TODO confirm).
        attention_mask: int64 attention mask matching ``input_ids``.
        start_token_id: Token id placed at the first decoder position.

    Returns:
        A 1-D array of token ids that begins with ``start_token_id`` and
        includes the end-of-sequence token if one was generated.
    """
    eos_token_id = tokenizer.eos_token_id  # loop-invariant lookup
    decoder_input_ids = np.array([[start_token_id]], dtype=np.int64)

    for _ in range(MAX_DECODER_LENGTH):
        # The whole graph is re-run on the full prefix at every step.
        outputs = session.run(
            None,
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "decoder_input_ids": decoder_input_ids,
            },
        )
        # The first output is treated as logits, indexed as
        # (batch, position, vocabulary); only the last position matters.
        logits = outputs[0]
        next_token_id = int(np.argmax(logits[0, -1, :]))

        decoder_input_ids = np.concatenate(
            [decoder_input_ids, np.array([[next_token_id]], dtype=np.int64)],
            axis=1,
        )

        # Stop if the end-of-sequence token is generated.
        if next_token_id == eos_token_id:
            break

    return decoder_input_ids[0]


def gradio_predict(input_text: str) -> str:
    """Translate ``input_text`` and return the decoded result.

    Blank input returns an empty string without running the model. Any
    exception is printed with its traceback and reported to the caller as
    an ``"Error during translation: ..."`` string instead of being raised,
    so the web UI always receives text.

    Args:
        input_text: Text entered in the Gradio textbox.

    Returns:
        The translated text, ``""`` for blank input, or an error message.
    """
    # With live=True the callback also fires when the textbox is cleared;
    # there is nothing to translate, so skip tokenization and inference.
    if input_text is None or not input_text.strip():
        return ""

    try:
        tokenized_input = tokenizer(
            input_text,
            return_tensors="np",
            padding=True,
            truncation=True,
            max_length=MAX_INPUT_LENGTH,
        )
        # The decoding loop always feeds int64 arrays to the session.
        input_ids = tokenized_input["input_ids"].astype(np.int64)
        attention_mask = tokenized_input["attention_mask"].astype(np.int64)

        generated_ids = _greedy_decode(
            input_ids, attention_mask, _decoder_start_token_id()
        )
        return tokenizer.decode(generated_ids, skip_special_tokens=True)
    except Exception as e:  # UI boundary: report the failure, never crash
        print(f"Detailed error: {str(e)}")
        print(traceback.format_exc())
        return f"Error during translation: {str(e)}"


# Gradio interface for the web app. live=True re-runs the callback whenever
# the input changes rather than waiting for a submit click.
demo = gr.Interface(
    fn=gradio_predict,
    inputs="text",
    outputs="text",
    live=True,
)
demo.launch()