| | """
|
| | Adapted from https://github.com/CompVis/latent-diffusion/blob/main/ldm/models/diffusion/ddpm.py
|
| | Pared down to simplify the code.
|
| |
|
| | The original file acknowledges:
|
| | https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
|
| | https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
|
| | https://github.com/CompVis/taming-transformers
|
| | """
|
| |
|
| | import torch
|
| | import torch.nn as nn
|
| | import numpy as np
|
| | import pytorch_lightning as pl
|
| | from contextlib import contextmanager
|
| | from functools import partial
|
| | from torchmetrics import MeanSquaredError
|
| |
|
| | from .utils import make_beta_schedule, extract_into_tensor, noise_like, timestep_embedding
|
| | from .ema import LitEma
|
| | from ..blocks.afno import PatchEmbed3d, PatchExpand3d, AFNOBlock3d
|
| |
|
| |
|
class LatentDiffusion(pl.LightningModule):
    """Denoising diffusion model trained in the latent space of a frozen autoencoder.

    Training/validation batches are ``(x, y)`` pairs. ``y`` is passed through
    ``autoencoder.encode`` and the first element of the result is used as the
    clean latent that gets noised and denoised. ``x`` is passed through
    ``context_encoder`` (when one is given) to produce the ``context`` fed to
    ``model``.

    Args:
        model: Denoising network, called as ``model(x_noisy, t, context=...)``.
        autoencoder: Module providing ``encode``; its parameters are frozen
            with ``requires_grad_(False)``.
        context_encoder: Optional module mapping the conditioning input to the
            ``context`` passed to ``model``. ``None`` makes the model
            unconditional.
        timesteps: Number of diffusion steps in the noise schedule.
        beta_schedule: Schedule name forwarded to ``make_beta_schedule``.
        loss_type: ``"l1"`` or ``"l2"``.
        use_ema: Keep an exponential moving average (``LitEma``) of the
            ``model`` weights, updated after every training batch.
        lr: Base learning rate for AdamW.
        lr_warmup: Number of initial global steps over which the learning
            rate is ramped linearly up to ``lr``; ``0`` disables warmup.
        linear_start, linear_end, cosine_s: Schedule parameters forwarded to
            ``make_beta_schedule``.
        parameterization: ``"eps"`` trains the model to predict the added
            noise; ``"x0"`` trains it to predict the clean latent.
    """

    def __init__(self,
                 model,
                 autoencoder,
                 context_encoder=None,
                 timesteps=1000,
                 beta_schedule="linear",
                 loss_type="l2",
                 use_ema=True,
                 lr=1e-4,
                 lr_warmup=0,
                 linear_start=1e-4,
                 linear_end=2e-2,
                 cosine_s=8e-3,
                 parameterization="eps",
                 ):
        super().__init__()
        self.model = model
        # The autoencoder only maps data to latents; it is not trained here.
        self.autoencoder = autoencoder.requires_grad_(False)
        self.conditional = (context_encoder is not None)
        self.context_encoder = context_encoder
        self.lr = lr
        self.lr_warmup = lr_warmup

        # NOTE(review): not referenced anywhere in this class; kept so any
        # external code that accesses it keeps working.
        self.val_loss = MeanSquaredError()

        assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
        self.parameterization = parameterization

        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self.model)

        self.register_schedule(
            beta_schedule=beta_schedule, timesteps=timesteps,
            linear_start=linear_start, linear_end=linear_end,
            cosine_s=cosine_s
        )

        self.loss_type = loss_type

    def register_schedule(self, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        """Build the noise schedule and register it as float32 buffers.

        Registers ``betas``, ``alphas_cumprod`` (cumulative product of
        ``1 - betas``), ``alphas_cumprod_prev`` (the same shifted right by one
        with a leading 1), and the two square-root terms used by ``q_sample``.
        Also sets ``num_timesteps`` from the length of the schedule.
        """
        betas = make_beta_schedule(
            beta_schedule, timesteps,
            linear_start=linear_start, linear_end=linear_end,
            cosine_s=cosine_s
        )
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

        to_torch = partial(torch.tensor, dtype=torch.float32)

        # Buffers follow the module across devices and are saved in the
        # state dict, unlike plain tensor attributes.
        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))

    @contextmanager
    def ema_scope(self, context=None):
        """Temporarily load the EMA weights into ``self.model``.

        The current training weights are stored on entry and restored on exit,
        including when the body raises. This does nothing when ``use_ema`` is
        False. If ``context`` is given, a message is printed on switch and on
        restore.
        """
        if self.use_ema:
            self.model_ema.store(self.model.parameters())
            self.model_ema.copy_to(self.model)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.model.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def apply_model(self, x_noisy, t, cond=None, return_ids=False):
        """Run the denoiser with EMA weights on ``x_noisy`` at timesteps ``t``.

        When the model is conditional, ``cond`` is first passed through
        ``context_encoder``. ``return_ids`` is unused and is kept for
        interface compatibility.
        """
        if self.conditional:
            cond = self.context_encoder(cond)
        with self.ema_scope():
            return self.model(x_noisy, t, context=cond)

    def q_sample(self, x_start, t, noise=None):
        """Noise ``x_start`` to timestep ``t``.

        Returns ``sqrt(alphas_cumprod[t]) * x_start +
        sqrt(1 - alphas_cumprod[t]) * noise``. When ``noise`` is None, fresh
        standard-normal noise is drawn.
        """
        if noise is None:
            noise = torch.randn_like(x_start)
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def get_loss(self, pred, target, mean=True):
        """Elementwise L1 or L2 loss between ``pred`` and ``target``.

        Args:
            pred: Model output.
            target: Regression target.
            mean: Reduce to a scalar mean when True; otherwise return the
                unreduced elementwise loss.

        Raises:
            NotImplementedError: If ``self.loss_type`` is not "l1" or "l2".
        """
        if self.loss_type == 'l1':
            loss = (target - pred).abs()
            if mean:
                loss = loss.mean()
        elif self.loss_type == 'l2':
            if mean:
                loss = torch.nn.functional.mse_loss(target, pred)
            else:
                loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
        else:
            # Previously a plain string referencing an undefined name, so the
            # offending value was never reported.
            raise NotImplementedError(f"unknown loss type '{self.loss_type}'")

        return loss

    def p_losses(self, x_start, t, noise=None, context=None):
        """Compute the scalar training loss for clean latents at timesteps ``t``.

        The latents are noised with ``q_sample`` and the model's output is
        regressed onto either the noise (``"eps"``) or ``x_start`` (``"x0"``).
        """
        if noise is None:
            noise = torch.randn_like(x_start)
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        model_out = self.model(x_noisy, t, context=context)

        if self.parameterization == "eps":
            target = noise
        elif self.parameterization == "x0":
            target = x_start
        else:
            raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")

        return self.get_loss(model_out, target, mean=False).mean()

    def forward(self, x, *args, **kwargs):
        """Sample one uniform random timestep per batch element and return ``p_losses``."""
        t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
        return self.p_losses(x, t, *args, **kwargs)

    def shared_step(self, batch):
        """Encode the target to latents, build the context, and return the diffusion loss.

        ``batch`` is unpacked as ``(x, y)``. ``y`` is encoded with the frozen
        autoencoder (first element of ``encode``'s output), and ``x`` is the
        conditioning input.
        """
        (x, y) = batch
        y = self.autoencoder.encode(y)[0]
        context = self.context_encoder(x) if self.conditional else None
        return self(y, context=context)

    def training_step(self, batch, batch_idx):
        """Lightning training hook: log and return the diffusion loss."""
        loss = self.shared_step(batch)
        self.log("train_loss", loss)
        return loss

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        """Log the validation loss with the training weights and with the EMA weights.

        NOTE(review): each ``shared_step`` call draws its own random timesteps
        and noise, so the two losses are not evaluated on identical noisy
        inputs.
        """
        loss = self.shared_step(batch)
        with self.ema_scope():
            loss_ema = self.shared_step(batch)
        log_params = {"on_step": False, "on_epoch": True, "prog_bar": True}
        self.log("val_loss", loss, **log_params)
        # Log the EMA loss here, not the training-weight loss: the LR
        # scheduler in configure_optimizers monitors this key.
        self.log("val_loss_ema", loss_ema, **log_params)

    def test_step(self, batch, batch_idx):
        """Reuse the validation logic for testing."""
        return self.validation_step(batch, batch_idx)

    def on_train_batch_end(self, *args, **kwargs):
        """Update the EMA weights after every training batch."""
        if self.use_ema:
            self.model_ema(self.model)

    def configure_optimizers(self):
        """Set up AdamW with a ReduceLROnPlateau scheduler that monitors ``val_loss_ema``."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr,
                                      betas=(0.5, 0.9), weight_decay=1e-3)
        # NOTE(review): recent PyTorch releases deprecate the scheduler
        # ``verbose`` argument; confirm against the pinned torch version.
        reduce_lr = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, patience=3, factor=0.25, verbose=True
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": reduce_lr,
                "monitor": "val_loss_ema",
                "frequency": 1,
            },
        }

    def optimizer_step(
        self,
        epoch,
        batch_idx,
        optimizer,
        optimizer_idx,
        **kwargs
    ):
        """Apply linear LR warmup, then delegate to Lightning's optimizer step.

        For the first ``lr_warmup`` global steps, every param group's
        learning rate is set to ``(global_step + 1) / lr_warmup * lr``.

        NOTE(review): this signature (with ``optimizer_idx``) matches the
        PyTorch Lightning 1.x hook. Lightning 2.x removed ``optimizer_idx``,
        so confirm the pinned version.
        """
        if self.trainer.global_step < self.lr_warmup:
            lr_scale = (self.trainer.global_step + 1) / self.lr_warmup
            for pg in optimizer.param_groups:
                pg['lr'] = lr_scale * self.lr

        super().optimizer_step(
            epoch, batch_idx, optimizer,
            optimizer_idx,
            **kwargs
        )
|
| |
|
| |
|