| |
| |
| |
| import torch |
| from torchvision import datasets |
| from torchvision.transforms import v2 |
| from torch.utils.data import DataLoader |
|
|
| import utils |
| from typing import Tuple |
|
|
| import io |
| import base64 |
| from PIL import Image |
| import numpy as np |
|
|
| |
# MNIST pixel mean and standard deviation on the [0, 1] scale produced by
# ``ToDtype(..., scale=True)``. Defined once so that the training and
# evaluation/inference pipelines below can never drift apart.
MNIST_MEAN = (0.1307,)
MNIST_STD = (0.3081,)


# Deterministic pipeline (used for the test split and for inference): convert the
# input to a tensor image, cast it to float32 scaled into [0, 1], then normalize
# with the MNIST statistics.
BASE_TRANSFORMS = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=MNIST_MEAN, std=MNIST_STD),
])


# Training pipeline: a random affine augmentation (rotation within +/-15 degrees,
# scaling by 0.8-1.2, translation up to 8% of width/height, shear within +/-10
# degrees) applied before the same conversion and normalization as BASE_TRANSFORMS.
TRAIN_TRANSFORMS = v2.Compose([
    v2.RandomAffine(degrees=15,
                    scale=(0.8, 1.2),
                    translate=(0.08, 0.08),
                    shear=10),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=MNIST_MEAN, std=MNIST_STD),
])
|
|
|
|
| |
| |
| |
def get_dataloaders(root: str,
                    batch_size: int,
                    num_workers: int = 0) -> Tuple[DataLoader, DataLoader]:
    '''
    Build the MNIST training and test dataloaders.

    Args:
        root (str): Directory the MNIST data is read from (and downloaded into if missing).
        batch_size (int): Number of samples per batch, used for both loaders.
        num_workers (int): Number of worker processes used for data loading.
            Default is 0 (load in the main process).

    Returns:
        Tuple[DataLoader, DataLoader]: ``(train_dl, test_dl)``. The training loader
        shuffles and applies TRAIN_TRANSFORMS; the test loader keeps a fixed order
        and applies BASE_TRANSFORMS.
    '''
    train_set = datasets.MNIST(root, download=True, train=True,
                               transform=TRAIN_TRANSFORMS)
    test_set = datasets.MNIST(root, download=True, train=False,
                              transform=BASE_TRANSFORMS)

    spawn_workers = num_workers > 0
    # Settings shared by both loaders. A multiprocessing context and persistent
    # workers only apply when worker processes are actually used.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        multiprocessing_context=utils.MP_CONTEXT if spawn_workers else None,
        pin_memory=utils.PIN_MEM,
        persistent_workers=spawn_workers,
    )

    train_dl = DataLoader(dataset=train_set, shuffle=True, **loader_kwargs)
    test_dl = DataLoader(dataset=test_set, shuffle=False, **loader_kwargs)
    return train_dl, test_dl
|
|
def _mnist_decode_ink(uri: str) -> np.ndarray:
    '''Decode a base64 image (data URI or bare payload) into a uint8 array where ink is bright.'''
    # A data URI looks like "data:<mime>;base64,<payload>". Base64 never contains ',',
    # so everything after the last comma is the payload; without a comma the whole
    # string is treated as the payload.
    payload = uri.rpartition(',')[2]
    with Image.open(io.BytesIO(base64.b64decode(payload))) as img:
        gray = np.array(img.convert('L'))
    # Invert: the drawing is assumed to be dark strokes on a light background
    # (TODO confirm with the client), whereas MNIST digits are light on black.
    return 255 - gray


def _mnist_fit_to_box(ink: np.ndarray, box: int = 20) -> np.ndarray:
    '''Crop `ink` to its non-zero bounding box and rescale it to fit a box x box square, keeping the aspect ratio.'''
    rows = np.flatnonzero(ink.any(axis=1))
    cols = np.flatnonzero(ink.any(axis=0))
    if rows.size == 0:
        # Blank drawing: nothing to crop or scale.
        return ink[:0, :0]
    digit = np.ascontiguousarray(ink[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1])
    height, width = digit.shape
    # Scale so the longer side becomes exactly `box`. Unlike Image.thumbnail,
    # this also enlarges small digits. Never collapse a thin stroke to 0 pixels.
    scale = box / max(height, width)
    new_size = (max(1, round(width * scale)), max(1, round(height * scale)))
    resized = Image.fromarray(digit).resize(new_size, Image.Resampling.LANCZOS)
    return np.array(resized)


def _mnist_center_by_mass(digit: np.ndarray, size: int = 28) -> np.ndarray:
    '''Paste `digit` into a zeroed size x size canvas so its center of mass lands on pixel (size // 2, size // 2).'''
    canvas = np.zeros((size, size), dtype=np.uint8)
    mass = digit.sum(dtype=np.float64)
    if mass == 0:
        # No ink: the center of mass is undefined (0 / 0), so return an empty canvas.
        return canvas
    rows, cols = np.indices(digit.shape)
    com_y = int(np.round((rows * digit).sum() / mass))
    com_x = int(np.round((cols * digit).sum() / mass))
    # Offset mapping a source pixel (y, x) to canvas pixel (y + off_y, x + off_x).
    off_y = size // 2 - com_y
    off_x = size // 2 - com_x
    # Clip the destination window to the canvas, then copy the matching source window.
    y0, x0 = max(off_y, 0), max(off_x, 0)
    y1 = min(off_y + digit.shape[0], size)
    x1 = min(off_x + digit.shape[1], size)
    canvas[y0:y1, x0:x1] = digit[y0 - off_y:y1 - off_y, x0 - off_x:x1 - off_x]
    return canvas


def mnist_preprocess(uri: str):
    '''
    Preprocesses a data URI representing a handwritten digit image according to the pipeline used in the MNIST dataset.
    The pipeline includes:
    1. Decoding the image, converting it to grayscale and inverting it, so dark strokes on a
       light background become light strokes on black, as in MNIST.
    2. Cropping to the bounding box of the ink (non-zero) pixels and resizing that crop with
       Lanczos resampling to fit a 20x20 box, preserving the aspect ratio (scaling up or down).
    3. Pasting the result into a 28x28 black image so its center of mass (COM) lies at pixel (14, 14).
    4. Converting the image to a tensor (pixel values between 0 and 1) and normalizing it using MNIST statistics.

    A blank drawing (no ink at all) yields an all-background image instead of failing.

    Reference: https://paperswithcode.com/dataset/mnist

    Args:
        uri (str): A string representing the full data URI ("data:<mime>;base64,<payload>").
            A bare base64 payload without the header is also accepted.

    Returns:
        Tensor: A tensor of shape (1, 28, 28) representing the preprocessed image, normalized using MNIST statistics.

    Raises:
        Whatever base64.b64decode / PIL.Image.open raise for malformed input
        (e.g. binascii.Error, PIL.UnidentifiedImageError).
    '''
    ink = _mnist_decode_ink(uri)
    digit = _mnist_fit_to_box(ink, box=20)
    canvas = _mnist_center_by_mass(digit, size=28)
    return BASE_TRANSFORMS(canvas)
|
|