MSA-S2S-lemmatizer
Model Description
The model is developed for Arabic lemmatization, focusing on Modern Standard Arabic (MSA). It follows a sequence-to-sequence formulation of lemmatization: rather than treating lemmas as fixed classification labels, the model generates the lemma of a given word from a context window of up to two words before and two words after it.
The model is evaluated using lemma accuracy as the main metric, with an additional normalized lemma accuracy metric that accounts for orthographic and diacritic variation. The full methodology, training setup, hyperparameters, and evaluation results are described in our paper "Lemmatization as a Classification Task: Results from Arabic across Multiple Genres".
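To make the two metrics concrete, here is a minimal sketch of how exact and normalized lemma accuracy could be computed. The normalization shown (removing Arabic diacritics and folding hamzated alef forms to bare alef) is an assumption chosen for illustration; the paper defines the normalization actually used. The function names `normalize` and `lemma_accuracy` are hypothetical and not part of the repository.

```python
import re

# Assumption: these rules are illustrative only, not the paper's definition.
# U+064B..U+0652 are the common Arabic diacritics; U+0670 is the dagger alef.
_DIACRITICS = re.compile("[\u064B-\u0652\u0670]")
# Fold alef with madda / hamza above / hamza below (U+0622, U+0623, U+0625)
# to bare alef (U+0627).
_ALEF_FOLD = str.maketrans({"\u0622": "\u0627", "\u0623": "\u0627", "\u0625": "\u0627"})


def normalize(lemma: str) -> str:
    """Strip diacritics and fold alef variants (hypothetical normalization)."""
    return _DIACRITICS.sub("", lemma).translate(_ALEF_FOLD)


def lemma_accuracy(gold: list[str], pred: list[str], normalized: bool = False) -> float:
    """Fraction of predictions equal to the gold lemma, optionally after normalization."""
    if len(gold) != len(pred):
        raise ValueError("gold and pred must have the same length")
    if not gold:
        return 0.0
    if normalized:
        gold = [normalize(g) for g in gold]
        pred = [normalize(p) for p in pred]
    return sum(g == p for g, p in zip(gold, pred)) / len(gold)
```

Calling `lemma_accuracy(gold, pred)` gives the strict score, while `lemma_accuracy(gold, pred, normalized=True)` forgives differences that the chosen normalization removes.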
Standalone Usage
The model can also be used independently of the full lemmatization workflow in the GitHub repository (https://github.com/CAMeL-Lab/lemmatization-as-classification). In this case, the input should contain the target word wrapped in the special token <target> (the same token is placed before and after the word), with up to two words before and two words after the target word.
import math
import pandas as pd
import torch
from tqdm.auto import tqdm
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Register DataFrame.progress_apply
tqdm.pandas()
DIALECT_MODELS = {
    "msa": "CAMeL-Lab/MSA-S2S-lemmatizer",
}

def load_model(s2s_dialect: str):
    model_name = DIALECT_MODELS[s2s_dialect]
    tokenizer = T5Tokenizer.from_pretrained(model_name, use_fast=True, legacy=False)
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    # Make sure <target> is a single special token and the embeddings match
    tokenizer.add_special_tokens({"additional_special_tokens": ["<target>"]})
    model.resize_token_embeddings(len(tokenizer))
    return tokenizer, model
def predict(tokenizer, model, texts: list[str], device=None, batch_size: int = 16) -> list[str]:
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device).eval()
    all_preds = []
    total_batches = math.ceil(len(texts) / batch_size)
    for i in tqdm(range(0, len(texts), batch_size), total=total_batches, desc="Predicting"):
        batch = texts[i:i + batch_size]
        enc = tokenizer(
            batch,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=64
        )
        enc = {k: v.to(device) for k, v in enc.items()}
        # Greedy decoding: one beam, no sampling
        with torch.no_grad():
            out = model.generate(
                **enc,
                max_length=50,
                num_beams=1,
                do_sample=False
            )
        all_preds.extend(tokenizer.batch_decode(out, skip_special_tokens=True))
    return all_preds
def get_context_window_fast(sentence_index, word_index, window_size=2):
    # Relies on the module-level sentence_lookup dict built below
    words, indices = sentence_lookup[sentence_index]
    target_pos = indices.index(word_index)
    start_idx = max(0, target_pos - window_size)
    end_idx = min(len(words), target_pos + window_size + 1)
    context_words = words[start_idx:end_idx]  # slicing already returns a copy
    target_word_idx = target_pos - start_idx
    context_words[target_word_idx] = f"<target>{context_words[target_word_idx]}<target>"
    return f"lemmatize: {' '.join(context_words)}"
# df is expected to have sentence_index, word_index, and word columns.
# The input_text column built below marks the target word using <target>.
# Example input: "أنا عايز <target>أروح<target> البيت دلوقتي"
# Sort df by sentence_index and word_index
df = df.sort_values(by=["sentence_index", "word_index"])

# Build a lookup dict: {sentence_index: (words_list, indices_list)}
sentence_lookup = {
    sid: (group['word'].astype(str).tolist(), group['word_index'].tolist())
    for sid, group in df.sort_values('word_index').groupby('sentence_index')
}

df['input_text'] = df.progress_apply(
    lambda row: get_context_window_fast(row['sentence_index'], row['word_index']), axis=1
)

tokenizer, model = load_model("msa")
df["predicted_lex"] = predict(tokenizer, model, df["input_text"].tolist())
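When the sentence is already a plain Python list of tokens rather than a DataFrame, the same windowing can be sketched without pandas. The helper names `mark_target` and `build_inputs` are hypothetical and not part of the repository; the "lemmatize: " prefix, the <target> wrapping, and the default window of two words mirror the code above.

```python
def mark_target(words: list[str], target_pos: int, window_size: int = 2) -> str:
    """Build one model input: the target word wrapped in <target>, plus context."""
    if not 0 <= target_pos < len(words):
        raise IndexError("target_pos is outside the sentence")
    start = max(0, target_pos - window_size)
    end = min(len(words), target_pos + window_size + 1)
    window = words[start:end]  # a new list, so the caller's list is not modified
    window[target_pos - start] = f"<target>{words[target_pos]}<target>"
    return "lemmatize: " + " ".join(window)


def build_inputs(words: list[str], window_size: int = 2) -> list[str]:
    """Build one input string per word in the sentence."""
    return [mark_target(words, i, window_size) for i in range(len(words))]
```

The resulting list of strings can be passed directly as the `texts` argument of `predict`. Near the start or end of a sentence the window is simply shorter, which matches the "up to two words" wording above.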
Citation
If you use this model in your research, please cite the following paper:
@inproceedings{saeed-habash-2025-lemmatization,
    title = "Lemmatization as a Classification Task: Results from {A}rabic across Multiple Genres",
    author = "Saeed, Mostafa and Habash, Nizar",
    booktitle = "Proceedings of the 2025 Conference on Empirical Methods in Natural Language Processing",
    year = "2025",
    address = "Suzhou, China",
    url = "https://aclanthology.org/2025.emnlp-main.1525/",
}