| | import torchvision.datasets as dset |
| | from torch.utils.data import Dataset |
| | import torch |
| | from torch.utils.data import DataLoader |
| | import glob |
| | import os |
| | from transformers import AutoTokenizer |
| | from torch.utils.data import Dataset, DataLoader, random_split |
| |
|
| |
|
class GithubDataset(Dataset):
    """Map-style dataset of source files, each tokenized to a fixed length.

    Every regular file directly inside ``root_dir`` whose name contains a dot
    (glob pattern ``*.*``, non-recursive) becomes one sample. Files are read
    as UTF-8 with undecodable bytes dropped, then tokenized and
    padded/truncated to ``max_length`` tokens.

    Args:
        root_dir: Directory to scan for source files.
        train: Accepted for API compatibility; currently unused (train/test
            separation is done by the caller, e.g. with ``random_split``).
        max_length: Number of tokens every sample is padded/truncated to.
        tokenizer_name: Name or path passed to
            ``AutoTokenizer.from_pretrained``.
    """

    def __init__(
        self,
        root_dir=os.path.expanduser("~/torch_datasets/github-python/corpus"),
        train=False,
        max_length=512,
        tokenizer_name="bert-base-uncased",
    ):
        self.root = root_dir
        self.file_list = self._list_files(root_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.max_length = max_length

    @staticmethod
    def _list_files(root_dir):
        """Return the sorted paths of regular files matching ``root_dir/*.*``."""
        # glob order follows the OS directory listing, which is arbitrary;
        # sorting makes the index -> file mapping (and any split built on top
        # of it) reproducible. The isfile filter drops directories such as
        # "pkg.egg-info" that "*.*" also matches and that open() would reject.
        return sorted(
            path
            for path in glob.glob(os.path.join(root_dir, "*.*"))
            if os.path.isfile(path)
        )

    def __len__(self):
        """Return the number of files in the dataset."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """Read and tokenize the ``idx``-th file.

        Returns:
            A pair ``(input_ids, attention_mask)`` of 1-D tensors, each of
            length ``max_length``.
        """
        path = self.file_list[idx]
        with open(path, "r", encoding="utf-8", errors="ignore") as file:
            code = file.read()

        encoding = self.tokenizer(
            code,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        # return_tensors="pt" adds a leading batch dimension of size 1; drop
        # it so the DataLoader can stack individual samples itself.
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)
        return input_ids, attention_mask
| |
|
| |
|
# Dataset built from the local test-data directory.
dataset = GithubDataset(root_dir="./test-data/")

# 80/20 train/test split by file count.
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size

# A fixed-seed generator yields the same index partition on every run, so a
# separate evaluation process does not end up testing on training files.
train_dataset, test_dataset = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(42),
)
| |
|
| |
|
def get_train_dataset():
    """Return the training subset produced by the module-level random split."""
    split = train_dataset
    return split
| |
|
| |
|
def get_test_dataset():
    """Return the held-out subset produced by the module-level random split."""
    split = test_dataset
    return split
| |
|
| |
|
def get_dataloader(dataset, batch_size=64, shuffle=True):
    """Wrap ``dataset`` in a ``torch.utils.data.DataLoader``.

    Args:
        dataset: Any map-style dataset (e.g. a split returned by
            ``get_train_dataset`` or ``get_test_dataset``).
        batch_size: Number of samples per batch.
        shuffle: Reshuffle sample order every epoch. Defaults to ``True`` as
            before; pass ``False`` for a deterministic evaluation loader.

    Returns:
        A ``DataLoader`` iterating over ``dataset``.
    """
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
| |
|
| |
|
if __name__ == "__main__":
    d = get_train_dataset()
    print("Number of samples: ", len(d))

    if len(d) == 0:
        raise SystemExit("Training split is empty; nothing to preview.")

    # Preview sample 4 when it exists, otherwise the last sample, instead of
    # failing with IndexError on small splits.
    input_ids, attention_mask = d[min(4, len(d) - 1)]

    # random_split returns a Subset; reuse the wrapped dataset's tokenizer
    # rather than loading a second copy.
    tokenizer = d.dataset.tokenizer

    # Decode only real tokens; padding positions have attention_mask == 0.
    for token_id in input_ids[attention_mask.bool()]:
        print(tokenizer.decode(token_id.item()), end=" ")
    print()
| |
|