from torch.utils.data import BatchSampler, DataLoader, IterableDataset
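
# Default values for each `torch.utils.data.DataLoader` argument, used below to
# rebuild a dataloader from an existing one. Note: `prefetch_factor` defaults to
# 2 here, matching older PyTorch releases; on recent releases the attribute read
# back from an existing dataloader is `None` whenever `num_workers == 0`, so
# copying it over stays valid either way.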
_PYTORCH_DATALOADER_KWARGS = {
    "batch_size": 1,
    "shuffle": False,
    "sampler": None,
    "batch_sampler": None,
    "num_workers": 0,
    "collate_fn": None,
    "pin_memory": False,
    "drop_last": False,
    "timeout": 0,
    "worker_init_fn": None,
    "multiprocessing_context": None,
    "generator": None,
    "prefetch_factor": 2,
    "persistent_workers": False,
}


class SkipBatchSampler(BatchSampler):
    """
    A `torch.utils.data.BatchSampler` that skips the first `n` batches of another `torch.utils.data.BatchSampler`.
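
    Example (a minimal sketch batching a sequential sampler over 10 indices):

        >>> from torch.utils.data import BatchSampler, SequentialSampler
        >>> base = BatchSampler(SequentialSampler(range(10)), batch_size=2, drop_last=False)
        >>> list(SkipBatchSampler(base, skip_batches=2))
        [[4, 5], [6, 7], [8, 9]]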
| """ |
|
|
| def __init__(self, batch_sampler, skip_batches=0): |
| self.batch_sampler = batch_sampler |
| self.skip_batches = skip_batches |
|
|
| def __iter__(self): |
| for index, samples in enumerate(self.batch_sampler): |
| if index >= self.skip_batches: |
| yield samples |
|
|
| @property |
| def total_length(self): |
| return len(self.batch_sampler) |
|
|
| def __len__(self): |
| return len(self.batch_sampler) - self.skip_batches |


class SkipDataLoader(DataLoader):
    """
    Subclass of a PyTorch `DataLoader` that will skip the first batches.

    Args:
        dataset (`torch.utils.data.dataset.Dataset`):
            The dataset to use to build this dataloader.
        skip_batches (`int`, *optional*, defaults to 0):
            The number of batches to skip at the beginning.
        kwargs:
            All other keyword arguments to pass to the regular `DataLoader` initialization.
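
    Example (illustrative only; the output assumes the default collate function):

        >>> import torch
        >>> from torch.utils.data import TensorDataset
        >>> dataset = TensorDataset(torch.arange(6))
        >>> list(SkipDataLoader(dataset, skip_batches=2, batch_size=2))
        [[tensor([4, 5])]]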
| """ |
|
|
| def __init__(self, dataset, skip_batches=0, **kwargs): |
| super().__init__(dataset, **kwargs) |
| self.skip_batches = skip_batches |
|
|
| def __iter__(self): |
| for index, batch in enumerate(super().__iter__()): |
| if index >= self.skip_batches: |
| yield batch |


def skip_first_batches(dataloader, num_batches=0):
    """
    Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`.
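
    For a map-style dataset the skipped batches are never loaded (the new batch
    sampler jumps over their indices); for an `IterableDataset` they are
    produced and discarded.

    Example (a sketch of resuming iteration mid-epoch):

        >>> import torch
        >>> from torch.utils.data import DataLoader, TensorDataset
        >>> dataloader = DataLoader(TensorDataset(torch.arange(8)), batch_size=2)
        >>> list(skip_first_batches(dataloader, num_batches=3))
        [[tensor([6, 7])]]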
| """ |
| dataset = dataloader.dataset |
| sampler_is_batch_sampler = False |
| if isinstance(dataset, IterableDataset): |
| new_batch_sampler = None |
| else: |
| sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler) |
| batch_sampler = ( |
| dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler |
| ) |
| new_batch_sampler = SkipBatchSampler(batch_sampler, skip_batches=num_batches) |

    # All batching-related arguments are handled by the new batch sampler (or by
    # `SkipDataLoader`), so they must not be copied over from the old dataloader.
    ignore_kwargs = [
        "batch_size",
        "shuffle",
        "sampler",
        "batch_sampler",
        "drop_last",
    ]

    # Copy every other argument from the existing dataloader, falling back to
    # PyTorch's defaults when an attribute is missing.
    kwargs = {
        k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k])
        for k in _PYTORCH_DATALOADER_KWARGS
        if k not in ignore_kwargs
    }

    if new_batch_sampler is None:
        # No batch sampler to manipulate (IterableDataset): restore the batching
        # arguments and drop the first batches as the dataloader produces them.
        kwargs["drop_last"] = dataloader.drop_last
        kwargs["batch_size"] = dataloader.batch_size
        dataloader = SkipDataLoader(dataset, skip_batches=num_batches, **kwargs)
    elif sampler_is_batch_sampler:
        # The original dataloader received a `BatchSampler` as its `sampler`, so
        # the skipping version has to be passed the same way.
        dataloader = DataLoader(
            dataset,
            sampler=new_batch_sampler,
            batch_size=dataloader.batch_size,
            **kwargs,
        )
    else:
        dataloader = DataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs)

    return dataloader