File size: 3,080 Bytes
96a8ebf
7a24f06
 
96a8ebf
 
7a24f06
96a8ebf
 
 
 
7a24f06
96a8ebf
2cecc10
96a8ebf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a24f06
96a8ebf
7a24f06
96a8ebf
 
7a24f06
96a8ebf
 
 
 
 
7a24f06
96a8ebf
7a24f06
96a8ebf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a24f06
 
96a8ebf
 
 
 
 
 
 
 
 
7a24f06
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import gradio as gr
import onnxruntime as ort
from transformers import AutoTokenizer
import numpy as np
from gradio_client import Client

# Context classifier: an ONNX model loaded from a local file and run through
# ONNX Runtime, paired with the tokenizer published as "facebook/bart-large-mnli".
# Both are created once at import time and shared by every request.
context_model_file = "./bart-large-mnli.onnx"
context_session = ort.InferenceSession(context_model_file)
context_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-mnli")

# Gradio client used to call the translation model hosted in a separate Space.
translation_client = Client("Frenchizer/space_3")  # Replace with your Space name

# Candidate context labels. Order is significant: detect_context pairs these
# positionally with the model's output scores, so do not sort or regroup them.
labels = [
    "aerospace", "anatomy", "anthropology",
    "art", "automotive", "blockchain",
    "biology", "chemistry", "cryptocurrency",
    "data science", "design", "e-commerce",
    "education", "engineering", "entertainment",
    "environment", "fashion", "finance",
    "food commerce", "general", "gaming",
    "healthcare", "history", "html",
    "information technology", "IT", "keywords",
    "legal", "literature", "machine learning",
    "marketing", "medicine", "music",
    "personal development", "philosophy", "physics",
    "politics", "poetry", "programming",
    "real estate", "retail", "robotics",
    "slang", "social media", "speech",
    "sports", "sustained", "technical",
    "theater", "tourism", "travel",
]

def softmax_with_temperature(logits, temperature=1.0):
    """Convert logits to probabilities along the last axis.

    Args:
        logits: Array-like of raw scores (NumPy array or nested sequence).
            The softmax is taken over the last axis.
        temperature: Divisor applied to the logits before exponentiation.
            Values above 1 flatten the distribution, values below 1 sharpen
            it. Must be non-zero.

    Returns:
        A NumPy array with the same shape as ``logits`` whose entries along
        the last axis sum to 1.
    """
    scaled = np.asarray(logits) / temperature
    if scaled.size == 0:
        # Nothing to normalise; np.max would raise on an empty axis.
        return scaled

    # Softmax is invariant to adding a constant, so shift by the per-row
    # maximum: the largest exponent becomes 0 and np.exp cannot overflow
    # to inf (which would otherwise turn the result into NaN).
    shifted = scaled - np.max(scaled, axis=-1, keepdims=True)
    exp_logits = np.exp(shifted)
    return exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)

def detect_context(input_text, top_n=3, score_threshold=0.05):
    """Return the most likely context labels for ``input_text``.

    The text is tokenized, run through the ONNX context model, and the
    resulting logits are normalised with a softmax so that
    ``score_threshold`` is compared against probabilities rather than
    unbounded raw logits.

    Args:
        input_text: Text to classify.
        top_n: Maximum number of labels to return.
        score_threshold: Minimum probability (exclusive) a label needs in
            order to be returned.

    Returns:
        A list of at most ``top_n`` label strings ordered from most to least
        probable, or ``["general"]`` when no label clears the threshold.
    """
    import warnings  # local import: keeps this block self-contained

    # Tokenize input text (truncated to the 512-token limit).
    encoded = context_tokenizer(
        input_text,
        return_tensors="np",
        padding=True,
        truncation=True,
        max_length=512,
    )

    # The ONNX session is fed int64 tensors, as in the original call.
    model_inputs = {
        "input_ids": encoded["input_ids"].astype(np.int64),
        "attention_mask": encoded["attention_mask"].astype(np.int64),
    }
    outputs = context_session.run(None, model_inputs)

    # Assuming batch size 1; take the first (only) row of logits.
    logits = np.asarray(outputs[0][0], dtype=np.float64)

    # zip() below pairs labels and scores positionally and silently stops at
    # the shorter one, so surface a mismatch instead of hiding it.
    # NOTE(review): a stock MNLI head emits 3 logits
    # (contradiction/neutral/entailment), not one per label - confirm that
    # the ONNX file really is a len(labels)-way classifier.
    if logits.shape[-1] != len(labels):
        warnings.warn(
            f"Context model returned {logits.shape[-1]} scores for "
            f"{len(labels)} labels; only the overlapping prefix is used.",
            RuntimeWarning,
            stacklevel=2,
        )

    # Normalise so the threshold has a stable, model-independent meaning.
    probabilities = softmax_with_temperature(logits)

    # Rank label/probability pairs from most to least likely.
    ranked = sorted(
        zip(labels, probabilities),
        key=lambda pair: pair[1],
        reverse=True,
    )

    # Keep labels above the threshold, capped at top_n.
    top_contexts = [
        label for label, probability in ranked if probability > score_threshold
    ][:top_n]

    return top_contexts if top_contexts else ["general"]

def translate_text(input_text):
    """Translate ``input_text`` by calling the remote translation Space.

    The text is passed unchanged to ``translation_client.predict`` and
    whatever that call returns is handed back to the caller.
    """
    return translation_client.predict(input_text)

def process_request(input_text):
    """Handle one UI request: log the detected context, return the translation.

    The detected context is only printed; it is not passed to the
    translation step.
    """
    context = detect_context(input_text)
    print(f"Detected context: {context}")

    return translate_text(input_text)

# Build the web UI: a single text box in, a single text box out, with
# process_request handling each submission.
interface = gr.Interface(
    process_request,
    "text",
    "text",
    title="Frenchizer",
    description="Translate text from English to French with context detection.",
)

# Start serving the app.
interface.launch()