import torch
import torch.distributed as dist
from torch.utils.data import Sampler
import math
import itertools
import numpy as np

class FastRandomDistributedSampler(Sampler[int]):
    r"""
    A distributed sampler that continuously yields random indices with replacement,
    avoiding frequent iterator recreation overhead for DataLoader.

    Instead of stopping after one pass through the dataset, this sampler's
    iterator yields a specified number of indices (`epoch_steps`) before
    stopping. This significantly reduces the frequency of DataLoader worker
    restarts when the underlying dataset is small.

    Args:
        dataset: Dataset used for sampling.
        num_replicas (int, optional): Number of processes participating in
            distributed training. Defaults to current world size.
        rank (int, optional): Rank of the current process. Defaults to current rank.
        seed (int): Base seed for the random number generator. Each epoch/rank
            gets a different derived seed. Defaults to 0.
        epoch_steps (int, optional): The number of indices this sampler should
            yield per __iter__ call (per replica). Set this to a large number
            to reduce iterator recreation frequency. If None, it defaults to
            ceil(len(dataset) / num_replicas).
    """
    def __init__(self, dataset, num_replicas=None, rank=None, seed=0, epoch_steps=None):
        if num_replicas is None:
            if not dist.is_available() or not dist.is_initialized():
                raise RuntimeError("Requires distributed package to be available and initialized")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available() or not dist.is_initialized():
                raise RuntimeError("Requires distributed package to be available and initialized")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]")

        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.seed = seed
        self.epoch = 0
        self.dataset_len = len(self.dataset)

        # Determine the number of steps/indices per iterator cycle for this rank
        if epoch_steps is None:
            # Default behavior: roughly one pass over the data
            self.num_samples_per_epoch = math.ceil(self.dataset_len / self.num_replicas)
        else:
            # User-defined length for the iterator cycle
            self.num_samples_per_epoch = epoch_steps

        if not isinstance(self.num_samples_per_epoch, int) or self.num_samples_per_epoch <= 0:
            raise ValueError("epoch_steps (or the derived per-replica sample count) must be a positive integer")

    def _infinite_indices(self):
        """A generator that yields random indices indefinitely."""
        g = torch.Generator()
        # Ensure distinct seeds based on rank, epoch, and base seed
        current_seed = self.seed + self.epoch * self.num_replicas + self.rank
        g.manual_seed(current_seed)
        while True:
            yield torch.randint(low=0, high=self.dataset_len, size=(1,), generator=g).item()

    def __iter__(self):
        """
        Returns an iterator that yields 'num_samples_per_epoch' indices.
        It uses itertools.islice to take a finite slice from the
        infinite generator, avoiding expensive list creation.
        """
        # Create a fresh infinite generator (seeded from seed, epoch and rank) and slice it.
        # The generator's state persists across next() calls within a single pass, and the
        # expensive DataLoader worker setup only happens once this sliced iterator is exhausted.
        return itertools.islice(self._infinite_indices(), self.num_samples_per_epoch)
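
    # A minimal illustration of the islice pattern used above (values are
    # arbitrary, not part of the original module): islice takes exactly `n`
    # items from an otherwise endless generator, without materialising a list.
    #
    #   stream = itertools.count()                       # 0, 1, 2, ...
    #   first_five = list(itertools.islice(stream, 5))   # [0, 1, 2, 3, 4]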

    def __len__(self):
        """The number of samples produced by the iterator per __iter__ call."""
        return self.num_samples_per_epoch

    def set_epoch(self, epoch: int) -> None:
        """
        Sets the epoch for this sampler. This is used to vary the random seed sequence
        each time __iter__ is called.
        """
        self.epoch = epoch
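
# Example usage of FastRandomDistributedSampler (illustrative sketch, not part
# of the original module; the dataset, batch size and epoch_steps below are
# placeholder values). num_replicas and rank are passed explicitly so that
# torch.distributed does not need to be initialized, and persistent_workers is
# enabled so worker processes survive across sampler cycles.
#
#   from torch.utils.data import DataLoader, TensorDataset
#
#   dataset = TensorDataset(torch.arange(1000))
#   sampler = FastRandomDistributedSampler(
#       dataset, num_replicas=1, rank=0, seed=0, epoch_steps=100_000)
#   loader = DataLoader(dataset, batch_size=32, sampler=sampler,
#                       num_workers=2, persistent_workers=True)
#
#   for epoch in range(3):
#       sampler.set_epoch(epoch)  # reseeds the index stream for this cycle
#       for batch in loader:
#           ...                   # training step; one cycle lasts epoch_steps indices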

class QAMNISTSampler(Sampler):
    """
    Batch sampler for the QAMNIST dataset.

    Shuffles the dataset and yields lists of indices, one list per batch. Before
    each batch it randomly draws how many digits and how many operations the
    dataset should use (within its configured ranges) and reconfigures the
    dataset accordingly, so every sample in a batch shares the same structure.
    """
    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_samples = len(dataset)

    def __iter__(self):
        indices = torch.randperm(self.num_samples).tolist()
        for i in range(0, self.num_samples, self.batch_size):
            batch_indices = indices[i:i + self.batch_size]
            
            # If the configured range is degenerate (low == high), use that value
            # directly; otherwise draw uniformly. Note that np.random.randint
            # samples from the half-open interval [low, high).
            if self.dataset.num_images_range[0] == self.dataset.num_images_range[1]:
                batch_num_digits = self.dataset.num_images_range[0]
            else:
                batch_num_digits = np.random.randint(self.dataset.num_images_range[0], self.dataset.num_images_range[1])

            if self.dataset.num_operations_range[0] == self.dataset.num_operations_range[1]:
                batch_num_operations = self.dataset.num_operations_range[0]
            else:
                batch_num_operations = np.random.randint(self.dataset.num_operations_range[0], self.dataset.num_operations_range[1])

            # Reconfigure the dataset so every sample fetched for this batch
            # uses the same number of digits and operations.
            self.dataset.set_num_digits(batch_num_digits)
            self.dataset.set_num_operations(batch_num_operations)
            
            yield batch_indices

    def __len__(self):
        # __iter__ also yields a final partial batch, so round up rather than down.
        return math.ceil(self.num_samples / self.batch_size)
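
# Example usage of QAMNISTSampler (illustrative sketch; `qamnist_dataset` is a
# placeholder for a dataset exposing num_images_range, num_operations_range,
# set_num_digits and set_num_operations, as assumed by the sampler above).
# Because the sampler yields whole batches of indices, it is passed to the
# DataLoader as batch_sampler rather than sampler.
#
#   from torch.utils.data import DataLoader
#
#   sampler = QAMNISTSampler(qamnist_dataset, batch_size=64)
#   loader = DataLoader(qamnist_dataset, batch_sampler=sampler)
#
#   for batch in loader:
#       ...  # every sample in `batch` was built with the same number of
#            # digits and operations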