DheivaCodes committed on
Commit
ab39076
·
verified ·
1 Parent(s): 7f39ef1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -12
app.py CHANGED
@@ -10,6 +10,7 @@ from sacrebleu import corpus_bleu
10
  import os
11
  import tempfile
12
 
 
13
  # Load Models
14
  lang_detect_model = AutoModelForSequenceClassification.from_pretrained("papluca/xlm-roberta-base-language-detection")
15
  lang_detect_tokenizer = AutoTokenizer.from_pretrained("papluca/xlm-roberta-base-language-detection")
@@ -49,7 +50,7 @@ dimension = corpus_embeddings.shape[1]
49
  index = faiss.IndexFlatL2(dimension)
50
  index.add(corpus_embeddings)
51
 
52
- # Language Detection
53
  def detect_language(text):
54
  inputs = lang_detect_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
55
  with torch.no_grad():
@@ -58,7 +59,7 @@ def detect_language(text):
58
  pred = torch.argmax(probs, dim=1).item()
59
  return id2lang[pred]
60
 
61
- # Translation
62
  def translate(text, src_code, tgt_code):
63
  trans_tokenizer.src_lang = src_code
64
  encoded = trans_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
@@ -74,8 +75,8 @@ def search_semantic(query, top_k=3):
74
  query_embedding = embed_model.encode([query])
75
  distances, indices = index.search(query_embedding, top_k)
76
  return [(corpus[i], float(distances[0][idx])) for idx, i in enumerate(indices[0])]
77
-
78
- # Save Report
79
  def save_output_to_file(detected_lang, translated, sem_results, bleu_score):
80
  with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f:
81
  f.write(f"Detected Language: {detected_lang}\n")
@@ -87,25 +88,24 @@ def save_output_to_file(detected_lang, translated, sem_results, bleu_score):
87
  f.write(f"\nBLEU Score: {bleu_score}")
88
  return f.name
89
 
90
- # Full Pipeline
91
  def full_pipeline(user_input_text, target_lang_code, human_ref=""):
92
  if not user_input_text.strip():
93
  return "Empty input", "", [], "", "", None
94
 
95
  if len(user_input_text) > 2048:
96
- return "Input too long", "Please enter shorter text (under 2000 characters).", [], "", "", None
97
 
98
  detected_lang = detect_language(user_input_text)
99
  src_nllb = xlm_to_nllb.get(detected_lang, "eng_Latn")
100
 
101
  translated = translate(user_input_text, src_nllb, target_lang_code)
102
  if not translated:
103
- return detected_lang, "Translation failed", [], "", "", None
104
 
105
  sem_results = search_semantic(translated)
106
  result_list = [f"{i+1}. {txt} (Score: {score:.2f})" for i, (txt, score) in enumerate(sem_results)]
107
 
108
- # Plot similarity
109
  labels = [f"{i+1}" for i in range(len(sem_results))]
110
  scores = [score for _, score in sem_results]
111
  plt.figure(figsize=(6, 4))
@@ -128,7 +128,8 @@ def full_pipeline(user_input_text, target_lang_code, human_ref=""):
128
  download_file_path = save_output_to_file(detected_lang, translated, sem_results, bleu_score)
129
  return detected_lang, translated, "\n".join(result_list), plot_path, bleu_score, download_file_path
130
 
131
- # Gradio UI
 
132
  gr.Interface(
133
  fn=full_pipeline,
134
  inputs=[
@@ -142,8 +143,8 @@ gr.Interface(
142
  gr.Textbox(label="Top Semantic Matches"),
143
  gr.Image(label="Semantic Similarity Plot"),
144
  gr.Textbox(label="BLEU Score"),
145
- gr.File(label="Download Translation Report")
146
  ],
147
- title="Multilingual Translator + Semantic Search",
148
  description="Detects language → Translates → Finds related Sanskrit concepts → BLEU optional → Downloadable report."
149
- ).launch()
 
10
  import os
11
  import tempfile
12
 
13
+
14
  # Load Models
15
  lang_detect_model = AutoModelForSequenceClassification.from_pretrained("papluca/xlm-roberta-base-language-detection")
16
  lang_detect_tokenizer = AutoTokenizer.from_pretrained("papluca/xlm-roberta-base-language-detection")
 
50
  index = faiss.IndexFlatL2(dimension)
51
  index.add(corpus_embeddings)
52
 
53
+ # Detect Language
54
  def detect_language(text):
55
  inputs = lang_detect_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
56
  with torch.no_grad():
 
59
  pred = torch.argmax(probs, dim=1).item()
60
  return id2lang[pred]
61
 
62
+ # Translate
63
  def translate(text, src_code, tgt_code):
64
  trans_tokenizer.src_lang = src_code
65
  encoded = trans_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
 
75
  query_embedding = embed_model.encode([query])
76
  distances, indices = index.search(query_embedding, top_k)
77
  return [(corpus[i], float(distances[0][idx])) for idx, i in enumerate(indices[0])]
78
+
79
+ # Create downloadable output file
80
  def save_output_to_file(detected_lang, translated, sem_results, bleu_score):
81
  with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f:
82
  f.write(f"Detected Language: {detected_lang}\n")
 
88
  f.write(f"\nBLEU Score: {bleu_score}")
89
  return f.name
90
 
 
91
  def full_pipeline(user_input_text, target_lang_code, human_ref=""):
92
  if not user_input_text.strip():
93
  return "Empty input", "", [], "", "", None
94
 
95
  if len(user_input_text) > 2048:
96
+ return " Input too long", "Please enter shorter text (under 2000 characters).", [], "", "", None
97
 
98
  detected_lang = detect_language(user_input_text)
99
  src_nllb = xlm_to_nllb.get(detected_lang, "eng_Latn")
100
 
101
  translated = translate(user_input_text, src_nllb, target_lang_code)
102
  if not translated:
103
+ return detected_lang, " Translation failed", [], "", "", None
104
 
105
  sem_results = search_semantic(translated)
106
  result_list = [f"{i+1}. {txt} (Score: {score:.2f})" for i, (txt, score) in enumerate(sem_results)]
107
 
108
+ # Plot
109
  labels = [f"{i+1}" for i in range(len(sem_results))]
110
  scores = [score for _, score in sem_results]
111
  plt.figure(figsize=(6, 4))
 
128
  download_file_path = save_output_to_file(detected_lang, translated, sem_results, bleu_score)
129
  return detected_lang, translated, "\n".join(result_list), plot_path, bleu_score, download_file_path
130
 
131
+
132
+ # Gradio Interface
133
  gr.Interface(
134
  fn=full_pipeline,
135
  inputs=[
 
143
  gr.Textbox(label="Top Semantic Matches"),
144
  gr.Image(label="Semantic Similarity Plot"),
145
  gr.Textbox(label="BLEU Score"),
146
+ gr.File(label="Download Translation Report") # NEW OUTPUT
147
  ],
148
+ title=" Multilingual Translator + Semantic Search",
149
  description="Detects language → Translates → Finds related Sanskrit concepts → BLEU optional → Downloadable report."
150
+ ).launch()