# PMA-VAE / model.py
# Uploaded by krystv via huggingface_hub (commit 92f566d, verified)
"""
PMA-VAE: Parallel Mobile Artistic Variational Autoencoder
=========================================================
Attention-free, mobile-deployable VAE with:
- Parallel 2D Mamba/SSM blocks (no sequential pixel loops)
- Mobile depthwise-separable convolutions
- Multi-scale latents: z_base (H/16), z_detail (H/8), z_style (global vector)
- FiLM style conditioning throughout decoder
- Designed for: image generation, super-resolution, artifact removal, style transfer
Architecture:
Image → PixelUnshuffle stem → MobileConv + Parallel 2D Mamba encoder
→ Multi-scale latent (base + detail + style)
→ Light parallel decoder with FiLM modulation → Reconstructed image
Total params target: ~20-40M (encoder heavier, decoder light for mobile)
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
# ==============================================================================
# Parallel Scan (Blelloch-style) — Pure PyTorch, no CUDA kernels
# Based on: https://github.com/alxndrTL/mamba.py/blob/main/mambapy/pscan.py
# ==============================================================================
class PScan(torch.autograd.Function):
    """
    Parallel prefix scan (Blelloch-style) in pure PyTorch.

    Computes the first-order linear recurrence
        y[t] = A[t] * y[t-1] + X[t],   with y[-1] = 0,
    for all t in O(log L) parallel steps instead of a sequential loop.

    Shapes: A and X are (B, D, L, N); the scan runs along the L axis
    independently for every (b, d, n) element.
    """

    @staticmethod
    def pscan_forward(A, X):
        """Tree scan over clones of A, X; returns the inclusive scan (B, D, L, N)."""
        B, D, L, N = A.size()
        orig_L = L
        # Pad L up to a power of 2 with the scan's identity elements
        # (A=1, X=0) so the binary-tree sweeps are well formed. Padding
        # at the END cannot influence positions < orig_L (causal scan).
        if L & (L - 1) != 0:  # not a power of 2
            next_pow2 = 1 << (L - 1).bit_length()
            pad = next_pow2 - L
            A = F.pad(A, (0, 0, 0, pad), value=1.0)
            X = F.pad(X, (0, 0, 0, pad), value=0.0)
            L = next_pow2
        num_steps = int(math.log2(L))
        Aa = A.clone()
        Xa = X.clone()
        # Up-sweep (reduce): at level k, fold each pair of adjacent
        # 2^k-segments into the RIGHT endpoint of the pair, which lives at
        # indices step-1, 2*step-1, ...
        # BUGFIX: this previously iterated arange(half-1, L, step), which
        # at k=0 produced idx=0,2,4,... and idx_prev=-1,1,3,... — pairing
        # element 0 with element -1 (wraparound) and accumulating into the
        # left element of each pair, so the scan result was wrong.
        for k in range(num_steps):
            step = 1 << (k + 1)
            half = step >> 1
            idx = torch.arange(step - 1, L, step, device=A.device)
            idx_prev = idx - half
            Xa[:, :, idx] = Aa[:, :, idx] * Xa[:, :, idx_prev] + Xa[:, :, idx]
            Aa[:, :, idx] = Aa[:, :, idx] * Aa[:, :, idx_prev]
        # Down-sweep: completed prefixes sit at segment ends (step-1,
        # 2*step-1, ...); compose each into the midpoint of the segment to
        # its right. Aa at a midpoint still holds the transition product of
        # its own sub-segment from the up-sweep, which is exactly the
        # factor needed for the composition.
        for k in range(num_steps - 2, -1, -1):
            step = 1 << (k + 1)
            half = step >> 1
            idx = torch.arange(step - 1, L, step, device=A.device)
            tgt = idx + half
            tgt = tgt[tgt < L]
            src = tgt - half
            Xa[:, :, tgt] = Aa[:, :, tgt] * Xa[:, :, src] + Xa[:, :, tgt]
        return Xa[:, :, :orig_L]

    @staticmethod
    def forward(ctx, A_in, X_in):
        # pscan_forward works on internal clones, so the inputs are not mutated.
        result = PScan.pscan_forward(A_in, X_in)
        ctx.save_for_backward(A_in, X_in, result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        """
        Reverse-mode gradients of the recurrence.

        With y[t] = A[t] * y[t-1] + X[t] and g[t] = dL/dy[t] accumulated
        backwards (g[t] = grad_out[t] + A[t+1] * g[t+1]):
            dL/dX[t] = g[t]
            dL/dA[t] = g[t] * y[t-1]   (elementwise; y[-1] = 0)
        Implemented as a simple sequential reverse loop — O(L), correct,
        and cheap relative to the rest of the graph.
        """
        A_in, X_in, result = ctx.saved_tensors
        B, D, L, N = A_in.size()
        grad_A = torch.zeros_like(A_in)
        grad_X = torch.zeros_like(X_in)
        grad_h = torch.zeros(B, D, N, device=A_in.device, dtype=A_in.dtype)
        for t in range(L - 1, -1, -1):
            grad_h = grad_h + grad_output[:, :, t]
            grad_X[:, :, t] = grad_h
            if t > 0:
                # BUGFIX: each (b, d, n) entry of A is an independent
                # parameter of the recurrence, so the gradient is the plain
                # elementwise product. The previous code summed over N and
                # broadcast the result, mixing gradients across state dims.
                grad_A[:, :, t] = grad_h * result[:, :, t - 1]
                grad_h = grad_h * A_in[:, :, t]
            # t == 0: y[-1] = 0, so grad_A[:, :, 0] stays zero.
        return grad_A, grad_X


pscan = PScan.apply
# ==============================================================================
# Selective State Space (S6) Block — The core Mamba mechanism
# ==============================================================================
class SelectiveSSM(nn.Module):
    """
    Selective State Space Model (S6) from the Mamba paper.

    Uses a parallel prefix scan for O(L) computation without sequential
    loops. For 2D images the caller flattens H*W into a sequence, runs the
    SSM, and reshapes back.

    Args:
        d_model: embedding width of the input/output sequence.
        d_state: SSM hidden-state size N per inner channel.
        d_conv: kernel size of the depthwise 1D conv run before the SSM.
        expand: width multiplier for the inner SSM channels (d_inner).
        use_parallel_scan: Blelloch scan (True) or sequential loop (False).
    """
    def __init__(self, d_model, d_state=16, d_conv=4, expand=2, use_parallel_scan=True):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(expand * d_model)
        self.use_parallel_scan = use_parallel_scan
        # Input projection: x → (x_ssm, z); x_ssm goes through the SSM,
        # z is the SiLU gate applied to the SSM output.
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
        # Depthwise causal-ish 1D conv for local context before the SSM;
        # padding=d_conv-1 over-pads on the right, trimmed in forward().
        self.conv1d = nn.Conv1d(
            self.d_inner, self.d_inner,
            kernel_size=d_conv, bias=True,
            groups=self.d_inner, padding=d_conv - 1
        )
        # Input-dependent SSM parameters: per position emits
        # [dt (1) | B (d_state) | C (d_state)].
        self.x_proj = nn.Linear(self.d_inner, self.d_state * 2 + 1, bias=False)
        # Broadcast the scalar dt to one timescale per inner channel.
        self.dt_proj = nn.Linear(1, self.d_inner, bias=True)
        # A matrix: structured init (1..d_state per row), stored in log
        # space; negated+exponentiated in ssm_parallel so A is always < 0.
        A = torch.arange(1, d_state + 1, dtype=torch.float32).unsqueeze(0).expand(self.d_inner, -1)
        self.A_log = nn.Parameter(torch.log(A))
        # D: learned direct (skip) path weight per inner channel.
        self.D = nn.Parameter(torch.ones(self.d_inner))
        # Output projection back to d_model.
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
        # Pre-norm applied in forward().
        self.norm = nn.RMSNorm(d_model)
    def ssm_parallel(self, x):
        """Run the selective scan. x: (B, L, d_inner) → (B, L, d_inner)."""
        B_size, L, D = x.shape
        # A < 0 guarantees |exp(dt*A)| < 1 (stable recurrence) since dt > 0.
        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)
        D_skip = self.D.float()
        # Input-dependent dt, B, C (the "selective" part of S6).
        x_dbl = self.x_proj(x)  # (B, L, d_state*2 + 1)
        dt, B_mat, C_mat = x_dbl.split([1, self.d_state, self.d_state], dim=-1)
        dt = F.softplus(self.dt_proj(dt))  # (B, L, d_inner), strictly positive
        # Discretize (ZOH-style): dA = exp(dt*A), input term dB*x = dt*B*x.
        dA = torch.exp(dt.unsqueeze(-1) * A.unsqueeze(0).unsqueeze(0))  # (B, L, D, N)
        dBx = dt.unsqueeze(-1) * B_mat.unsqueeze(2) * x.unsqueeze(-1)  # (B, L, D, N)
        # Rearrange for the scan: (B, D, L, N), scanning over L.
        dA = dA.permute(0, 2, 1, 3).contiguous()
        dBx = dBx.permute(0, 2, 1, 3).contiguous()
        if self.use_parallel_scan:
            # Parallel prefix scan: h[t] = dA[t]*h[t-1] + dBx[t].
            h = pscan(dA, dBx)  # (B, D, L, N)
        else:
            # Sequential fallback (same recurrence, O(L) python loop).
            h = torch.zeros_like(dBx)
            state = torch.zeros(B_size, self.d_inner, self.d_state,
                                device=x.device, dtype=x.dtype)
            for t in range(L):
                state = dA[:, :, t] * state + dBx[:, :, t]
                h[:, :, t] = state
        # Readout: y[t] = C[t] · h[t] + D * x[t].
        h = h.permute(0, 2, 1, 3)  # (B, L, D, N)
        C_mat_exp = C_mat.unsqueeze(2)  # (B, L, 1, N)
        y = (h * C_mat_exp).sum(-1)  # (B, L, D)
        y = y + D_skip * x
        return y
    def forward(self, x):
        """x: (B, L, d_model) → (B, L, d_model); pre-norm residual block."""
        residual = x
        x = self.norm(x)
        # Project and split into SSM path and gate.
        xz = self.in_proj(x)  # (B, L, 2*d_inner)
        x_ssm, z = xz.chunk(2, dim=-1)
        # Depthwise 1D conv for local context; trim the right over-padding
        # back to the original sequence length.
        x_ssm = rearrange(x_ssm, 'b l d -> b d l')
        x_ssm = self.conv1d(x_ssm)[:, :, :residual.shape[1]]
        x_ssm = rearrange(x_ssm, 'b d l -> b l d')
        x_ssm = F.silu(x_ssm)
        # Selective scan.
        y = self.ssm_parallel(x_ssm)
        # Gate and project back, with residual.
        y = y * F.silu(z)
        return self.out_proj(y) + residual
# ==============================================================================
# 2D Cross-Scan for Vision — VMamba style
# ==============================================================================
def cross_scan_2d(x):
    """
    Unfold a 2D feature map into 4 directional 1D sequences (VMamba-style).

    x: (B, H, W, C)
    Returns a list of 4 tensors, each (B, H*W, C):
      [0] row-major raster (top-left → bottom-right)
      [1] reversed raster (both spatial axes flipped, then row-major)
      [2] column-major (transpose, then raster)
      [3] reversed column-major
    """
    B, H, W, C = x.shape
    d1 = x.reshape(B, H * W, C)
    d2 = x.flip([1, 2]).reshape(B, H * W, C)
    xt = x.permute(0, 2, 1, 3)  # (B, W, H, C) — column-major view
    d3 = xt.reshape(B, W * H, C)
    d4 = xt.flip([1, 2]).reshape(B, W * H, C)
    return [d1, d2, d3, d4]


def cross_merge_2d(ys, H, W):
    """
    Inverse of cross_scan_2d: fold the 4 directional sequences back to 2D
    and average them.

    ys: list of 4 tensors (B, H*W, C) in the order produced by cross_scan_2d
    Returns: (B, H, W, C)

    BUGFIX: directions 3/4 were previously unflattened with a row-major
    '(h w)' einops pattern although they were flattened column-major
    '(w h)'. That transposed those two branches when H == W and scrambled
    them when H != W. They must be unflattened to (B, W, H, C) first and
    only then transposed back.
    """
    B = ys[0].shape[0]
    C = ys[0].shape[-1]
    d1 = ys[0].reshape(B, H, W, C)
    d2 = ys[1].reshape(B, H, W, C).flip([1, 2])
    # Column-major branches: unflatten as (W, H), undo the flip where
    # needed, then transpose back to (H, W).
    d3 = ys[2].reshape(B, W, H, C).permute(0, 2, 1, 3)
    d4 = ys[3].reshape(B, W, H, C).flip([1, 2]).permute(0, 2, 1, 3)
    return (d1 + d2 + d3 + d4) * 0.25
class Mamba2DBlock(nn.Module):
    """
    Attention-free 2D Mamba block using the cross-scan pattern.

    The feature map is unfolded into four directional 1D sequences, each is
    run through a single shared SelectiveSSM, and the results are folded
    back to 2D and averaged before a norm + pointwise mix projection. The
    whole block is wrapped in a residual connection.
    """

    def __init__(self, channels, d_state=16, expand=2, use_parallel_scan=True):
        super().__init__()
        self.channels = channels
        # One SSM reused for all 4 directions — weight sharing keeps the
        # parameter count down.
        self.ssm = SelectiveSSM(
            d_model=channels,
            d_state=d_state,
            d_conv=4,
            expand=expand,
            use_parallel_scan=use_parallel_scan,
        )
        self.mix_proj = nn.Linear(channels, channels)
        self.norm = nn.RMSNorm(channels)

    def forward(self, x):
        """x: (B, C, H, W) → (B, C, H, W)"""
        _, _, height, width = x.shape
        shortcut = x
        channels_last = x.permute(0, 2, 3, 1)  # (B, H, W, C)
        directional = cross_scan_2d(channels_last)
        scanned = [self.ssm(seq) for seq in directional]
        fused = cross_merge_2d(scanned, height, width)  # (B, H, W, C)
        fused = self.mix_proj(self.norm(fused))
        return fused.permute(0, 3, 1, 2) + shortcut
# ==============================================================================
# Mobile Convolution Blocks
# ==============================================================================
class SqueezeExcitation(nn.Module):
"""Channel attention via squeeze-excitation."""
def __init__(self, channels, reduction=4):
super().__init__()
reduced = max(8, channels // reduction)
self.pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(channels, reduced),
nn.SiLU(inplace=True),
nn.Linear(reduced, channels),
nn.Sigmoid()
)
def forward(self, x):
B, C, H, W = x.shape
w = self.pool(x).view(B, C)
w = self.fc(w).view(B, C, 1, 1)
return x * w
class FiLM(nn.Module):
    """Feature-wise Linear Modulation: per-channel affine driven by a conditioning vector."""

    def __init__(self, cond_dim, channels):
        super().__init__()
        # Single projection emits [gamma | beta], each of size `channels`.
        self.proj = nn.Linear(cond_dim, channels * 2)

    def forward(self, x, cond):
        """x: (B, C, H, W); cond: (B, cond_dim) → x * (1 + gamma) + beta."""
        channels = x.shape[1]
        gamma, beta = self.proj(cond).chunk(2, dim=-1)  # each (B, C)
        scale = 1 + gamma.reshape(-1, channels, 1, 1)
        shift = beta.reshape(-1, channels, 1, 1)
        return x * scale + shift
class MobileConvBlock(nn.Module):
    """
    Mobile-friendly inverted-residual block (MobileNetV3-style):
    optional 1x1 expand → 3x3 depthwise → Squeeze-Excitation → 1x1 project,
    with optional FiLM style conditioning on the output and an identity
    residual or learned skip connection.

    Args:
        in_ch / out_ch: input / output channel counts.
        expand_ratio: width multiplier for the hidden depthwise stage.
        stride: spatial stride of the depthwise conv.
        use_se: insert Squeeze-Excitation after the depthwise stage.
        cond_dim: if set, FiLM-modulate the output with a (B, cond_dim) vector.
    """
    def __init__(self, in_ch, out_ch, expand_ratio=4, stride=1,
                 use_se=True, cond_dim=None):
        super().__init__()
        mid_ch = in_ch * expand_ratio
        # Identity residual only when spatial size and channels both match.
        self.use_residual = (stride == 1 and in_ch == out_ch)
        layers = []
        # Pointwise expansion (skipped when expand_ratio == 1).
        if expand_ratio != 1:
            layers.extend([
                nn.Conv2d(in_ch, mid_ch, 1, bias=False),
                nn.BatchNorm2d(mid_ch),
                nn.SiLU(inplace=True),
            ])
        # Depthwise 3x3 — carries the stride.
        layers.extend([
            nn.Conv2d(mid_ch, mid_ch, 3, stride=stride, padding=1,
                      groups=mid_ch, bias=False),
            nn.BatchNorm2d(mid_ch),
            nn.SiLU(inplace=True),
        ])
        self.conv = nn.Sequential(*layers)
        # Squeeze-Excitation on the expanded features.
        self.se = SqueezeExcitation(mid_ch) if use_se else nn.Identity()
        # Linear (non-activated) projection back to out_ch.
        self.project = nn.Sequential(
            nn.Conv2d(mid_ch, out_ch, 1, bias=False),
            nn.BatchNorm2d(out_ch),
        )
        # Optional FiLM conditioning on the projected output.
        self.film = FiLM(cond_dim, out_ch) if cond_dim else None
        # Skip path for the non-identity cases. `self.skip` is assigned on
        # every branch, so forward() can rely on it unconditionally.
        # NOTE(review): the stride==1 projection skip has no BatchNorm while
        # the strided one does — looks unintentional, but adding one would
        # invalidate existing checkpoints, so it is left as-is.
        if not self.use_residual and stride == 1:
            self.skip = nn.Conv2d(in_ch, out_ch, 1, bias=False)
        elif not self.use_residual:
            self.skip = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch),
            )
        else:
            self.skip = nn.Identity()

    def forward(self, x, cond=None):
        """x: (B, in_ch, H, W); cond: optional (B, cond_dim) style vector."""
        out = self.conv(x)
        out = self.se(out)
        out = self.project(out)
        if self.film is not None and cond is not None:
            out = self.film(out, cond)
        if self.use_residual:
            return out + x
        # FIX: the previous `out + self.skip(x) if hasattr(self, 'skip') else out`
        # was a dead guard — `skip` is always defined in __init__ — and the
        # conditional expression bound more loosely than `+`, obscuring that.
        # Behavior is unchanged; the intent is now explicit.
        return out + self.skip(x)
class GatedConvBlock(nn.Module):
"""Gated convolution block — alternative to attention for global mixing."""
def __init__(self, channels):
super().__init__()
self.norm = nn.GroupNorm(min(32, channels), channels)
self.proj = nn.Conv2d(channels, channels * 2, 1)
self.dw = nn.Conv2d(channels, channels, 5, padding=2, groups=channels)
self.out = nn.Conv2d(channels, channels, 1)
def forward(self, x):
residual = x
x = self.norm(x)
gate, val = self.proj(x).chunk(2, dim=1)
val = self.dw(val)
x = val * F.silu(gate)
return self.out(x) + residual
# ==============================================================================
# PMA-VAE Encoder
# ==============================================================================
class PMAEncoder(nn.Module):
    """
    Encoder with progressive downsampling:
        H → H/2 → H/4 → H/8 → H/16
    Outputs multi-scale latents as diagonal-Gaussian (mu, logvar) pairs:
        - z_base:   (B, latent_base_dim,   H/16, W/16)
        - z_detail: (B, latent_detail_dim, H/8,  W/8)
        - z_style:  (B, latent_style_dim) global vector

    Inputs should have H, W divisible by 16 (PixelUnshuffle needs even
    sizes and three further stride-2 stages follow) — assumed, not checked.
    """
    def __init__(self, in_channels=3,
                 stage_channels=(64, 128, 192, 256),
                 stage_blocks=(2, 2, 4, 4),
                 latent_base_dim=32,
                 latent_detail_dim=8,
                 latent_style_dim=128,
                 d_state=16,
                 use_parallel_scan=True):
        # NOTE(review): stage_blocks[3] is never read — stages 1/2/3 use
        # indices 0/1/2; confirm whether a fourth stage was planned.
        super().__init__()
        self.latent_base_dim = latent_base_dim
        self.latent_detail_dim = latent_detail_dim
        self.latent_style_dim = latent_style_dim
        # Stem: PixelUnshuffle is a lossless 2x downsample (space→channels),
        # followed by a conv to the first stage width.
        self.stem = nn.Sequential(
            nn.PixelUnshuffle(2),  # (B, C*4, H/2, W/2)
            nn.Conv2d(in_channels * 4, stage_channels[0], 3, padding=1, bias=False),
            nn.BatchNorm2d(stage_channels[0]),
            nn.SiLU(inplace=True),
        )
        # Stage 1: H/2 → H/4, MobileConv only (cheap at high resolution).
        self.stage1 = self._make_mobile_stage(
            stage_channels[0], stage_channels[1], stage_blocks[0], stride=2
        )
        # Stage 2: H/4 → H/8, MobileConv + some Mamba (mamba_ratio=0.5).
        self.stage2 = self._make_hybrid_stage(
            stage_channels[1], stage_channels[2], stage_blocks[1],
            stride=2, d_state=d_state, mamba_ratio=0.5,
            use_parallel_scan=use_parallel_scan
        )
        # Detail latent head, tapped at H/8 resolution (fused mid-decoder).
        self.detail_head_mu = nn.Conv2d(stage_channels[2], latent_detail_dim, 1)
        self.detail_head_logvar = nn.Conv2d(stage_channels[2], latent_detail_dim, 1)
        # Stage 3: H/8 → H/16, Mamba-heavy (global context, low resolution).
        self.stage3 = self._make_hybrid_stage(
            stage_channels[2], stage_channels[3], stage_blocks[2],
            stride=2, d_state=d_state, mamba_ratio=0.75,
            use_parallel_scan=use_parallel_scan
        )
        # One extra global mixing block at H/16.
        self.global_mix = GatedConvBlock(stage_channels[3])
        # Base latent head at H/16 resolution.
        self.base_head_mu = nn.Conv2d(stage_channels[3], latent_base_dim, 1)
        self.base_head_logvar = nn.Conv2d(stage_channels[3], latent_base_dim, 1)
        # Style latent head: global average pool → linear mu/logvar heads.
        self.style_pool = nn.AdaptiveAvgPool2d(1)
        self.style_head_mu = nn.Linear(stage_channels[3], latent_style_dim)
        self.style_head_logvar = nn.Linear(stage_channels[3], latent_style_dim)

    def _make_mobile_stage(self, in_ch, out_ch, num_blocks, stride=1):
        # First block carries the stride/channel change; the rest keep shape.
        blocks = [MobileConvBlock(in_ch, out_ch, stride=stride)]
        for _ in range(num_blocks - 1):
            blocks.append(MobileConvBlock(out_ch, out_ch))
        return nn.Sequential(*blocks)

    def _make_hybrid_stage(self, in_ch, out_ch, num_blocks, stride=1,
                           d_state=16, mamba_ratio=0.5, use_parallel_scan=True):
        # Returns a ModuleList (not Sequential) — forward() iterates manually.
        blocks = nn.ModuleList()
        # First block handles the stride/channel change.
        blocks.append(MobileConvBlock(in_ch, out_ch, stride=stride))
        # Split the remaining blocks between MobileConv and Mamba; the
        # Mamba blocks come last, after the cheaper local convs.
        num_mamba = max(1, int((num_blocks - 1) * mamba_ratio))
        num_mobile = (num_blocks - 1) - num_mamba
        for _ in range(num_mobile):
            blocks.append(MobileConvBlock(out_ch, out_ch))
        for _ in range(num_mamba):
            blocks.append(Mamba2DBlock(out_ch, d_state=d_state, expand=2,
                                       use_parallel_scan=use_parallel_scan))
        return blocks

    def forward(self, x):
        """
        x: (B, 3, H, W)
        Returns: dict with mu/logvar for base, detail, style latents.
        """
        # Stem: H → H/2
        x = self.stem(x)
        # Stage 1: H/2 → H/4
        x = self.stage1(x)
        # Stage 2: H/4 → H/8
        for block in self.stage2:
            x = block(x)
        # Detail latent taps the H/8 feature map before stage 3.
        detail_mu = self.detail_head_mu(x)
        detail_logvar = self.detail_head_logvar(x)
        # Stage 3: H/8 → H/16
        for block in self.stage3:
            x = block(x)
        # Global mixing at the lowest resolution.
        x = self.global_mix(x)
        # Base latent at H/16.
        base_mu = self.base_head_mu(x)
        base_logvar = self.base_head_logvar(x)
        # Style latent: spatially pooled global descriptor.
        style_feat = self.style_pool(x).flatten(1)
        style_mu = self.style_head_mu(style_feat)
        style_logvar = self.style_head_logvar(style_feat)
        return {
            'base_mu': base_mu, 'base_logvar': base_logvar,
            'detail_mu': detail_mu, 'detail_logvar': detail_logvar,
            'style_mu': style_mu, 'style_logvar': style_logvar,
        }
# ==============================================================================
# PMA-VAE Decoder
# ==============================================================================
class UpsampleBlock(nn.Module):
    """2x spatial upsample: 3x3 conv to 4x channels, then PixelShuffle, then BN + SiLU."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch * 4, 3, padding=1, bias=False)
        self.ps = nn.PixelShuffle(2)
        self.norm = nn.BatchNorm2d(out_ch)
        self.act = nn.SiLU(inplace=True)

    def forward(self, x):
        y = self.conv(x)   # (B, 4*out_ch, H, W)
        y = self.ps(y)     # (B, out_ch, 2H, 2W)
        return self.act(self.norm(y))
class PMADecoder(nn.Module):
    """
    Lightweight decoder for mobile deployment.

    Pipeline:
        z_base (H/16) + z_style → stage1 → up → fuse z_detail (H/8)
        → stage2 → up → stage3 → up → stage4 → up → output head.
    The style vector conditions stages via FiLM; the output is
    Tanh-bounded to [-1, 1].
    """
    def __init__(self, out_channels=3,
                 stage_channels=(256, 192, 128, 96, 64),
                 latent_base_dim=32,
                 latent_detail_dim=8,
                 latent_style_dim=128,
                 d_state=16,
                 use_parallel_scan=True):
        super().__init__()
        # Project the base latent into feature space at H/16.
        self.base_proj = nn.Sequential(
            nn.Conv2d(latent_base_dim, stage_channels[0], 3, padding=1, bias=False),
            nn.BatchNorm2d(stage_channels[0]),
            nn.SiLU(inplace=True),
        )
        # Stage 1 (H/16): FiLM-conditioned MobileConv + one Mamba block.
        self.stage1_blocks = nn.ModuleList([
            MobileConvBlock(stage_channels[0], stage_channels[0],
                            cond_dim=latent_style_dim),
            Mamba2DBlock(stage_channels[0], d_state=d_state,
                         use_parallel_scan=use_parallel_scan),
        ])
        # Upsample H/16 → H/8.
        self.up1 = UpsampleBlock(stage_channels[0], stage_channels[1])
        # Concatenate the detail latent at H/8 and fuse with a 1x1 conv.
        self.detail_fuse = nn.Sequential(
            nn.Conv2d(stage_channels[1] + latent_detail_dim, stage_channels[1], 1, bias=False),
            nn.BatchNorm2d(stage_channels[1]),
            nn.SiLU(inplace=True),
        )
        # Stage 2 (H/8): two FiLM MobileConvs + one Mamba block.
        self.stage2_blocks = nn.ModuleList([
            MobileConvBlock(stage_channels[1], stage_channels[1],
                            cond_dim=latent_style_dim),
            MobileConvBlock(stage_channels[1], stage_channels[1],
                            cond_dim=latent_style_dim),
            Mamba2DBlock(stage_channels[1], d_state=d_state,
                         use_parallel_scan=use_parallel_scan),
        ])
        # Upsample H/8 → H/4.
        self.up2 = UpsampleBlock(stage_channels[1], stage_channels[2])
        # Stage 3 (H/4): only the first block has a FiLM head.
        self.stage3_blocks = nn.ModuleList([
            MobileConvBlock(stage_channels[2], stage_channels[2],
                            cond_dim=latent_style_dim),
            MobileConvBlock(stage_channels[2], stage_channels[2]),
        ])
        # Upsample H/4 → H/2.
        self.up3 = UpsampleBlock(stage_channels[2], stage_channels[3])
        # Stage 4 (H/2): plain MobileConvs, no style conditioning.
        self.stage4_blocks = nn.ModuleList([
            MobileConvBlock(stage_channels[3], stage_channels[3]),
            MobileConvBlock(stage_channels[3], stage_channels[3]),
        ])
        # Upsample H/2 → H (PixelShuffle).
        self.up4 = UpsampleBlock(stage_channels[3], stage_channels[4])
        # Output head; Tanh bounds the reconstruction to [-1, 1].
        self.head = nn.Sequential(
            nn.Conv2d(stage_channels[4], stage_channels[4], 3, padding=1),
            nn.SiLU(inplace=True),
            nn.Conv2d(stage_channels[4], out_channels, 3, padding=1),
            nn.Tanh(),  # output [-1, 1]
        )

    def forward(self, z_base, z_detail, z_style):
        """
        z_base:   (B, latent_base_dim, H/16, W/16)
        z_detail: (B, latent_detail_dim, H/8, W/8)
        z_style:  (B, latent_style_dim)
        Returns:  (B, out_channels, H, W) in [-1, 1].
        """
        # Project base latent into feature space.
        x = self.base_proj(z_base)
        # Stage 1 (H/16): the style vector is routed only to MobileConvBlocks
        # (Mamba blocks take no conditioning input).
        for block in self.stage1_blocks:
            if isinstance(block, MobileConvBlock):
                x = block(x, cond=z_style)
            else:
                x = block(x)
        # Upsample to H/8.
        x = self.up1(x)
        # Fuse detail latent.
        x = self.detail_fuse(torch.cat([x, z_detail], dim=1))
        # Stage 2 (H/8).
        for block in self.stage2_blocks:
            if isinstance(block, MobileConvBlock):
                x = block(x, cond=z_style)
            else:
                x = block(x)
        # Upsample to H/4.
        x = self.up2(x)
        # Stage 3 (H/4). The second block was built without cond_dim, so it
        # has no FiLM head and silently ignores the cond argument.
        for block in self.stage3_blocks:
            if isinstance(block, MobileConvBlock):
                x = block(x, cond=z_style)
            else:
                x = block(x)
        # Upsample to H/2.
        x = self.up3(x)
        # Stage 4 (H/2): unconditioned.
        for block in self.stage4_blocks:
            x = block(x)
        # Upsample to H.
        x = self.up4(x)
        # Output head.
        return self.head(x)
# ==============================================================================
# Full PMA-VAE Model
# ==============================================================================
class PMAVAE(nn.Module):
    """
    Parallel Mobile Artistic VAE — full encoder/decoder model.

    Attention-free (Mamba SSM + mobile convolutions), with a multi-scale
    latent space (base + detail + style), FiLM style conditioning in the
    decoder, and a parallel-scan training path.

    Args:
        in_channels: image channels (3 for RGB); also used for the output.
        enc_channels / dec_channels: per-stage widths for encoder / decoder.
        enc_blocks: number of blocks per encoder stage.
        latent_base_dim: channels of the H/16 base latent.
        latent_detail_dim: channels of the H/8 detail latent.
        latent_style_dim: size of the global style vector.
        d_state: SSM state dimension.
        use_parallel_scan: Blelloch scan (True) or sequential fallback (False).
    """

    def __init__(self,
                 in_channels=3,
                 enc_channels=(64, 128, 192, 256),
                 dec_channels=(256, 192, 128, 96, 64),
                 enc_blocks=(2, 2, 4, 4),
                 latent_base_dim=32,
                 latent_detail_dim=8,
                 latent_style_dim=128,
                 d_state=16,
                 use_parallel_scan=True):
        super().__init__()
        # Settings shared verbatim by both halves of the model.
        shared = dict(
            latent_base_dim=latent_base_dim,
            latent_detail_dim=latent_detail_dim,
            latent_style_dim=latent_style_dim,
            d_state=d_state,
            use_parallel_scan=use_parallel_scan,
        )
        self.encoder = PMAEncoder(
            in_channels=in_channels,
            stage_channels=enc_channels,
            stage_blocks=enc_blocks,
            **shared,
        )
        self.decoder = PMADecoder(
            out_channels=in_channels,
            stage_channels=dec_channels,
            **shared,
        )

    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps * std while training; return mu at eval time."""
        if not self.training:
            return mu
        std = (0.5 * logvar).exp()
        return mu + torch.randn_like(std) * std

    def encode(self, x):
        """Encode an image to the multi-scale posterior dict (mu/logvar per latent)."""
        return self.encoder(x)

    def decode(self, z_base, z_detail, z_style):
        """Decode sampled latents back to an image."""
        return self.decoder(z_base, z_detail, z_style)

    def forward(self, x):
        """
        Full pass: encode → sample each latent → decode.
        Returns: (recon, posteriors_dict)
        """
        posteriors = self.encode(x)
        samples = [
            self.reparameterize(posteriors[name + '_mu'], posteriors[name + '_logvar'])
            for name in ('base', 'detail', 'style')
        ]
        return self.decode(*samples), posteriors

    def get_last_decoder_layer(self):
        """Weight of the final conv (before Tanh) — for adaptive discriminator balancing."""
        return self.decoder.head[-2].weight

    @torch.no_grad()
    def encode_to_latent(self, x):
        """Deterministic encoding: return the three posterior means only."""
        posteriors = self.encode(x)
        return (posteriors['base_mu'], posteriors['detail_mu'], posteriors['style_mu'])

    @torch.no_grad()
    def decode_from_latent(self, z_base, z_detail, z_style):
        """Decode from given latents (inference mode)."""
        return self.decode(z_base, z_detail, z_style)

    def count_parameters(self):
        """Parameter counts for encoder, decoder, and total (raw and in millions)."""
        stats = {
            'encoder': sum(p.numel() for p in self.encoder.parameters()),
            'decoder': sum(p.numel() for p in self.decoder.parameters()),
        }
        stats['total'] = stats['encoder'] + stats['decoder']
        for key in ('encoder', 'decoder', 'total'):
            stats[key + '_M'] = stats[key] / 1e6
        return stats
# ==============================================================================
# Model Configs
# ==============================================================================
def pmavae_tiny(**kwargs):
    """Tiny config for smoke testing (~5M params)."""
    defaults = dict(
        enc_channels=(32, 64, 96, 128),
        dec_channels=(128, 96, 64, 48, 32),
        enc_blocks=(1, 1, 2, 2),
        latent_base_dim=16,
        latent_detail_dim=4,
        latent_style_dim=64,
        d_state=8,
    )
    # Duplicate keywords in **kwargs raise TypeError, same as explicit args.
    return PMAVAE(**defaults, **kwargs)
def pmavae_small(**kwargs):
    """Small config sized for Colab free tier (~20M params)."""
    defaults = dict(
        enc_channels=(48, 96, 144, 192),
        dec_channels=(192, 144, 96, 72, 48),
        enc_blocks=(2, 2, 3, 3),
        latent_base_dim=24,
        latent_detail_dim=6,
        latent_style_dim=96,
        d_state=16,
    )
    # Duplicate keywords in **kwargs raise TypeError, same as explicit args.
    return PMAVAE(**defaults, **kwargs)
def pmavae_base(**kwargs):
    """Base config (~40M params)."""
    defaults = dict(
        enc_channels=(64, 128, 192, 256),
        dec_channels=(256, 192, 128, 96, 64),
        enc_blocks=(2, 2, 4, 4),
        latent_base_dim=32,
        latent_detail_dim=8,
        latent_style_dim=128,
        d_state=16,
    )
    # Duplicate keywords in **kwargs raise TypeError, same as explicit args.
    return PMAVAE(**defaults, **kwargs)
if __name__ == '__main__':
    # Smoke test: build the tiny config and run one forward pass on CPU.
    # The sequential scan path is used here because it is cheaper to trace
    # on CPU; the parallel-scan path is exercised identically in training.
    device = 'cpu'
    model = pmavae_tiny(use_parallel_scan=False).to(device)
    x = torch.randn(2, 3, 256, 256, device=device)
    recon, posteriors = model(x)
    print(f"Input: {x.shape}")
    print(f"Recon: {recon.shape}")
    # Print each posterior tensor's shape (mu/logvar for base, detail, style).
    for k, v in posteriors.items():
        print(f" {k}: {v.shape}")
    # Parameter breakdown in millions.
    params = model.count_parameters()
    print(f"\nParams: {params['total_M']:.2f}M (enc: {params['encoder_M']:.2f}M, dec: {params['decoder_M']:.2f}M)")