import torch
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from transformers import AutoTokenizer

def train_tokenizer(texts, vocab_size=50000, min_frequency=2):
    """Train a new tokenizer on `texts` and save it to ./tokenizer."""
    # Start from the GPT-2 tokenizer so the new one inherits its pipeline
    # (byte-level pre-tokenization, BPE model), then learn a fresh vocabulary.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer = tokenizer.train_new_from_iterator(
        texts, vocab_size=vocab_size, min_frequency=min_frequency
    )
    # GPT-2 ships without a pad token; add one so padded batches can be built.
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    tokenizer.save_pretrained("./tokenizer")
    return tokenizer

def load_tokenizer():
    """Reload the tokenizer previously saved by train_tokenizer."""
    tokenizer = AutoTokenizer.from_pretrained("./tokenizer")
    # Defensive check: ensure a pad token exists even in an older saved copy.
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    return tokenizer

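
# A minimal round-trip sketch for the two helpers above. The demo function
# name and the sample sentences are illustrative placeholders, not part of
# the original module; call it manually to sanity-check the save/load path.
def demo_tokenizer_roundtrip():
    samples = ["hello world", "tokenizers learn merges from raw text"]
    trained = train_tokenizer(samples, vocab_size=500, min_frequency=1)
    reloaded = load_tokenizer()
    # The reloaded tokenizer should encode text exactly like the trained one.
    assert reloaded.encode(samples[0]) == trained.encode(samples[0])
    print(reloaded.tokenize(samples[1]))
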
class TextDataset(Dataset):
    """Wrap raw strings; tokenize and pad each example lazily in __getitem__."""

    def __init__(self, texts, tokenizer, max_length):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        # Truncate or pad every example to max_length so batches stack cleanly.
        encodings = self.tokenizer(
            text, truncation=True, padding='max_length', max_length=self.max_length
        )
        return torch.tensor(encodings['input_ids'])

def get_dataloader(dataset_name, config_name, tokenizer, max_length, batch_size):
    raw_dataset = load_dataset(dataset_name, config_name)
    # Keep only the first 50 training texts: a small slice for quick experiments.
    texts = raw_dataset['train']['text'][:50]
    dataset = TextDataset(texts, tokenizer, max_length)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return dataloader
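
# A hedged end-to-end sketch. The "wikitext" dataset and its
# "wikitext-2-raw-v1" config are example choices, not mandated by the module;
# any Hub dataset whose train split has a 'text' column works the same way.
if __name__ == "__main__":
    raw = load_dataset("wikitext", "wikitext-2-raw-v1")
    # Train (and save) a tokenizer on a small slice of the corpus.
    tokenizer = train_tokenizer(raw['train']['text'][:1000], vocab_size=5000)
    dataloader = get_dataloader("wikitext", "wikitext-2-raw-v1", tokenizer,
                                max_length=128, batch_size=8)
    batch = next(iter(dataloader))
    print(batch.shape)  # expected: torch.Size([8, 128])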