"""mDiffAE v2 decoder: skip-concat topology with dual PDG (token masking + path drop).

No outer RMSNorms (use_other_outer_rms_norms=False during training):
norm_in, latent_norm, and norm_out are all absent.
"""
|
|
from __future__ import annotations

import math

import torch
from torch import Tensor, nn

from .adaln import AdaLNZeroLowRankDelta, AdaLNZeroProjector
from .dico_block import DiCoBlock
from .straight_through_encoder import Patchify
from .time_embed import SinusoidalTimeEmbeddingMLP
|
|
|
|
class Decoder(nn.Module):
    """VP diffusion decoder conditioned on encoder latents and timestep.

    Architecture (skip-concat, 2+4+2 default):
        Patchify x_t -> Fuse with channel-projected z (1x1 conv)
        -> Start blocks (2) -> Middle blocks (4) -> Skip fuse -> End blocks (2)
        -> Conv1x1 -> PixelShuffle

    The latents are only projected along the channel axis (``latent_up`` is a
    1x1 convolution) and are then concatenated with the patchified image
    features along ``dim=1``, so both must share the same spatial size.

    Dual PDG at inference:
        - Path drop: replace middle block output with ``path_drop_mask_feature``.
        - Token mask: replace a fraction of projected latent tokens with
          ``latent_mask_feature`` before fusion.
    """

    def __init__(
        self,
        in_channels: int,
        patch_size: int,
        model_dim: int,
        depth: int,
        start_block_count: int,
        end_block_count: int,
        bottleneck_dim: int,
        mlp_ratio: float,
        depthwise_kernel_size: int,
        adaln_low_rank_rank: int,
        pdg_mask_ratio: float = 0.75,
    ) -> None:
        """Build the decoder.

        Args:
            in_channels: Channel count of the image being denoised.
            patch_size: Patch edge length used by ``Patchify`` and by the
                final ``PixelShuffle``.
            model_dim: Feature width shared by every block.
            depth: Total number of blocks (start + middle + end); also the
                number of per-layer AdaLN deltas.
            start_block_count: Blocks run before the skip branch point.
            end_block_count: Blocks run after the skip fusion.
            bottleneck_dim: Channel count of the encoder latents.
            mlp_ratio: Forwarded to every ``DiCoBlock``.
            depthwise_kernel_size: Forwarded to every ``DiCoBlock``.
            adaln_low_rank_rank: Rank of each per-layer AdaLN delta.
            pdg_mask_ratio: Fraction of latent tokens replaced per 2x2 group
                when token masking is requested.

        Raises:
            ValueError: If the block counts are negative or
                ``start_block_count + end_block_count`` exceeds ``depth``.
        """
        super().__init__()

        # The middle group takes whatever depth is left. A negative remainder
        # would make start/end groups index overlapping AdaLN deltas, so reject
        # it up front instead of silently building a mis-wired model.
        middle_count = depth - start_block_count - end_block_count
        if start_block_count < 0 or end_block_count < 0 or middle_count < 0:
            raise ValueError(
                "start_block_count and end_block_count must be non-negative and "
                f"sum to at most depth; got depth={depth}, "
                f"start_block_count={start_block_count}, "
                f"end_block_count={end_block_count}"
            )

        self.patch_size = int(patch_size)
        self.model_dim = int(model_dim)
        self.pdg_mask_ratio = float(pdg_mask_ratio)

        # Image -> patch features.
        self.patchify = Patchify(in_channels, patch_size, model_dim)

        # Latent channel projection and fusion with the image features.
        self.latent_up = nn.Conv2d(bottleneck_dim, model_dim, kernel_size=1, bias=True)
        self.fuse_in = nn.Conv2d(2 * model_dim, model_dim, kernel_size=1, bias=True)

        # Timestep conditioning.
        self.time_embed = SinusoidalTimeEmbeddingMLP(model_dim)

        # AdaLN: one shared base projector plus one low-rank delta per layer.
        self.adaln_base = AdaLNZeroProjector(d_model=model_dim, d_cond=model_dim)
        self.adaln_deltas = nn.ModuleList(
            [
                AdaLNZeroLowRankDelta(
                    d_model=model_dim, d_cond=model_dim, rank=adaln_low_rank_rank
                )
                for _ in range(depth)
            ]
        )

        # Global layer index at which each group starts (indexes adaln_deltas).
        self._middle_start_idx = start_block_count
        self._end_start_idx = start_block_count + middle_count

        def _make_blocks(count: int) -> nn.ModuleList:
            return nn.ModuleList(
                [
                    DiCoBlock(
                        model_dim,
                        mlp_ratio,
                        depthwise_kernel_size=depthwise_kernel_size,
                        use_external_adaln=True,
                    )
                    for _ in range(count)
                ]
            )

        self.start_blocks = _make_blocks(start_block_count)
        self.middle_blocks = _make_blocks(middle_count)
        self.fuse_skip = nn.Conv2d(2 * model_dim, model_dim, kernel_size=1, bias=True)
        self.end_blocks = _make_blocks(end_block_count)

        # Learned replacement features used by the two PDG perturbations.
        self.latent_mask_feature = nn.Parameter(torch.zeros((1, model_dim, 1, 1)))
        self.path_drop_mask_feature = nn.Parameter(torch.zeros((1, model_dim, 1, 1)))

        # Features -> pixels.
        self.out_proj = nn.Conv2d(
            model_dim, in_channels * (patch_size**2), kernel_size=1, bias=True
        )
        self.unpatchify = nn.PixelShuffle(patch_size)

    def _adaln_m_for_layer(self, cond: Tensor, layer_idx: int) -> Tensor:
        """Compute packed AdaLN modulation = shared_base + per-layer delta."""
        act = self.adaln_base.act(cond)
        base_m = self.adaln_base.forward_activated(act)
        delta_m = self.adaln_deltas[layer_idx](act)
        return base_m + delta_m

    def _run_blocks(
        self, blocks: nn.ModuleList, x: Tensor, cond: Tensor, start_index: int
    ) -> Tensor:
        """Run a group of decoder blocks with per-block AdaLN modulation.

        Args:
            blocks: The blocks of one group, applied in order.
            x: Input features for the first block.
            cond: Conditioning vector shared by every block.
            start_index: Global layer index of the first block in ``blocks``;
                selects the matching entries of ``adaln_deltas``.

        Returns:
            Output of the last block, or ``x`` unchanged for an empty group.
        """
        if len(blocks) == 0:
            return x

        # The activated conditioning and the shared base modulation depend only
        # on `cond`, so compute them once per group rather than once per block.
        # NOTE(review): assumes adaln_base.act / forward_activated are pure
        # functions of their input (no dropout or internal state) - confirm.
        act = self.adaln_base.act(cond)
        base_m = self.adaln_base.forward_activated(act)

        for layer_idx, block in enumerate(blocks, start=start_index):
            adaln_m = base_m + self.adaln_deltas[layer_idx](act)
            x = block(x, adaln_m=adaln_m)
        return x

    def _apply_latent_token_mask(self, z_up: Tensor) -> Tensor:
        """Replace a fraction of projected latent tokens with latent_mask_feature.

        Uses 2x2 groupwise masking: divides the spatial grid into 2x2 groups
        and masks floor(ratio * 4) tokens per group (lowest random scores).
        The same spatial mask is applied to every channel. Odd spatial sizes
        are zero-padded on the bottom/right to form complete groups and the
        padding is cropped away again before returning.

        Args:
            z_up: [B, C, H, W] projected latent conditioning.

        Returns:
            Masked tensor with same shape. When floor(ratio * 4) is not
            positive, ``z_up`` itself is returned and no random numbers are
            drawn.
        """
        num_mask = min(4, math.floor(self.pdg_mask_ratio * 4))
        if num_mask <= 0:
            # Nothing to mask: skip the regrouping round-trip entirely.
            return z_up

        b, c, h, w = z_up.shape

        # Pad odd sizes up to the next even number (bottom/right, zeros).
        h_pad = h % 2
        w_pad = w % 2
        if h_pad or w_pad:
            z_up = torch.nn.functional.pad(z_up, (0, w_pad, 0, h_pad))
        full_h, full_w = h + h_pad, w + w_pad
        grid_h, grid_w = full_h // 2, full_w // 2

        # [B, C, H, W] -> [B, C, H/2, W/2, 4]: last axis enumerates one 2x2 group.
        groups = (
            z_up.reshape(b, c, grid_h, 2, grid_w, 2)
            .permute(0, 1, 2, 4, 3, 5)
            .reshape(b, c, grid_h, grid_w, 4)
        )

        # One random score per token (shared across channels); the num_mask
        # lowest-scoring tokens of each group are replaced.
        scores = torch.rand(b, 1, grid_h, grid_w, 4, device=z_up.device)
        drop_idx = scores.argsort(dim=-1)[..., :num_mask]
        mask = torch.zeros_like(scores, dtype=torch.bool)
        mask.scatter_(-1, drop_idx, True)

        # Broadcasting covers the batch/spatial axes of the feature and the
        # channel axis of the mask.
        mask_feat = self.latent_mask_feature.to(
            device=z_up.device, dtype=z_up.dtype
        ).reshape(1, c, 1, 1, 1)
        groups = torch.where(mask, mask_feat, groups)

        # [B, C, H/2, W/2, 4] -> [B, C, H, W] (inverse of the regrouping above).
        out = (
            groups.reshape(b, c, grid_h, grid_w, 2, 2)
            .permute(0, 1, 2, 4, 3, 5)
            .reshape(b, c, full_h, full_w)
        )

        # Drop any padding rows/columns added above.
        return out[:, :, :h, :w]

    def forward(
        self,
        x_t: Tensor,
        t: Tensor,
        latents: Tensor,
        *,
        drop_middle_blocks: bool = False,
        mask_latent_tokens: bool = False,
    ) -> Tensor:
        """Single decoder forward pass.

        Args:
            x_t: Noised image [B, C, H, W].
            t: Timestep [B] in [0, 1].
            latents: Encoder latents [B, bottleneck_dim, h, w]. Must have the
                same spatial size as the patchified ``x_t`` because the two
                are concatenated along the channel axis.
            drop_middle_blocks: If True, replace middle block output with
                path_drop_mask_feature (for path-drop PDG).
            mask_latent_tokens: If True, mask a fraction of projected latent
                tokens with latent_mask_feature (for token-mask PDG).

        Returns:
            x0 prediction [B, C, H, W].
        """
        # Image features and (optionally masked) latent conditioning.
        x_feat = self.patchify(x_t)
        z_up = self.latent_up(latents)
        if mask_latent_tokens:
            z_up = self._apply_latent_token_mask(z_up)

        fused = self.fuse_in(torch.cat([x_feat, z_up], dim=1))

        # The time embedding is always computed from a float32 timestep.
        cond = self.time_embed(t.to(device=x_t.device, dtype=torch.float32))

        start_out = self._run_blocks(self.start_blocks, fused, cond, start_index=0)

        if drop_middle_blocks:
            # Match the tensor it is expanded to and concatenated with, not the
            # raw input, so that the two stay consistent under mixed precision.
            middle_out = self.path_drop_mask_feature.to(
                device=start_out.device, dtype=start_out.dtype
            ).expand_as(start_out)
        else:
            middle_out = self._run_blocks(
                self.middle_blocks,
                start_out,
                cond,
                start_index=self._middle_start_idx,
            )

        # Skip connection: start-group output joins the middle-group output.
        skip_fused = self.fuse_skip(torch.cat([start_out, middle_out], dim=1))

        end_out = self._run_blocks(
            self.end_blocks, skip_fused, cond, start_index=self._end_start_idx
        )

        return self.unpatchify(self.out_proj(end_out))
|
|