| import json |
| import logging |
| import random |
|
|
| import numpy as np |
| import pandas as pd |
| import torch |
|
|
| from src.data.augmentations import ( |
| NanAugmenter, |
| ) |
| from src.data.constants import DEFAULT_NAN_STATS_PATH, LENGTH_CHOICES, LENGTH_WEIGHTS |
| from src.data.containers import BatchTimeSeriesContainer |
| from src.data.datasets import CyclicalBatchDataset |
| from src.data.frequency import Frequency |
| from src.data.scalers import MeanScaler, MedianScaler, MinMaxScaler, RobustScaler |
| from src.data.utils import sample_future_length |
|
|
# Module-scoped logger used by the classes below.
logger = logging.getLogger(name=__name__)
# NOTE(review): configuring the root logger at import time affects every
# importer of this module; consider moving this to the application entry point.
logging.basicConfig(level=logging.INFO)
|
|
|
|
| class BatchComposer: |
| """ |
| Composes batches from saved generator data according to specified proportions. |
| Manages multiple CyclicalBatchDataset instances and creates uniform or mixed batches. |
| """ |
|
|
| def __init__( |
| self, |
| base_data_dir: str, |
| generator_proportions: dict[str, float] | None = None, |
| mixed_batches: bool = True, |
| device: torch.device | None = None, |
| augmentations: dict[str, bool] | None = None, |
| augmentation_probabilities: dict[str, float] | None = None, |
| nan_stats_path: str | None = None, |
| nan_patterns_path: str | None = None, |
| global_seed: int = 42, |
| chosen_scaler_name: str | None = None, |
| rank: int = 0, |
| world_size: int = 1, |
| ): |
| """ |
| Initialize the BatchComposer. |
| |
| Args: |
| base_data_dir: Base directory containing generator subdirectories |
| generator_proportions: Dict mapping generator names to proportions |
| mixed_batches: If True, create mixed batches; if False, uniform batches |
| device: Device to load tensors to |
| augmentations: Dict mapping augmentation names to booleans |
| augmentation_probabilities: Dict mapping augmentation names to probabilities |
| global_seed: Global random seed |
| chosen_scaler_name: Name of the scaler that used in training |
| rank: Rank of current process for distributed data loading |
| world_size: Total number of processes for distributed data loading |
| """ |
| self.base_data_dir = base_data_dir |
| self.mixed_batches = mixed_batches |
| self.device = device |
| self.global_seed = global_seed |
| self.nan_stats_path = nan_stats_path |
| self.nan_patterns_path = nan_patterns_path |
| self.rank = rank |
| self.world_size = world_size |
| self.augmentation_probabilities = augmentation_probabilities or { |
| "noise_augmentation": 0.3, |
| "scaler_augmentation": 0.5, |
| } |
| |
| self.chosen_scaler_name = chosen_scaler_name.lower() if chosen_scaler_name is not None else None |
|
|
| |
| self.rng = np.random.default_rng(global_seed) |
| random.seed(global_seed) |
| torch.manual_seed(global_seed) |
|
|
| |
| self._setup_augmentations(augmentations) |
|
|
| |
| self._setup_proportions(generator_proportions) |
|
|
| |
| self.datasets = self._initialize_datasets() |
|
|
| logger.info( |
| f"Initialized BatchComposer with {len(self.datasets)} generators, " |
| f"mixed_batches={mixed_batches}, proportions={self.generator_proportions}, " |
| f"augmentations={self.augmentations}, " |
| f"augmentation_probabilities={self.augmentation_probabilities}" |
| ) |
|
|
| def _setup_augmentations(self, augmentations: dict[str, bool] | None): |
| """Setup only the augmentations that should remain online (NaN).""" |
| default_augmentations = { |
| "nan_augmentation": False, |
| "scaler_augmentation": False, |
| "length_shortening": False, |
| } |
|
|
| self.augmentations = augmentations or default_augmentations |
|
|
| |
| self.nan_augmenter = None |
| if self.augmentations.get("nan_augmentation", False): |
| stats_path_to_use = self.nan_stats_path or DEFAULT_NAN_STATS_PATH |
| stats = json.load(open(stats_path_to_use)) |
| self.nan_augmenter = NanAugmenter( |
| p_series_has_nan=stats["p_series_has_nan"], |
| nan_ratio_distribution=stats["nan_ratio_distribution"], |
| nan_length_distribution=stats["nan_length_distribution"], |
| nan_patterns_path=self.nan_patterns_path, |
| ) |
|
|
| def _should_apply_scaler_augmentation(self) -> bool: |
| """ |
| Decide whether to apply scaler augmentation for a single series based on |
| the boolean toggle and probability from the configuration. |
| """ |
| if not self.augmentations.get("scaler_augmentation", False): |
| return False |
| probability = float(self.augmentation_probabilities.get("scaler_augmentation", 0.0)) |
| probability = max(0.0, min(1.0, probability)) |
| return bool(self.rng.random() < probability) |
|
|
| def _choose_random_scaler(self) -> object | None: |
| """ |
| Choose a random scaler for augmentation, explicitly avoiding the one that |
| is already selected in the training configuration (if any). |
| |
| Returns an instance of the selected scaler or None when no valid option exists. |
| """ |
| chosen: str | None = None |
| if self.chosen_scaler_name is not None: |
| chosen = self.chosen_scaler_name.strip().lower() |
|
|
| candidates = ["custom_robust", "minmax", "median", "mean"] |
|
|
| |
| if chosen in candidates: |
| candidates = [c for c in candidates if c != chosen] |
| if not candidates: |
| return None |
|
|
| pick = str(self.rng.choice(candidates)) |
| if pick == "custom_robust": |
| return RobustScaler() |
| if pick == "minmax": |
| return MinMaxScaler() |
| if pick == "median": |
| return MedianScaler() |
| if pick == "mean": |
| return MeanScaler() |
| return None |
|
|
| def _setup_proportions(self, generator_proportions): |
| """Setup default or custom generator proportions.""" |
| default_proportions = { |
| "forecast_pfn": 1.0, |
| "gp": 1.0, |
| "kernel": 1.0, |
| "sinewave": 1.0, |
| "sawtooth": 1.0, |
| "step": 0.1, |
| "anomaly": 1.0, |
| "spike": 2.0, |
| "cauker_univariate": 2.0, |
| "cauker_multivariate": 0.00, |
| "lmc": 0.00, |
| "ou_process": 1.0, |
| "audio_financial_volatility": 0.1, |
| "audio_multi_scale_fractal": 0.1, |
| "audio_network_topology": 0.5, |
| "audio_stochastic_rhythm": 1.0, |
| "augmented_per_sample_2048": 3.0, |
| "augmented_temp_batch_2048": 3.0, |
| } |
| self.generator_proportions = generator_proportions or default_proportions |
|
|
| |
| total = sum(self.generator_proportions.values()) |
| if total <= 0: |
| raise ValueError("Total generator proportions must be positive") |
| self.generator_proportions = {k: v / total for k, v in self.generator_proportions.items()} |
|
|
| def _initialize_datasets(self) -> dict[str, CyclicalBatchDataset]: |
| """Initialize CyclicalBatchDataset for each generator with proportion > 0.""" |
| datasets = {} |
|
|
| for generator_name, proportion in self.generator_proportions.items(): |
| |
| if proportion <= 0: |
| logger.info(f"Skipping {generator_name} (proportion = {proportion})") |
| continue |
|
|
| batches_dir = f"{self.base_data_dir}/{generator_name}" |
|
|
| try: |
| dataset = CyclicalBatchDataset( |
| batches_dir=batches_dir, |
| generator_type=generator_name, |
| device=None, |
| prefetch_next=True, |
| prefetch_threshold=32, |
| rank=self.rank, |
| world_size=self.world_size, |
| ) |
| datasets[generator_name] = dataset |
| logger.info(f"Loaded dataset for {generator_name} (proportion = {proportion})") |
|
|
| except Exception as e: |
| logger.warning(f"Failed to load dataset for {generator_name}: {e}") |
| continue |
|
|
| if not datasets: |
| raise ValueError(f"No valid datasets found in {self.base_data_dir} or all generators have proportion <= 0") |
|
|
| return datasets |
|
|
| def _convert_sample_to_tensors( |
| self, sample: dict, future_length: int | None = None |
| ) -> tuple[torch.Tensor, np.datetime64, Frequency]: |
| """ |
| Convert a sample dict to tensors and metadata. |
| |
| Args: |
| sample: Sample dict from CyclicalBatchDataset |
| future_length: Desired future length (if None, use default split) |
| |
| Returns: |
| Tuple of (history_values, future_values, start, frequency) |
| """ |
| |
| num_channels = sample.get("num_channels", 1) |
| values_data = sample["values"] |
| generator_type = sample.get("generator_type", "unknown") |
|
|
| if num_channels == 1: |
| |
| if isinstance(values_data[0], list): |
| |
| values = torch.tensor(values_data[0], dtype=torch.float32) |
| logger.debug(f"{generator_type}: Using new univariate format, shape: {values.shape}") |
| else: |
| |
| values = torch.tensor(values_data, dtype=torch.float32) |
| values = values.unsqueeze(0).unsqueeze(-1) |
| else: |
| |
| channel_tensors = [] |
| for channel_values in values_data: |
| channel_tensor = torch.tensor(channel_values, dtype=torch.float32) |
| channel_tensors.append(channel_tensor) |
|
|
| |
| values = torch.stack(channel_tensors, dim=-1).unsqueeze(0) |
| logger.debug(f"{generator_type}: Using multivariate format, {num_channels} channels, shape: {values.shape}") |
|
|
| |
| freq_str = sample["frequency"] |
| try: |
| frequency = Frequency(freq_str) |
| except ValueError: |
| |
| freq_mapping = { |
| "h": Frequency.H, |
| "D": Frequency.D, |
| "W": Frequency.W, |
| "M": Frequency.M, |
| "Q": Frequency.Q, |
| "A": Frequency.A, |
| "Y": Frequency.A, |
| "1min": Frequency.T1, |
| "5min": Frequency.T5, |
| "10min": Frequency.T10, |
| "15min": Frequency.T15, |
| "30min": Frequency.T30, |
| "s": Frequency.S, |
| } |
| frequency = freq_mapping.get(freq_str, Frequency.H) |
|
|
| |
| if isinstance(sample["start"], pd.Timestamp): |
| start = sample["start"].to_numpy() |
| else: |
| start = np.datetime64(sample["start"]) |
|
|
| return values, start, frequency |
|
|
| def _effective_proportions_for_length(self, total_length_for_batch: int) -> dict[str, float]: |
| """ |
| Build a simple, length-aware proportion map for the current batch. |
| |
| Rules: |
| - For generators named 'augmented{L}', keep only the one matching the |
| chosen length L; zero out others. |
| - Keep non-augmented generators as-is. |
| - Drop generators that are unavailable (not loaded) or zero-weight. |
| - If nothing remains, fall back to 'augmented{L}' if available, else any dataset. |
| - Normalize the final map to sum to 1. |
| """ |
|
|
| def augmented_length_from_name(name: str) -> int | None: |
| if not name.startswith("augmented"): |
| return None |
| suffix = name[len("augmented") :] |
| if not suffix: |
| return None |
| try: |
| return int(suffix) |
| except ValueError: |
| return None |
|
|
| |
| adjusted: dict[str, float] = {} |
| for name, proportion in self.generator_proportions.items(): |
| aug_len = augmented_length_from_name(name) |
| if aug_len is None: |
| adjusted[name] = proportion |
| else: |
| adjusted[name] = proportion if aug_len == total_length_for_batch else 0.0 |
|
|
| |
| adjusted = {name: p for name, p in adjusted.items() if name in self.datasets and p > 0.0} |
|
|
| |
| if not adjusted: |
| preferred = f"augmented{total_length_for_batch}" |
| if preferred in self.datasets: |
| adjusted = {preferred: 1.0} |
| elif self.datasets: |
| |
| first_key = next(iter(self.datasets.keys())) |
| adjusted = {first_key: 1.0} |
| else: |
| raise ValueError("No datasets available to create batch") |
|
|
| |
| total = sum(adjusted.values()) |
| return {name: p / total for name, p in adjusted.items()} |
|
|
| def _compute_sample_counts_for_batch(self, proportions: dict[str, float], batch_size: int) -> dict[str, int]: |
| """ |
| Convert a proportion map into integer sample counts that sum to batch_size. |
| |
| Strategy: allocate floor(batch_size * p) to each generator in order, and let the |
| last generator absorb any remainder to ensure the total matches exactly. |
| """ |
| counts: dict[str, int] = {} |
| remaining = batch_size |
| names = list(proportions.keys()) |
| values = list(proportions.values()) |
| for index, (name, p) in enumerate(zip(names, values, strict=True)): |
| if index == len(names) - 1: |
| counts[name] = remaining |
| else: |
| n = int(batch_size * p) |
| counts[name] = n |
| remaining -= n |
| return counts |
|
|
| def _calculate_generator_samples(self, batch_size: int) -> dict[str, int]: |
| """ |
| Calculate the number of samples each generator should contribute. |
| |
| Args: |
| batch_size: Total batch size |
| |
| Returns: |
| Dict mapping generator names to sample counts |
| """ |
| generator_samples = {} |
| remaining_samples = batch_size |
|
|
| generators = list(self.generator_proportions.keys()) |
| proportions = list(self.generator_proportions.values()) |
|
|
| |
| for i, (generator, proportion) in enumerate(zip(generators, proportions, strict=True)): |
| if generator not in self.datasets: |
| continue |
|
|
| if i == len(generators) - 1: |
| samples = remaining_samples |
| else: |
| samples = int(batch_size * proportion) |
| remaining_samples -= samples |
| generator_samples[generator] = samples |
|
|
| return generator_samples |
|
|
| def create_batch( |
| self, |
| batch_size: int = 128, |
| seed: int | None = None, |
| future_length: int | None = None, |
| ) -> tuple[BatchTimeSeriesContainer, str]: |
| """ |
| Create a batch of the specified size. |
| |
| Args: |
| batch_size: Size of the batch to create |
| seed: Random seed for this batch |
| future_length: Fixed future length to use. If None, samples from gift_eval range |
| |
| Returns: |
| Tuple of (batch_container, generator_info) |
| """ |
| if seed is not None: |
| batch_rng = np.random.default_rng(seed) |
| random.seed(seed) |
| else: |
| batch_rng = self.rng |
|
|
| if self.mixed_batches: |
| return self._create_mixed_batch(batch_size, future_length) |
| else: |
| return self._create_uniform_batch(batch_size, batch_rng, future_length) |
|
|
| def _create_mixed_batch( |
| self, batch_size: int, future_length: int | None = None |
| ) -> tuple[BatchTimeSeriesContainer, str]: |
| """Create a mixed batch with samples from multiple generators, rejecting NaNs.""" |
|
|
| |
| |
| if self.augmentations.get("length_shortening", False): |
| lengths = list(LENGTH_WEIGHTS.keys()) |
| probs = list(LENGTH_WEIGHTS.values()) |
| total_length_for_batch = int(self.rng.choice(lengths, p=probs)) |
| else: |
| total_length_for_batch = int(max(LENGTH_CHOICES)) |
|
|
| if future_length is None: |
| prediction_length = int(sample_future_length(range="gift_eval", total_length=total_length_for_batch)) |
| else: |
| prediction_length = future_length |
|
|
| history_length = total_length_for_batch - prediction_length |
|
|
| |
| effective_props = self._effective_proportions_for_length(total_length_for_batch) |
| generator_samples = self._compute_sample_counts_for_batch(effective_props, batch_size) |
|
|
| all_values = [] |
| all_starts = [] |
| all_frequencies = [] |
| actual_proportions = {} |
|
|
| |
| for generator_name, num_samples in generator_samples.items(): |
| if num_samples == 0 or generator_name not in self.datasets: |
| continue |
|
|
| dataset = self.datasets[generator_name] |
|
|
| |
| generator_values = [] |
| generator_starts = [] |
| generator_frequencies = [] |
|
|
| |
| max_attempts = 50 |
| attempts = 0 |
| while len(generator_values) < num_samples and attempts < max_attempts: |
| attempts += 1 |
| |
| need = num_samples - len(generator_values) |
| fetch_n = max(need * 2, 8) |
| samples = dataset.get_samples(fetch_n) |
|
|
| for sample in samples: |
| if len(generator_values) >= num_samples: |
| break |
|
|
| values, sample_start, sample_freq = self._convert_sample_to_tensors(sample, future_length) |
|
|
| |
| if torch.isnan(values).any(): |
| continue |
|
|
| |
| if total_length_for_batch < values.shape[1]: |
| strategy = self.rng.choice(["cut", "subsample"]) |
| if strategy == "cut": |
| max_start_idx = values.shape[1] - total_length_for_batch |
| start_idx = int(self.rng.integers(0, max_start_idx + 1)) |
| values = values[:, start_idx : start_idx + total_length_for_batch, :] |
| else: |
| indices = np.linspace( |
| 0, |
| values.shape[1] - 1, |
| total_length_for_batch, |
| dtype=int, |
| ) |
| values = values[:, indices, :] |
|
|
| |
| if self._should_apply_scaler_augmentation(): |
| scaler = self._choose_random_scaler() |
| if scaler is not None: |
| values = scaler.scale(values, scaler.compute_statistics(values)) |
|
|
| generator_values.append(values) |
| generator_starts.append(sample_start) |
| generator_frequencies.append(sample_freq) |
|
|
| if len(generator_values) < num_samples: |
| logger.warning( |
| f"Generator {generator_name}: collected {len(generator_values)}/" |
| f"{num_samples} after {attempts} attempts" |
| ) |
|
|
| |
| if generator_values: |
| all_values.extend(generator_values) |
| all_starts.extend(generator_starts) |
| all_frequencies.extend(generator_frequencies) |
| actual_proportions[generator_name] = len(generator_values) |
|
|
| if not all_values: |
| raise RuntimeError("No valid samples could be collected from any generator.") |
|
|
| combined_values = torch.cat(all_values, dim=0) |
| |
| combined_history = combined_values[:, :history_length, :] |
| combined_future = combined_values[:, history_length : history_length + prediction_length, :] |
|
|
| if self.nan_augmenter is not None: |
| combined_history = self.nan_augmenter.transform(combined_history) |
|
|
| |
| container = BatchTimeSeriesContainer( |
| history_values=combined_history, |
| future_values=combined_future, |
| start=all_starts, |
| frequency=all_frequencies, |
| ) |
|
|
| return container, "MixedBatch" |
|
|
| def _create_uniform_batch( |
| self, |
| batch_size: int, |
| batch_rng: np.random.Generator, |
| future_length: int | None = None, |
| ) -> tuple[BatchTimeSeriesContainer, str]: |
| """Create a uniform batch with samples from a single generator.""" |
|
|
| |
| generators = list(self.datasets.keys()) |
| proportions = [self.generator_proportions[gen] for gen in generators] |
| selected_generator = batch_rng.choice(generators, p=proportions) |
|
|
| |
| if future_length is None: |
| future_length = sample_future_length(range="gift_eval") |
|
|
| |
| dataset = self.datasets[selected_generator] |
| samples = dataset.get_samples(batch_size) |
|
|
| all_history_values = [] |
| all_future_values = [] |
| all_starts = [] |
| all_frequencies = [] |
|
|
| for sample in samples: |
| values, sample_start, sample_freq = self._convert_sample_to_tensors(sample, future_length) |
|
|
| total_length = values.shape[1] |
| history_length = max(1, total_length - future_length) |
|
|
| |
| if self._should_apply_scaler_augmentation(): |
| scaler = self._choose_random_scaler() |
| if scaler is not None: |
| values = scaler.scale(values, scaler.compute_statistics(values)) |
|
|
| |
| hist_vals = values[:, :history_length, :] |
| fut_vals = values[:, history_length : history_length + future_length, :] |
|
|
| all_history_values.append(hist_vals) |
| all_future_values.append(fut_vals) |
| all_starts.append(sample_start) |
| all_frequencies.append(sample_freq) |
|
|
| |
| combined_history = torch.cat(all_history_values, dim=0) |
| combined_future = torch.cat(all_future_values, dim=0) |
|
|
| |
| container = BatchTimeSeriesContainer( |
| history_values=combined_history, |
| future_values=combined_future, |
| start=all_starts, |
| frequency=all_frequencies, |
| ) |
|
|
| return container, selected_generator |
|
|
| def get_dataset_info(self) -> dict[str, dict]: |
| """Get information about all datasets.""" |
| info = {} |
| for name, dataset in self.datasets.items(): |
| info[name] = dataset.get_info() |
| return info |
|
|
| def get_generator_info(self) -> dict[str, any]: |
| """Get information about the composer configuration.""" |
| return { |
| "mixed_batches": self.mixed_batches, |
| "generator_proportions": self.generator_proportions, |
| "active_generators": list(self.datasets.keys()), |
| "total_generators": len(self.datasets), |
| "augmentations": self.augmentations, |
| "augmentation_probabilities": self.augmentation_probabilities, |
| "nan_augmenter_enabled": self.nan_augmenter is not None, |
| } |
|
|
|
|
class ComposedDataset(torch.utils.data.Dataset):
    """
    Map-style PyTorch ``Dataset`` in which each item is one composed batch.

    Every index asks the wrapped ``BatchComposer`` for a fresh batch of
    ``batch_size`` samples, seeded with ``global_seed + idx``.
    """

    def __init__(
        self,
        batch_composer: BatchComposer,
        num_batches_per_epoch: int = 100,
        batch_size: int = 128,
    ):
        """
        Wrap a composer so it can be consumed like a regular dataset.

        Args:
            batch_composer: Composer that builds each batch.
            num_batches_per_epoch: Number of items (batches) reported by ``len``.
            batch_size: Number of series in every produced batch.
        """
        self.batch_composer = batch_composer
        self.num_batches_per_epoch = num_batches_per_epoch
        self.batch_size = batch_size

    def __len__(self) -> int:
        """Number of batches in one epoch."""
        return self.num_batches_per_epoch

    def __getitem__(self, idx: int) -> BatchTimeSeriesContainer:
        """
        Build the batch for position ``idx``.

        The per-batch seed is derived from the composer's global seed and the
        index. The generator-info string returned by the composer is discarded.
        """
        composer = self.batch_composer
        batch_seed = composer.global_seed + idx
        container, _generator_info = composer.create_batch(batch_size=self.batch_size, seed=batch_seed)
        return container
|
|