space_4 / app.py
Frenchizer's picture
Update app.py
2cecc10 verified
import gradio as gr
import onnxruntime as ort
from transformers import AutoTokenizer
import numpy as np
from gradio_client import Client
# Context model: presumably an ONNX export of facebook/bart-large-mnli, paired
# with that model's tokenizer.
# Execution providers are listed explicitly: since onnxruntime 1.9, builds that
# ship more than the CPU provider (e.g. onnxruntime-gpu) raise ValueError when
# `providers` is omitted.
context_model_file = "./bart-large-mnli.onnx"
context_session = ort.InferenceSession(
    context_model_file,
    providers=["CPUExecutionProvider"],
)
context_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-mnli")
# Gradio client for the remote Space that performs the actual translation.
translation_client = Client("Frenchizer/space_3")
# Candidate domains for context detection. Order is significant:
# detect_context pairs these labels positionally with model scores.
# NOTE(review): "sustained" looks like a typo (perhaps "sustainability"?), and
# "IT" duplicates "information technology" — confirm both are intended.
labels = [
    "aerospace", "anatomy", "anthropology", "art", "automotive",
    "blockchain", "biology", "chemistry", "cryptocurrency", "data science",
    "design", "e-commerce", "education", "engineering", "entertainment",
    "environment", "fashion", "finance", "food commerce", "general",
    "gaming", "healthcare", "history", "html", "information technology",
    "IT", "keywords", "legal", "literature", "machine learning",
    "marketing", "medicine", "music", "personal development", "philosophy",
    "physics", "politics", "poetry", "programming", "real estate",
    "retail", "robotics", "slang", "social media", "speech",
    "sports", "sustained", "technical", "theater", "tourism",
    "travel",
]
def softmax_with_temperature(logits, temperature=1.0):
    """Return the temperature-scaled softmax of ``logits`` along the last axis.

    Args:
        logits: Array-like of raw scores; normalisation is over the last axis.
        temperature: Strictly positive divisor applied to the logits before
            exponentiation (scalar or broadcastable array). Values > 1 flatten
            the distribution, values < 1 sharpen it.

    Returns:
        An ndarray with the same shape as ``logits`` whose entries along the
        last axis are non-negative and sum to 1. Empty input is returned as an
        empty array.

    Raises:
        ValueError: If any ``temperature`` value is not strictly positive.
    """
    if np.any(np.asarray(temperature) <= 0):
        raise ValueError(f"temperature must be > 0, got {temperature!r}")
    scaled = np.asarray(logits) / temperature
    if scaled.size == 0:
        return scaled
    # Shifting by the per-row maximum leaves the softmax unchanged
    # mathematically, but keeps np.exp from overflowing to inf (which would
    # turn the result into nan) when logits are large.
    shifted = scaled - np.max(scaled, axis=-1, keepdims=True)
    exp_shifted = np.exp(shifted)
    return exp_shifted / np.sum(exp_shifted, axis=-1, keepdims=True)
# Zero-shot setup for the NLI context model (the same scheme used by
# transformers' ZeroShotClassificationPipeline): each candidate label becomes
# a hypothesis that is scored against the input text as premise.
_HYPOTHESIS_TEMPLATE = "This example is {}."
# Column of the "entailment" logit. facebook/bart-large-mnli's config orders
# its head as 0=contradiction, 1=neutral, 2=entailment.
# NOTE(review): confirm the ONNX export preserves this output order.
_ENTAILMENT_INDEX = 2


def detect_context(input_text, top_n=3, score_threshold=0.05):
    """Return up to ``top_n`` labels from ``labels`` that describe ``input_text``.

    The context model is an MNLI (natural-language-inference) classifier, not a
    domain classifier, so its raw output cannot be zipped with ``labels``.
    Instead, the text is paired, as premise, with one hypothesis per label.
    The entailment logit of each pair is that label's score. These logits are
    softmax-normalised across all labels, so ``score_threshold`` is a
    probability.

    Args:
        input_text: Text whose domain is being detected.
        top_n: Maximum number of labels to return.
        score_threshold: Labels whose probability is not strictly greater than
            this value are discarded.

    Returns:
        A list of label names ordered by descending probability, or
        ``["general"]`` when the input is blank or no label passes the
        threshold.
    """
    # Blank input carries no signal; skip the (expensive) model call.
    if not input_text or not input_text.strip():
        return ["general"]

    # One (premise, hypothesis) pair per candidate label. Only the premise is
    # truncated, so every hypothesis reaches the model intact.
    hypotheses = [_HYPOTHESIS_TEMPLATE.format(label) for label in labels]
    inputs = context_tokenizer(
        [input_text] * len(hypotheses),
        hypotheses,
        return_tensors="np",
        padding=True,
        truncation="only_first",
        max_length=512,
    )

    # All pairs are scored in a single batch.
    # NOTE(review): assumes the ONNX graph was exported with a dynamic batch
    # axis — TODO confirm.
    outputs = context_session.run(None, {
        "input_ids": inputs["input_ids"].astype(np.int64),
        "attention_mask": inputs["attention_mask"].astype(np.int64),
    })
    logits = np.asarray(outputs[0])  # expected shape: (len(labels), 3)

    # Softmax over the per-label entailment logits yields a distribution
    # across labels that the threshold can be compared against.
    probabilities = softmax_with_temperature(logits[:, _ENTAILMENT_INDEX])

    ranked = sorted(zip(labels, probabilities), key=lambda pair: pair[1], reverse=True)
    top_contexts = [label for label, prob in ranked if prob > score_threshold][:top_n]
    return top_contexts or ["general"]
def translate_text(input_text):
    """Translate ``input_text`` by delegating to the remote translation Space.

    Returns whatever ``translation_client.predict`` yields for the text.
    """
    return translation_client.predict(input_text)
def process_request(input_text):
    """Gradio handler: log the detected context, then return the translation.

    The detected context is only printed; it is not forwarded to the
    translation call.
    """
    detected = detect_context(input_text)
    print(f"Detected context: {detected}")
    return translate_text(input_text)
# Expose process_request as a text-in / text-out Gradio UI and start serving it.
interface = gr.Interface(
    process_request,
    "text",
    "text",
    title="Frenchizer",
    description="Translate text from English to French with context detection.",
)
interface.launch()