| from datetime import datetime, timedelta
|
| import os
|
| import pickle
|
|
|
| from numba import njit, prange, types
|
| from numba.typed import Dict
|
| import numpy as np
|
| from torch.utils.data import Dataset, IterableDataset
|
|
|
| from .patches import unpack_patches
|
| from .sampling import EqualFrequencySampler
|
|
|
|
|
class BatchGenerator:
    """Draw sample locations and assemble predictor/target batches.

    For every variable in use, the raw source data are fetched through a
    per-source patch index, passed through the variable's ``transform``
    callable and returned together with the relative time of each
    timestep.
    """

    def __init__(self,
        variables,
        raw,
        predictors,
        target,
        primary_var,
        time_range_sampling=(-1,2),
        forecast_raw_vars=(),
        sampling_bins=None,
        sampler_file=None,
        sample_shape=(4,4),
        batch_size=32,
        interval=timedelta(minutes=5),
        random_seed=None,
        augment=False
    ):
        """Build the patch indices and create or load the sampler.

        Args:
            variables: Mapping from variable name to a dict with the keys
                ``"sources"`` (raw source names), ``"timesteps"`` (integer
                array of timestep offsets), ``"transform"`` (callable
                applied to the source batches) and, optionally,
                ``"timestep_secs"`` (defaults to the generator interval).
            raw: Mapping from raw source name to the raw data handed to
                ``unpack_patches``.
            predictors: Sequence of predictor variable names.
            target: Name of the target variable.
            primary_var: Variable whose first source drives the sampler.
            time_range_sampling: Forwarded to ``EqualFrequencySampler``.
            forecast_raw_vars: Base names of sources stored in ``raw`` as
                several ``"<name>-<lag>"`` entries.
            sampling_bins: Forwarded to ``EqualFrequencySampler``.
            sampler_file: Optional path of a pickled sampler cache. It is
                loaded when the file exists and written otherwise.
            sample_shape: Sample size in patches, ``(rows, cols)``.
            batch_size: Default number of samples per batch.
            interval: Base time interval as a ``timedelta``.
            random_seed: Seed of the RNG used to draw augmentations.
            augment: Whether batches are randomly transposed/flipped.
        """
        super().__init__()
        self.batch_size = batch_size
        self.interval = interval
        self.interval_secs = np.int64(self.interval.total_seconds())
        self.variables = variables
        self.predictors = predictors
        self.target = target
        # list() lets callers pass the predictors as a tuple as well.
        self.used_variables = list(predictors) + [target]
        self.rng = np.random.RandomState(seed=random_seed)
        self.augment = augment

        # Set up the indices used to retrieve the raw source data.
        self.sources = set.union(
            *(set(variables[v]["sources"]) for v in self.used_variables)
        )
        self.forecast_raw_vars = set(forecast_raw_vars) & self.sources
        self.patch_index = {}
        for raw_name_base in self.sources:
            # Test against the normalised set so that a one-shot iterable
            # passed as forecast_raw_vars is not consulted a second time.
            if raw_name_base in self.forecast_raw_vars:
                raw_names = [
                    rn for rn in raw if rn.startswith(raw_name_base+"-")
                ]
            else:
                raw_names = (raw_name_base,)
            for raw_name in raw_names:
                self.setup_index(raw_name, raw[raw_name], sample_shape)

        # Group the per-lag indices of each forecast source behind one
        # wrapper registered under the base name.
        for raw_name in self.forecast_raw_vars:
            patch_index_var = {
                k: v for (k,v) in self.patch_index.items()
                if k.startswith(raw_name+"-")
            }
            self.patch_index[raw_name] = \
                ForecastPatchIndexWrapper(patch_index_var)

        # Set up the sampler, reusing a cached one when available.
        if (sampler_file is None) or not os.path.isfile(sampler_file):
            print("No cached sampler found, creating a new one...")
            primary_raw_var = variables[primary_var]["sources"][0]
            t0 = t1 = None
            for var_data in variables.values():
                # First and last timestep of the variable, with the first
                # one moved one step earlier, converted to units of the
                # base interval.
                timesteps = var_data["timesteps"][[0,-1]].copy()
                timesteps[0] -= 1
                ts_secs = timesteps * \
                    var_data.get("timestep_secs", self.interval_secs)
                timesteps = ts_secs // self.interval_secs
                t0 = timesteps[0] if t0 is None else min(t0,timesteps[0])
                t1 = timesteps[-1] if t1 is None else max(t1,timesteps[-1])
            time_range_valid = (t0,t1+1)
            self.sampler = EqualFrequencySampler(
                sampling_bins, raw[primary_raw_var],
                self.patch_index[primary_raw_var], sample_shape,
                time_range_valid, time_range_sampling=time_range_sampling,
                timestep_secs=self.interval_secs
            )
            if sampler_file is not None:
                print(f"Caching sampler to {sampler_file}.")
                with open(sampler_file, 'wb') as f:
                    pickle.dump(self.sampler, f)
        else:
            print(f"Loading cached sampler from {sampler_file}.")
            # SECURITY: unpickling executes arbitrary code; only point
            # sampler_file at cache files from a trusted location.
            with open(sampler_file, 'rb') as f:
                self.sampler = pickle.load(f)

    def setup_index(self, raw_name, raw_data, box_size):
        """Create the ``PatchIndex`` for one raw source and register it.

        The fill values are read from ``raw_data``: ``"zero_value"``
        defaults to 0 and ``"missing_value"`` defaults to the zero value.
        """
        zero_value = raw_data.get("zero_value", 0)
        missing_value = raw_data.get("missing_value", zero_value)

        self.patch_index[raw_name] = PatchIndex(
            *unpack_patches(raw_data),
            zero_value=zero_value,
            missing_value=missing_value,
            interval=self.interval,
            box_size=box_size
        )

    def augmentations(self):
        """Draw three random 0/1 flags: (transpose, flipud, fliplr)."""
        return tuple(self.rng.randint(2, size=3))

    def augment_batch(self, batch, transpose, flipud, fliplr):
        """Apply the selected augmentations to the last two axes.

        Nothing is changed when augmentation is disabled. A copy is
        always returned, so the result never aliases ``batch``.
        """
        if self.augment:
            if transpose:
                batch = np.swapaxes(batch, -2, -1)
            flips = []
            if flipud:
                flips.append(-2)
            if fliplr:
                flips.append(-1)
            if flips:
                batch = np.flip(batch, axis=tuple(flips))
        return batch.copy()

    def batch(self, samples=None, batch_size=None):
        """Assemble one batch.

        Args:
            samples: Optional integer array with one ``(t0, i0, j0)`` row
                per sample. Drawn from the sampler when omitted.
            batch_size: Number of samples to draw when ``samples`` is
                omitted. Defaults to ``self.batch_size``.

        Returns:
            ``(pred_batch, target_batch)``. ``pred_batch`` holds one
            ``(data, t_relative)`` tuple per predictor, and
            ``target_batch`` is the data array of the target variable.
        """
        if batch_size is None:
            batch_size = self.batch_size

        if samples is None:
            # Draw the sample coordinates.
            samples = self.sampler(batch_size)

        (t0,i0,j0) = samples.T

        # One set of augmentations is shared by all variables so that
        # predictors and target stay spatially aligned.
        if self.augment:
            augmentations = self.augmentations()

        batch = {}
        for var_name in self.used_variables:
            var_data = self.variables[var_name]

            # Round the sample time down to a multiple of the variable's
            # own timestep, then expand to all requested timesteps.
            ts_secs = var_data.get("timestep_secs", self.interval_secs)
            t_shift = -(t0 % ts_secs)
            t0_shifted = t0 + t_shift
            t = t0_shifted[:,None] + ts_secs*var_data["timesteps"][None,:]
            # Time of each step relative to the sample time, in units of
            # the base interval.
            t_relative = (t - t0[:,None]) / self.interval_secs

            # Read the raw data of every source and transform them into
            # the variable.
            raw_data = (
                self.patch_index[raw_name](t,i0,j0)
                for raw_name in var_data["sources"]
            )
            batch_var = var_data["transform"](*raw_data)

            # Four-dimensional results get an extra axis inserted at
            # position 1.
            if batch_var.ndim == 4:
                batch_var = np.expand_dims(batch_var, 1)

            if self.augment:
                batch_var = self.augment_batch(batch_var, *augmentations)

            batch[var_name] = (batch_var, t_relative.astype(np.float32))

        pred_batch = [batch[v] for v in self.predictors]
        target_batch = batch[self.target][0]
        return (pred_batch, target_batch)

    def batches(self, *args, num=None, **kwargs):
        """Yield batches from ``batch``: ``num`` of them, or endlessly.

        Positional and keyword arguments are forwarded to ``batch``.
        """
        if num is not None:
            for _ in range(num):
                yield self.batch(*args, **kwargs)
        else:
            while True:
                yield self.batch(*args, **kwargs)
|
|
|
|
|
class StreamBatchDataset(IterableDataset):
    """Iterable dataset that streams freshly generated batches.

    Each pass over the dataset asks the batch generator for
    ``batches_per_epoch`` new batches.
    """

    def __init__(self, batch_gen, batches_per_epoch):
        """Store the batch generator and the number of batches per pass."""
        super().__init__()
        self.batches_per_epoch = batches_per_epoch
        self.batch_gen = batch_gen

    def __iter__(self):
        """Yield ``batches_per_epoch`` batches from the generator."""
        epoch_batches = self.batch_gen.batches(num=self.batches_per_epoch)
        for item in epoch_batches:
            yield item
|
|
|
|
|
class DeterministicBatchDataset(Dataset):
    """Map-style dataset over a fixed, pre-drawn set of sample batches.

    The sample coordinates of every batch are drawn once at construction
    time, so item ``ind`` always refers to the same samples.
    """

    def __init__(self, batch_gen, batches_per_epoch, random_seed=None):
        """Pre-draw the sample coordinates for all batches.

        Args:
            batch_gen: Batch generator providing ``sampler``,
                ``batch_size`` and ``batch(samples=...)``.
            batches_per_epoch: Number of batches (dataset length).
            random_seed: Seed for the RNG installed on the sampler.
        """
        super().__init__()
        self.batch_gen = batch_gen
        self.batches_per_epoch = batches_per_epoch
        # NOTE: this replaces the sampler's RNG in place, so any other
        # user of the same sampler is affected as well.
        self.batch_gen.sampler.rng = np.random.RandomState(seed=random_seed)
        self.samples = [
            self.batch_gen.sampler(self.batch_gen.batch_size)
            for _ in range(self.batches_per_epoch)
        ]

    def __len__(self):
        """Return the number of batches."""
        return self.batches_per_epoch

    def __getitem__(self, ind):
        """Build and return the batch for the pre-drawn samples ``ind``."""
        return self.batch_gen.batch(samples=self.samples[ind])
|
|
|
|
|
class PatchIndex:
    """Look up stored patches by ``(time, i, j)`` and tile them into samples.

    The index maps each ``(time, i, j)`` key to a row of ``patch_data``
    or to ``IDX_ZERO``. Keys that are absent are treated as missing.
    """

    # Sentinel index values stored/returned instead of a row of patch_data.
    IDX_ZERO = -1
    IDX_MISSING = -2

    def __init__(
        self, patch_data, patch_coords, patch_times,
        zero_patch_coords, zero_patch_times,
        interval=timedelta(minutes=5),
        box_size=(4,4), zero_value=0,
        missing_value=0
    ):
        """Build the lookup table.

        Args:
            patch_data: Array of patches; axes 1 and 2 are the patch
                height and width.
            patch_coords: Per-patch ``(i, j)`` coordinates, indexed as
                ``patch_coords[k, 0]`` and ``patch_coords[k, 1]``.
            patch_times: Per-patch time, indexed as ``patch_times[k]``.
            zero_patch_coords: Coordinates of the all-zero patches.
            zero_patch_times: Times of the all-zero patches.
            interval: Base time interval as a ``timedelta``.
            box_size: Sample size in patches, ``(rows, cols)``.
            zero_value: Fill value for all-zero patches.
            missing_value: Fill value for patches absent from the index.
        """
        self.dt = int(round(interval.total_seconds()))
        self.box_size = box_size
        self.zero_value = zero_value
        self.missing_value = missing_value
        self.patch_data = patch_data
        # Sample size in pixels: patch size times the number of patches.
        self.sample_shape = (
            self.patch_data.shape[1]*box_size[0],
            self.patch_data.shape[2]*box_size[1]
        )

        self.patch_index = Dict.empty(
            key_type=types.UniTuple(types.int64, 3),
            value_type=types.int64
        )
        init_patch_index(self.patch_index, patch_coords, patch_times)
        init_patch_index_zero(self.patch_index, zero_patch_coords,
            zero_patch_times, PatchIndex.IDX_ZERO)

        # Output buffer, allocated lazily and reused between calls.
        self._batch = None

    def _alloc_batch(self, batch_size, num_timesteps):
        """Return a ``(batch_size, num_timesteps, H, W)`` output view.

        The underlying buffer is reused between calls and reallocated
        only when it is too small along either of the first two axes.
        The returned array is a view trimmed to exactly the requested
        size, so rows left over from an earlier, larger request are never
        exposed to the caller.
        """
        needs_rebuild = (self._batch is None) or \
            (self._batch.shape[0] < batch_size) or \
            (self._batch.shape[1] < num_timesteps)
        if needs_rebuild:
            # Drop the old buffer before allocating the new one. Setting
            # it to None (rather than del) keeps the attribute defined if
            # the allocation fails.
            self._batch = None
            self._batch = np.zeros(
                (batch_size,num_timesteps)+self.sample_shape,
                self.patch_data.dtype
            )
        return self._batch[:batch_size,:num_timesteps]

    def __call__(self, t, i0_all, j0_all):
        """Assemble the samples for the given times and patch origins.

        Args:
            t: 2-D integer array of times, one row per sample.
            i0_all: First patch row of each sample.
            j0_all: First patch column of each sample.

        Returns:
            Array of shape ``t.shape + sample_shape``. It is a view into
            an internal buffer that is overwritten by the next call.
        """
        batch = self._alloc_batch(*t.shape)

        i1_all = i0_all + self.box_size[0]
        j1_all = j0_all + self.box_size[1]
        bi_size = self.patch_data.shape[1]
        bj_size = self.patch_data.shape[2]

        build_batch(batch, self.patch_data, self.patch_index,
            t, i0_all, i1_all, j0_all, j1_all,
            bi_size, bj_size, self.zero_value,
            self.missing_value)

        return batch
|
|
|
|
|
@njit
def init_patch_index(patch_index, patch_coords, patch_times):
    """Register every stored patch under its ``(time, i, j)`` key.

    The value stored for patch ``k`` is ``k`` itself, i.e. its position
    along the first axis of ``patch_coords``.
    """
    num_patches = patch_coords.shape[0]
    for idx in range(num_patches):
        row = np.int64(patch_coords[idx,0])
        col = np.int64(patch_coords[idx,1])
        patch_index[(patch_times[idx],row,col)] = idx
|
|
|
|
|
@njit
def init_patch_index_zero(patch_index, zero_patch_coords,
    zero_patch_times, idx_zero):
    """Register every all-zero patch under its ``(time, i, j)`` key.

    All of these keys map to the same sentinel value ``idx_zero``.
    """
    num_zero_patches = zero_patch_coords.shape[0]
    for idx in range(num_zero_patches):
        row = np.int64(zero_patch_coords[idx,0])
        col = np.int64(zero_patch_coords[idx,1])
        patch_index[(zero_patch_times[idx],row,col)] = idx_zero
|
|
|
|
|
|
|
# Module-level aliases of the sentinels so that build_batch can refer to
# them as plain globals (presumably because the jitted function cannot
# read the class attributes directly; confirm before changing).
IDX_ZERO = PatchIndex.IDX_ZERO
IDX_MISSING = PatchIndex.IDX_MISSING


@njit(parallel=True)
def build_batch(
    batch, patch_data, patch_index,
    t_all, i0_all, i1_all, j0_all, j1_all,
    bi_size, bj_size, zero_value, missing_value
):
    """Fill ``batch`` by tiling patches looked up in ``patch_index``.

    For sample ``k`` and timestep ``bt``, every patch position ``(i, j)``
    with ``i0 <= i < i1`` and ``j0 <= j < j1`` is written to its
    ``bi_size`` x ``bj_size`` tile of ``batch[k, bt]``. The tile holds
    the stored patch, ``zero_value`` or ``missing_value``, depending on
    the index entry. Samples are processed in parallel.
    """
    num_samples = t_all.shape[0]
    num_times = t_all.shape[1]
    for k in prange(num_samples):
        row_start = i0_all[k]
        row_stop = i1_all[k]
        col_start = j0_all[k]
        col_stop = j1_all[k]

        for bt in range(num_times):
            t = t_all[k,bt]
            for i in range(row_start, row_stop):
                r0 = (i-row_start) * bi_size
                r1 = r0 + bi_size
                for j in range(col_start, col_stop):
                    c0 = (j-col_start) * bj_size
                    c1 = c0 + bj_size
                    # Keys absent from the index count as missing.
                    ind = int(patch_index.get((t,i,j), IDX_MISSING))
                    if ind == IDX_MISSING:
                        batch[k,bt,r0:r1,c0:c1] = missing_value
                    elif ind == IDX_ZERO:
                        batch[k,bt,r0:r1,c0:c1] = zero_value
                    elif ind >= 0:
                        batch[k,bt,r0:r1,c0:c1] = patch_data[ind]
|
|
|
|
|
class ForecastPatchIndexWrapper(PatchIndex):
    """Combine several per-lag patch indices of one forecast variable.

    The wrapped indices are keyed ``"<base name>-<lag in hours>"``. For
    each requested time, the data are taken from the index whose lag
    matches that time.
    """

    def __init__(self, patch_index):
        """Validate and store the per-lag indices.

        Args:
            patch_index: Mapping from ``"<base name>-<lag>"`` to the
                patch index holding the data for that lag.

        Raises:
            ValueError: If the keys do not share one base name, if the
                lags are not evenly spaced, or if 24 is not a multiple of
                the lag spacing.
        """
        self.patch_index = patch_index
        raw_names = {"-".join(v.split("-")[:-1]) for v in patch_index}
        if len(raw_names) != 1:
            raise ValueError(
                "Can only wrap variables with the same base name")
        self.raw_name = list(raw_names)[0]
        lags_hour = [int(v.split("-")[-1]) for v in patch_index]
        self.lags_hour = set(lags_hour)
        forecast_interval_hour = np.diff(sorted(lags_hour))
        if len(set(forecast_interval_hour)) != 1:
            raise ValueError("Lags must be evenly spaced")
        forecast_interval_hour = forecast_interval_hour[0]
        if (24 % forecast_interval_hour):
            raise ValueError(
                "24 hours must be a multiple of the forecast interval")
        self.forecast_interval_hour = forecast_interval_hour
        self.forecast_interval = 3600 * forecast_interval_hour

        # The base-class constructor is not called; set the attributes
        # that the inherited _alloc_batch relies on from one of the
        # wrapped indices.
        self._batch = None
        v = list(self.patch_index.keys())[0]
        self.sample_shape = self.patch_index[v].sample_shape
        self.patch_data = self.patch_index[v].patch_data

    def __call__(self, t, i0, j0):
        """Assemble the samples, picking the matching lag for every time.

        Args:
            t: 2-D integer array of times, one row per sample
                (presumably in seconds; confirm against the caller).
            i0: First patch row of each sample.
            j0: First patch column of each sample.

        Returns:
            Array of shape ``t.shape + sample_shape``. It is a view into
            an internal buffer that is overwritten by the next call.
        """
        (num_samples, num_timesteps) = t.shape
        # The reused buffer may be larger than requested along either of
        # the first two axes; trim it so that no stale rows are returned.
        batch = self._alloc_batch(num_samples, num_timesteps)
        batch = batch[:num_samples,:num_timesteps]

        # Offset of every time from the last multiple of the forecast
        # interval at or before the first timestep, converted to a lag
        # in hours.
        t0 = t[:,:1]
        start_time_from_fc = t0 % self.forecast_interval
        time_from_fc = start_time_from_fc + (t - t0)
        lags_hour = (time_from_fc // self.forecast_interval) * \
            self.forecast_interval_hour

        for lag in self.lags_hour:
            raw_name_lag = f"{self.raw_name}-{lag}"
            batch_lag = self.patch_index[raw_name_lag](t,i0,j0)
            # Trim here as well: copy_masked_times loops over the shape
            # of its source array, which must not exceed the mask.
            batch_lag = batch_lag[:num_samples,:num_timesteps]
            lag_mask = (lags_hour == lag)
            copy_masked_times(batch_lag, batch, lag_mask)

        return batch
|
|
|
|
|
@njit(parallel=True)
def copy_masked_times(from_batch, to_batch, mask):
    """Copy the frames selected by ``mask`` from one batch to another.

    ``from_batch[k, bt]`` is copied to ``to_batch[k, bt]`` for every
    ``(k, bt)`` where ``mask[k, bt]`` is true. All other frames of
    ``to_batch`` are left untouched. Samples are processed in parallel.
    """
    num_samples = from_batch.shape[0]
    num_times = from_batch.shape[1]
    for s in prange(num_samples):
        for ti in range(num_times):
            if not mask[s,ti]:
                continue
            to_batch[s,ti,:,:] = from_batch[s,ti,:,:]
|
|
|