| | import numpy as np |
| | import pickle |
| | import os |
| | import torch |
| | from torch.utils.data import TensorDataset |
| | from torchvision.datasets import ImageFolder |
| | import torchvision.transforms as transforms |
| | from sklearn.model_selection import train_test_split |
| |
|
| |
|
def set_up_data(H):
    """Prepare train/valid datasets and a GPU preprocessing closure for H.dataset.

    Mutates H in place (image_size, image_channels) and returns
    (H, train_data, valid_data, preprocess_func). The shift/scale constants
    are per-dataset normalization statistics baked in by the original authors.

    Raises:
        ValueError: if H.dataset is not one of the supported names.
    """
    shift_loss = -127.5
    scale_loss = 1. / 127.5
    if H.dataset == 'imagenet32':
        trX, vaX, teX = imagenet32(H.data_root)
        H.image_size = 32
        H.image_channels = 3
        shift = -116.2373
        scale = 1. / 69.37404
    elif H.dataset == 'imagenet64':
        trX, vaX, teX = imagenet64(H.data_root)
        H.image_size = 64
        H.image_channels = 3
        shift = -115.92961967
        scale = 1. / 69.37404
    elif H.dataset == 'ffhq_256':
        trX, vaX, teX = ffhq256(H.data_root)
        H.image_size = 256
        H.image_channels = 3
        shift = -112.8666757481
        scale = 1. / 69.84780273
    elif H.dataset == 'ffhq_1024':
        trX, vaX, teX = ffhq1024(H.data_root)
        H.image_size = 1024
        H.image_channels = 3
        shift = -0.4387
        scale = 1.0 / 0.2743
        # ffhq_1024 images come from ImageFolder/ToTensor already in [0, 1],
        # so the loss normalization differs from the raw-uint8 datasets.
        shift_loss = -0.5
        scale_loss = 2.0
    elif H.dataset == 'cifar10':
        (trX, _), (vaX, _), (teX, _) = cifar10(H.data_root, one_hot=False)
        H.image_size = 32
        H.image_channels = 3
        shift = -120.63838
        scale = 1. / 64.16736
    else:
        raise ValueError(f'unknown dataset: {H.dataset}')

    # ffhq_256 targets are quantized to 5 bits in preprocess_func below.
    do_low_bit = H.dataset in ['ffhq_256']

    if H.test_eval:
        print('DOING TEST')
        eval_dataset = teX
    else:
        eval_dataset = vaX

    # Broadcastable (1, 1, 1, 1) GPU constants applied to NHWC batches.
    shift = torch.tensor([shift]).cuda().view(1, 1, 1, 1)
    scale = torch.tensor([scale]).cuda().view(1, 1, 1, 1)
    shift_loss = torch.tensor([shift_loss]).cuda().view(1, 1, 1, 1)
    scale_loss = torch.tensor([scale_loss]).cuda().view(1, 1, 1, 1)

    if H.dataset == 'ffhq_1024':
        # Stored as image files; ToTensor yields NCHW, so preprocess_func
        # permutes back to NHWC (untranspose=True).
        train_data = ImageFolder(trX, transforms.ToTensor())
        valid_data = ImageFolder(eval_dataset, transforms.ToTensor())
        untranspose = True
    else:
        train_data = TensorDataset(torch.as_tensor(trX))
        valid_data = TensorDataset(torch.as_tensor(eval_dataset))
        untranspose = False

    def preprocess_func(x):
        """Take a data example and return (network input, loss target).

        Reads the enclosing shift/scale constants; they are only read or
        mutated in place (never rebound), so no nonlocal declarations needed.
        """
        if untranspose:
            x[0] = x[0].permute(0, 2, 3, 1)
        inp = x[0].cuda(non_blocking=True).float()
        out = inp.clone()
        inp.add_(shift).mul_(scale)
        if do_low_bit:
            # 5 bits of precision: snap the loss target to multiples of 8.
            out.mul_(1. / 8.).floor_().mul_(8.)
        out.add_(shift_loss).mul_(scale_loss)
        return inp, out

    return H, train_data, valid_data, preprocess_func
| |
|
| |
|
def mkdir_p(path):
    """Create *path* and any missing parent directories; no-op if it exists."""
    os.makedirs(path, exist_ok=True)
| |
|
| |
|
def flatten(outer):
    """Concatenate the sub-iterables of *outer* into one flat list."""
    flat = []
    for inner in outer:
        flat.extend(inner)
    return flat
| |
|
| |
|
def unpickle_cifar10(file):
    """Load one pickled CIFAR-10 batch and return it with str (not bytes) keys.

    The distributed batches are pickled with bytes keys (e.g. b'data');
    decoding them lets callers use plain-string lookups like data['labels'].
    """
    # 'with' guarantees the handle is closed even if pickle.load raises,
    # unlike the bare open/close pair this replaces.
    with open(file, 'rb') as fo:
        data = pickle.load(fo, encoding='bytes')
    return {k.decode(): v for k, v in data.items()}
| |
|
| |
|
def imagenet32(data_root):
    """Return (train, valid, test) memory-mapped arrays for ImageNet-32.

    The validation set is the last 5000 examples of a fixed seed-42 shuffle
    of the training file, so the split is reproducible across runs.
    """
    trX = np.load(os.path.join(data_root, 'imagenet32-train.npy'), mmap_mode='r')
    # Local RandomState instead of np.random.seed(42): produces the identical
    # permutation (legacy seeding) without clobbering the process-wide RNG
    # state for unrelated code.
    tr_va_split_indices = np.random.RandomState(42).permutation(trX.shape[0])
    train = trX[tr_va_split_indices[:-5000]]
    valid = trX[tr_va_split_indices[-5000:]]
    test = np.load(os.path.join(data_root, 'imagenet32-valid.npy'), mmap_mode='r')
    return train, valid, test
| |
|
| |
|
def imagenet64(data_root):
    """Return (train, valid, test) memory-mapped arrays for ImageNet-64.

    The validation set is the last 5000 examples of a fixed seed-42 shuffle
    of the training file, matching the imagenet32 split convention.
    """
    trX = np.load(os.path.join(data_root, 'imagenet64-train.npy'), mmap_mode='r')
    # Local RandomState instead of np.random.seed(42): identical permutation
    # (legacy seeding) without mutating global random state.
    tr_va_split_indices = np.random.RandomState(42).permutation(trX.shape[0])
    train = trX[tr_va_split_indices[:-5000]]
    valid = trX[tr_va_split_indices[-5000:]]
    test = np.load(os.path.join(data_root, 'imagenet64-valid.npy'), mmap_mode='r')
    return train, valid, test
| |
|
| |
|
def ffhq1024(data_root):
    """Return (train_dir, valid_dir, test_dir) for FFHQ-1024 ImageFolder data.

    There is no separate test split; the validation directory doubles as it.
    """
    train_dir = os.path.join(data_root, 'ffhq1024/train')
    valid_dir = os.path.join(data_root, 'ffhq1024/valid')
    return train_dir, valid_dir, valid_dir
| |
|
| |
|
def ffhq256(data_root):
    """Return (train, valid, test) arrays for FFHQ-256.

    The validation set is the last 7000 examples of a fixed seed-5 shuffle;
    there is no separate test split, so valid is returned for both.
    """
    trX = np.load(os.path.join(data_root, 'ffhq-256.npy'), mmap_mode='r')
    # Local RandomState instead of np.random.seed(5): identical permutation
    # (legacy seeding) without clobbering global random state.
    tr_va_split_indices = np.random.RandomState(5).permutation(trX.shape[0])
    train = trX[tr_va_split_indices[:-7000]]
    valid = trX[tr_va_split_indices[-7000:]]
    return train, valid, valid
| |
|
| |
|
def cifar10(data_root, one_hot=True):
    """Load CIFAR-10 and return ((trX, trY), (vaX, vaY), (teX, teY)).

    Images are returned NHWC; 5000 training examples are held out as a fixed
    (random_state=11172018) validation split. Labels are one-hot float32
    arrays when *one_hot*, else integer column vectors of shape (N, 1).
    """
    batch_dir = os.path.join(data_root, 'cifar-10-batches-py')
    tr_data = [unpickle_cifar10(os.path.join(batch_dir, 'data_batch_%d' % i)) for i in range(1, 6)]
    # np.vstack requires a sequence of arrays; passing a generator is
    # deprecated (NumPy >= 1.16) and rejected by newer NumPy, so build a list.
    trX = np.vstack([data['data'] for data in tr_data])
    trY = np.asarray(flatten([data['labels'] for data in tr_data]))
    te_data = unpickle_cifar10(os.path.join(batch_dir, 'test_batch'))
    teX = np.asarray(te_data['data'])
    teY = np.asarray(te_data['labels'])
    # Stored layout is (N, 3072) = flattened (C, H, W); convert to NHWC.
    trX = trX.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    teX = teX.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    trX, vaX, trY, vaY = train_test_split(trX, trY, test_size=5000, random_state=11172018)
    if one_hot:
        trY = np.eye(10, dtype=np.float32)[trY]
        vaY = np.eye(10, dtype=np.float32)[vaY]
        teY = np.eye(10, dtype=np.float32)[teY]
    else:
        trY = np.reshape(trY, [-1, 1])
        vaY = np.reshape(vaY, [-1, 1])
        teY = np.reshape(teY, [-1, 1])
    return (trX, trY), (vaX, vaY), (teX, teY)
| |
|