Spaces:
Sleeping
Sleeping
| import os | |
| from typing import Any | |
| import gradio as gr | |
| import torch | |
| from transformers import AutoModelForMaskedLM, AutoTokenizer | |
# Hub identifiers, each overridable through an environment variable of the
# same name. ADAPTED_MODEL_ID additionally falls back to MODEL_ID before
# using its built-in default.
ADAPTED_MODEL_ID = os.environ.get(
    "ADAPTED_MODEL_ID",
    os.environ.get("MODEL_ID", "Rogendo/afribert-kenya-adapted"),
)
BASE_MODEL_ID = os.environ.get("BASE_MODEL_ID", "castorini/afriberta_large")
TOKENIZER_ID = os.environ.get("TOKENIZER_ID", "castorini/afriberta_large")

# Optional Hub access token: HF_TOKEN wins; HUGGINGFACE_HUB_TOKEN is used when
# HF_TOKEN is unset or empty. Stays None when neither is provided.
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
def load_models() -> tuple[Any, Any, Any, torch.device]:
    """Load the shared tokenizer and both masked-LM checkpoints for inference.

    The device is CUDA when ``torch.cuda.is_available()`` and CPU otherwise.
    Both models are moved to that device and switched to eval mode.

    Returns:
        A ``(tokenizer, base_model, adapted_model, device)`` tuple.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID, token=HF_TOKEN, use_fast=False)

    def _load_masked_lm(model_id: str) -> Any:
        """Load one checkpoint, place it on ``device`` and set eval mode."""
        # The token is passed for every checkpoint (previously only for the
        # tokenizer and the adapted model) so that a private or gated
        # BASE_MODEL_ID override authenticates as well. With HF_TOKEN unset
        # this is token=None, i.e. the library default.
        model = AutoModelForMaskedLM.from_pretrained(
            model_id,
            token=HF_TOKEN,
            use_safetensors=True,
        )
        model.to(device)
        model.eval()
        return model

    base_model = _load_masked_lm(BASE_MODEL_ID)
    adapted_model = _load_masked_lm(ADAPTED_MODEL_ID)
    return tokenizer, base_model, adapted_model, device
# Load everything once at import time; the request handlers below reuse these
# module-level objects.
tokenizer, base_model, adapted_model, device = load_models()

# Mask string shown to users and substituted into inputs. Falls back to the
# literal "[MASK]" when the tokenizer reports no mask token.
MASK_TOKEN = tokenizer.mask_token if tokenizer.mask_token else "[MASK]"

# Sample sentences offered in the UI, each containing one mask token.
EXAMPLES = [
    f"Oya, twendeni zetu, kuna {MASK_TOKEN} flani ameniudhi.",
    f"Tuma {MASK_TOKEN} kwa kutumia nambari ya simu kupitia huduma ya M-PESA.",
    f"Mtoto aliripotiwa kwa ofisi ya {MASK_TOKEN} wa jamii baada ya kudhulumiwa nyumbani.",
    f"Tulifanya meeting jana na manager akasema {MASK_TOKEN} itakuwa ready wiki ijayo.",
    f"Msee alikuwa poa sana, akanisaidia kupata {MASK_TOKEN} ya ofisi.",
]
def normalize_input(text: str) -> str:
    """Trim the input and rewrite ``[MASK]`` to the tokenizer's mask token.

    A falsy ``text`` (for example ``None`` or ``""``) yields an empty string.
    """
    cleaned = text.strip() if text else ""
    # str.replace is a no-op both when "[MASK]" is absent and when MASK_TOKEN
    # already equals "[MASK]", so no explicit guard is needed.
    return cleaned.replace("[MASK]", MASK_TOKEN)
def model_predictions(model, inputs, mask_positions, top_k: int, model_label: str) -> list[list[Any]]:
    """Return the top-k candidate tokens for every mask position.

    Args:
        model: Callable invoked as ``model(**inputs)``; its result must expose
            ``.logits``, of which only the first batch entry is used.
        inputs: Mapping of model inputs; ``inputs["input_ids"][0]`` is the
            sequence that candidates are substituted into.
        mask_positions: 1-D tensor of token indices to predict.
        top_k: Number of candidates per mask (coerced with ``int``).
        model_label: Text stored in the first column of every row.

    Returns:
        One row per (mask, candidate):
        ``[model_label, mask number, rank, token, score, completed sentence]``.
        Mask numbers and ranks start at 1; scores are softmax probabilities
        rounded to four decimals.
    """
    with torch.no_grad():
        sequence_logits = model(**inputs).logits[0]

    original_ids = inputs["input_ids"][0]
    candidate_count = int(top_k)
    table: list[list[Any]] = []
    for mask_number, slot in enumerate(mask_positions.tolist(), start=1):
        distribution = torch.softmax(sequence_logits[slot], dim=-1)
        best = torch.topk(distribution, k=candidate_count)
        # Convert to plain Python numbers once per mask rather than per candidate.
        ranked = zip(best.values.tolist(), best.indices.tolist())
        for rank, (probability, candidate_id) in enumerate(ranked, start=1):
            candidate_text = tokenizer.decode([candidate_id]).strip()
            # Substitute the candidate into a copy so the shared input is untouched.
            filled = original_ids.clone()
            filled[slot] = candidate_id
            table.append([
                model_label,
                mask_number,
                rank,
                candidate_text,
                round(float(probability), 4),
                tokenizer.decode(filled, skip_special_tokens=True),
            ])
    return table
def predict_masks(text: str, top_k: int) -> tuple[str, list[list[Any]], list[list[Any]]]:
    """Run both models on ``text`` and tabulate their mask predictions.

    Args:
        text: Sentence containing at least one mask token; ``[MASK]`` is
            accepted and rewritten by ``normalize_input``.
        top_k: Number of candidates to return per mask position.

    Returns:
        ``(summary, comparison_rows, detail_rows)``. ``summary`` is Markdown
        text. ``comparison_rows`` pairs base and adapted predictions per
        (mask, rank). ``detail_rows`` is every base row followed by every
        adapted row as produced by ``model_predictions``. When the input is
        unusable, ``summary`` carries the explanation and both lists are empty.
    """
    text = normalize_input(text)
    if not text:
        return "Enter a sentence with a mask token.", [], []
    if MASK_TOKEN not in text:
        return f"Add at least one mask token: `{MASK_TOKEN}`", [], []

    # Cap the encoded length so long free-text input cannot index past the
    # models' position-embedding tables (which raises inside the forward pass).
    # NOTE(review): the "- 2" assumes RoBERTa-style position ids that start
    # after the padding index; for BERT-style models it is merely conservative.
    # Confirm against the checkpoints actually deployed.
    declared_limits = [
        limit
        for limit in (
            getattr(getattr(model, "config", None), "max_position_embeddings", None)
            for model in (base_model, adapted_model)
        )
        if isinstance(limit, int) and limit > 2
    ]
    max_tokens = min(declared_limits) - 2 if declared_limits else None

    encode_kwargs: dict[str, Any] = {"return_tensors": "pt"}
    if max_tokens is not None:
        encode_kwargs.update(truncation=True, max_length=max_tokens)
    inputs = tokenizer(text, **encode_kwargs).to(device)

    token_ids = inputs["input_ids"][0]
    mask_positions = (token_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]
    mask_count = int(mask_positions.numel())
    if mask_count == 0:
        if max_tokens is not None and token_ids.shape[0] >= max_tokens:
            # The text contained the mask string, but the length cap cut it off.
            return (
                f"The input was cut to {max_tokens} tokens and no mask token remains. "
                f"Shorten the text before `{MASK_TOKEN}`.",
                [],
                [],
            )
        return f"No valid mask token found. Use `{MASK_TOKEN}`.", [], []

    base_rows = model_predictions(base_model, inputs, mask_positions, top_k, "Base AfriBERT")
    adapted_rows = model_predictions(adapted_model, inputs, mask_positions, top_k, "Adapted AfriBERT Kenya")

    # Both lists are ordered by (mask, rank), so pairing them positionally
    # lines up the two models' candidates. Row layout:
    # [label, mask, rank, token, score, sentence].
    comparison_rows = [
        [base_row[1], base_row[2], base_row[3], base_row[4], adapted_row[3], adapted_row[4]]
        for base_row, adapted_row in zip(base_rows, adapted_rows)
    ]

    summary = (
        f"Base model: `{BASE_MODEL_ID}`\n\n"
        f"Adapted model: `{ADAPTED_MODEL_ID}`\n\n"
        f"Tokenizer: `{TOKENIZER_ID}`\n\n"
        f"Mask token: `{MASK_TOKEN}`\n\n"
        f"Found {mask_count} mask position{'s' if mask_count != 1 else ''}."
    )
    return summary, comparison_rows, base_rows + adapted_rows
# Column name -> Gradio datatype for the two result tables (insertion order is
# the display order).
_COMPARISON_COLUMNS = {
    "Mask": "number",
    "Rank": "number",
    "Base prediction": "str",
    "Base score": "number",
    "Adapted prediction": "str",
    "Adapted score": "number",
}
_DETAIL_COLUMNS = {
    "Model": "str",
    "Mask": "number",
    "Rank": "number",
    "Prediction": "str",
    "Score": "number",
    "Completed sentence": "str",
}

with gr.Blocks(title="AfriBERT Kenya Masked LM") as demo:
    # Page header.
    gr.Markdown(
        """
# AfriBERT Kenya Masked Language Modeling
Compare base AfriBERT against the Kenya-adapted model on Swahili, Sheng,
Kenyan institutional text, M-PESA language, and English-Swahili code-switching.
"""
    )

    with gr.Row():
        # Wider column: the input controls.
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                label="Input text",
                value=EXAMPLES[0],
                lines=4,
                placeholder=f"Type a sentence containing {MASK_TOKEN}",
            )
            top_k = gr.Slider(
                label="Top predictions",
                minimum=1,
                maximum=10,
                value=5,
                step=1,
            )
            predict_button = gr.Button("Compare masked-token predictions", variant="primary")

        # Narrower column: usage notes.
        with gr.Column(scale=1):
            gr.Markdown(
                f"""
**How to use**
Add `{MASK_TOKEN}` where you want the model to predict a token.
`[MASK]` is also accepted and converted automatically.
For private models, set `HF_TOKEN` before launching the app.
The same base AfriBERT tokenizer is used for both models.
"""
            )

    # Result area: summary text followed by the two tables.
    summary_output = gr.Markdown()
    comparison_output = gr.Dataframe(
        headers=list(_COMPARISON_COLUMNS),
        datatype=list(_COMPARISON_COLUMNS.values()),
        label="Side-by-side comparison",
        wrap=True,
    )
    details_output = gr.Dataframe(
        headers=list(_DETAIL_COLUMNS),
        datatype=list(_DETAIL_COLUMNS.values()),
        label="Detailed predictions",
        wrap=True,
    )

    gr.Examples(
        examples=EXAMPLES,
        inputs=text_input,
    )

    # Clicking the button and submitting the textbox run the same prediction
    # with identical inputs and outputs.
    _prediction_event = {
        "fn": predict_masks,
        "inputs": [text_input, top_k],
        "outputs": [summary_output, comparison_output, details_output],
    }
    predict_button.click(**_prediction_event)
    text_input.submit(**_prediction_event)


if __name__ == "__main__":
    demo.launch()