Spaces:
Build error
Build error
# Third-party dependencies: Gradio UI, ONNX Runtime inference,
# Hugging Face tokenizer, NumPy math, and the Gradio client for remote Spaces.
import gradio as gr
import onnxruntime as ort
from transformers import AutoTokenizer
import numpy as np
from gradio_client import Client

# Initialize the context-detection model: a local ONNX export of
# facebook/bart-large-mnli plus its matching tokenizer (loaded once at startup).
context_model_file = "./bart-large-mnli.onnx"
context_session = ort.InferenceSession(context_model_file)
context_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-mnli")

# Initialize the Gradio client for the translation model (a separate Space).
translation_client = Client("Frenchizer/space_3")  # Replace with your Space name
# Candidate context labels, paired positionally with the context model's
# output scores inside detect_context().
# NOTE(review): facebook/bart-large-mnli is an NLI model whose standard head
# emits 3 logits (contradiction/neutral/entailment), not one per label —
# verify the ONNX export actually produces len(labels) scores, otherwise the
# zip() in detect_context silently truncates to the first few labels.
# NOTE(review): some entries overlap ("information technology" vs "IT",
# "tourism" vs "travel"), and "sustained" may be a typo for "sustainability" —
# confirm against the model/training setup before renaming (the strings are
# part of runtime behavior).
labels = [
    "aerospace", "anatomy", "anthropology", "art",
    "automotive", "blockchain", "biology", "chemistry",
    "cryptocurrency", "data science", "design", "e-commerce",
    "education", "engineering", "entertainment", "environment",
    "fashion", "finance", "food commerce", "general",
    "gaming", "healthcare", "history", "html",
    "information technology", "IT", "keywords", "legal",
    "literature", "machine learning", "marketing", "medicine",
    "music", "personal development", "philosophy", "physics",
    "politics", "poetry", "programming", "real estate", "retail",
    "robotics", "slang", "social media", "speech", "sports",
    "sustained", "technical", "theater", "tourism", "travel"
]
def softmax_with_temperature(logits, temperature=1.0):
    """Return temperature-scaled softmax probabilities over the last axis.

    Args:
        logits: array-like of raw scores.
        temperature: divisor applied to the logits before exponentiation;
            values < 1 sharpen the distribution, values > 1 flatten it.

    Returns:
        np.ndarray of the same shape as ``logits`` whose entries are
        non-negative and sum to 1 along the last axis.
    """
    scaled = np.asarray(logits) / temperature
    # Subtract the row-wise max before exponentiating so large logits cannot
    # overflow np.exp; the shift cancels out in the normalization, so the
    # result is mathematically identical to the naive formula.
    scaled = scaled - np.max(scaled, axis=-1, keepdims=True)
    exp_logits = np.exp(scaled)
    return exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)
def detect_context(input_text, top_n=3, score_threshold=0.05):
    """Return up to ``top_n`` context labels detected for ``input_text``.

    Runs the local ONNX context model, converts its logits to probabilities,
    and keeps the labels whose probability exceeds ``score_threshold``,
    highest first. Falls back to ``["general"]`` when nothing clears the
    threshold.

    Args:
        input_text: text to classify.
        top_n: maximum number of labels to return.
        score_threshold: minimum probability for a label to be kept.

    Returns:
        list[str] of 1..top_n label names.
    """
    # Tokenize input text; ORT expects int64 id tensors.
    inputs = context_tokenizer(input_text, return_tensors="np", padding=True, truncation=True, max_length=512)
    input_ids = inputs["input_ids"].astype(np.int64)
    attention_mask = inputs["attention_mask"].astype(np.int64)
    # Run inference with the ONNX context model.
    outputs = context_session.run(None, {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
    })
    logits = outputs[0][0]  # batch size 1: take the first (only) row of logits
    # Fix: the original compared raw logits to score_threshold, so 0.05 was
    # effectively meaningless (logits are unbounded). Normalize to
    # probabilities with the previously-unused softmax helper so the
    # threshold is a real probability cutoff.
    scores = softmax_with_temperature(logits)
    # NOTE(review): assumes the model emits one score per entry in `labels`;
    # zip() silently truncates if the lengths differ — confirm the ONNX
    # export's output shape.
    ranked = sorted(zip(labels, scores), key=lambda pair: pair[1], reverse=True)
    top_contexts = [label for label, score in ranked if score > score_threshold][:top_n]
    return top_contexts if top_contexts else ["general"]
def translate_text(input_text):
    """Forward ``input_text`` to the remote translation Space and return its reply."""
    return translation_client.predict(input_text)
def process_request(input_text):
    """Handle one UI request: log the detected context, return the translation.

    NOTE: the detected context is only printed — it is not passed to the
    translation call.
    """
    detected = detect_context(input_text)
    print(f"Detected context: {detected}")
    return translate_text(input_text)
# Create the Gradio interface: one text box in, translated text out,
# routed through process_request (context detection + remote translation).
interface = gr.Interface(
    fn=process_request,
    inputs="text",
    outputs="text",
    title="Frenchizer",
    description="Translate text from English to French with context detection."
)

# Launch the Gradio app (blocks until the server is stopped).
interface.launch()