import traceback

import gradio as gr
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

MODEL_FILE = "./model.onnx"

# Load the exported ONNX model and the matching tokenizer.
session = ort.InferenceSession(MODEL_FILE)
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-fr")
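
# This app assumes model.onnx is a single-file ("monolith") seq2seq export whose
# graph takes input_ids, attention_mask and decoder_input_ids and returns logits.
# One way to produce such a file (an assumption, not documented here) is
# Hugging Face Optimum, e.g.:
#   optimum-cli export onnx --model Helsinki-NLP/opus-mt-en-fr --monolith ./onnx/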

def gradio_predict(input_text):
    """Translate English text to French by greedy decoding over the ONNX model."""
    try:
        # Tokenize input text
        tokenized_input = tokenizer(
            input_text,
            return_tensors="np",
            padding=True,
            truncation=True,
            max_length=512
        )

        # Prepare inputs
        input_ids = tokenized_input["input_ids"].astype(np.int64)
        attention_mask = tokenized_input["attention_mask"].astype(np.int64)

        # Initialize the decoder input with the decoder start token; MarianMT
        # models use the pad token as the decoder start token.
        decoder_start_token_id = tokenizer.pad_token_id
        decoder_input_ids = np.array([[decoder_start_token_id]], dtype=np.int64)

        # Greedy decoding loop: emit one token per iteration, up to a safety cap.
        max_decoder_length = 512  # Adjust as needed
        for _ in range(max_decoder_length):
            # Perform inference
            outputs = session.run(
                None,
                {
                    "input_ids": input_ids,
                    "attention_mask": attention_mask,
                    "decoder_input_ids": decoder_input_ids,
                }
            )
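
            # Note: this export is assumed not to expose past key/values (no KV
            # cache), so each step re-runs the decoder over the entire prefix and
            # decoding cost grows quadratically with the output length.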

            # Get logits and predicted token
            logits = outputs[0]
            next_token_id = np.argmax(logits[:, -1, :], axis=-1).item()

            # Append the predicted token to decoder input
            decoder_input_ids = np.concatenate(
                [decoder_input_ids, np.array([[next_token_id]], dtype=np.int64)], axis=1
            )

            # Stop if end-of-sequence token is generated
            if next_token_id == tokenizer.eos_token_id:
                break

        # Decode the sequence
        translated_text = tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)
        return translated_text

    except Exception as e:
        # Log the full traceback server-side; return a readable message to the UI.
        print(f"Detailed error: {e}")
        print(traceback.format_exc())
        return f"Error during translation: {e}"

# Gradio interface for the web app. Note that live=True re-runs the translation
# on every keystroke, which is expensive for this iterative decoding loop.
demo = gr.Interface(
    fn=gradio_predict,
    inputs="text",
    outputs="text",
    live=True,
)

if __name__ == "__main__":
    demo.launch()