File size: 2,899 Bytes
720fb6d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 | """Frozen model architecture and user-tunable inference configuration."""
from __future__ import annotations
import json
from dataclasses import asdict, dataclass
from pathlib import Path
@dataclass(frozen=True)
class FCDMDiffAEConfig:
    """Frozen model architecture config. Stored alongside weights as config.json.

    Instances are immutable (``frozen=True``). The three string-valued mode
    fields (``bottleneck_posterior_kind``, ``bottleneck_norm_mode``,
    ``bottleneck_patchify_mode``) accept only the values listed in their
    comments. Any other value raises ``ValueError`` at construction time.
    Without this check, a typo would be silently treated as the default
    downstream; for example, ``latent_channels`` and ``effective_patch_size``
    treat every value other than ``"patch_2x2"`` as "no patchification".
    """

    in_channels: int = 3
    patch_size: int = 16
    model_dim: int = 896
    encoder_depth: int = 4
    decoder_depth: int = 8
    decoder_start_blocks: int = 2
    decoder_end_blocks: int = 2
    bottleneck_dim: int = 128
    mlp_ratio: float = 4.0
    depthwise_kernel_size: int = 7
    adaln_low_rank_rank: int = 128
    # Encoder posterior kind: "diagonal_gaussian" or "deterministic"
    bottleneck_posterior_kind: str = "diagonal_gaussian"
    # Post-bottleneck normalization: "channel_wise" or "disabled"
    bottleneck_norm_mode: str = "disabled"
    # Bottleneck patchification: "off" or "patch_2x2"
    # When "patch_2x2", encoder latents are 2x2 patchified after the bottleneck
    # (channels * 4, spatial / 2), and decode unpatchifies before the decoder.
    bottleneck_patchify_mode: str = "off"
    # VP diffusion schedule endpoints
    logsnr_min: float = -10.0
    logsnr_max: float = 10.0
    # Pixel-space noise std for VP diffusion initialization
    pixel_noise_std: float = 0.558

    def __post_init__(self) -> None:
        """Reject mode strings outside their documented value sets.

        Raises:
            ValueError: if any mode field holds an unsupported value.
        """
        # Kept local (not a class attribute) so it is never mistaken for a
        # dataclass field and never serialized by ``asdict``.
        allowed_modes = {
            "bottleneck_posterior_kind": ("diagonal_gaussian", "deterministic"),
            "bottleneck_norm_mode": ("channel_wise", "disabled"),
            "bottleneck_patchify_mode": ("off", "patch_2x2"),
        }
        for field_name, choices in allowed_modes.items():
            value = getattr(self, field_name)
            if value not in choices:
                raise ValueError(
                    f"{field_name} must be one of {choices}, got {value!r}"
                )

    @property
    def latent_channels(self) -> int:
        """Channel width of the exported latent space.

        ``bottleneck_dim * 4`` when 2x2 patchification is on (four spatial
        positions are folded into channels), otherwise ``bottleneck_dim``.
        """
        if self.bottleneck_patchify_mode == "patch_2x2":
            return self.bottleneck_dim * 4
        return self.bottleneck_dim

    @property
    def effective_patch_size(self) -> int:
        """Effective spatial stride from image to latent grid.

        Doubles ``patch_size`` when 2x2 patchification halves the latent grid.
        """
        if self.bottleneck_patchify_mode == "patch_2x2":
            return self.patch_size * 2
        return self.patch_size

    def save(self, path: str | Path) -> None:
        """Save config as UTF-8 JSON, creating parent directories as needed.

        Args:
            path: Destination file path (typically ``config.json``).
        """
        p = Path(path)
        p.parent.mkdir(parents=True, exist_ok=True)
        p.write_text(json.dumps(asdict(self), indent=2) + "\n", encoding="utf-8")

    @classmethod
    def load(cls, path: str | Path) -> FCDMDiffAEConfig:
        """Load config from a UTF-8 JSON file written by :meth:`save`.

        Args:
            path: Source file path.

        Returns:
            A new config built from the JSON object's keys.

        Raises:
            TypeError: if the JSON top level is not an object, or it contains
                keys that are not fields of this class.
            ValueError: if a mode field holds an unsupported value.
        """
        data = json.loads(Path(path).read_text(encoding="utf-8"))
        if not isinstance(data, dict):
            raise TypeError(
                f"expected a JSON object in {path}, got {type(data).__name__}"
            )
        return cls(**data)
@dataclass
class FCDMDiffAEInferenceConfig:
    """User-tunable inference parameters with sensible defaults.

    PDG (Path-Drop Guidance) sharpens reconstructions by degrading conditioning
    in one pass and amplifying the difference. When enabled, uses 2 NFE per step.
    Recommended: ``pdg=True, pdg_strength=2.0, num_steps=10``.

    At construction time:
    - ``sampler`` and ``schedule`` must be one of the values listed in their
      field comments;
    - ``num_steps`` must be at least 1.

    Violations raise ``ValueError``. The dataclass is mutable, so fields
    reassigned after construction are not re-checked.
    """

    num_steps: int = 1  # denoising steps; NFE = num_steps (2 * num_steps with PDG)
    sampler: str = "ddim"  # "ddim" or "dpmpp_2m"
    schedule: str = "linear"  # "linear" or "cosine"
    pdg: bool = False  # enable PDG for perceptual sharpening
    pdg_strength: float = 2.0  # CFG-like strength when pdg=True
    seed: int | None = None  # presumably the sampling RNG seed; None = unseeded (confirm with caller)

    def __post_init__(self) -> None:
        """Validate step count and the closed-set string options.

        Raises:
            ValueError: if ``num_steps < 1`` or ``sampler``/``schedule`` is
                not a supported value.
        """
        if self.num_steps < 1:
            raise ValueError(f"num_steps must be >= 1, got {self.num_steps!r}")
        allowed_options = {
            "sampler": ("ddim", "dpmpp_2m"),
            "schedule": ("linear", "cosine"),
        }
        for field_name, choices in allowed_options.items():
            value = getattr(self, field_name)
            if value not in choices:
                raise ValueError(
                    f"{field_name} must be one of {choices}, got {value!r}"
                )
|