import torch
import torch.distributed as dist
from torch.utils.data import Sampler
import math
import itertools
import numpy as np
class FastRandomDistributedSampler(Sampler[int]):
r"""
A distributed sampler that continuously yields random indices with replacement,
avoiding frequent iterator recreation overhead for DataLoader.
Instead of stopping after one pass through the dataset, this sampler's
iterator yields a specified number of indices (`epoch_steps`) before
stopping. This significantly reduces the frequency of DataLoader worker
restarts when the underlying dataset is small.
Args:
dataset: Dataset used for sampling.
num_replicas (int, optional): Number of processes participating in
distributed training. Defaults to current world size.
rank (int, optional): Rank of the current process. Defaults to current rank.
seed (int): Base seed for the random number generator. Each epoch/rank
gets a different derived seed. Defaults to 0.
epoch_steps (int): The number of indices this sampler should yield per
__iter__ call (per replica). Set this to a large number
to reduce iterator recreation frequency. If None, it defaults
to ceil(len(dataset) / num_replicas).
"""
def __init__(self, dataset, num_replicas=None, rank=None, seed=0, epoch_steps=None):
if num_replicas is None:
if not dist.is_available() or not dist.is_initialized():
raise RuntimeError("Requires distributed package to be available and initialized")
num_replicas = dist.get_world_size()
if rank is None:
if not dist.is_available() or not dist.is_initialized():
raise RuntimeError("Requires distributed package to be available and initialized")
rank = dist.get_rank()
if rank >= num_replicas or rank < 0:
raise ValueError(
f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]")
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.seed = seed
self.epoch = 0
self.dataset_len = len(self.dataset)
# Determine the number of steps/indices per iterator cycle for this rank
if epoch_steps is None:
# Default behavior: roughly one pass over the data
self.num_samples_per_epoch = math.ceil(self.dataset_len / self.num_replicas)
else:
# User-defined length for the iterator cycle
self.num_samples_per_epoch = epoch_steps
if not isinstance(self.num_samples_per_epoch, int) or self.num_samples_per_epoch <= 0:
raise ValueError("epoch_steps must be a positive integer")
def _infinite_indices(self):
"""A generator that yields random indices indefinitely."""
g = torch.Generator()
# Ensure distinct seeds based on rank, epoch, and base seed
current_seed = self.seed + self.epoch * self.num_replicas + self.rank
g.manual_seed(current_seed)
while True:
yield torch.randint(low=0, high=self.dataset_len, size=(1,), generator=g).item()
    def __iter__(self):
        """
        Returns an iterator that yields `num_samples_per_epoch` indices.

        itertools.islice takes a finite slice from the infinite generator,
        avoiding the cost of materializing the indices as a list.
        """
        # Each __iter__ call creates a fresh infinite generator (seeded from the
        # current epoch) and slices it. Its state is preserved across the
        # DataLoader's next() calls, and the expensive DataLoader worker setup
        # only happens once this sliced iterator is exhausted.
        return itertools.islice(self._infinite_indices(), self.num_samples_per_epoch)
def __len__(self):
"""The number of samples produced by the iterator per __iter__ call."""
return self.num_samples_per_epoch
def set_epoch(self, epoch: int) -> None:
"""
Sets the epoch for this sampler. This is used to vary the random seed sequence
each time __iter__ is called.
"""
self.epoch = epoch
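
# Usage sketch (illustrative, not part of this module): wiring the sampler into a
# DataLoader. Assumes torch.distributed is already initialized and that
# `train_dataset` and `num_epochs` are defined by the calling code.
#
#     sampler = FastRandomDistributedSampler(train_dataset, seed=0, epoch_steps=100_000)
#     loader = torch.utils.data.DataLoader(train_dataset, batch_size=32,
#                                          sampler=sampler, num_workers=4)
#     for epoch in range(num_epochs):
#         sampler.set_epoch(epoch)  # reseed so each cycle draws a new index stream
#         for batch in loader:
#             ...
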
class QAMNISTSampler(Sampler):
    """Batch sampler that shuffles the dataset and yields lists of indices,
    randomly re-configuring the dataset's number of digits and operations
    before each batch so all samples in a batch share the same structure."""

    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_samples = len(dataset)
    def __iter__(self):
        indices = torch.randperm(self.num_samples).tolist()
        for i in range(0, self.num_samples, self.batch_size):
            batch_indices = indices[i:i + self.batch_size]
            # Pick this batch's number of digits; np.random.randint excludes the
            # upper bound and raises if low == high, hence the guard.
            if self.dataset.num_images_range[0] == self.dataset.num_images_range[1]:
                batch_num_digits = self.dataset.num_images_range[0]
            else:
                batch_num_digits = np.random.randint(self.dataset.num_images_range[0], self.dataset.num_images_range[1])
            # Same logic for the number of operations in the question.
            if self.dataset.num_operations_range[0] == self.dataset.num_operations_range[1]:
                batch_num_operations = self.dataset.num_operations_range[0]
            else:
                batch_num_operations = np.random.randint(self.dataset.num_operations_range[0], self.dataset.num_operations_range[1])
            # Configure the dataset for this batch before yielding its indices.
            self.dataset.set_num_digits(batch_num_digits)
            self.dataset.set_num_operations(batch_num_operations)
            yield batch_indices
    def __len__(self):
        # Number of full batches; a trailing partial batch yielded by __iter__ is not counted.
        return self.num_samples // self.batch_size
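
# Usage sketch (illustrative, not part of this module): QAMNISTSampler yields whole
# batches of indices, so it would be passed to DataLoader as `batch_sampler` rather
# than `sampler`. `qamnist_dataset` is assumed to provide num_images_range,
# num_operations_range, set_num_digits and set_num_operations.
#
#     sampler = QAMNISTSampler(qamnist_dataset, batch_size=64)
#     loader = torch.utils.data.DataLoader(qamnist_dataset, batch_sampler=sampler)
#     for batch in loader:
#         ...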