import os
import time

import torch
from transformers import (
    GPT2LMHeadModel,
    GPT2Tokenizer,
    TextDataset,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)


class GptHumorTrainer:

    def __init__(self, silent=False) -> None:
        start_time = time.perf_counter()
        self.tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
        # Resume from fine-tuned weights when they exist; otherwise start
        # from the base distilgpt2 checkpoint so a first run still works.
        save_path = self.local_file_path("SaveState")
        if os.path.isdir(save_path):
            self.model = GPT2LMHeadModel.from_pretrained(save_path)
        else:
            self.model = GPT2LMHeadModel.from_pretrained("distilgpt2")
        self.model.eval()  # inference mode; Trainer re-enables train mode itself
        if not silent:
            print(f"Model loading took {time.perf_counter() - start_time:.2f} seconds")

    def local_file_path(self, path):
        # Resolve a path relative to this file so the script works from any cwd.
        return os.path.join(os.path.dirname(os.path.abspath(__file__)), path)

    def train(self, train_file, epochs=3):
        # Keep the model on the CPU for this small setup; note that Trainer
        # will move it to a GPU automatically if one is available.
        device = torch.device("cpu")
        self.model.to(device)

        # TextDataset is deprecated in recent transformers releases, but it is
        # a simple way to split a plain-text file into fixed-size token blocks.
        train_dataset = TextDataset(
            tokenizer=self.tokenizer,
            file_path=train_file,
            block_size=128,
        )

        # mlm=False selects causal language modeling (next-token prediction),
        # the objective GPT-2 is pretrained with.
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False,
        )

        # One pass over the data per iteration, each with its own output
        # directory, so the total number of passes matches `epochs`.
        for epoch in range(epochs):
            training_args = TrainingArguments(
                output_dir=f"./results/epoch_{epoch + 1}",
                overwrite_output_dir=True,
                num_train_epochs=1,
                per_device_train_batch_size=3,
                save_strategy="no",  # no checkpoints; the model is saved manually below
                prediction_loss_only=True,
            )

            trainer = Trainer(
                model=self.model,
                args=training_args,
                data_collator=data_collator,
                train_dataset=train_dataset,
            )

            trainer.train()

        # Persist the fine-tuned weights so the next run resumes from them.
        self.model.save_pretrained(self.local_file_path("SaveState"))

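    # Hypothetical helper, not part of the original script: a minimal sketch
    # of sampling from the fine-tuned model (which is why __init__ leaves the
    # model in eval mode). The decoding settings here are assumptions.
    def generate(self, prompt, max_new_tokens=50):
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
        with torch.no_grad():
            output_ids = self.model.generate(
                input_ids,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                top_p=0.9,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        return self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
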

if __name__ == "__main__":
    humor_trainer = GptHumorTrainer()
    humor_trainer.train(humor_trainer.local_file_path("TrainData.txt"), epochs=5)
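    # Usage sketch for the hypothetical generate() helper above:
    # print(humor_trainer.generate("Tell me a joke about computers"))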