bitz support
gradio app for masked modeling task
f1ee6d0
Raw
History Blame Contribute Delete
6.55 kB
import os
from typing import Any
import gradio as gr
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
# Hub repository identifiers; each can be overridden through an environment variable.
# The adapted model id falls back to MODEL_ID, then to the default repository.
ADAPTED_MODEL_ID = os.environ.get(
    "ADAPTED_MODEL_ID",
    os.environ.get("MODEL_ID", "Rogendo/afribert-kenya-adapted"),
)
BASE_MODEL_ID = os.environ.get("BASE_MODEL_ID", "castorini/afriberta_large")
TOKENIZER_ID = os.environ.get("TOKENIZER_ID", "castorini/afriberta_large")
# Access token for private models. An unset or empty HF_TOKEN falls through to
# HUGGINGFACE_HUB_TOKEN; the result is None when neither is set.
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
def load_models() -> tuple[Any, Any, Any, torch.device]:
    """Load the shared tokenizer plus the base and adapted masked-LM models.

    Both models are loaded from safetensors weights, moved to CUDA when available
    (CPU otherwise) and switched to eval mode. ``HF_TOKEN`` is passed to every
    Hub download so that any of the three repositories may be private.

    Returns:
        ``(tokenizer, base_model, adapted_model, device)``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID, token=HF_TOKEN, use_fast=False)

    def _load_mlm(model_id: str) -> Any:
        # One code path for both models, so they are always loaded the same way
        # (previously the base model was fetched without the access token).
        model = AutoModelForMaskedLM.from_pretrained(
            model_id,
            token=HF_TOKEN,
            use_safetensors=True,
        )
        model.to(device)
        model.eval()
        return model

    base_model = _load_mlm(BASE_MODEL_ID)
    adapted_model = _load_mlm(ADAPTED_MODEL_ID)
    return tokenizer, base_model, adapted_model, device
# Load the tokenizer and both models once, at import time; the functions below use these globals.
tokenizer, base_model, adapted_model, device = load_models()

# Use the tokenizer's own mask token, or the BERT-style literal if it defines none.
MASK_TOKEN = tokenizer.mask_token or "[MASK]"

# Sample inputs for the UI; "{mask}" in each template is replaced with MASK_TOKEN.
EXAMPLES = [
    template.format(mask=MASK_TOKEN)
    for template in (
        "Oya, twendeni zetu, kuna {mask} flani ameniudhi.",
        "Tuma {mask} kwa kutumia nambari ya simu kupitia huduma ya M-PESA.",
        "Mtoto aliripotiwa kwa ofisi ya {mask} wa jamii baada ya kudhulumiwa nyumbani.",
        "Tulifanya meeting jana na manager akasema {mask} itakuwa ready wiki ijayo.",
        "Msee alikuwa poa sana, akanisaidia kupata {mask} ya ofisi.",
    )
]
def normalize_input(text: str) -> str:
    """Strip surrounding whitespace and rewrite ``[MASK]`` to the tokenizer's mask token.

    A falsy ``text`` (``""`` or ``None``) yields ``""``. The literal ``[MASK]`` is only
    rewritten when the tokenizer's mask token is something else.
    """
    cleaned = text.strip() if text else ""
    needs_rewrite = "[MASK]" in cleaned and MASK_TOKEN != "[MASK]"
    return cleaned.replace("[MASK]", MASK_TOKEN) if needs_rewrite else cleaned
def model_predictions(model, inputs, mask_positions, top_k: int, model_label: str) -> list[list[Any]]:
    """Return the top-k fill-in candidates of ``model`` for every masked position.

    Args:
        model: masked-LM model whose forward output exposes ``.logits``.
        inputs: tokenizer output created with ``return_tensors="pt"``; only the first
            sequence (``[0]``) is used.
        mask_positions: tensor of indices into ``inputs["input_ids"][0]`` that hold
            the mask token.
        top_k: number of candidates per mask (converted with ``int``).
        model_label: text written into the first column of every row.

    Returns:
        One row per (mask, rank), ordered by mask then rank:
        ``[model_label, mask_index (1-based), rank (1-based), predicted token,
        probability rounded to 4 places, completed sentence]``.
    """
    with torch.no_grad():
        logits = model(**inputs).logits[0]

    original_ids = inputs["input_ids"][0].tolist()
    # Hide BOS/EOS/padding in the rendered sentence, but keep the mask token: with
    # skip_special_tokens=True every *other* mask in a multi-mask input vanished,
    # making the "completed sentence" misrepresent the input.
    hidden_ids = set(tokenizer.all_special_ids) - {tokenizer.mask_token_id}

    rows: list[list[Any]] = []
    for mask_index, position in enumerate(mask_positions.tolist(), start=1):
        probabilities = torch.softmax(logits[position], dim=-1)
        scores, token_ids = torch.topk(probabilities, k=int(top_k))
        for rank, (score, token_id) in enumerate(zip(scores.tolist(), token_ids.tolist()), start=1):
            completed = list(original_ids)
            completed[position] = token_id
            sequence = tokenizer.decode([tid for tid in completed if tid not in hidden_ids])
            rows.append([
                model_label,
                mask_index,
                rank,
                tokenizer.decode([token_id]).strip(),
                round(score, 4),
                sequence,
            ])
    return rows
def predict_masks(text: str, top_k: int) -> tuple[str, list[list[Any]], list[list[Any]]]:
    """Compare base and adapted masked-LM predictions for ``text``.

    Args:
        text: user input; ``[MASK]`` is rewritten to the tokenizer's mask token.
        top_k: candidates per mask (UI slider value); values below 1 are raised to 1.

    Returns:
        ``(summary_markdown, comparison_rows, detail_rows)``. For empty input or input
        without a usable mask, the summary is a short instruction and both tables are empty.
    """
    text = normalize_input(text)
    if not text:
        return "Enter a sentence with a mask token.", [], []
    if MASK_TOKEN not in text:
        return f"Add at least one mask token: `{MASK_TOKEN}`", [], []

    # truncation=True caps the input at the tokenizer's model_max_length, so a long
    # paste cannot overflow the model's position embeddings. If truncation removes
    # every mask, the check below reports it instead of the model crashing.
    inputs = tokenizer(text, return_tensors="pt", truncation=True).to(device)
    mask_positions = (inputs["input_ids"][0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]
    if len(mask_positions) == 0:
        return f"No valid mask token found. Use `{MASK_TOKEN}`.", [], []

    # k == 0 would silently produce empty tables; a negative k makes torch.topk raise.
    k = max(1, int(top_k))
    base_rows = model_predictions(base_model, inputs, mask_positions, k, "Base AfriBERT")
    adapted_rows = model_predictions(adapted_model, inputs, mask_positions, k, "Adapted AfriBERT Kenya")

    # Both lists are ordered by (mask, rank) with the same k, so zip pairs them row by row.
    # Columns: mask, rank, base token, base score, adapted token, adapted score.
    comparison_rows = [
        [*base_row[1:5], *adapted_row[3:5]]
        for base_row, adapted_row in zip(base_rows, adapted_rows)
    ]

    mask_count = len(mask_positions)
    plural = "s" if mask_count != 1 else ""
    summary = (
        f"Base model: `{BASE_MODEL_ID}`\n\n"
        f"Adapted model: `{ADAPTED_MODEL_ID}`\n\n"
        f"Tokenizer: `{TOKENIZER_ID}`\n\n"
        f"Mask token: `{MASK_TOKEN}`\n\n"
        f"Found {mask_count} mask position{plural}."
    )
    return summary, comparison_rows, base_rows + adapted_rows
# Gradio UI: input column on the left, usage notes on the right, result tables below.
with gr.Blocks(title="AfriBERT Kenya Masked LM") as demo:
    gr.Markdown(
        "# AfriBERT Kenya Masked Language Modeling\n\n"
        "Compare base AfriBERT against the Kenya-adapted model on Swahili, Sheng, "
        "Kenyan institutional text, M-PESA language, and English-Swahili code-switching."
    )
    with gr.Row():
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                label="Input text",
                value=EXAMPLES[0],
                lines=4,
                placeholder=f"Type a sentence containing {MASK_TOKEN}",
            )
            top_k = gr.Slider(
                label="Top predictions",
                minimum=1,
                maximum=10,
                value=5,
                step=1,
            )
            predict_button = gr.Button("Compare masked-token predictions", variant="primary")
        with gr.Column(scale=1):
            # A bullet list, because plain consecutive lines are merged into one
            # paragraph by the Markdown renderer.
            gr.Markdown(
                "**How to use**\n\n"
                f"- Add `{MASK_TOKEN}` where you want the model to predict a token.\n"
                "- `[MASK]` is also accepted and converted automatically.\n"
                "- For private models, set `HF_TOKEN` before launching the app.\n"
                f"- Both models use the same tokenizer: `{TOKENIZER_ID}`."
            )

    summary_output = gr.Markdown()
    comparison_output = gr.Dataframe(
        headers=["Mask", "Rank", "Base prediction", "Base score", "Adapted prediction", "Adapted score"],
        datatype=["number", "number", "str", "number", "str", "number"],
        label="Side-by-side comparison",
        wrap=True,
    )
    details_output = gr.Dataframe(
        headers=["Model", "Mask", "Rank", "Prediction", "Score", "Completed sentence"],
        datatype=["str", "number", "number", "str", "number", "str"],
        label="Detailed predictions",
        wrap=True,
    )
    gr.Examples(
        examples=EXAMPLES,
        inputs=text_input,
    )

    # The button click and pressing Enter in the textbox run the same prediction.
    predict_inputs = [text_input, top_k]
    predict_outputs = [summary_output, comparison_output, details_output]
    for trigger in (predict_button.click, text_input.submit):
        trigger(fn=predict_masks, inputs=predict_inputs, outputs=predict_outputs)
# Start the Gradio server only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    demo.launch()