from datasets import load_dataset
from transformers import AutoTokenizer, T5ForConditionalGeneration
from trl import SFTTrainer, SFTConfig


# Fetch the "train" split of the container-status dataset from the Hub.
dataset = load_dataset("mindchain/container-status-de", split="train")

# Hold out 15 % of the rows for evaluation; the fixed seed keeps the
# train/test partition identical from run to run.
split = dataset.train_test_split(test_size=0.15, seed=42)


def fmt(ex: dict[str, object]) -> dict[str, object]:
    """Build the training record for one dataset example.

    Args:
        ex: A mapping that provides at least the keys ``"text"`` and
            ``"label"``. Any other keys are ignored.

    Returns:
        A new dict with exactly two keys: ``"text"``, holding the original
        text prefixed with ``"Status: "``, and ``"label"``, passed through
        unchanged.

    Raises:
        KeyError: If ``ex`` lacks ``"text"`` or ``"label"``.
    """
    return {"text": f"Status: {ex['text']}", "label": ex["label"]}


# Run fmt over both partitions. All original columns are listed in
# remove_columns; per the `datasets` map documentation they are dropped
# before the function's output is merged in, so the "text" and "label"
# values returned by fmt are the only columns left afterwards.
train_ds = split["train"].map(fmt, remove_columns=split["train"].column_names)
eval_ds = split["test"].map(fmt, remove_columns=split["test"].column_names)


# Tokenizer and model weights are loaded from the same checkpoint id.
# NOTE(review): the checkpoint name suggests a T5Gemma architecture, while the
# class used here is the classic T5 one -- confirm that this class can actually
# load the checkpoint (and that the repo id is spelled as published).
tokenizer = AutoTokenizer.from_pretrained("google/t5gemma-2-270m")
model = T5ForConditionalGeneration.from_pretrained("google/t5gemma-2-270m")


# Training hyper-parameters and Hub/logging destinations for the SFT run.
config = SFTConfig(
    output_dir="out",
    # Publish the result to the Hub under the given repository id.
    push_to_hub=True,
    hub_model_id="mindchain/t5gemma-270m-container-status",
    num_train_epochs=5,
    # 2 samples per device x 4 accumulation steps = 8 samples per optimizer
    # step on each device.
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    learning_rate=3e-4,
    logging_steps=5,
    max_length=256,
    report_to="trackio",
)


# Wire model, tokenizer, data and config together.
# The tokenizer is passed as `processing_class`: TRL releases whose SFTConfig
# accepts `max_length` (used by this script's config) no longer take a
# `tokenizer=` keyword on SFTTrainer, so the old spelling raises TypeError.
# NOTE(review): SFTTrainer is primarily documented for decoder-only models --
# confirm it handles this encoder-decoder model and the "text"/"label" columns
# as intended.
trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    args=config,
)
trainer.train()

# Explicit final upload of the trained model to the Hub.
trainer.push_to_hub()
print('DONE')