| """ |
| PMA-VAE: Parallel Mobile Artistic Variational Autoencoder |
| ========================================================= |
| Attention-free, mobile-deployable VAE with: |
| - Parallel 2D Mamba/SSM blocks (no sequential pixel loops) |
| - Mobile depthwise-separable convolutions |
| - Multi-scale latents: z_base (H/16), z_detail (H/8), z_style (global vector) |
| - FiLM style conditioning throughout decoder |
| - Designed for: image generation, super-resolution, artifact removal, style transfer |
| |
| Architecture: |
| Image → PixelUnshuffle stem → MobileConv + Parallel 2D Mamba encoder |
| → Multi-scale latent (base + detail + style) |
| → Light parallel decoder with FiLM modulation → Reconstructed image |
| |
| Total params target: ~20-40M (encoder heavier, decoder light for mobile) |
| """ |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from einops import rearrange |
|
|
|
|
| |
| |
| |
| |
|
|
| class PScan(torch.autograd.Function): |
| """ |
| Parallel prefix scan (Blelloch algorithm) in pure PyTorch. |
| Computes: y[t] = A[t] * y[t-1] + X[t] for all t in parallel. |
| """ |
| @staticmethod |
| def pscan_forward(A, X): |
| B, D, L, N = A.size() |
| |
| orig_L = L |
| if L & (L - 1) != 0: |
| next_pow2 = 1 << (L - 1).bit_length() |
| pad = next_pow2 - L |
| A = F.pad(A, (0, 0, 0, pad), value=1.0) |
| X = F.pad(X, (0, 0, 0, pad), value=0.0) |
| L = next_pow2 |
|
|
| num_steps = int(math.log2(L)) |
| |
| |
| Aa = A.clone() |
| Xa = X.clone() |
| |
| |
| for k in range(num_steps): |
| step = 1 << (k + 1) |
| half = step // 2 |
| |
| idx = torch.arange(half - 1, L, step, device=A.device) |
| idx_prev = idx - half |
| |
| Xa[:, :, idx] = Aa[:, :, idx] * Xa[:, :, idx_prev] + Xa[:, :, idx] |
| Aa[:, :, idx] = Aa[:, :, idx] * Aa[:, :, idx_prev] |
| |
| |
| for k in range(num_steps - 2, -1, -1): |
| step = 1 << (k + 1) |
| half = step // 2 |
| idx = torch.arange(step - 1, L, step, device=A.device) |
| if idx.numel() > 0 and (idx + half < L).any(): |
| valid = idx + half |
| valid = valid[valid < L] |
| if valid.numel() > 0: |
| src_idx = valid - half |
| Xa[:, :, valid] = Aa[:, :, valid] * Xa[:, :, src_idx] + Xa[:, :, valid] |
|
|
| return Xa[:, :, :orig_L] |
|
|
| @staticmethod |
| def forward(ctx, A_in, X_in): |
| A = A_in.clone() |
| X = X_in.clone() |
| result = PScan.pscan_forward(A, X) |
| ctx.save_for_backward(A_in, X_in, result) |
| return result |
|
|
| @staticmethod |
| def backward(ctx, grad_output): |
| A_in, X_in, result = ctx.saved_tensors |
| |
| |
| |
| B, D, L, N = A_in.size() |
| |
| grad_A = torch.zeros_like(A_in) |
| grad_X = torch.zeros_like(X_in) |
| |
| |
| grad_h = torch.zeros(B, D, N, device=A_in.device, dtype=A_in.dtype) |
| |
| for t in range(L - 1, -1, -1): |
| grad_h = grad_h + grad_output[:, :, t] |
| grad_X[:, :, t] = grad_h |
| if t > 0: |
| |
| y_prev = result[:, :, t - 1] |
| grad_A[:, :, t] = (grad_h * y_prev).sum(-1, keepdim=True).expand_as(A_in[:, :, t]) |
| grad_h = grad_h * A_in[:, :, t] |
| else: |
| grad_A[:, :, 0] = torch.zeros_like(A_in[:, :, 0]) |
| |
| return grad_A, grad_X |
|
|
| pscan = PScan.apply |
|
|
|
|
| |
| |
| |
|
|
class SelectiveSSM(nn.Module):
    """
    Selective State Space Model (S6) from the Mamba paper.

    Pre-norm block: RMSNorm → in_proj (split into SSM stream and gate) →
    causal depthwise conv → SiLU → selective scan → SiLU-gated merge →
    out_proj, with a residual connection around the whole block.

    For 2D images the caller flattens H*W into the sequence dimension,
    runs this block, and reshapes back.

    Args:
        d_model: embedding width of the input/output.
        d_state: SSM hidden-state size N per inner channel.
        d_conv: kernel size of the causal depthwise conv.
        expand: width multiplier for the inner SSM channels.
        use_parallel_scan: use the Blelloch parallel scan (True) or a
            Python loop over the sequence (False; debugging/export path).
    """
    def __init__(self, d_model, d_state=16, d_conv=4, expand=2, use_parallel_scan=True):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(expand * d_model)
        self.use_parallel_scan = use_parallel_scan

        # Joint projection producing the SSM input and the multiplicative gate z.
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)

        # Depthwise conv over the sequence; padding=d_conv-1 over-pads and
        # forward() truncates the right overhang back to length L (causal).
        self.conv1d = nn.Conv1d(
            self.d_inner, self.d_inner,
            kernel_size=d_conv, bias=True,
            groups=self.d_inner, padding=d_conv - 1
        )

        # Input-dependent ("selective") parameters: one scalar dt plus
        # B and C (d_state each) per position.
        self.x_proj = nn.Linear(self.d_inner, self.d_state * 2 + 1, bias=False)
        # Low-rank Δ parameterization: scalar dt expanded up to d_inner.
        self.dt_proj = nn.Linear(1, self.d_inner, bias=True)

        # S4D-real style init: A_n = -(1..d_state), stored as log magnitude
        # so the sign/negativity is enforced at use time via -exp(A_log).
        A = torch.arange(1, d_state + 1, dtype=torch.float32).unsqueeze(0).expand(self.d_inner, -1)
        self.A_log = nn.Parameter(torch.log(A))

        # Learned direct-feedthrough (skip) coefficient per inner channel.
        self.D = nn.Parameter(torch.ones(self.d_inner))

        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)

        # Pre-normalization applied at the top of forward().
        self.norm = nn.RMSNorm(d_model)

    def ssm_parallel(self, x):
        """Selective scan: x (B, L, d_inner) → y (B, L, d_inner).

        Runs the recurrence h[t] = dA[t]*h[t-1] + dB[t]*x[t] with a
        parallel scan (no sequential loop) when use_parallel_scan is set.
        """
        B_size, L, D = x.shape

        # Continuous-time A, kept negative-real for stability.
        A = -torch.exp(self.A_log.float())
        D_skip = self.D.float()

        # Per-position dt, B, C derived from the input itself.
        x_dbl = self.x_proj(x)
        dt, B_mat, C_mat = x_dbl.split([1, self.d_state, self.d_state], dim=-1)
        dt = F.softplus(self.dt_proj(dt))  # (B, L, d_inner), strictly positive

        # Discretization: dA = exp(dt*A); dB*x folded into one term.
        # Both are (B, L, d_inner, d_state).
        dA = torch.exp(dt.unsqueeze(-1) * A.unsqueeze(0).unsqueeze(0))
        dBx = dt.unsqueeze(-1) * B_mat.unsqueeze(2) * x.unsqueeze(-1)

        # The scan runs over dim 2, so move L there: (B, d_inner, L, d_state).
        dA = dA.permute(0, 2, 1, 3).contiguous()
        dBx = dBx.permute(0, 2, 1, 3).contiguous()

        if self.use_parallel_scan:
            # O(log L) Blelloch scan (see PScan above).
            h = pscan(dA, dBx)
        else:
            # Reference sequential recurrence — slow but simple.
            h = torch.zeros_like(dBx)
            state = torch.zeros(B_size, self.d_inner, self.d_state,
                                device=x.device, dtype=x.dtype)
            for t in range(L):
                state = dA[:, :, t] * state + dBx[:, :, t]
                h[:, :, t] = state

        # Readout: y[t] = C[t] · h[t] + D * x[t].
        h = h.permute(0, 2, 1, 3)
        C_mat_exp = C_mat.unsqueeze(2)
        y = (h * C_mat_exp).sum(-1)
        y = y + D_skip * x

        return y

    def forward(self, x):
        """x: (B, L, d_model) → (B, L, d_model), residual included."""
        residual = x
        x = self.norm(x)

        # Split into SSM stream and gate.
        xz = self.in_proj(x)
        x_ssm, z = xz.chunk(2, dim=-1)

        # Causal depthwise conv over the sequence; trim the right overhang
        # introduced by padding=d_conv-1 back to the original length L.
        x_ssm = rearrange(x_ssm, 'b l d -> b d l')
        x_ssm = self.conv1d(x_ssm)[:, :, :residual.shape[1]]
        x_ssm = rearrange(x_ssm, 'b d l -> b l d')
        x_ssm = F.silu(x_ssm)

        y = self.ssm_parallel(x_ssm)

        # SiLU-gated merge with the parallel gate branch.
        y = y * F.silu(z)

        return self.out_proj(y) + residual
|
|
|
|
| |
| |
| |
|
|
def cross_scan_2d(x):
    """
    Unfold a 2D feature map into 4 directional 1D sequences:
    row-major, reversed row-major, column-major, reversed column-major.

    x: (B, H, W, C)
    Returns: list of 4 tensors, each (B, H*W, C)
    """
    B, H, W, C = x.shape
    transposed = x.permute(0, 2, 1, 3)                   # (B, W, H, C)
    d1 = x.reshape(B, H * W, C)                          # raster order
    d2 = x.flip([1, 2]).reshape(B, H * W, C)             # reversed raster
    d3 = transposed.reshape(B, W * H, C)                 # column-major
    d4 = transposed.flip([1, 2]).reshape(B, W * H, C)    # reversed column-major
    return [d1, d2, d3, d4]
|
|
|
|
def cross_merge_2d(ys, H, W):
    """
    Inverse of cross_scan_2d: fold the 4 directional sequences back onto
    the (H, W) grid and average them.

    ys: list of 4 tensors (B, H*W, C)
    Returns: (B, H, W, C)
    """
    B, _, C = ys[0].shape
    # d1/d2 were flattened row-major (h outer, w inner).
    d1 = ys[0].reshape(B, H, W, C)
    d2 = ys[1].reshape(B, H, W, C).flip([1, 2])
    # BUGFIX: d3/d4 were flattened COLUMN-major (index = w*H + h), so they
    # must be reshaped to (W, H) and transposed back. The previous code
    # reshaped them row-major as (H, W) (einops '(h w)') and then applied
    # two transposes that cancelled out, silently re-gridding both
    # directional maps transposed — the merge did not invert the scan.
    d3 = ys[2].reshape(B, W, H, C).permute(0, 2, 1, 3)
    d4 = ys[3].reshape(B, W, H, C).flip([1, 2]).permute(0, 2, 1, 3)
    return (d1 + d2 + d3 + d4) * 0.25
|
|
|
|
class Mamba2DBlock(nn.Module):
    """
    2D Mamba block built on a 4-direction cross-scan.

    The feature map is unfolded into four directional sequences, each is
    processed by a shared SelectiveSSM, and the results are folded back and
    averaged. No attention — pure SSM plus local convolution.
    """
    def __init__(self, channels, d_state=16, expand=2, use_parallel_scan=True):
        super().__init__()
        self.channels = channels
        self.ssm = SelectiveSSM(
            d_model=channels,
            d_state=d_state,
            d_conv=4,
            expand=expand,
            use_parallel_scan=use_parallel_scan,
        )
        self.mix_proj = nn.Linear(channels, channels)
        self.norm = nn.RMSNorm(channels)

    def forward(self, x):
        """x: (B, C, H, W) → (B, C, H, W), residual included."""
        _, _, height, width = x.shape
        shortcut = x

        feats = x.permute(0, 2, 3, 1)                 # channels-last for the scan
        directional = cross_scan_2d(feats)            # 4 x (B, H*W, C)
        scanned = [self.ssm(seq) for seq in directional]

        fused = cross_merge_2d(scanned, height, width)
        fused = self.mix_proj(self.norm(fused))

        return fused.permute(0, 3, 1, 2) + shortcut
|
|
|
|
| |
| |
| |
|
|
class SqueezeExcitation(nn.Module):
    """Channel re-weighting: global average pool → bottleneck MLP → sigmoid gate."""

    def __init__(self, channels, reduction=4):
        super().__init__()
        hidden = max(8, channels // reduction)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden),
            nn.SiLU(inplace=True),
            nn.Linear(hidden, channels),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, ch = x.shape[:2]
        scale = self.fc(self.pool(x).view(batch, ch))
        return x * scale.view(batch, ch, 1, 1)
|
|
|
|
class FiLM(nn.Module):
    """Feature-wise Linear Modulation: a per-channel affine driven by a cond vector."""

    def __init__(self, cond_dim, channels):
        super().__init__()
        # Single projection emits both scale (gamma) and shift (beta).
        self.proj = nn.Linear(cond_dim, channels * 2)

    def forward(self, x, cond):
        """x: (B, C, H, W); cond: (B, cond_dim)."""
        channels = x.shape[1]
        gamma, beta = self.proj(cond).chunk(2, dim=-1)
        scale = gamma.reshape(-1, channels, 1, 1)
        shift = beta.reshape(-1, channels, 1, 1)
        # (1 + gamma) keeps the modulation identity-centered at init.
        return x * (1 + scale) + shift
|
|
|
|
class MobileConvBlock(nn.Module):
    """
    Mobile-friendly inverted-residual (MBConv-style) block:
        1x1 expand → 3x3 depthwise → squeeze-excitation → 1x1 project,
    with an optional FiLM style modulation on the projected output and a
    learned skip path when the identity residual is impossible.

    Args:
        in_ch: input channels.
        out_ch: output channels.
        expand_ratio: width multiplier for the depthwise stage.
        stride: stride of the depthwise conv (1 or 2).
        use_se: apply squeeze-excitation to the expanded features.
        cond_dim: if set, FiLM-modulate the output with a (B, cond_dim)
            conditioning vector passed to forward().
    """
    def __init__(self, in_ch, out_ch, expand_ratio=4, stride=1,
                 use_se=True, cond_dim=None):
        super().__init__()
        mid_ch = in_ch * expand_ratio
        # Identity residual only when input/output shapes match exactly.
        self.use_residual = (stride == 1 and in_ch == out_ch)

        layers = []
        # 1x1 expansion; skipped at ratio 1 (depthwise then runs on in_ch).
        if expand_ratio != 1:
            layers.extend([
                nn.Conv2d(in_ch, mid_ch, 1, bias=False),
                nn.BatchNorm2d(mid_ch),
                nn.SiLU(inplace=True),
            ])
        # 3x3 depthwise conv; carries the optional spatial downsampling.
        layers.extend([
            nn.Conv2d(mid_ch, mid_ch, 3, stride=stride, padding=1,
                      groups=mid_ch, bias=False),
            nn.BatchNorm2d(mid_ch),
            nn.SiLU(inplace=True),
        ])
        self.conv = nn.Sequential(*layers)

        # Channel attention on the expanded features.
        self.se = SqueezeExcitation(mid_ch) if use_se else nn.Identity()

        # Linear 1x1 projection back to out_ch (no activation, MBConv style).
        self.project = nn.Sequential(
            nn.Conv2d(mid_ch, out_ch, 1, bias=False),
            nn.BatchNorm2d(out_ch),
        )

        # Optional style conditioning of the projected features.
        self.film = FiLM(cond_dim, out_ch) if cond_dim else None

        # Learned skip path for the non-identity cases; Identity otherwise,
        # so self.skip is ALWAYS defined.
        if not self.use_residual and stride == 1:
            self.skip = nn.Conv2d(in_ch, out_ch, 1, bias=False)
        elif not self.use_residual:
            self.skip = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch),
            )
        else:
            self.skip = nn.Identity()

    def forward(self, x, cond=None):
        """x: (B, in_ch, H, W); cond: optional (B, cond_dim) style vector."""
        out = self.conv(x)
        out = self.se(out)
        out = self.project(out)
        if self.film is not None and cond is not None:
            out = self.film(out, cond)
        if self.use_residual:
            return out + x
        # BUGFIX: the original `return out + self.skip(x) if hasattr(self, 'skip')
        # else out` was a ternary with an always-true condition (skip is always
        # assigned in __init__) — dead, misleading code. Call it unconditionally.
        return out + self.skip(x)
|
|
|
|
class GatedConvBlock(nn.Module):
    """Gated large-kernel conv block — an attention-free mixing unit."""

    def __init__(self, channels):
        super().__init__()
        self.norm = nn.GroupNorm(min(32, channels), channels)
        self.proj = nn.Conv2d(channels, channels * 2, 1)
        self.dw = nn.Conv2d(channels, channels, 5, padding=2, groups=channels)
        self.out = nn.Conv2d(channels, channels, 1)

    def forward(self, x):
        shortcut = x
        # Split the 1x1 projection into a gate stream and a value stream.
        gate, val = self.proj(self.norm(x)).chunk(2, dim=1)
        # Spatially mix the values, then gate them with SiLU(gate).
        mixed = self.dw(val) * F.silu(gate)
        return self.out(mixed) + shortcut
|
|
|
|
| |
| |
| |
|
|
class PMAEncoder(nn.Module):
    """
    Encoder with progressive downsampling:
        H → H/2 → H/4 → H/8 → H/16

    Outputs multi-scale latents, each as a (mu, logvar) diagonal-Gaussian pair:
        - z_base: H/16 x W/16 x latent_base_dim
        - z_detail: H/8 x W/8 x latent_detail_dim
        - z_style: 1 x 1 x latent_style_dim (global)
    """
    def __init__(self, in_channels=3,
                 stage_channels=(64, 128, 192, 256),
                 stage_blocks=(2, 2, 4, 4),
                 latent_base_dim=32,
                 latent_detail_dim=8,
                 latent_style_dim=128,
                 d_state=16,
                 use_parallel_scan=True):
        # NOTE(review): only stage_blocks[0..2] are consumed below —
        # stage_blocks[3] is never used. Confirm whether a 4th stage was intended.
        super().__init__()

        self.latent_base_dim = latent_base_dim
        self.latent_detail_dim = latent_detail_dim
        self.latent_style_dim = latent_style_dim

        # Stem: PixelUnshuffle(2) is a lossless space-to-depth 2x downsample
        # (in_channels → in_channels*4), followed by a conv to the first
        # stage width.  Resolution: H → H/2.
        self.stem = nn.Sequential(
            nn.PixelUnshuffle(2),
            nn.Conv2d(in_channels * 4, stage_channels[0], 3, padding=1, bias=False),
            nn.BatchNorm2d(stage_channels[0]),
            nn.SiLU(inplace=True),
        )

        # Stage 1: mobile convs only.  H/2 → H/4.
        self.stage1 = self._make_mobile_stage(
            stage_channels[0], stage_channels[1], stage_blocks[0], stride=2
        )

        # Stage 2: conv/Mamba hybrid (≈ half Mamba).  H/4 → H/8.
        self.stage2 = self._make_hybrid_stage(
            stage_channels[1], stage_channels[2], stage_blocks[1],
            stride=2, d_state=d_state, mamba_ratio=0.5,
            use_parallel_scan=use_parallel_scan
        )

        # Detail latent heads read the H/8 features (before stage 3).
        self.detail_head_mu = nn.Conv2d(stage_channels[2], latent_detail_dim, 1)
        self.detail_head_logvar = nn.Conv2d(stage_channels[2], latent_detail_dim, 1)

        # Stage 3: conv/Mamba hybrid (mostly Mamba).  H/8 → H/16.
        self.stage3 = self._make_hybrid_stage(
            stage_channels[2], stage_channels[3], stage_blocks[2],
            stride=2, d_state=d_state, mamba_ratio=0.75,
            use_parallel_scan=use_parallel_scan
        )

        # Final attention-free global mixing at H/16.
        self.global_mix = GatedConvBlock(stage_channels[3])

        # Base latent heads at H/16.
        self.base_head_mu = nn.Conv2d(stage_channels[3], latent_base_dim, 1)
        self.base_head_logvar = nn.Conv2d(stage_channels[3], latent_base_dim, 1)

        # Style latent: global average pool → linear mu/logvar heads.
        self.style_pool = nn.AdaptiveAvgPool2d(1)
        self.style_head_mu = nn.Linear(stage_channels[3], latent_style_dim)
        self.style_head_logvar = nn.Linear(stage_channels[3], latent_style_dim)

    def _make_mobile_stage(self, in_ch, out_ch, num_blocks, stride=1):
        # First block changes channels/resolution; the rest keep shape.
        blocks = [MobileConvBlock(in_ch, out_ch, stride=stride)]
        for _ in range(num_blocks - 1):
            blocks.append(MobileConvBlock(out_ch, out_ch))
        return nn.Sequential(*blocks)

    def _make_hybrid_stage(self, in_ch, out_ch, num_blocks, stride=1,
                           d_state=16, mamba_ratio=0.5, use_parallel_scan=True):
        # Downsampling conv first, then mobile convs, then Mamba blocks.
        # Returned as a ModuleList (not Sequential) and iterated explicitly
        # in forward().
        blocks = nn.ModuleList()

        blocks.append(MobileConvBlock(in_ch, out_ch, stride=stride))

        # At least one Mamba block per hybrid stage, remainder mobile convs.
        num_mamba = max(1, int((num_blocks - 1) * mamba_ratio))
        num_mobile = (num_blocks - 1) - num_mamba

        for _ in range(num_mobile):
            blocks.append(MobileConvBlock(out_ch, out_ch))
        for _ in range(num_mamba):
            blocks.append(Mamba2DBlock(out_ch, d_state=d_state, expand=2,
                                       use_parallel_scan=use_parallel_scan))
        return blocks

    def forward(self, x):
        """
        x: (B, 3, H, W)
        Returns: dict with mu/logvar for base, detail, style latents
        """
        x = self.stem(x)            # H/2

        x = self.stage1(x)          # H/4

        for block in self.stage2:   # H/8
            x = block(x)

        # Detail latent is taken at H/8, before further downsampling.
        detail_mu = self.detail_head_mu(x)
        detail_logvar = self.detail_head_logvar(x)

        for block in self.stage3:   # H/16
            x = block(x)

        x = self.global_mix(x)

        base_mu = self.base_head_mu(x)
        base_logvar = self.base_head_logvar(x)

        # Global style vector from pooled H/16 features.
        style_feat = self.style_pool(x).flatten(1)
        style_mu = self.style_head_mu(style_feat)
        style_logvar = self.style_head_logvar(style_feat)

        return {
            'base_mu': base_mu, 'base_logvar': base_logvar,
            'detail_mu': detail_mu, 'detail_logvar': detail_logvar,
            'style_mu': style_mu, 'style_logvar': style_logvar,
        }
|
|
|
|
| |
| |
| |
|
|
class UpsampleBlock(nn.Module):
    """2x spatial upsample: conv to 4x channels, then PixelShuffle + BN + SiLU."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch * 4, 3, padding=1, bias=False)
        self.ps = nn.PixelShuffle(2)
        self.norm = nn.BatchNorm2d(out_ch)
        self.act = nn.SiLU(inplace=True)

    def forward(self, x):
        y = self.conv(x)
        y = self.ps(y)      # (B, out_ch, 2H, 2W)
        y = self.norm(y)
        return self.act(y)
|
|
|
|
class PMADecoder(nn.Module):
    """
    Lightweight decoder for mobile deployment.

    Takes multi-scale latents and reconstructs image:
        z_base (H/16) + z_style → decode → fuse z_detail (H/8) → upsample → image

    The global z_style vector modulates intermediate features via FiLM
    inside the style-conditioned MobileConvBlocks.
    """
    def __init__(self, out_channels=3,
                 stage_channels=(256, 192, 128, 96, 64),
                 latent_base_dim=32,
                 latent_detail_dim=8,
                 latent_style_dim=128,
                 d_state=16,
                 use_parallel_scan=True):
        super().__init__()

        # Lift the base latent to the decoder width at H/16.
        self.base_proj = nn.Sequential(
            nn.Conv2d(latent_base_dim, stage_channels[0], 3, padding=1, bias=False),
            nn.BatchNorm2d(stage_channels[0]),
            nn.SiLU(inplace=True),
        )

        # Stage 1 (H/16): style-conditioned conv + Mamba global mixing.
        self.stage1_blocks = nn.ModuleList([
            MobileConvBlock(stage_channels[0], stage_channels[0],
                            cond_dim=latent_style_dim),
            Mamba2DBlock(stage_channels[0], d_state=d_state,
                         use_parallel_scan=use_parallel_scan),
        ])

        # H/16 → H/8.
        self.up1 = UpsampleBlock(stage_channels[0], stage_channels[1])

        # Concatenate the detail latent at H/8 and fuse with a 1x1 conv.
        self.detail_fuse = nn.Sequential(
            nn.Conv2d(stage_channels[1] + latent_detail_dim, stage_channels[1], 1, bias=False),
            nn.BatchNorm2d(stage_channels[1]),
            nn.SiLU(inplace=True),
        )

        # Stage 2 (H/8): two conditioned convs + Mamba mixing.
        self.stage2_blocks = nn.ModuleList([
            MobileConvBlock(stage_channels[1], stage_channels[1],
                            cond_dim=latent_style_dim),
            MobileConvBlock(stage_channels[1], stage_channels[1],
                            cond_dim=latent_style_dim),
            Mamba2DBlock(stage_channels[1], d_state=d_state,
                         use_parallel_scan=use_parallel_scan),
        ])

        # H/8 → H/4.
        self.up2 = UpsampleBlock(stage_channels[1], stage_channels[2])

        # Stage 3 (H/4): convs only; the first is still style-conditioned.
        self.stage3_blocks = nn.ModuleList([
            MobileConvBlock(stage_channels[2], stage_channels[2],
                            cond_dim=latent_style_dim),
            MobileConvBlock(stage_channels[2], stage_channels[2]),
        ])

        # H/4 → H/2.
        self.up3 = UpsampleBlock(stage_channels[2], stage_channels[3])

        # Stage 4 (H/2): plain convs, no conditioning — kept light for mobile.
        self.stage4_blocks = nn.ModuleList([
            MobileConvBlock(stage_channels[3], stage_channels[3]),
            MobileConvBlock(stage_channels[3], stage_channels[3]),
        ])

        # H/2 → H.
        self.up4 = UpsampleBlock(stage_channels[3], stage_channels[4])

        # Output head; Tanh constrains pixels to [-1, 1].
        self.head = nn.Sequential(
            nn.Conv2d(stage_channels[4], stage_channels[4], 3, padding=1),
            nn.SiLU(inplace=True),
            nn.Conv2d(stage_channels[4], out_channels, 3, padding=1),
            nn.Tanh(),
        )

    def forward(self, z_base, z_detail, z_style):
        """
        z_base: (B, latent_base_dim, H/16, W/16)
        z_detail: (B, latent_detail_dim, H/8, W/8)
        z_style: (B, latent_style_dim)
        """
        x = self.base_proj(z_base)

        # Only MobileConvBlocks take the style vector (FiLM);
        # Mamba2DBlocks consume features alone.
        for block in self.stage1_blocks:
            if isinstance(block, MobileConvBlock):
                x = block(x, cond=z_style)
            else:
                x = block(x)

        x = self.up1(x)     # H/8

        # Re-inject the fine-structure latent at H/8.
        x = self.detail_fuse(torch.cat([x, z_detail], dim=1))

        for block in self.stage2_blocks:
            if isinstance(block, MobileConvBlock):
                x = block(x, cond=z_style)
            else:
                x = block(x)

        x = self.up2(x)     # H/4

        # Unconditioned blocks simply ignore cond (their film is None).
        for block in self.stage3_blocks:
            if isinstance(block, MobileConvBlock):
                x = block(x, cond=z_style)
            else:
                x = block(x)

        x = self.up3(x)     # H/2

        for block in self.stage4_blocks:
            x = block(x)

        x = self.up4(x)     # H

        return self.head(x)
|
|
|
|
| |
| |
| |
|
|
class PMAVAE(nn.Module):
    """
    Parallel Mobile Artistic VAE — full encoder/decoder model.

    Attention-free VAE built from mobile convolutions and 2D Mamba/SSM
    blocks, with a three-part latent space:
      * base   — spatial latent at H/16 (coarse structure)
      * detail — spatial latent at H/8 (fine structure)
      * style  — global vector, injected into the decoder via FiLM

    Args:
        in_channels: Input image channels (3 for RGB)
        enc_channels: Channel widths per encoder stage
        dec_channels: Channel widths per decoder stage
        enc_blocks: Block counts per encoder stage
        latent_base_dim: Channels for H/16 base latent
        latent_detail_dim: Channels for H/8 detail latent
        latent_style_dim: Dimension of global style vector
        d_state: SSM state dimension
        use_parallel_scan: Use Blelloch parallel scan (True) or sequential (False)
    """
    def __init__(self,
                 in_channels=3,
                 enc_channels=(64, 128, 192, 256),
                 dec_channels=(256, 192, 128, 96, 64),
                 enc_blocks=(2, 2, 4, 4),
                 latent_base_dim=32,
                 latent_detail_dim=8,
                 latent_style_dim=128,
                 d_state=16,
                 use_parallel_scan=True):
        super().__init__()

        self.encoder = PMAEncoder(
            in_channels=in_channels,
            stage_channels=enc_channels,
            stage_blocks=enc_blocks,
            latent_base_dim=latent_base_dim,
            latent_detail_dim=latent_detail_dim,
            latent_style_dim=latent_style_dim,
            d_state=d_state,
            use_parallel_scan=use_parallel_scan,
        )

        self.decoder = PMADecoder(
            out_channels=in_channels,
            stage_channels=dec_channels,
            latent_base_dim=latent_base_dim,
            latent_detail_dim=latent_detail_dim,
            latent_style_dim=latent_style_dim,
            d_state=d_state,
            use_parallel_scan=use_parallel_scan,
        )

    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps*std while training; return mu at eval time."""
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        return mu + torch.randn_like(std) * std

    def encode(self, x):
        """Encode an image into the multi-scale posterior parameters."""
        return self.encoder(x)

    def decode(self, z_base, z_detail, z_style):
        """Decode sampled latents back to an image."""
        return self.decoder(z_base, z_detail, z_style)

    def forward(self, x):
        """
        Full pass: encode → sample each latent → decode.
        Returns: (recon, posteriors_dict)
        """
        posteriors = self.encode(x)

        z_base = self.reparameterize(posteriors['base_mu'], posteriors['base_logvar'])
        z_detail = self.reparameterize(posteriors['detail_mu'], posteriors['detail_logvar'])
        z_style = self.reparameterize(posteriors['style_mu'], posteriors['style_logvar'])

        return self.decode(z_base, z_detail, z_style), posteriors

    def get_last_decoder_layer(self):
        """Weight of the final conv — for adaptive discriminator weighting."""
        return self.decoder.head[-2].weight

    @torch.no_grad()
    def encode_to_latent(self, x):
        """Deterministic encode: return the posterior means only."""
        posteriors = self.encode(x)
        return (posteriors['base_mu'], posteriors['detail_mu'], posteriors['style_mu'])

    @torch.no_grad()
    def decode_from_latent(self, z_base, z_detail, z_style):
        """Gradient-free decode from given latents (inference mode)."""
        return self.decode(z_base, z_detail, z_style)

    def count_parameters(self):
        """Parameter counts (raw and in millions) for encoder/decoder/total."""
        n_enc = sum(p.numel() for p in self.encoder.parameters())
        n_dec = sum(p.numel() for p in self.decoder.parameters())
        n_total = n_enc + n_dec
        return {
            'encoder': n_enc,
            'decoder': n_dec,
            'total': n_total,
            'encoder_M': n_enc / 1e6,
            'decoder_M': n_dec / 1e6,
            'total_M': n_total / 1e6,
        }
|
|
|
|
| |
| |
| |
|
|
def pmavae_tiny(**kwargs):
    """Tiny config for testing. ~5M params."""
    config = {
        'enc_channels': (32, 64, 96, 128),
        'dec_channels': (128, 96, 64, 48, 32),
        'enc_blocks': (1, 1, 2, 2),
        'latent_base_dim': 16,
        'latent_detail_dim': 4,
        'latent_style_dim': 64,
        'd_state': 8,
    }
    # Overlapping keys in kwargs raise TypeError, same as explicit keywords.
    return PMAVAE(**config, **kwargs)
|
|
|
|
def pmavae_small(**kwargs):
    """Small config for Colab free tier. ~20M params."""
    config = {
        'enc_channels': (48, 96, 144, 192),
        'dec_channels': (192, 144, 96, 72, 48),
        'enc_blocks': (2, 2, 3, 3),
        'latent_base_dim': 24,
        'latent_detail_dim': 6,
        'latent_style_dim': 96,
        'd_state': 16,
    }
    # Overlapping keys in kwargs raise TypeError, same as explicit keywords.
    return PMAVAE(**config, **kwargs)
|
|
|
|
def pmavae_base(**kwargs):
    """Base config. ~40M params."""
    config = {
        'enc_channels': (64, 128, 192, 256),
        'dec_channels': (256, 192, 128, 96, 64),
        'enc_blocks': (2, 2, 4, 4),
        'latent_base_dim': 32,
        'latent_detail_dim': 8,
        'latent_style_dim': 128,
        'd_state': 16,
    }
    # Overlapping keys in kwargs raise TypeError, same as explicit keywords.
    return PMAVAE(**config, **kwargs)
|
|
|
|
if __name__ == '__main__':
    # CPU smoke test of the tiny config; the sequential scan path avoids
    # the parallel-scan autograd Function for a simpler debug run.
    device = 'cpu'
    model = pmavae_tiny(use_parallel_scan=False).to(device)

    # 256x256 RGB batch → expect a same-shape reconstruction.
    x = torch.randn(2, 3, 256, 256, device=device)
    recon, posteriors = model(x)

    print(f"Input: {x.shape}")
    print(f"Recon: {recon.shape}")
    for k, v in posteriors.items():
        print(f" {k}: {v.shape}")

    # Parameter budget breakdown (encoder-heavy, light decoder).
    params = model.count_parameters()
    print(f"\nParams: {params['total_M']:.2f}M (enc: {params['encoder_M']:.2f}M, dec: {params['decoder_M']:.2f}M)")
|
|