# Hugging Face Space app: ONNX English-to-French translation demo.
import numpy as np
import onnxruntime as ort
from transformers import MarianTokenizer
import gradio as gr

# --- Model assets -----------------------------------------------------------
# MarianMT tokenizer shipped alongside the app (vocabulary + config files).
tokenizer_path = "./tokenizer"  # local tokenizer folder
tokenizer = MarianTokenizer.from_pretrained(tokenizer_path)

# Exported MarianMT seq2seq graph, served through ONNX Runtime (CPU by default).
onnx_model_path = "./model.onnx"
session = ort.InferenceSession(onnx_model_path)
def translate(texts, max_length=512):
    """Greedily decode translations for a batch of input texts.

    Args:
        texts: A string or list of strings to translate.
        max_length: Cap on both the tokenized input length and the number
            of decoding steps.

    Returns:
        list[str]: One decoded translation per input text.
    """
    # Tokenize the batch; ONNX Runtime expects int64 tensors.
    inputs = tokenizer(texts, return_tensors="np", padding=True, truncation=True, max_length=max_length)
    input_ids = inputs["input_ids"].astype(np.int64)
    attention_mask = inputs["attention_mask"].astype(np.int64)

    batch_size = input_ids.shape[0]
    # MarianMT uses the pad token as the decoder start token.
    decoder_input_ids = np.full((batch_size, 1), tokenizer.pad_token_id, dtype=np.int64)
    eos_reached = np.zeros(batch_size, dtype=bool)  # per-sequence completion flags

    # Generate output tokens one step at a time (greedy decoding).
    for _ in range(max_length):
        # NOTE(review): the exported graph takes encoder inputs directly, so the
        # encoder is re-run every step; correct but slow for long outputs.
        onnx_outputs = session.run(
            None,
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "decoder_input_ids": decoder_input_ids,
            },
        )
        # Logits at the last generated position: (batch_size, vocab_size).
        next_token_logits = onnx_outputs[0][:, -1, :]
        # Greedy decoding: pick the most probable token per sequence.
        next_tokens = np.argmax(next_token_logits, axis=-1)  # (batch_size,)
        # FIX: sequences that already emitted EOS must stop producing visible
        # tokens — otherwise argmax keeps appending ordinary vocabulary tokens
        # that survive skip_special_tokens and corrupt shorter translations in
        # a batch. Force pad (a special token) for finished rows.
        next_tokens = np.where(eos_reached, tokenizer.pad_token_id, next_tokens)

        # Append this step's tokens for the next decoder pass.
        decoder_input_ids = np.concatenate([decoder_input_ids, next_tokens[:, None]], axis=-1)

        # Track which sequences have just finished; stop once all are done.
        eos_reached |= next_tokens == tokenizer.eos_token_id
        if eos_reached.all():
            break

    # skip_special_tokens drops the leading pad, EOS, and the forced padding.
    return tokenizer.batch_decode(decoder_input_ids, skip_special_tokens=True)
| # Gradio interface | |
def gradio_translate(input_text):
    """Gradio handler: translate a multi-line textbox entry line by line.

    Each line of ``input_text`` is treated as one sentence; the translated
    lines are joined back with newlines for display.
    """
    lines = input_text.strip().split("\n")
    translated = translate(lines)
    return "\n".join(translated)
| # Create the Gradio interface | |
# --- Gradio UI ---------------------------------------------------------------
interface = gr.Interface(
    fn=gradio_translate,
    inputs=gr.Textbox(lines=5, placeholder="Enter text to translate...", label="Input Text"),
    outputs=gr.Textbox(lines=5, label="Translated Text"),
    title="ONNX English to French Translation",
    description="Translate English text to French using a MarianMT ONNX model.",
)

# Guard the launch so importing this module (e.g. from tests or another app)
# does not start the web server as a side effect; running the script directly
# behaves exactly as before.
if __name__ == "__main__":
    interface.launch()