Spaces:
Sleeping
Sleeping
File size: 2,325 Bytes
52a30dd c84f925 10fd17d ea68300 10fd17d a4ea3eb 76720d2 10fd17d a4ea3eb ce28322 a4ea3eb 6fd7731 a4ea3eb b1da10c a4ea3eb b1da10c 76720d2 6fd7731 10fd17d 52a30dd e77bb3f 8cfd943 52a30dd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import gradio as gr
import onnxruntime as ort
from transformers import AutoTokenizer
import numpy as np
# Path to the exported ONNX seq2seq translation model, relative to the app root.
MODEL_FILE = "./model.onnx"
# Load the ONNX model once at startup; the session is reused for every request.
session = ort.InferenceSession(MODEL_FILE)
# Tokenizer matching the exported model: Marian English→French (downloads from the Hub).
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-fr")
def gradio_predict(input_text):
    """Translate English text to French with the ONNX seq2seq model.

    Performs greedy decoding: the source text is tokenized once for the
    encoder, then the decoder is re-run on the growing output sequence
    until an EOS token is produced or a hard length cap is hit.

    Args:
        input_text: English source text to translate.

    Returns:
        The French translation as a string, or an error message string
        if translation fails (this function is a UI boundary and never
        raises).
    """
    try:
        # Robustness: skip the (expensive) decode loop entirely for empty input.
        if not input_text or not input_text.strip():
            return ""

        # Tokenize the source text for the encoder.
        tokenized_input = tokenizer(
            input_text,
            return_tensors="np",
            padding=True,
            truncation=True,
            max_length=512,
        )
        input_ids = tokenized_input["input_ids"].astype(np.int64)
        attention_mask = tokenized_input["attention_mask"].astype(np.int64)

        # Marian/OPUS models use the PAD token as the decoder start token;
        # cls_token_id is None for this tokenizer, so the `or` fallback
        # resolves to pad_token_id.
        decoder_start_token_id = tokenizer.cls_token_id or tokenizer.pad_token_id
        decoder_input_ids = np.array([[decoder_start_token_id]], dtype=np.int64)

        eos_token_id = tokenizer.eos_token_id  # hoisted out of the decode loop
        max_decoder_length = 512  # hard cap so decoding always terminates

        for _ in range(max_decoder_length):
            # NOTE: this ONNX graph exposes no past-key-values cache, so the
            # full sequence is re-decoded every step (quadratic in output length).
            outputs = session.run(
                None,
                {
                    "input_ids": input_ids,
                    "attention_mask": attention_mask,
                    "decoder_input_ids": decoder_input_ids,
                },
            )
            # Greedy choice: highest-scoring token at the last decoder position.
            logits = outputs[0]
            next_token_id = np.argmax(logits[:, -1, :], axis=-1).item()
            # Append the predicted token and feed the extended sequence back in.
            decoder_input_ids = np.concatenate(
                [decoder_input_ids, np.array([[next_token_id]], dtype=np.int64)],
                axis=1,
            )
            if next_token_id == eos_token_id:
                break

        # skip_special_tokens drops the start/EOS/pad tokens from the output.
        translated_text = tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)
        return translated_text
    except Exception as e:
        # UI boundary: log the full traceback server-side, return a readable
        # message to the user instead of crashing the app.
        print(f"Detailed error: {str(e)}")
        import traceback
        print(traceback.format_exc())
        return f"Error during translation: {str(e)}"
# Gradio interface for the web app
# Minimal text-in/text-out UI; live=True re-runs translation as the user types,
# and .launch() starts the web server (blocks until the app is stopped).
gr.Interface(
    fn=gradio_predict,
    inputs="text",
    outputs="text",
    live=True
).launch()
|