Frenchizer committed on
Commit
a4ea3eb
·
verified ·
1 Parent(s): ce28322

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -48
app.py CHANGED
@@ -1,12 +1,3 @@
1
- import gradio as gr
2
- import onnxruntime as ort
3
- from transformers import AutoTokenizer
4
- import numpy as np
5
-
6
- MODEL_FILE = "./model.onnx"
7
- session = ort.InferenceSession(MODEL_FILE)
8
- tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-fr")
9
-
10
  def gradio_predict(input_text):
11
  try:
12
  # Tokenize input text
@@ -17,53 +8,47 @@ def gradio_predict(input_text):
17
  truncation=True,
18
  max_length=512
19
  )
20
-
21
  # Prepare inputs
22
  input_ids = tokenized_input["input_ids"].astype(np.int64)
23
  attention_mask = tokenized_input["attention_mask"].astype(np.int64)
24
-
25
- # Use a specific token ID for decoder start (for Helsinki-NLP models)
26
- decoder_start_token_id = 59513 # This is the typical start token for Helsinki-NLP models
27
  decoder_input_ids = np.array([[decoder_start_token_id]], dtype=np.int64)
28
-
29
- print("Initial shapes:")
30
- print(f"input_ids shape: {input_ids.shape}")
31
- print(f"attention_mask shape: {attention_mask.shape}")
32
- print(f"decoder_input_ids shape: {decoder_input_ids.shape}")
33
-
34
- # Run initial inference
35
- outputs = session.run(
36
- None,
37
- {
38
- "input_ids": input_ids,
39
- "attention_mask": attention_mask,
40
- "decoder_input_ids": decoder_input_ids
41
- }
42
- )
43
-
44
- # Get logits and convert to token ids
45
- logits = outputs[0]
46
- token_ids = np.argmax(logits[0], axis=-1)
47
-
48
- # Find end of sequence (using pad token since eos might also be None)
49
- end_idx = np.where(token_ids == tokenizer.pad_token_id)[0]
50
- if len(end_idx) > 0:
51
- token_ids = token_ids[:end_idx[0]]
52
-
 
 
53
  # Decode the sequence
54
- translated_text = tokenizer.decode(token_ids, skip_special_tokens=True)
55
  return translated_text
56
-
57
  except Exception as e:
58
  print(f"Detailed error: {str(e)}")
59
  import traceback
60
  print(traceback.format_exc())
61
  return f"Error during translation: {str(e)}"
62
-
63
- # Gradio interface for the web app
64
- gr.Interface(
65
- fn=gradio_predict,
66
- inputs="text",
67
- outputs="text",
68
- live=True
69
- ).launch()
 
 
 
 
 
 
 
 
 
 
1
  def gradio_predict(input_text):
2
  try:
3
  # Tokenize input text
 
8
  truncation=True,
9
  max_length=512
10
  )
11
+
12
  # Prepare inputs
13
  input_ids = tokenized_input["input_ids"].astype(np.int64)
14
  attention_mask = tokenized_input["attention_mask"].astype(np.int64)
15
+
16
+ # Initialize decoder input with start token
17
+ decoder_start_token_id = tokenizer.cls_token_id or tokenizer.pad_token_id # Use cls or pad as a fallback
18
  decoder_input_ids = np.array([[decoder_start_token_id]], dtype=np.int64)
19
+
20
+ # Iterative decoding loop
21
+ max_decoder_length = 512 # Adjust as needed
22
+ for _ in range(max_decoder_length):
23
+ # Perform inference
24
+ outputs = session.run(
25
+ None,
26
+ {
27
+ "input_ids": input_ids,
28
+ "attention_mask": attention_mask,
29
+ "decoder_input_ids": decoder_input_ids,
30
+ }
31
+ )
32
+
33
+ # Get logits and predicted token
34
+ logits = outputs[0]
35
+ next_token_id = np.argmax(logits[:, -1, :], axis=-1).item()
36
+
37
+ # Append the predicted token to decoder input
38
+ decoder_input_ids = np.concatenate(
39
+ [decoder_input_ids, np.array([[next_token_id]], dtype=np.int64)], axis=1
40
+ )
41
+
42
+ # Stop if end-of-sequence token is generated
43
+ if next_token_id == tokenizer.eos_token_id:
44
+ break
45
+
46
  # Decode the sequence
47
+ translated_text = tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)
48
  return translated_text
49
+
50
  except Exception as e:
51
  print(f"Detailed error: {str(e)}")
52
  import traceback
53
  print(traceback.format_exc())
54
  return f"Error during translation: {str(e)}"