Update app.py
app.py
CHANGED
@@ -14,32 +14,44 @@ def gradio_predict(input_text):
     tokenized_input = tokenizer(
         input_text,
         return_tensors="np",
-        padding=
-        truncation=True,
-        max_length=512
+        padding='max_length',  # Pad to max length
+        truncation=True,  # Truncate if longer than max length
+        max_length=512  # Ensure the sequence doesn't exceed the model's max length
     )

     # Convert tokenized inputs to numpy arrays and ensure correct shape
-    input_ids = np.array(tokenized_input["input_ids"], dtype=np.int64) #
-    attention_mask = np.array(tokenized_input["attention_mask"], dtype=np.int64)
-    decoder_input_ids = input_ids # Same as input_ids for translation models
+    input_ids = np.array(tokenized_input["input_ids"], dtype=np.int64)  # Shape should be [1, 512]
+    attention_mask = np.array(tokenized_input["attention_mask"], dtype=np.int64)  # Shape should be [1, 512]

-    #
+    # Prepare decoder input ids if required by your model
+    decoder_input_ids = input_ids  # Adjust as needed based on model requirements
+
+    # Debugging output
+    print(f"Input Text: {input_text}")
+    print(f"Tokens: {tokenizer.tokenize(input_text)}")
     print(f"input_ids shape: {input_ids.shape}, attention_mask shape: {attention_mask.shape}")

     # Perform inference with ONNX model
-
-
-
-
-
-
-
-
+    try:
+        outputs = session.run(
+            None,
+            {
+                "input_ids": input_ids,
+                "attention_mask": attention_mask,
+                "decoder_input_ids": decoder_input_ids  # Include this only if required by your model
+            }
+        )
+
+        # Debugging output for outputs
+        print(f"Outputs: {outputs}, type: {type(outputs)}")
+
+        # Decode output and return translated text
+        translated_text = tokenizer.decode(outputs[0][0], skip_special_tokens=True)
+        return translated_text

-
-
-
+    except Exception as e:
+        print(f"Error during inference: {e}")
+        return "An error occurred during inference."

     # Gradio interface for the web app
     gr.Interface(
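For context: the diff body relies on tokenizer, session, np, and gr being defined in the first 13 lines of app.py, which this commit does not touch. A minimal sketch of that scaffolding is below, assuming a Hugging Face tokenizer and an onnxruntime session; the checkpoint name and ONNX file path are placeholders for illustration, not taken from the Space itself.

# Sketch of the scaffolding the diff assumes exists above line 14 (not part of this commit).
# The checkpoint name and model path are placeholders.
import numpy as np
import gradio as gr
import onnxruntime as ort
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")  # placeholder checkpoint
session = ort.InferenceSession("model.onnx")  # placeholder path to the exported ONNX model

def gradio_predict(input_text):
    ...  # body as shown in the diff above

# Gradio interface for the web app
gr.Interface(fn=gradio_predict, inputs="text", outputs="text").launch()

The gr.Interface( call at line 57 is cut off in this view; its actual arguments are not shown in the diff, so the fn/inputs/outputs wiring above is an assumption.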
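One caveat the new comments already hint at ("Adjust as needed based on model requirements"): for a typical encoder-decoder translation model exported to ONNX, a single session.run returns decoder logits rather than generated token ids, so reusing the source input_ids as decoder_input_ids and then calling tokenizer.decode(outputs[0][0], ...) only works if the exported graph performs generation internally. If it does not, the usual alternative is an explicit greedy loop. A hedged sketch, assuming the model's first output is logits of shape [batch, seq_len, vocab_size], dynamic sequence axes, and a Marian/T5-style decoder that starts from the pad token:

# Illustrative greedy decoding loop, not the Space's actual code.
import numpy as np

def greedy_decode(session, tokenizer, input_ids, attention_mask, max_new_tokens=64):
    # Marian/T5-style models use the pad token as the decoder start token (assumption).
    decoder_ids = np.array([[tokenizer.pad_token_id]], dtype=np.int64)
    for _ in range(max_new_tokens):
        logits = session.run(
            None,
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "decoder_input_ids": decoder_ids,
            },
        )[0]
        next_id = int(logits[0, -1].argmax())  # highest-scoring next token
        decoder_ids = np.concatenate([decoder_ids, [[next_id]]], axis=1)
        if next_id == tokenizer.eos_token_id:
            break
    return tokenizer.decode(decoder_ids[0], skip_special_tokens=True)

Alternatively, ORTModelForSeq2SeqLM from optimum.onnxruntime wraps an exported encoder/decoder pair and exposes a generate() method, which avoids hand-rolling this loop.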