import logging

import gradio as gr
import numpy as np
import onnxruntime as ort
from gradio_client import Client
from transformers import AutoTokenizer

logger = logging.getLogger(__name__)

# Initialize the context model.
# bart-large-mnli is a natural-language-inference (NLI) model: it classifies a
# (premise, hypothesis) pair, so it is used here for zero-shot classification.
context_model_file = "./bart-large-mnli.onnx"
context_session = ort.InferenceSession(context_model_file)
context_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-mnli")

# Index of the "entailment" logit in bart-large-mnli's output
# (model config: 0=contradiction, 1=neutral, 2=entailment).
ENTAILMENT_INDEX = 2
# Hypothesis used to ask the NLI model "is the text about <label>?";
# same default template as the Hugging Face zero-shot pipeline.
HYPOTHESIS_TEMPLATE = "This example is {}."

# Initialize the Gradio client for the translation model
translation_client = Client("Frenchizer/space_3")  # Replace with your Space name

# Candidate contexts for zero-shot classification.
# NOTE(review): "IT" overlaps "information technology", and "sustained" may be
# intended as "sustainability" -- confirm the intended label set.
labels = [
    "aerospace", "anatomy", "anthropology", "art", "automotive", "blockchain",
    "biology", "chemistry", "cryptocurrency", "data science", "design",
    "e-commerce", "education", "engineering", "entertainment", "environment",
    "fashion", "finance", "food commerce", "general", "gaming", "healthcare",
    "history", "html", "information technology", "IT", "keywords", "legal",
    "literature", "machine learning", "marketing", "medicine", "music",
    "personal development", "philosophy", "physics", "politics", "poetry",
    "programming", "real estate", "retail", "robotics", "slang",
    "social media", "speech", "sports", "sustained", "technical", "theater",
    "tourism", "travel",
]


def softmax_with_temperature(logits, temperature=1.0):
    """Return softmax(logits / temperature) along the last axis.

    The maximum is subtracted before exponentiating so large logits cannot
    overflow ``np.exp``; this does not change the mathematical result.

    Raises:
        ValueError: if ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    scaled = scaled - np.max(scaled, axis=-1, keepdims=True)
    exp_scaled = np.exp(scaled)
    return exp_scaled / np.sum(exp_scaled, axis=-1, keepdims=True)


def detect_context(input_text, top_n=3, score_threshold=0.05, temperature=1.0):
    """Return up to ``top_n`` labels from ``labels`` that best describe ``input_text``.

    The NLI model does not emit one logit per domain label. Instead, each
    label becomes a hypothesis ("This example is <label>.") paired with the
    input text as premise. The entailment logits for all labels are then
    softmaxed into one probability distribution over ``labels``.

    Args:
        input_text: Text to classify (used as the NLI premise).
        top_n: Maximum number of labels to return.
        score_threshold: Minimum probability a label needs to be returned.
        temperature: Softmax temperature applied to the entailment logits.

    Returns:
        Labels in descending order of probability, or ``["general"]`` when
        no label clears ``score_threshold``.
    """
    premises = [input_text] * len(labels)
    hypotheses = [HYPOTHESIS_TEMPLATE.format(label) for label in labels]
    inputs = context_tokenizer(
        premises,
        hypotheses,
        return_tensors="np",
        padding=True,
        truncation="only_first",  # trim the user text, never the hypothesis
        max_length=512,
    )

    # Run all (premise, hypothesis) pairs through the ONNX model in one batch.
    outputs = context_session.run(None, {
        "input_ids": inputs["input_ids"].astype(np.int64),
        "attention_mask": inputs["attention_mask"].astype(np.int64),
    })
    # Assumes outputs[0] holds the classification logits with shape
    # (len(labels), 3) -- one NLI row per hypothesis. TODO confirm for the export.
    entailment_logits = np.asarray(outputs[0])[:, ENTAILMENT_INDEX]
    probabilities = softmax_with_temperature(entailment_logits, temperature)

    ranked = sorted(zip(labels, probabilities), key=lambda pair: pair[1], reverse=True)
    top_contexts = [label for label, prob in ranked if prob > score_threshold][:top_n]
    return top_contexts or ["general"]


def translate_text(input_text):
    """Translate ``input_text`` by calling the remote translation Space."""
    return translation_client.predict(input_text)


def process_request(input_text):
    """Gradio handler: log the detected context of ``input_text`` and return its translation.

    Empty or whitespace-only input returns ``""`` without calling the
    remote translation Space.
    """
    if not input_text or not input_text.strip():
        return ""

    try:
        context = detect_context(input_text)
    except Exception:
        # Context is only informational here, so a detection failure must
        # not prevent the translation from being returned.
        logger.exception("Context detection failed")
        context = ["general"]
    print(f"Detected context: {context}")

    # NOTE(review): the detected context is not sent to the translation
    # Space; wire it through once that endpoint accepts it.
    return translate_text(input_text)


# Create a Gradio interface
interface = gr.Interface(
    fn=process_request,
    inputs="text",
    outputs="text",
    title="Frenchizer",
    description="Translate text from English to French with context detection.",
)

# Launch the Gradio app
interface.launch()