# Provenance: uploaded by weatherforecast1024 via huggingface_hub (commit 7667a87, verified)
from typing import Dict, Any
import torch
from .alignment_utils import get_sample_align_fn
from models.core_model.cuboid_transformer import NoisyCuboidTransformerEncoder
class SEVIRAvgIntensityAlignment():
    r"""
    Guidance module that steers SEVIR sequence generation toward a target
    average intensity.  A learned noisy-observation model predicts the
    per-frame average intensity from the noisy latent, and the L2 distance
    to the ground-truth average is used as the guidance energy, whose
    gradient w.r.t. the latent is returned by `get_mean_shift`.
    """

    def __init__(
        self,
        alignment_type: str = "avg_x",
        guide_scale: float = 1.0,
        model_type: str = "cuboid",
        model_args: Dict[str, Any] = None,
        model_ckpt_path: str = None,
    ):
        r"""
        Parameters
        ----------
        alignment_type: str
            Which alignment objective to use.  Only "avg_x" (average
            intensity of the generated frames) is currently supported.
        guide_scale: float
            Scalar multiplier applied to the guidance gradient in
            `get_mean_shift`.
        model_type: str
            Which observation-model architecture to build.  Only "cuboid"
            (`NoisyCuboidTransformerEncoder`) is currently supported.
        model_args: Dict[str, Any]
            Keyword arguments forwarded to the model constructor.
            Defaults to an empty dict when None.
        model_ckpt_path: str
            If not None, load the model weights from this checkpoint
            (loaded onto CPU; move the model to a device afterwards).
        """
        super().__init__()
        # Raise instead of assert: asserts are stripped under `python -O`.
        if alignment_type not in ("avg_x",):
            raise ValueError(f"alignment_type {alignment_type} is not supported")
        self.alignment_type = alignment_type
        self.guide_scale = guide_scale
        model_args = model_args if model_args is not None else {}
        if model_type == "cuboid":
            self.model = NoisyCuboidTransformerEncoder(**model_args)
        else:
            raise NotImplementedError(f"model_type={model_type} is not implemented")
        if model_ckpt_path is not None:
            self.model.load_state_dict(torch.load(model_ckpt_path, map_location="cpu"))

    @classmethod
    def model_objective(cls, x, y=None, **kwargs):
        r"""
        Ground-truth objective: per-frame average intensity of `x`.

        Parameters
        ----------
        x: torch.Tensor
            shape = (b t h w c)
        y: torch.Tensor
            Unused; kept so the objective shares a uniform signature with
            other callables in this module.

        Returns
        -------
        avg: torch.Tensor
            shape = (b t 1); mean of each frame over its (h, w, c) dims.
        """
        # Reduce over spatial and channel dims -> (b t), then restore the
        # trailing singleton dim to match the model's (b t 1) output.
        return torch.mean(x, dim=[2, 3, 4], keepdim=False).unsqueeze(-1)

    def alignment_fn(self, zt, t, y=None, zc=None, **kwargs):
        r"""
        Transform the learned model into the final guidance energy \mathcal{F}.

        Parameters
        ----------
        zt: torch.Tensor
            noisy latent z
        t: torch.Tensor
            timestamp
        y: torch.Tensor
            context sequence in pixel space
        zc: torch.Tensor
            encoded context sequence in latent space
        kwargs: Dict[str, Any]
            auxiliary knowledge for guided generation.
            `avg_x_gt` (the target average intensity) is required.

        Returns
        -------
        ret: torch.Tensor
            Scalar L2 norm between the predicted sequence-averaged
            intensity and the target.

        Raises
        ------
        ValueError
            If `avg_x_gt` is missing from `kwargs`.
        """
        # NOTE(review): kwargs (including `avg_x_gt`) is forwarded to the
        # model unchanged, matching the original behavior — confirm the
        # model tolerates/ignores extra keys.
        pred = self.model(zt, t, zc=zc, y=y, **kwargs)
        if self.alignment_type == "avg_x":
            target = kwargs.get("avg_x_gt")
            if target is None:
                # Fail fast with a clear message instead of the cryptic
                # TypeError that `pred - None` would raise below.
                raise ValueError(
                    "kwargs['avg_x_gt'] is required when alignment_type='avg_x'")
        else:
            raise NotImplementedError
        pred = pred.mean(dim=1)  # (b t 1) -> (b 1): average over time
        ret = torch.linalg.vector_norm(pred - target, ord=2)
        return ret

    def get_mean_shift(self, zt, t, y=None, zc=None, **kwargs):
        r"""
        Gradient of the guidance energy w.r.t. the noisy latent.

        Parameters
        ----------
        zt: torch.Tensor
            noisy latent z
        t: torch.Tensor
            timestamp
        y: torch.Tensor
            context sequence in pixel space
        zc: torch.Tensor
            encoded context sequence in latent space

        Returns
        -------
        ret: torch.Tensor
            `guide_scale` * \nabla_zt U, where U is `alignment_fn`.
        """
        grad_fn = get_sample_align_fn(self.alignment_fn)
        grad = grad_fn(zt, t, y=y, zc=zc, **kwargs)
        return self.guide_scale * grad