| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
| from pathlib import Path |
|
|
| from diffusers import AutoencoderKLWan |
|
|
|
|
# Ancestor directory ``parents[2]`` of this file's resolved path, i.e. two levels
# above the directory containing this file. ``resolve_wan_vae_source`` uses it as the
# base for the default Wan2.2 VAE location. Presumably the repository root —
# NOTE(review): confirm this still holds if the file is moved.
_REPO_ROOT = Path(__file__).resolve().parents[2]
|
|
|
|
| def canonicalize_wan_vae_variant(value: str | None) -> str: |
| raw = str(value or "wan2_1").strip().lower().replace(".", "_").replace("-", "_") |
| if raw in {"wan21", "wan2_1"}: |
| return "wan2_1" |
| if raw in {"wan22", "wan2_2"}: |
| return "wan2_2" |
| raise ValueError(f"Unsupported wan_vae_variant={value!r}. Expected wan2_1 or wan2_2.") |
|
|
|
|
| def canonicalize_wan_vae_adaptation_mode(value: str | None) -> str: |
| raw = str(value or "auto").strip().lower().replace(".", "_").replace("-", "_") |
| if raw in {"auto", ""}: |
| return "auto" |
| if raw in {"none", "identity"}: |
| return "none" |
| if raw in {"reinit_io", "reinitio", "reinit"}: |
| return "reinit_io" |
| raise ValueError( |
| f"Unsupported wan_vae_adaptation_mode={value!r}. Expected auto, none, or reinit_io." |
| ) |
|
|
|
|
def resolve_wan_vae_adaptation_mode(
    *,
    variant: str | None,
    adaptation_mode: str | None,
    transformer_channels: int,
    vae_channels: int,
) -> str:
    """Decide how to reconcile a Wan VAE's latent channels with the transformer's.

    ``variant`` and ``adaptation_mode`` are both canonicalized up front, so invalid
    values raise ``ValueError`` even when the mode is given explicitly.

    An explicit (non-``"auto"``) mode is returned unchanged, and no channel check
    is performed in that case. In ``"auto"`` mode:
        * equal channel counts -> ``"none"``;
        * mismatched counts with the ``wan2_2`` variant -> ``"reinit_io"``;
        * mismatched counts with any other variant -> ``ValueError``.

    Returns:
        ``"none"`` or ``"reinit_io"``.
    """
    variant_key = canonicalize_wan_vae_variant(variant)
    requested = canonicalize_wan_vae_adaptation_mode(adaptation_mode)

    if requested == "auto":
        n_transformer = int(transformer_channels)
        n_vae = int(vae_channels)
        if n_transformer == n_vae:
            return "none"
        if variant_key == "wan2_2":
            return "reinit_io"
        raise ValueError(
            "Wan VAE/transformer latent channels do not match, but no supported adaptation path exists. "
            f"variant={variant_key} transformer_channels={n_transformer} vae_channels={n_vae}"
        )

    return requested
|
|
|
|
@dataclass(frozen=True)
class WanVaeSource:
    """Immutable description of where a Wan VAE should be loaded from."""

    # Canonical variant name ("wan2_1" / "wan2_2" when built by resolve_wan_vae_source).
    variant: str
    # Directory passed (as a string) to the AutoencoderKLWan loading calls.
    load_root: Path
    # Sub-directory of load_root holding the VAE config, or None when load_root
    # itself contains config.json ("vae" or None when built by resolve_wan_vae_source).
    subfolder: str | None
|
|
|
|
def resolve_wan_vae_source(
    *,
    model_root: str | Path,
    variant: str | None = None,
    vae_root: str | Path | None = None,
) -> WanVaeSource:
    """Locate the directory (and optional subfolder) holding a Wan VAE config.

    The candidate directory is chosen in this order:
        * ``vae_root`` when it is given (any truthy value);
        * otherwise ``model_root`` for the ``wan2_1`` variant;
        * otherwise ``hugg_model/Wan2.2-TI2V-5B-Diffusers`` under ``_REPO_ROOT``.

    Within the candidate, ``vae/config.json`` is preferred over a top-level
    ``config.json``. The former yields ``subfolder="vae"``; the latter yields
    ``subfolder=None``.

    Args:
        model_root: Model directory, used as the candidate for ``wan2_1``
            when ``vae_root`` is not given.
        variant: Wan VAE variant name (see ``canonicalize_wan_vae_variant``).
        vae_root: Explicit VAE directory overriding the default candidate.
            ``None`` or ``""`` means "not provided".

    Returns:
        A ``WanVaeSource`` with all paths user-expanded and resolved.

    Raises:
        ValueError: If ``variant`` is not a supported Wan VAE variant.
        FileNotFoundError: If no ``config.json`` is found in either location.
    """
    model_root_p = Path(model_root).expanduser().resolve()
    variant_norm = canonicalize_wan_vae_variant(variant)
    # Resolve the override exactly once. A falsy value (None or "") is ignored, and
    # the error message below reports it as None rather than the resolved cwd.
    vae_root_p = Path(vae_root).expanduser().resolve() if vae_root else None

    if vae_root_p is not None:
        candidate = vae_root_p
    elif variant_norm == "wan2_1":
        candidate = model_root_p
    else:
        candidate = (_REPO_ROOT / "hugg_model" / "Wan2.2-TI2V-5B-Diffusers").resolve()

    # Check the diffusers pipeline layout (<root>/vae/config.json) first, then a
    # standalone VAE directory (<root>/config.json).
    if (candidate / "vae" / "config.json").is_file():
        return WanVaeSource(variant=variant_norm, load_root=candidate, subfolder="vae")
    if (candidate / "config.json").is_file():
        return WanVaeSource(variant=variant_norm, load_root=candidate, subfolder=None)

    raise FileNotFoundError(
        "Cannot resolve Wan VAE source. "
        f"variant={variant_norm!r} model_root={str(model_root_p)!r} "
        f"vae_root={None if vae_root_p is None else str(vae_root_p)!r} "
        f"candidate={str(candidate)!r}"
    )
|
|
|
|
def load_wan_vae(
    *,
    source: WanVaeSource,
    load_pretrained: bool,
) -> AutoencoderKLWan:
    """Instantiate an ``AutoencoderKLWan`` described by ``source``.

    Args:
        source: Location of the VAE. ``source.subfolder`` is forwarded as the
            ``subfolder`` keyword only when it is truthy.
        load_pretrained: If True, load via ``AutoencoderKLWan.from_pretrained``.
            Otherwise read only the config with ``load_config`` and build the
            model with ``from_config``; ``from_pretrained`` is not called on
            that path.

    Returns:
        The constructed ``AutoencoderKLWan``.
    """
    root = str(source.load_root)
    # Omit the keyword entirely when there is no subfolder, as the callers below expect.
    location_kwargs = {"subfolder": source.subfolder} if source.subfolder else {}

    if not load_pretrained:
        vae_config = AutoencoderKLWan.load_config(root, **location_kwargs)
        return AutoencoderKLWan.from_config(vae_config)

    return AutoencoderKLWan.from_pretrained(root, **location_kwargs)
|
|