| | import logging |
| |
|
| | import matplotlib.pyplot as plt |
| | import numpy as np |
| | import pandas as pd |
| | import torch |
| | import torchmetrics |
| | from matplotlib.figure import Figure |
| |
|
| | from src.data.containers import BatchTimeSeriesContainer |
| | from src.data.frequency import Frequency |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
def calculate_smape(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Return the Symmetric Mean Absolute Percentage Error (SMAPE) of a forecast.

    Both arrays are converted to float32 torch tensors and scored with
    ``torchmetrics.SymmetricMeanAbsolutePercentageError``.

    Args:
        y_true: Ground-truth values.
        y_pred: Predicted values; presumably the same shape as ``y_true``
            (not validated here).

    Returns:
        The metric value as a plain Python float.
    """
    # The metric is called as (preds, target), so the forecast goes first.
    forecast = torch.from_numpy(y_pred).float()
    actual = torch.from_numpy(y_true).float()
    smape_metric = torchmetrics.SymmetricMeanAbsolutePercentageError()
    score = smape_metric(forecast, actual)
    return score.item()
| |
|
| |
|
def _create_date_ranges(
    start: np.datetime64 | pd.Timestamp | None,
    frequency: Frequency | str | None,
    history_length: int,
    prediction_length: int,
) -> tuple[pd.DatetimeIndex, pd.DatetimeIndex]:
    """Create date ranges for the history and future periods.

    When both ``start`` and ``frequency`` are given, the history range begins
    at ``start`` and the future range continues one frequency step after the
    last history date. Otherwise a placeholder daily range ending at the
    current time is used for the history, followed by daily future dates.

    Args:
        start: Timestamp of the first history point, or ``None``.
        frequency: An object exposing ``to_pandas_freq(for_date_range=True)``,
            or a plain string. A string is passed to pandas as-is, i.e. it is
            assumed to already be a pandas frequency alias (TODO confirm this
            matches how callers use string frequencies).
        history_length: Number of history time steps.
        prediction_length: Number of future time steps; values ``<= 0`` yield
            an empty future index.

    Returns:
        A ``(history_dates, future_dates)`` tuple of ``DatetimeIndex`` objects.
    """
    if start is not None and frequency is not None:
        anchor = pd.Timestamp(start)
        # Prefer the object's own conversion. Only values without that method
        # (plain strings) are handed to pandas directly; previously those
        # raised AttributeError even though the signature allows them.
        to_pandas_freq = getattr(frequency, "to_pandas_freq", None)
        pandas_freq = to_pandas_freq(for_date_range=True) if callable(to_pandas_freq) else frequency
        history_dates = pd.date_range(start=anchor, periods=history_length, freq=pandas_freq)
        step = pd.tseries.frequencies.to_offset(pandas_freq)
    else:
        # No usable calendar information: fall back to daily dates ending now.
        anchor = pd.Timestamp.now()
        pandas_freq = "D"
        history_dates = pd.date_range(end=anchor, periods=history_length, freq=pandas_freq)
        step = pd.Timedelta(days=1)

    if prediction_length <= 0:
        return history_dates, pd.DatetimeIndex([])

    # With an empty history there is no "last date" to continue from, so the
    # future range starts at the anchor itself instead of raising IndexError.
    future_start = history_dates[-1] + step if len(history_dates) > 0 else anchor
    future_dates = pd.date_range(start=future_start, periods=prediction_length, freq=pandas_freq)
    return history_dates, future_dates
| |
|
| |
|
def _plot_single_channel(
    ax: plt.Axes,
    channel_idx: int,
    history_dates: pd.DatetimeIndex,
    future_dates: pd.DatetimeIndex,
    history_values: np.ndarray,
    future_values: np.ndarray | None = None,
    predicted_values: np.ndarray | None = None,
    lower_bound: np.ndarray | None = None,
    upper_bound: np.ndarray | None = None,
) -> None:
    """Draw one channel of a multivariate series onto ``ax``.

    Every value array is indexed as ``values[:, channel_idx]``, i.e. the
    second axis selects the channel.

    Args:
        ax: Axes to draw on.
        channel_idx: Zero-based channel index; the title shows it one-based.
        history_dates: X-values for the history line.
        future_dates: X-values shared by ground truth, prediction and band.
        history_values: History observations.
        future_values: Optional ground truth for the future window.
        predicted_values: Optional point forecast (labelled as the median).
        lower_bound: Optional lower edge of the uncertainty band.
        upper_bound: Optional upper edge of the uncertainty band.
    """
    ax.plot(history_dates, history_values[:, channel_idx], color="black", label="History")

    # Lines drawn over the future window, in drawing order; absent series are skipped.
    future_window_lines = (
        (future_values, {"color": "blue", "label": "Ground Truth"}),
        (predicted_values, {"color": "orange", "linestyle": "--", "label": "Prediction (Median)"}),
    )
    for series, line_kwargs in future_window_lines:
        if series is None:
            continue
        ax.plot(future_dates, series[:, channel_idx], **line_kwargs)

    # The band needs both edges; a single bound on its own is ignored.
    band_available = lower_bound is not None and upper_bound is not None
    if band_available:
        ax.fill_between(
            future_dates,
            lower_bound[:, channel_idx],
            upper_bound[:, channel_idx],
            color="orange",
            alpha=0.2,
            label="Uncertainty Band",
        )

    ax.set_title(f"Channel {channel_idx + 1}")
    ax.grid(True, which="both", linestyle="--", linewidth=0.5)
| |
|
| |
|
def _setup_figure(num_channels: int) -> tuple[Figure, list[plt.Axes]]:
    """Create a figure with one vertically stacked subplot per channel.

    Args:
        num_channels: Number of subplot rows to create; must be at least 1.

    Returns:
        The figure and a plain list of its axes, top to bottom. The list form
        is returned for any channel count so callers get the annotated type
        instead of a NumPy array when there are several channels.

    Raises:
        ValueError: If ``num_channels`` is less than 1.
    """
    if num_channels < 1:
        # Fail before pyplot allocates a figure that could never hold a subplot.
        raise ValueError(f"num_channels must be at least 1, got {num_channels}")

    # squeeze=False always yields a 2-D (rows, 1) array of axes, which removes
    # the need to special-case a single channel.
    fig, axes_grid = plt.subplots(
        num_channels,
        1,
        figsize=(15, 3 * num_channels),
        sharex=True,
        squeeze=False,
    )
    return fig, list(axes_grid[:, 0])
| |
|
| |
|
def _finalize_plot(
    fig: Figure,
    axes: list[plt.Axes],
    title: str | None = None,
    smape_value: float | None = None,
    output_file: str | None = None,
    show: bool = True,
) -> None:
    """Add the legend and title to ``fig``, then save and show or close it.

    Layout and saving go through ``fig`` itself rather than pyplot's
    "current figure", so the right figure is finalized even if another one
    has become current in the meantime.

    Args:
        fig: Figure to finalize.
        axes: Axes of the figure; legend entries are taken from the first one.
        title: Optional figure title. Without a title no SMAPE is displayed.
        smape_value: Optional SMAPE appended to the title with four decimals.
        output_file: Optional path to save the figure to (300 dpi).
        show: If true, display via ``plt.show()``; otherwise close ``fig``.
    """
    # A single figure-level legend built from the first axes' entries.
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc="upper right")

    if title:
        if smape_value is not None:
            title = f"{title} | SMAPE: {smape_value:.4f}"
        fig.suptitle(title, fontsize=16)
        # Reserve a strip at the top so the suptitle does not overlap the axes.
        fig.tight_layout(rect=(0, 0.03, 1, 0.95))
    else:
        fig.tight_layout()

    if output_file:
        fig.savefig(output_file, dpi=300)

    if show:
        plt.show()
    else:
        plt.close(fig)
| |
|
| |
|
def _as_channel_matrix(values: np.ndarray | None) -> np.ndarray | None:
    """Return ``values`` as an array with a channel axis; ``None`` passes through."""
    if values is None:
        return None
    array = np.asarray(values)
    # A 1-D input is a single-channel series: give it an explicit channel axis.
    return array[:, np.newaxis] if array.ndim == 1 else array


def plot_multivariate_timeseries(
    history_values: np.ndarray,
    future_values: np.ndarray | None = None,
    predicted_values: np.ndarray | None = None,
    start: np.datetime64 | pd.Timestamp | None = None,
    frequency: Frequency | str | None = None,
    title: str | None = None,
    output_file: str | None = None,
    show: bool = True,
    lower_bound: np.ndarray | None = None,
    upper_bound: np.ndarray | None = None,
) -> Figure:
    """Plot a multivariate time series with history, future, predictions, and uncertainty bands.

    One subplot is drawn per channel. Arrays are indexed as
    ``values[time, channel]``; 1-D arrays are accepted and treated as a
    single channel.

    Args:
        history_values: Observed history; its second axis sets the channel count.
        future_values: Optional ground truth for the forecast window.
        predicted_values: Optional point forecast. When given, its length
            defines the forecast horizon; otherwise ``future_values`` does.
        start: Optional timestamp of the first history point.
        frequency: Optional sampling frequency used to build the date axis.
        title: Optional figure title; SMAPE is appended to it when available.
        output_file: Optional path to save the figure to.
        show: Whether to display the figure; it is closed otherwise.
        lower_bound: Optional lower edge of the uncertainty band.
        upper_bound: Optional upper edge of the uncertainty band.

    Returns:
        The created matplotlib figure.
    """
    # Normalize every series so the channel axis always exists; previously a
    # 1-D history failed on ``shape[1]``.
    history_values = _as_channel_matrix(history_values)
    future_values = _as_channel_matrix(future_values)
    predicted_values = _as_channel_matrix(predicted_values)
    lower_bound = _as_channel_matrix(lower_bound)
    upper_bound = _as_channel_matrix(upper_bound)

    # SMAPE is best-effort decoration for the title: a metric failure must not
    # prevent the plot from being produced.
    smape_value = None
    if predicted_values is not None and future_values is not None:
        try:
            smape_value = calculate_smape(future_values, predicted_values)
        except Exception as e:
            logger.warning("Failed to calculate SMAPE: %s", e)

    history_length, num_channels = history_values.shape[0], history_values.shape[1]
    if predicted_values is not None:
        prediction_length = predicted_values.shape[0]
    elif future_values is not None:
        prediction_length = future_values.shape[0]
    else:
        prediction_length = 0

    history_dates, future_dates = _create_date_ranges(start, frequency, history_length, prediction_length)

    fig, axes = _setup_figure(num_channels)

    for channel_idx in range(num_channels):
        _plot_single_channel(
            ax=axes[channel_idx],
            channel_idx=channel_idx,
            history_dates=history_dates,
            future_dates=future_dates,
            history_values=history_values,
            future_values=future_values,
            predicted_values=predicted_values,
            lower_bound=lower_bound,
            upper_bound=upper_bound,
        )

    _finalize_plot(fig, axes, title, smape_value, output_file, show)

    return fig
| |
|
| |
|
| | def _extract_quantile_predictions( |
| | predicted_values: np.ndarray, |
| | model_quantiles: list[float], |
| | ) -> tuple[np.ndarray | None, np.ndarray | None, np.ndarray | None]: |
| | """Extract median, lower, and upper bound predictions from quantile output.""" |
| | try: |
| | median_idx = model_quantiles.index(0.5) |
| | lower_idx = model_quantiles.index(0.1) |
| | upper_idx = model_quantiles.index(0.9) |
| |
|
| | median_preds = predicted_values[..., median_idx] |
| | lower_bound = predicted_values[..., lower_idx] |
| | upper_bound = predicted_values[..., upper_idx] |
| |
|
| | return median_preds, lower_bound, upper_bound |
| | except (ValueError, IndexError): |
| | logger.warning("Could not find 0.1, 0.5, 0.9 quantiles for plotting. Using median of available quantiles.") |
| | median_preds = predicted_values[..., predicted_values.shape[-1] // 2] |
| | return median_preds, None, None |
| |
|
| |
|
def plot_from_container(
    batch: BatchTimeSeriesContainer,
    sample_idx: int,
    predicted_values: np.ndarray | None = None,
    model_quantiles: list[float] | None = None,
    title: str | None = None,
    output_file: str | None = None,
    show: bool = True,
) -> Figure:
    """Plot one sample of a ``BatchTimeSeriesContainer``, with optional forecasts.

    Args:
        batch: Container providing ``history_values``, ``future_values``,
            ``start`` and ``frequency``, each indexed by sample.
        sample_idx: Index of the sample to plot.
        predicted_values: Optional predictions, either for the whole batch or
            for this sample only (see the heuristic in the body).
        model_quantiles: Optional quantile levels; when non-empty, the
            median and band are extracted from the last prediction axis.
        title: Optional figure title.
        output_file: Optional path to save the figure to.
        show: Whether to display the figure.

    Returns:
        The figure produced by ``plot_multivariate_timeseries``.
    """
    history = batch.history_values[sample_idx].cpu().numpy()
    target = batch.future_values[sample_idx].cpu().numpy()

    point_forecast = None
    band_low = None
    band_high = None

    if predicted_values is not None:
        # Heuristic: predictions count as batched when they have 3+ dims, or
        # are 2-D with more rows than this sample's future horizon.
        # NOTE(review): a per-sample quantile array is also 3-D and would be
        # indexed by sample_idx here - confirm callers pass batched arrays.
        looks_batched = predicted_values.ndim >= 3 or (
            predicted_values.ndim == 2 and predicted_values.shape[0] > target.shape[0]
        )
        per_sample = predicted_values[sample_idx] if looks_batched else predicted_values

        if model_quantiles:
            point_forecast, band_low, band_high = _extract_quantile_predictions(per_sample, model_quantiles)
        else:
            point_forecast = per_sample

    return plot_multivariate_timeseries(
        history_values=history,
        future_values=target,
        predicted_values=point_forecast,
        start=batch.start[sample_idx],
        frequency=batch.frequency[sample_idx],
        title=title,
        output_file=output_file,
        show=show,
        lower_bound=band_low,
        upper_bound=band_high,
    )
| |
|