#!/usr/bin/env python3
import os
from dataclasses import dataclass
from typing import Dict, List
import numpy as np
from datasets import load_dataset
import evaluate
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    DataCollatorWithPadding,
    TrainingArguments,
    Trainer,
)
# ======================
# LABEL SCHEMA
# ======================
LABELS: List[str] = [
"pre-1900",
"1900-1945",
"1946-1979",
"1980-1999",
"2000-2015",
"2016-present",
]
id2label: Dict[int, str] = {i: l for i, l in enumerate(LABELS)}
label2id: Dict[str, int] = {l: i for i, l in enumerate(LABELS)}
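# e.g. id2label[0] == "pre-1900" and label2id["2016-present"] == 5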
# Base model to fine-tune
BASE_MODEL = os.environ.get("BASE_MODEL", "distilroberta-base")
# Hugging Face hub repo where the fine-tuned model will be pushed
HUB_MODEL_ID = "DelaliScratchwerk/time-period-classifier-bert"
# ======================
# LOAD DATA
# ======================
# Expect CSVs at data/train.csv and data/val.csv
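# Each file needs 'text' and 'label' columns; an illustrative (made-up) row:
#   text,label
#   "The telegraph hummed through the night.",pre-1900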
dataset = load_dataset(
"csv",
data_files={
"train": "data/train.csv",
"validation": "data/val.csv",
},
)
print("Raw dataset:", dataset)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
def encode_batch(batch):
    # tokenize texts
    enc = tokenizer(batch["text"], truncation=True)
    # map string labels -> integer ids
    # strip helps if there are trailing spaces in the CSV
    enc["labels"] = [label2id[l.strip()] for l in batch["label"]]
    return enc
# IMPORTANT: remove original 'text' and 'label' columns so Trainer only sees tensors
encoded = dataset.map(
    encode_batch,
    batched=True,
    remove_columns=dataset["train"].column_names,
)
print(encoded)
print("Encoded train sample keys:", encoded["train"][0].keys())
# should be: dict_keys(['input_ids', 'attention_mask', 'labels'])
# ======================
# MODEL
# ======================
model = AutoModelForSequenceClassification.from_pretrained(
    BASE_MODEL,
    num_labels=len(LABELS),
    id2label=id2label,
    label2id=label2id,
)
# ======================
# METRICS
# ======================
accuracy = evaluate.load("accuracy")
f1_macro = evaluate.load("f1")
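# Macro F1 averages the per-class F1 scores, so all six periods count equally
# even if the dataset is imbalanced across eras.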
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    return {
        "accuracy": accuracy.compute(predictions=preds, references=labels)["accuracy"],
        "f1_macro": f1_macro.compute(
            predictions=preds, references=labels, average="macro"
        )["f1"],
    }
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
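# DataCollatorWithPadding pads each batch dynamically to its longest sequence,
# which pairs with the truncation-only tokenization in encode_batch above.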
# ======================
# TRAINING ARGS
# ======================
training_args = TrainingArguments(
output_dir="out",
per_device_train_batch_size=16,
per_device_eval_batch_size=32,
learning_rate=5e-5,
num_train_epochs=10,
eval_strategy="epoch",
save_strategy="no",
load_best_model_at_end=False,
logging_steps=50,
push_to_hub=True,
hub_model_id=HUB_MODEL_ID,
hub_private_repo=False,
)
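# push_to_hub=True requires Hugging Face credentials with write access to
# HUB_MODEL_ID, e.g. via `huggingface-cli login` or an HF_TOKEN env variable.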
# ======================
# TRAINER
# ======================
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded["train"],
    eval_dataset=encoded["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)
if __name__ == "__main__":
    trainer.train()
    # push the final (last-epoch) model + tokenizer to the Hub;
    # with save_strategy="no", no best-checkpoint selection happens
    trainer.push_to_hub()
    tokenizer.push_to_hub(HUB_MODEL_ID)
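# ======================
# INFERENCE (sketch)
# ======================
# A minimal example of loading the pushed checkpoint, assuming the push above
# succeeded; the sample sentence is made up for illustration.
# from transformers import pipeline
# clf = pipeline("text-classification", model=HUB_MODEL_ID)
# print(clf("Streaming services reshaped how people watch television."))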