"""Gradio demo comparing masked-token predictions of two masked language models.

A base checkpoint and an adapted checkpoint are loaded once at import time,
share a single tokenizer, and are queried with the same input so their top-k
predictions for every mask position can be shown side by side.
"""

import os
from typing import Any

import gradio as gr
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

# Model/tokenizer identifiers; each can be overridden through the environment.
ADAPTED_MODEL_ID = os.getenv("ADAPTED_MODEL_ID", os.getenv("MODEL_ID", "Rogendo/afribert-kenya-adapted"))
BASE_MODEL_ID = os.getenv("BASE_MODEL_ID", "castorini/afriberta_large")
TOKENIZER_ID = os.getenv("TOKENIZER_ID", "castorini/afriberta_large")
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")

# Upper bound on the tokenized input length (special tokens included). It is
# set explicitly rather than read from the tokenizer.
# NOTE(review): 512 follows the AfriBERTa model card, which sets
# `tokenizer.model_max_length = 512` by hand — confirm for other checkpoints.
MAX_SEQUENCE_LENGTH = int(os.getenv("MAX_SEQUENCE_LENGTH", "512"))


def load_models() -> tuple[Any, Any, Any, torch.device]:
    """Load the shared tokenizer and both masked-LM checkpoints.

    Returns:
        ``(tokenizer, base_model, adapted_model, device)``. Both models are
        moved to ``device`` (CUDA when available, otherwise CPU) and switched
        to eval mode.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID, token=HF_TOKEN, use_fast=False)
    # The token is passed for the base model too, so a private or gated
    # BASE_MODEL_ID loads the same way as the tokenizer and adapted model.
    base_model = AutoModelForMaskedLM.from_pretrained(
        BASE_MODEL_ID,
        token=HF_TOKEN,
        use_safetensors=True,
    )
    adapted_model = AutoModelForMaskedLM.from_pretrained(
        ADAPTED_MODEL_ID,
        token=HF_TOKEN,
        use_safetensors=True,
    )
    for model in (base_model, adapted_model):
        model.to(device)
        model.eval()
    return tokenizer, base_model, adapted_model, device


tokenizer, base_model, adapted_model, device = load_models()
MASK_TOKEN = tokenizer.mask_token or "[MASK]"

EXAMPLES = [
    f"Oya, twendeni zetu, kuna {MASK_TOKEN} flani ameniudhi.",
    f"Tuma {MASK_TOKEN} kwa kutumia nambari ya simu kupitia huduma ya M-PESA.",
    f"Mtoto aliripotiwa kwa ofisi ya {MASK_TOKEN} wa jamii baada ya kudhulumiwa nyumbani.",
    f"Tulifanya meeting jana na manager akasema {MASK_TOKEN} itakuwa ready wiki ijayo.",
    f"Msee alikuwa poa sana, akanisaidia kupata {MASK_TOKEN} ya ofisi.",
]


def normalize_input(text: str) -> str:
    """Strip *text* and rewrite literal ``[MASK]`` to the tokenizer's mask token.

    ``None`` or empty input yields an empty string.
    """
    text = (text or "").strip()
    if "[MASK]" in text and MASK_TOKEN != "[MASK]":
        text = text.replace("[MASK]", MASK_TOKEN)
    return text


def model_predictions(model, inputs, mask_positions, top_k: int, model_label: str) -> list[list[Any]]:
    """Return the top-k predictions of *model* for every mask position.

    Args:
        model: Masked-LM model whose output exposes ``.logits``.
        inputs: Tokenizer output for a single sequence (``input_ids`` row 0 is used).
        mask_positions: 1-D tensor of token indices holding the mask token.
        top_k: Number of predictions per mask; clamped to ``1..vocab_size``.
        model_label: Text placed in the first column of every row.

    Returns:
        One row per (mask, rank):
        ``[model_label, mask_index, rank, token, score, completed_sentence]``
        where ``mask_index`` and ``rank`` are 1-based and ``score`` is the
        softmax probability rounded to 4 decimals.
    """
    if len(mask_positions) == 0:
        return []

    with torch.no_grad():
        logits = model(**inputs).logits[0]

    # Softmax/top-k for all mask positions at once instead of once per mask.
    probabilities = torch.softmax(logits[mask_positions], dim=-1)
    # torch.topk raises for k outside the vocabulary range, so clamp it.
    k = max(1, min(int(top_k), probabilities.shape[-1]))
    top_scores, top_token_ids = torch.topk(probabilities, k=k)

    input_ids = inputs["input_ids"][0]
    rows: list[list[Any]] = []
    # .tolist() converts to Python values once, avoiding an .item() call per prediction.
    per_mask = zip(mask_positions.tolist(), top_scores.tolist(), top_token_ids.tolist())
    for mask_index, (position, scores, token_ids) in enumerate(per_mask, start=1):
        for rank, (score, token_id) in enumerate(zip(scores, token_ids), start=1):
            token = tokenizer.decode([token_id]).strip()
            # Fill this mask with the candidate and decode the full sentence.
            completed = input_ids.clone()
            completed[position] = token_id
            sequence = tokenizer.decode(completed, skip_special_tokens=True)
            rows.append([model_label, mask_index, rank, token, round(float(score), 4), sequence])
    return rows


def predict_masks(text: str, top_k: int) -> tuple[str, list[list[Any]], list[list[Any]]]:
    """Run both models on *text* and build the outputs shown in the UI.

    Args:
        text: Sentence containing at least one mask token (``[MASK]`` is accepted).
        top_k: Number of predictions to report per mask position.

    Returns:
        ``(summary_markdown, comparison_rows, detail_rows)``. When the input is
        empty or contains no mask token, the summary carries a hint and both
        row lists are empty.
    """
    text = normalize_input(text)
    if not text:
        return "Enter a sentence with a mask token.", [], []
    if MASK_TOKEN not in text:
        return f"Add at least one mask token: `{MASK_TOKEN}`", [], []

    # Truncate explicitly: an over-long input would otherwise fail inside the
    # model's forward pass instead of producing a usable response.
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_SEQUENCE_LENGTH,
    ).to(device)
    mask_positions = (inputs["input_ids"][0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]
    mask_count = len(mask_positions)
    if mask_count == 0:
        return f"No valid mask token found. Use `{MASK_TOKEN}`.", [], []

    base_rows = model_predictions(base_model, inputs, mask_positions, top_k, "Base AfriBERT")
    adapted_rows = model_predictions(adapted_model, inputs, mask_positions, top_k, "Adapted AfriBERT Kenya")

    # Rows from both models are in the same (mask, rank) order, so pair them up:
    # mask, rank, base token, base score, adapted token, adapted score.
    comparison_rows = [
        [base_row[1], base_row[2], base_row[3], base_row[4], adapted_row[3], adapted_row[4]]
        for base_row, adapted_row in zip(base_rows, adapted_rows)
    ]

    summary = (
        f"Base model: `{BASE_MODEL_ID}`\n\n"
        f"Adapted model: `{ADAPTED_MODEL_ID}`\n\n"
        f"Tokenizer: `{TOKENIZER_ID}`\n\n"
        f"Mask token: `{MASK_TOKEN}`\n\n"
        f"Found {mask_count} mask position{'s' if mask_count != 1 else ''}."
    )
    return summary, comparison_rows, base_rows + adapted_rows


with gr.Blocks(title="AfriBERT Kenya Masked LM") as demo:
    gr.Markdown(
        """
        # AfriBERT Kenya Masked Language Modeling

        Compare base AfriBERT against the Kenya-adapted model on Swahili, Sheng, Kenyan institutional text, M-PESA language, and English-Swahili code-switching.
        """
    )

    with gr.Row():
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                label="Input text",
                value=EXAMPLES[0],
                lines=4,
                placeholder=f"Type a sentence containing {MASK_TOKEN}",
            )
            top_k = gr.Slider(
                label="Top predictions",
                minimum=1,
                maximum=10,
                value=5,
                step=1,
            )
            predict_button = gr.Button("Compare masked-token predictions", variant="primary")

        with gr.Column(scale=1):
            gr.Markdown(
                f"""
                **How to use**

                Add `{MASK_TOKEN}` where you want the model to predict a token.

                `[MASK]` is also accepted and converted automatically.

                For private models, set `HF_TOKEN` before launching the app.

                The same base AfriBERT tokenizer is used for both models.
                """
            )

    summary_output = gr.Markdown()
    comparison_output = gr.Dataframe(
        headers=["Mask", "Rank", "Base prediction", "Base score", "Adapted prediction", "Adapted score"],
        datatype=["number", "number", "str", "number", "str", "number"],
        label="Side-by-side comparison",
        wrap=True,
    )
    details_output = gr.Dataframe(
        headers=["Model", "Mask", "Rank", "Prediction", "Score", "Completed sentence"],
        datatype=["str", "number", "number", "str", "number", "str"],
        label="Detailed predictions",
        wrap=True,
    )

    gr.Examples(
        examples=EXAMPLES,
        inputs=text_input,
    )

    # The button click and pressing Enter in the textbox trigger the same prediction.
    for register_event in (predict_button.click, text_input.submit):
        register_event(
            fn=predict_masks,
            inputs=[text_input, top_k],
            outputs=[summary_output, comparison_output, details_output],
        )


if __name__ == "__main__":
    demo.launch()