from typing import Any, Dict, Optional

import torch

from .alignment_utils import get_sample_align_fn
from models.core_model.cuboid_transformer import NoisyCuboidTransformerEncoder


class SEVIRAvgIntensityAlignment:
    r"""Alignment module for guided diffusion sampling on SEVIR data.

    Wraps a learned objective model (currently a
    ``NoisyCuboidTransformerEncoder``) that predicts the average intensity of
    the clean sequence from a noisy latent, and turns it into a gradient-based
    guidance term :math:`\nabla_{z_t} U` scaled by ``guide_scale``.
    """

    def __init__(
        self,
        alignment_type: str = "avg_x",
        guide_scale: float = 1.0,
        model_type: str = "cuboid",
        model_args: Optional[Dict[str, Any]] = None,
        model_ckpt_path: Optional[str] = None,
    ):
        r"""
        Parameters
        ----------
        alignment_type: str
            Which alignment objective to use. Only ``"avg_x"`` (average
            intensity of the clean sequence) is currently supported.
        guide_scale: float
            Multiplier applied to the guidance gradient in
            :meth:`get_mean_shift`.
        model_type: str
            Architecture of the objective model. Only ``"cuboid"`` is
            currently supported.
        model_args: Optional[Dict[str, Any]]
            Keyword arguments forwarded to the objective model constructor.
            Defaults to an empty dict.
        model_ckpt_path: Optional[str]
            If not None, load the model weights from this checkpoint
            (loaded onto CPU; move the model to the target device afterwards).

        Raises
        ------
        ValueError
            If ``alignment_type`` is not supported.
        NotImplementedError
            If ``model_type`` is not supported.
        """
        super().__init__()
        # Explicit raise instead of `assert`: asserts are stripped under `-O`
        # and must not be relied on for input validation.
        if alignment_type not in ("avg_x",):
            raise ValueError(f"alignment_type {alignment_type} is not supported")
        self.alignment_type = alignment_type
        self.guide_scale = guide_scale
        model_args = model_args if model_args is not None else {}
        if model_type == "cuboid":
            self.model = NoisyCuboidTransformerEncoder(**model_args)
        else:
            raise NotImplementedError(f"model_type={model_type} is not implemented")
        if model_ckpt_path is not None:
            self.model.load_state_dict(torch.load(model_ckpt_path, map_location="cpu"))

    @classmethod
    def model_objective(cls, x, y=None, **kwargs):
        r"""Average intensity of each frame of ``x``.

        Parameters
        ----------
        x: torch.Tensor
            shape = (b, t, h, w, c)
        y: optional
            Unused; kept for a uniform objective signature.

        Returns
        -------
        avg: torch.Tensor
            shape = (b, t, 1); per-frame mean over the (h, w, c) dims.
        """
        # Reduce over the spatial and channel dims, then restore a trailing
        # singleton dim so the output is (b, t, 1).
        return torch.mean(x, dim=(2, 3, 4), keepdim=False).unsqueeze(-1)

    def alignment_fn(self, zt, t, y=None, zc=None, **kwargs):
        r"""Transform the learned model into the final guidance :math:`\mathcal{F}`.

        Parameters
        ----------
        zt: torch.Tensor
            Noisy latent z.
        t: torch.Tensor
            Timestamp.
        y: torch.Tensor
            Context sequence in pixel space.
        zc: torch.Tensor
            Encoded context sequence in latent space.
        kwargs: Dict[str, Any]
            Auxiliary knowledge for guided generation. ``avg_x_gt`` (float)
            is required.

        Returns
        -------
        ret: torch.Tensor
            Scalar L2 distance between the predicted sequence-average
            intensity and the ground-truth target.

        Raises
        ------
        KeyError
            If ``avg_x_gt`` is missing from ``kwargs``.
        """
        pred = self.model(zt, t, zc=zc, y=y, **kwargs)
        if self.alignment_type == "avg_x":
            # Fail loudly on the missing required key instead of letting a
            # `None` target surface later as an opaque TypeError.
            if "avg_x_gt" not in kwargs:
                raise KeyError("`avg_x_gt` is required in kwargs when alignment_type='avg_x'")
            target = kwargs["avg_x_gt"]
        else:
            raise NotImplementedError
        pred = pred.mean(dim=1)  # (b, t, 1) -> (b, 1): average over time
        ret = torch.linalg.vector_norm(pred - target, ord=2)
        return ret

    def get_mean_shift(self, zt, t, y=None, zc=None, **kwargs):
        r"""Guidance gradient :math:`\nabla_{z_t} U`, scaled by ``guide_scale``.

        Parameters
        ----------
        zt: torch.Tensor
            Noisy latent z.
        t: torch.Tensor
            Timestamp.
        y: torch.Tensor
            Context sequence in pixel space.
        zc: torch.Tensor
            Encoded context sequence in latent space.

        Returns
        -------
        ret: torch.Tensor
            ``guide_scale`` times the gradient of :meth:`alignment_fn`
            with respect to ``zt``.
        """
        grad_fn = get_sample_align_fn(self.alignment_fn)
        grad = grad_fn(zt, t, y=y, zc=zc, **kwargs)
        return self.guide_scale * grad