Frenchizer committed on
Commit
b1da10c
·
verified ·
1 Parent(s): 75c7e8c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -18
app.py CHANGED
@@ -14,32 +14,44 @@ def gradio_predict(input_text):
14
  tokenized_input = tokenizer(
15
  input_text,
16
  return_tensors="np",
17
- padding="max_length",
18
- truncation=True, # Ensures input is truncated if it's too long
19
- max_length=512 # Ensure the sequence doesn't exceed the model's max length
20
  )
21
 
22
  # Convert tokenized inputs to numpy arrays and ensure correct shape
23
- input_ids = np.array(tokenized_input["input_ids"], dtype=np.int64) # No need for reshape
24
- attention_mask = np.array(tokenized_input["attention_mask"], dtype=np.int64)
25
- decoder_input_ids = input_ids # Same as input_ids for translation models
26
 
27
- # Check if arrays are correctly formed
 
 
 
 
 
28
  print(f"input_ids shape: {input_ids.shape}, attention_mask shape: {attention_mask.shape}")
29
 
30
  # Perform inference with ONNX model
31
- outputs = session.run(
32
- None,
33
- {
34
- "input_ids": input_ids,
35
- "attention_mask": attention_mask,
36
- "decoder_input_ids": decoder_input_ids
37
- }
38
- )
 
 
 
 
 
 
 
 
39
 
40
- # Decode output and return translated text
41
- translated_text = tokenizer.decode(outputs[0][0], skip_special_tokens=True)
42
- return translated_text
43
 
44
  # Gradio interface for the web app
45
  gr.Interface(
 
14
  tokenized_input = tokenizer(
15
  input_text,
16
  return_tensors="np",
17
+ padding='max_length', # Pad to max length
18
+ truncation=True, # Truncate if longer than max length
19
+ max_length=512 # Ensure the sequence doesn't exceed the model's max length
20
  )
21
 
22
  # Convert tokenized inputs to numpy arrays and ensure correct shape
23
+ input_ids = np.array(tokenized_input["input_ids"], dtype=np.int64) # Shape should be [1, 512]
24
+ attention_mask = np.array(tokenized_input["attention_mask"], dtype=np.int64) # Shape should be [1, 512]
 
25
 
26
+ # Prepare decoder input ids if required by your model
27
+ decoder_input_ids = input_ids # Adjust as needed based on model requirements
28
+
29
+ # Debugging output
30
+ print(f"Input Text: {input_text}")
31
+ print(f"Tokens: {tokenizer.tokenize(input_text)}")
32
  print(f"input_ids shape: {input_ids.shape}, attention_mask shape: {attention_mask.shape}")
33
 
34
  # Perform inference with ONNX model
35
+ try:
36
+ outputs = session.run(
37
+ None,
38
+ {
39
+ "input_ids": input_ids,
40
+ "attention_mask": attention_mask,
41
+ "decoder_input_ids": decoder_input_ids # Include this only if required by your model
42
+ }
43
+ )
44
+
45
+ # Debugging output for outputs
46
+ print(f"Outputs: {outputs}, type: {type(outputs)}")
47
+
48
+ # Decode output and return translated text
49
+ translated_text = tokenizer.decode(outputs[0][0], skip_special_tokens=True)
50
+ return translated_text
51
 
52
+ except Exception as e:
53
+ print(f"Error during inference: {e}")
54
+ return "An error occurred during inference."
55
 
56
  # Gradio interface for the web app
57
  gr.Interface(