ckpt_debug_libero / wan_vae_switch.py
jasonzhango's picture
Add files using upload-large-folder tool
36f9c79 verified
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from diffusers import AutoencoderKLWan
# Directory two levels above the directory that contains this file
# (presumably the repository root -- confirm against the project layout).
# resolve_wan_vae_source uses it as the base of the default Wan2.2 location.
_REPO_ROOT = Path(__file__).resolve().parents[2]
def canonicalize_wan_vae_variant(value: str | None) -> str:
    """Normalize a Wan VAE variant name to ``"wan2_1"`` or ``"wan2_2"``.

    Matching is case-insensitive, ignores surrounding whitespace, and treats
    ``.`` and ``-`` as ``_`` (so ``"Wan2.2"``, ``"wan2-2"`` and ``"wan22"``
    all map to ``"wan2_2"``).

    Args:
        value: Requested variant. A missing value (``None``, an empty string,
            or a whitespace-only string) selects the default ``"wan2_1"``.

    Returns:
        The canonical variant name.

    Raises:
        ValueError: If ``value`` is not a recognized spelling.
    """
    raw = str(value or "").strip().lower().replace(".", "_").replace("-", "_")
    # Blank input (after stripping) means "use the default", which matches how
    # canonicalize_wan_vae_adaptation_mode treats blank input.
    if raw in {"", "wan21", "wan2_1"}:
        return "wan2_1"
    if raw in {"wan22", "wan2_2"}:
        return "wan2_2"
    raise ValueError(f"Unsupported wan_vae_variant={value!r}. Expected wan2_1 or wan2_2.")
def canonicalize_wan_vae_adaptation_mode(value: str | None) -> str:
    """Normalize an adaptation-mode name to ``"auto"``, ``"none"`` or ``"reinit_io"``.

    The lookup is case-insensitive, ignores surrounding whitespace, and treats
    ``.`` and ``-`` as ``_``.

    Args:
        value: Requested mode. A falsy value (``None`` or ``""``) and a
            whitespace-only string both mean ``"auto"``. Note that the
            *string* ``"none"`` (or ``"identity"``) selects ``"none"``.

    Returns:
        The canonical mode name.

    Raises:
        ValueError: If ``value`` is not a recognized spelling.
    """
    # Every accepted spelling (after normalization) and the mode it stands for.
    aliases = {
        "": "auto",
        "auto": "auto",
        "none": "none",
        "identity": "none",
        "reinit_io": "reinit_io",
        "reinitio": "reinit_io",
        "reinit": "reinit_io",
    }

    text = str(value) if value else "auto"
    key = text.strip().lower()
    for separator in (".", "-"):
        key = key.replace(separator, "_")

    canonical = aliases.get(key)
    if canonical is None:
        raise ValueError(
            f"Unsupported wan_vae_adaptation_mode={value!r}. Expected auto, none, or reinit_io."
        )
    return canonical
def resolve_wan_vae_adaptation_mode(
    *,
    variant: str | None,
    adaptation_mode: str | None,
    transformer_channels: int,
    vae_channels: int,
) -> str:
    """Pick the concrete adaptation mode for a VAE/transformer pairing.

    An explicit mode (anything other than ``"auto"``) is returned as-is after
    canonicalization, without looking at the channel counts. In ``"auto"``
    mode the result is ``"none"`` when both channel counts are equal, and
    ``"reinit_io"`` when they differ and the variant is ``wan2_2``.

    Args:
        variant: VAE variant name; canonicalized before use.
        adaptation_mode: Requested mode; canonicalized before use.
        transformer_channels: Latent channel count on the transformer side.
        vae_channels: Latent channel count on the VAE side.

    Returns:
        ``"none"`` or ``"reinit_io"``.

    Raises:
        ValueError: If the variant or mode is unsupported, or if the channel
            counts differ in ``"auto"`` mode for a variant other than
            ``wan2_2``.
    """
    # The variant is validated first, so a bad variant is reported even when
    # an explicit mode would otherwise short-circuit below.
    canonical_variant = canonicalize_wan_vae_variant(variant)
    requested = canonicalize_wan_vae_adaptation_mode(adaptation_mode)
    if requested != "auto":
        return requested

    n_transformer = int(transformer_channels)
    n_vae = int(vae_channels)
    if n_transformer == n_vae:
        return "none"

    if canonical_variant != "wan2_2":
        raise ValueError(
            "Wan VAE/transformer latent channels do not match, but no supported adaptation path exists. "
            f"variant={canonical_variant} transformer_channels={n_transformer} vae_channels={n_vae}"
        )
    return "reinit_io"
@dataclass(frozen=True)
class WanVaeSource:
    """Immutable description of where a Wan VAE is loaded from.

    Instances are produced by ``resolve_wan_vae_source`` and consumed by
    ``load_wan_vae``.
    """

    # Canonical variant name ("wan2_1" or "wan2_2") the source was resolved for.
    variant: str
    # Directory handed to the diffusers loader.
    load_root: Path
    # Sub-directory of ``load_root`` holding the VAE config ("vae"), or None
    # when the config sits directly in ``load_root``.
    subfolder: str | None
def resolve_wan_vae_source(
    *,
    model_root: str | Path,
    variant: str | None = None,
    vae_root: str | Path | None = None,
) -> WanVaeSource:
    """Locate the directory that holds the Wan VAE config for a variant.

    The candidate directory is chosen in this order:

    1. ``vae_root`` when it is provided (non-empty);
    2. ``model_root`` for the ``wan2_1`` variant;
    3. ``_REPO_ROOT / "hugg_model" / "Wan2.2-TI2V-5B-Diffusers"`` otherwise.

    The candidate is accepted when it contains ``vae/config.json`` (the
    result then uses ``subfolder="vae"``) or, failing that, a top-level
    ``config.json`` (``subfolder=None``).

    Args:
        model_root: Root directory of the base model.
        variant: VAE variant name; canonicalized before use.
        vae_root: Optional explicit override for the VAE location. ``None``
            and an empty string both mean "no override".

    Returns:
        A ``WanVaeSource`` describing where to load the VAE from.

    Raises:
        ValueError: If ``variant`` is unsupported.
        FileNotFoundError: If no ``config.json`` is found for the candidate.
    """
    model_root_p = Path(model_root).expanduser().resolve()
    variant_norm = canonicalize_wan_vae_variant(variant)

    # Resolve the override once so the lookup and the error message agree.
    # An empty override is ignored by the lookup; previously the error message
    # still resolved it and reported the current working directory.
    vae_root_p = Path(vae_root).expanduser().resolve() if vae_root else None

    if vae_root_p is not None:
        candidate = vae_root_p
    elif variant_norm == "wan2_1":
        candidate = model_root_p
    else:
        candidate = (_REPO_ROOT / "hugg_model" / "Wan2.2-TI2V-5B-Diffusers").resolve()

    # Prefer a pipeline-style layout (<root>/vae/config.json) over a bare
    # VAE directory (<root>/config.json).
    if (candidate / "vae" / "config.json").is_file():
        return WanVaeSource(variant=variant_norm, load_root=candidate, subfolder="vae")
    if (candidate / "config.json").is_file():
        return WanVaeSource(variant=variant_norm, load_root=candidate, subfolder=None)

    raise FileNotFoundError(
        "Cannot resolve Wan VAE source. "
        f"variant={variant_norm!r} model_root={str(model_root_p)!r} "
        f"vae_root={None if vae_root_p is None else str(vae_root_p)!r} "
        f"candidate={str(candidate)!r}"
    )
def load_wan_vae(
    *,
    source: WanVaeSource,
    load_pretrained: bool,
) -> AutoencoderKLWan:
    """Instantiate the Wan VAE described by ``source``.

    Args:
        source: Resolved location of the VAE (root directory plus optional
            subfolder).
        load_pretrained: When true, load through
            ``AutoencoderKLWan.from_pretrained``. When false, read only the
            config with ``load_config`` and build the model with
            ``from_config`` (presumably leaving weights at their initial
            values -- confirm against diffusers).

    Returns:
        The constructed ``AutoencoderKLWan``.
    """
    root = str(source.load_root)
    # Forward ``subfolder`` only when one is set, so the loader is called with
    # exactly the same arguments as a plain root-only call otherwise.
    location_kwargs = {"subfolder": source.subfolder} if source.subfolder else {}

    if load_pretrained:
        return AutoencoderKLWan.from_pretrained(root, **location_kwargs)

    vae_config = AutoencoderKLWan.load_config(root, **location_kwargs)
    return AutoencoderKLWan.from_config(vae_config)