| import logging |
| from typing import Any |
|
|
| import numpy as np |
| import pandas as pd |
| import scipy.fft as fft |
| import torch |
| from gluonts.time_feature import time_features_from_frequency_str |
| from gluonts.time_feature._base import ( |
| day_of_month, |
| day_of_month_index, |
| day_of_week, |
| day_of_week_index, |
| day_of_year, |
| hour_of_day, |
| hour_of_day_index, |
| minute_of_hour, |
| minute_of_hour_index, |
| month_of_year, |
| month_of_year_index, |
| second_of_minute, |
| second_of_minute_index, |
| week_of_year, |
| week_of_year_index, |
| ) |
| from gluonts.time_feature.holiday import ( |
| BLACK_FRIDAY, |
| CHRISTMAS_DAY, |
| CHRISTMAS_EVE, |
| CYBER_MONDAY, |
| EASTER_MONDAY, |
| EASTER_SUNDAY, |
| GOOD_FRIDAY, |
| INDEPENDENCE_DAY, |
| LABOR_DAY, |
| MEMORIAL_DAY, |
| NEW_YEARS_DAY, |
| NEW_YEARS_EVE, |
| THANKSGIVING, |
| SpecialDateFeatureSet, |
| exponential_kernel, |
| squared_exponential_kernel, |
| ) |
| from gluonts.time_feature.seasonality import get_seasonality |
| from scipy.signal import find_peaks |
|
|
| from src.data.constants import BASE_END_DATE, BASE_START_DATE |
| from src.data.frequency import ( |
| Frequency, |
| validate_frequency_safety, |
| ) |
| from src.utils.utils import device |
|
|
| |
# NOTE(review): basicConfig configures the *root* logger when this module is
# imported, which affects every program importing it; consider moving this
# call to the application entry point.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.DEBUG,
)
# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)
|
|
|
|
| |
# GluonTS time-feature functions grouped by frequency category. Keys match the
# categories returned by TimeFeatureGenerator._get_feature_category.
# "normalized" functions are used as returned by GluonTS; "index" functions
# are rescaled by TimeFeatureGenerator before use.
ENHANCED_TIME_FEATURES = dict(
    high_freq=dict(
        normalized=[
            second_of_minute,
            minute_of_hour,
            hour_of_day,
            day_of_week,
            day_of_month,
        ],
        index=[
            second_of_minute_index,
            minute_of_hour_index,
            hour_of_day_index,
            day_of_week_index,
        ],
    ),
    medium_freq=dict(
        normalized=[
            hour_of_day,
            day_of_week,
            day_of_month,
            day_of_year,
            month_of_year,
        ],
        index=[
            hour_of_day_index,
            day_of_week_index,
            day_of_month_index,
            week_of_year_index,
        ],
    ),
    low_freq=dict(
        normalized=[day_of_week, day_of_month, month_of_year, week_of_year],
        index=[day_of_week_index, month_of_year_index, week_of_year_index],
    ),
)
|
|
| |
# Named lists of GluonTS special-date names. TimeFeatureGenerator picks one via
# its ``holiday_set`` argument and wraps it in a SpecialDateFeatureSet, so the
# order of each list fixes the order of the resulting holiday feature columns.
HOLIDAY_FEATURE_SETS = dict(
    us_business=[
        NEW_YEARS_DAY,
        MEMORIAL_DAY,
        INDEPENDENCE_DAY,
        LABOR_DAY,
        THANKSGIVING,
        CHRISTMAS_EVE,
        CHRISTMAS_DAY,
        NEW_YEARS_EVE,
    ],
    # Same dates as us_business plus Easter Sunday, Black Friday and Cyber Monday.
    us_retail=[
        NEW_YEARS_DAY,
        EASTER_SUNDAY,
        MEMORIAL_DAY,
        INDEPENDENCE_DAY,
        LABOR_DAY,
        THANKSGIVING,
        BLACK_FRIDAY,
        CYBER_MONDAY,
        CHRISTMAS_EVE,
        CHRISTMAS_DAY,
        NEW_YEARS_EVE,
    ],
    christian=[
        NEW_YEARS_DAY,
        GOOD_FRIDAY,
        EASTER_SUNDAY,
        EASTER_MONDAY,
        CHRISTMAS_EVE,
        CHRISTMAS_DAY,
        NEW_YEARS_EVE,
    ],
)
|
|
|
|
| class TimeFeatureGenerator: |
| """ |
| Enhanced time feature generator that leverages full GluonTS capabilities. |
| """ |
|
|
| def __init__( |
| self, |
| use_enhanced_features: bool = True, |
| use_holiday_features: bool = True, |
| holiday_set: str = "us_business", |
| holiday_kernel: str = "exponential", |
| holiday_kernel_alpha: float = 1.0, |
| use_index_features: bool = True, |
| k_max: int = 15, |
| include_seasonality_info: bool = True, |
| use_auto_seasonality: bool = False, |
| max_seasonal_periods: int = 3, |
| ): |
| """ |
| Initialize enhanced time feature generator. |
| |
| Parameters |
| ---------- |
| use_enhanced_features : bool |
| Whether to use frequency-specific enhanced features |
| use_holiday_features : bool |
| Whether to include holiday features |
| holiday_set : str |
| Which holiday set to use ('us_business', 'us_retail', 'christian') |
| holiday_kernel : str |
| Holiday kernel type ('indicator', 'exponential', 'squared_exponential') |
| holiday_kernel_alpha : float |
| Kernel parameter for exponential kernels |
| use_index_features : bool |
| Whether to include index-based features alongside normalized ones |
| k_max : int |
| Maximum number of time features to pad to |
| include_seasonality_info : bool |
| Whether to include seasonality information as features |
| use_auto_seasonality : bool |
| Whether to use automatic FFT-based seasonality detection |
| max_seasonal_periods : int |
| Maximum number of seasonal periods to detect automatically |
| """ |
| self.use_enhanced_features = use_enhanced_features |
| self.use_holiday_features = use_holiday_features |
| self.holiday_set = holiday_set |
| self.use_index_features = use_index_features |
| self.k_max = k_max |
| self.include_seasonality_info = include_seasonality_info |
| self.use_auto_seasonality = use_auto_seasonality |
| self.max_seasonal_periods = max_seasonal_periods |
|
|
| |
| self.holiday_feature_set = None |
| if use_holiday_features and holiday_set in HOLIDAY_FEATURE_SETS: |
| kernel_func = self._get_holiday_kernel(holiday_kernel, holiday_kernel_alpha) |
| self.holiday_feature_set = SpecialDateFeatureSet(HOLIDAY_FEATURE_SETS[holiday_set], kernel_func) |
|
|
| def _get_holiday_kernel(self, kernel_type: str, alpha: float): |
| """Get holiday kernel function.""" |
| if kernel_type == "exponential": |
| return exponential_kernel(alpha) |
| elif kernel_type == "squared_exponential": |
| return squared_exponential_kernel(alpha) |
| else: |
| |
| return lambda x: float(x == 0) |
|
|
| def _get_feature_category(self, freq_str: str) -> str: |
| """Determine feature category based on frequency.""" |
| if freq_str in ["s", "1min", "5min", "10min", "15min"]: |
| return "high_freq" |
| elif freq_str in ["h", "D"]: |
| return "medium_freq" |
| else: |
| return "low_freq" |
|
|
| def _compute_enhanced_features(self, period_index: pd.PeriodIndex, freq_str: str) -> np.ndarray: |
| """Compute enhanced time features based on frequency.""" |
| if not self.use_enhanced_features: |
| return np.array([]).reshape(len(period_index), 0) |
|
|
| category = self._get_feature_category(freq_str) |
| feature_config = ENHANCED_TIME_FEATURES[category] |
|
|
| features = [] |
|
|
| |
| for feat_func in feature_config["normalized"]: |
| try: |
| feat_values = feat_func(period_index) |
| features.append(feat_values) |
| except Exception: |
| continue |
|
|
| |
| if self.use_index_features: |
| for feat_func in feature_config["index"]: |
| try: |
| feat_values = feat_func(period_index) |
| |
| if feat_values.max() > 0: |
| feat_values = feat_values / feat_values.max() |
| features.append(feat_values) |
| except Exception: |
| continue |
|
|
| if features: |
| return np.stack(features, axis=-1) |
| else: |
| return np.array([]).reshape(len(period_index), 0) |
|
|
| def _compute_holiday_features(self, date_range: pd.DatetimeIndex) -> np.ndarray: |
| """Compute holiday features.""" |
| if not self.use_holiday_features or self.holiday_feature_set is None: |
| return np.array([]).reshape(len(date_range), 0) |
|
|
| try: |
| holiday_features = self.holiday_feature_set(date_range) |
| return holiday_features.T |
| except Exception: |
| return np.array([]).reshape(len(date_range), 0) |
|
|
| def _detect_auto_seasonality(self, time_series_values: np.ndarray) -> list: |
| """ |
| Detect seasonal periods automatically using FFT analysis. |
| |
| Parameters |
| ---------- |
| time_series_values : np.ndarray |
| Time series values for seasonality detection |
| |
| Returns |
| ------- |
| list |
| List of detected seasonal periods |
| """ |
| if not self.use_auto_seasonality or len(time_series_values) < 10: |
| return [] |
|
|
| try: |
| |
| values = time_series_values[~np.isnan(time_series_values)] |
| if len(values) < 10: |
| return [] |
|
|
| |
| x = np.arange(len(values)) |
| coeffs = np.polyfit(x, values, 1) |
| trend = np.polyval(coeffs, x) |
| detrended = values - trend |
|
|
| |
| window = np.hanning(len(detrended)) |
| windowed = detrended * window |
|
|
| |
| padded_length = len(windowed) * 2 |
| padded_values = np.zeros(padded_length) |
| padded_values[: len(windowed)] = windowed |
|
|
| |
| fft_values = fft.rfft(padded_values) |
| fft_magnitudes = np.abs(fft_values) |
| freqs = np.fft.rfftfreq(padded_length) |
|
|
| |
| fft_magnitudes[0] = 0.0 |
|
|
| |
| threshold = 0.05 * np.max(fft_magnitudes) |
| peak_indices, _ = find_peaks(fft_magnitudes, height=threshold) |
|
|
| if len(peak_indices) == 0: |
| return [] |
|
|
| |
| sorted_indices = peak_indices[np.argsort(fft_magnitudes[peak_indices])[::-1]] |
| top_indices = sorted_indices[: self.max_seasonal_periods] |
|
|
| |
| periods = [] |
| for idx in top_indices: |
| if freqs[idx] > 0: |
| period = 1.0 / freqs[idx] |
| |
| period = round(period / 2) |
| if 2 <= period <= len(values) // 2: |
| periods.append(period) |
|
|
| return list(set(periods)) |
|
|
| except Exception: |
| return [] |
|
|
| def _compute_seasonality_features( |
| self, |
| period_index: pd.PeriodIndex, |
| freq_str: str, |
| time_series_values: np.ndarray = None, |
| ) -> np.ndarray: |
| """Compute seasonality-aware features.""" |
| if not self.include_seasonality_info: |
| return np.array([]).reshape(len(period_index), 0) |
|
|
| all_seasonal_features = [] |
|
|
| |
| try: |
| seasonality = get_seasonality(freq_str) |
| if seasonality > 1: |
| positions = np.arange(len(period_index)) |
| sin_feat = np.sin(2 * np.pi * positions / seasonality) |
| cos_feat = np.cos(2 * np.pi * positions / seasonality) |
| all_seasonal_features.extend([sin_feat, cos_feat]) |
| except Exception: |
| pass |
|
|
| |
| if self.use_auto_seasonality and time_series_values is not None: |
| auto_periods = self._detect_auto_seasonality(time_series_values) |
| for period in auto_periods: |
| try: |
| positions = np.arange(len(period_index)) |
| sin_feat = np.sin(2 * np.pi * positions / period) |
| cos_feat = np.cos(2 * np.pi * positions / period) |
| all_seasonal_features.extend([sin_feat, cos_feat]) |
| except Exception: |
| continue |
|
|
| if all_seasonal_features: |
| return np.stack(all_seasonal_features, axis=-1) |
| else: |
| return np.array([]).reshape(len(period_index), 0) |
|
|
| def compute_features( |
| self, |
| period_index: pd.PeriodIndex, |
| date_range: pd.DatetimeIndex, |
| freq_str: str, |
| time_series_values: np.ndarray = None, |
| ) -> np.ndarray: |
| """ |
| Compute all time features for given period index. |
| |
| Parameters |
| ---------- |
| period_index : pd.PeriodIndex |
| Period index for computing features |
| date_range : pd.DatetimeIndex |
| Corresponding datetime index for holiday features |
| freq_str : str |
| Frequency string |
| time_series_values : np.ndarray, optional |
| Time series values for automatic seasonality detection |
| |
| Returns |
| ------- |
| np.ndarray |
| Time features array of shape [time_steps, num_features] |
| """ |
| all_features = [] |
|
|
| |
| try: |
| standard_features = time_features_from_frequency_str(freq_str) |
| if standard_features: |
| std_feat = np.stack([feat(period_index) for feat in standard_features], axis=-1) |
| all_features.append(std_feat) |
| except Exception: |
| pass |
|
|
| |
| enhanced_feat = self._compute_enhanced_features(period_index, freq_str) |
| if enhanced_feat.shape[1] > 0: |
| all_features.append(enhanced_feat) |
|
|
| |
| holiday_feat = self._compute_holiday_features(date_range) |
| if holiday_feat.shape[1] > 0: |
| all_features.append(holiday_feat) |
|
|
| |
| seasonality_feat = self._compute_seasonality_features(period_index, freq_str, time_series_values) |
| if seasonality_feat.shape[1] > 0: |
| all_features.append(seasonality_feat) |
|
|
| if all_features: |
| combined_features = np.concatenate(all_features, axis=-1) |
| else: |
| combined_features = np.zeros((len(period_index), 1)) |
|
|
| return combined_features |
|
|
|
|
def compute_batch_time_features(
    start: list[np.datetime64],
    history_length: int,
    future_length: int,
    batch_size: int,
    frequency: list[Frequency],
    K_max: int = 6,
    time_feature_config: dict[str, Any] | None = None,
):
    """
    Compute history and future time features for a batch of series.

    Parameters
    ----------
    start : array-like, shape (batch_size,)
        Start timestamps for each batch item. A start for which
        ``validate_frequency_safety`` fails is replaced by ``BASE_START_DATE``.
    history_length : int
        Length of history sequence (may be 0).
    future_length : int
        Length of target sequence (may be 0).
    batch_size : int
        Number of items to process; ``start`` and ``frequency`` must have at
        least this many entries.
    frequency : array-like, shape (batch_size,)
        ``Frequency`` of each time series.
    K_max : int, optional
        Number of feature columns each item is padded/truncated to (default: 6).
    time_feature_config : dict, optional
        Keyword arguments forwarded to ``TimeFeatureGenerator``.

    Returns
    -------
    tuple
        (history_time_features, target_time_features): float32 torch.Tensors on
        ``device`` with shapes (batch_size, history_length, K_max) and
        (batch_size, future_length, K_max).
    """
    feature_generator = TimeFeatureGenerator(**(time_feature_config or {}))
    total_length = history_length + future_length

    history_features_list = []
    future_features_list = []
    for i in range(batch_size):
        frequency_i = frequency[i]
        freq_str = frequency_i.to_pandas_freq(for_date_range=True)
        period_freq_str = frequency_i.to_pandas_freq(for_date_range=False)

        start_ts = pd.Timestamp(start[i])
        if not validate_frequency_safety(start_ts, total_length, frequency_i):
            logger.debug(
                f"Start date {start_ts} not safe for total_length={total_length}, frequency={frequency_i}. "
                f"Using BASE_START_DATE instead."
            )
            start_ts = BASE_START_DATE

        # Build history and future as one contiguous range. This way the upper
        # bound check covers the whole window, matching the shifted start below,
        # which reserves room for total_length steps. It also avoids indexing
        # an empty history range when history_length == 0.
        full_range = pd.date_range(start=start_ts, periods=total_length, freq=freq_str)
        if total_length > 0 and full_range[-1] > BASE_END_DATE:
            offset = pd.tseries.frequencies.to_offset(freq_str)
            safe_start = BASE_END_DATE - offset * total_length
            if safe_start < BASE_START_DATE:
                safe_start = BASE_START_DATE
            full_range = pd.date_range(start=safe_start, periods=total_length, freq=freq_str)

        history_range = full_range[:history_length]
        future_range = full_range[history_length:]

        history_period_idx = history_range.to_period(period_freq_str)
        future_period_idx = future_range.to_period(period_freq_str)

        history_features = feature_generator.compute_features(history_period_idx, history_range, freq_str)
        future_features = feature_generator.compute_features(future_period_idx, future_range, freq_str)

        # Items can yield different feature counts; force K_max columns so
        # they stack into a single batch array.
        history_features_list.append(_pad_or_truncate_features(history_features, K_max))
        future_features_list.append(_pad_or_truncate_features(future_features, K_max))

    history_time_features = np.stack(history_features_list, axis=0)
    future_time_features = np.stack(future_features_list, axis=0)

    return (
        torch.from_numpy(history_time_features).float().to(device),
        torch.from_numpy(future_time_features).float().to(device),
    )
|
|
|
|
def _pad_or_truncate_features(features: np.ndarray, K_max: int) -> np.ndarray:
    """
    Return ``features`` with exactly ``K_max`` columns.

    Extra columns are dropped from the right. Missing columns are filled with
    float zeros appended on the right. If the width already equals ``K_max``,
    the input array itself is returned.
    """
    n_rows, n_cols = features.shape
    if n_cols > K_max:
        return features[:, :K_max]
    if n_cols == K_max:
        return features
    filler = np.zeros((n_rows, K_max - n_cols))
    return np.concatenate([features, filler], axis=-1)
|
|