| """mDiffAE encoder: patchify -> DiCoBlocks -> bottleneck projection.""" |
|
|
| from __future__ import annotations |
|
|
| from torch import Tensor, nn |
|
|
| from .dico_block import DiCoBlock |
| from .norms import ChannelWiseRMSNorm |
| from .straight_through_encoder import Patchify |
|
|
|
|
class Encoder(nn.Module):
    """Deterministic encoder mapping an image [B,3,H,W] to latents [B,bottleneck_dim,h,w].

    Pipeline: Patchify -> RMSNorm -> DiCoBlocks (unconditioned) -> Conv1x1 -> RMSNorm(no affine)
    """

    def __init__(
        self,
        in_channels: int,
        patch_size: int,
        model_dim: int,
        depth: int,
        bottleneck_dim: int,
        mlp_ratio: float,
        depthwise_kernel_size: int,
    ) -> None:
        super().__init__()
        # Tokenize the image into a [B, model_dim, h, w] feature map.
        self.patchify = Patchify(in_channels, patch_size, model_dim)
        # Affine RMSNorm applied right after patch embedding.
        self.norm_in = ChannelWiseRMSNorm(model_dim, eps=1e-6, affine=True)
        # Stack of `depth` DiCo blocks; the encoder is unconditioned, so
        # external AdaLN modulation is disabled.
        trunk = [
            DiCoBlock(
                model_dim,
                mlp_ratio,
                depthwise_kernel_size=depthwise_kernel_size,
                use_external_adaln=False,
            )
            for _ in range(depth)
        ]
        self.blocks = nn.ModuleList(trunk)
        # 1x1 conv projects features down to the bottleneck width.
        self.to_bottleneck = nn.Conv2d(
            model_dim, bottleneck_dim, kernel_size=1, bias=True
        )
        # Final non-affine RMSNorm keeps the latent scale fixed.
        self.norm_out = ChannelWiseRMSNorm(bottleneck_dim, eps=1e-6, affine=False)

    def forward(self, images: Tensor) -> Tensor:
        """Encode images [B,3,H,W] in [-1,1] to latents [B,bottleneck_dim,h,w]."""
        features = self.norm_in(self.patchify(images))
        for stage in self.blocks:
            features = stage(features)
        return self.norm_out(self.to_bottleneck(features))
|
|