Spaces: Runtime error
| from torch.utils.data import DataLoader | |
| from torch.utils.data.distributed import DistributedSampler | |
| import spiga.data.loaders.alignments as zoo_alignments | |
# Registry of dataset "zoo" modules. get_dataset() asks each one, in list order,
# to build a dataset and returns the first non-None result, so earlier entries
# take precedence over later ones.
zoos = [zoo_alignments]
def get_dataset(data_config, pretreat=None, debug=False):
    """Return the first dataset that a registered zoo is able to build.

    The modules in the module-level ``zoos`` list are queried in order via
    ``zoo.get_dataset(data_config, pretreat=..., debug=...)``. Querying stops
    at the first result that is not ``None``; later zoos are never called.

    Args:
        data_config: dataset configuration, forwarded unchanged to each zoo.
        pretreat: forwarded unchanged to each zoo.
        debug: forwarded unchanged to each zoo.

    Returns:
        The first non-``None`` dataset produced by a zoo.

    Raises:
        NotImplementedError: if every zoo returns ``None`` (or ``zoos`` is empty).
    """
    # Lazy generator: each zoo is only asked once the previous one declined.
    attempts = (
        candidate_zoo.get_dataset(data_config, pretreat=pretreat, debug=debug)
        for candidate_zoo in zoos
    )
    built = next((ds for ds in attempts if ds is not None), None)
    if built is None:
        raise NotImplementedError('Dataset not available')
    return built
def get_dataloader(batch_size, data_config, pretreat=None, sampler_cfg=None, debug=False):
    """Build a (optionally distributed) DataLoader for the configured dataset.

    Args:
        batch_size: number of samples per batch.
        data_config: configuration object. This function reads its ``shuffle``
            and ``num_workers`` attributes and forwards the whole object to
            ``get_dataset``.
        pretreat: forwarded unchanged to ``get_dataset``.
        sampler_cfg: optional object exposing ``world_size`` and ``rank``. When
            given, a ``DistributedSampler`` shards the dataset across replicas
            and honours ``data_config.shuffle``.
        debug: forwarded unchanged to ``get_dataset``.

    Returns:
        A ``(dataloader, dataset)`` tuple.
    """
    dataset = get_dataset(data_config, pretreat=pretreat, debug=debug)
    shuffle = bool(data_config.shuffle)

    sampler = None
    if sampler_cfg is not None:
        # DistributedSampler shuffles by default. Pass the configured flag so a
        # non-shuffled (e.g. evaluation) loader keeps a deterministic order.
        sampler = DistributedSampler(dataset,
                                     num_replicas=sampler_cfg.world_size,
                                     rank=sampler_cfg.rank,
                                     shuffle=shuffle)

    # Under a DistributedSampler each replica iterates len(sampler) samples
    # (ceil(len(dataset) / world_size)), not len(dataset).
    num_samples = len(sampler) if sampler is not None else len(dataset)

    # When shuffling, drop a trailing batch that would contain exactly one
    # sample -- presumably to avoid single-sample training batches (e.g. for
    # BatchNorm); TODO confirm intent.
    drop_last_batch = shuffle and num_samples % batch_size == 1

    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            # DataLoader forbids shuffle=True together with a sampler;
                            # in that case the sampler itself handles shuffling.
                            shuffle=shuffle if sampler is None else False,
                            num_workers=data_config.num_workers,
                            pin_memory=True,
                            drop_last=drop_last_batch,
                            sampler=sampler)
    return dataloader, dataset