ckpt_debug_libero / wan_model_switch.py
jasonzhango's picture
Add files using upload-large-folder tool
36f9c79 verified
from __future__ import annotations
from pathlib import Path
# Repository root, taken as the third ancestor (parents[2]) of this file's
# resolved location. Used below to build default checkpoint paths under
# "<repo>/hugg_model/" when no explicit model path is supplied.
_REPO_ROOT = Path(__file__).resolve().parents[2]
def canonicalize_wan_model_variant(value: str | None) -> str:
    """Map a user-supplied Wan variant name onto its canonical key.

    ``None`` or an empty value falls back to ``"wan2_1"``. Matching is
    case-insensitive, ignores surrounding whitespace, and treats ``.`` and
    ``-`` as ``_``, so e.g. ``"Wan2.2-10L"`` becomes ``"wan2_2_10l"``.

    Args:
        value: Raw variant name, or ``None``.

    Returns:
        One of ``"wan2_1"``, ``"wan2_2"``, ``"wan2_2_10l"`` or ``"wan2_2_14l"``.

    Raises:
        ValueError: If the normalized name is not a recognized alias.
    """
    key = str(value or "wan2_1").strip().lower()
    for separator in (".", "-"):
        key = key.replace(separator, "_")

    # Canonical name -> every accepted (already normalized) spelling.
    # "tiny" is an alias of the 10-layer variant, "small" of the 14-layer one.
    spellings_by_canonical = {
        "wan2_1": ("wan21", "wan2_1"),
        "wan2_2": ("wan22", "wan2_2"),
        "wan2_2_10l": ("wan22_10l", "wan2_2_10l", "wan22_tiny", "wan2_2_tiny"),
        "wan2_2_14l": ("wan22_14l", "wan2_2_14l", "wan22_small", "wan2_2_small"),
    }
    for canonical, spellings in spellings_by_canonical.items():
        if key in spellings:
            return canonical
    raise ValueError(f"Unsupported wan_model_variant={value!r}. Expected wan2_1, wan2_2, wan2_2_10l, or wan2_2_14l.")
def resolve_wan_transformer_num_layers(
    *,
    variant: str | None,
    requested_override: int | None,
    full_num_layers: int,
) -> int:
    """Decide how many transformer blocks the model should keep.

    A positive ``requested_override`` wins outright, provided it does not
    exceed ``full_num_layers``. A zero or negative override is ignored.
    Otherwise the truncated variants cap the depth (``wan2_2_10l`` -> 10,
    ``wan2_2_14l`` -> 14, never more than ``full_num_layers``), and every
    other variant keeps all ``full_num_layers`` blocks.

    Args:
        variant: Variant name; validated via ``canonicalize_wan_model_variant``.
        requested_override: Explicit layer count, or ``None``.
        full_num_layers: Number of blocks in the full pretrained transformer.

    Returns:
        The number of transformer blocks to keep.

    Raises:
        ValueError: For an unknown variant, or a positive override larger
            than ``full_num_layers``.
    """
    # The variant is validated even when an override makes it irrelevant.
    canonical = canonicalize_wan_model_variant(variant)
    override = None if requested_override is None else int(requested_override)
    full = int(full_num_layers)

    if override is not None:
        if override > 0:
            if override > full:
                raise ValueError(
                    f"wan_transformer_num_layers_override={override} exceeds full_num_layers={full}"
                )
            return override

    depth_cap = {"wan2_2_10l": 10, "wan2_2_14l": 14}.get(canonical)
    if depth_cap is None:
        return full
    return min(depth_cap, full)
def resolve_wan_transformer_block_indices(
    *,
    variant: str | None,
    target_num_layers: int,
    full_num_layers: int,
) -> list[int]:
    """Pick which pretrained transformer blocks a truncated model keeps.

    When every block is kept, or for any variant other than ``wan2_2_14l``,
    the first ``target_num_layers`` blocks are returned. For ``wan2_2_14l``
    the blocks are spread uniformly over the full depth. The first block
    (0) and the last block (``full_num_layers - 1``) are always included,
    which keeps early, middle and late coverage.

    Args:
        variant: Variant name; validated via ``canonicalize_wan_model_variant``.
        target_num_layers: Number of blocks to keep (must be positive).
        full_num_layers: Number of blocks in the full model (must be positive).

    Returns:
        Sorted list of unique block indices, of length ``target_num_layers``.

    Raises:
        ValueError: For an unknown variant, non-positive counts, or
            ``target_num_layers > full_num_layers``.
        RuntimeError: If the uniform selection yields duplicate indices
            (a safety net; with ``target < full`` the spacing exceeds 1).
    """
    canonical = canonicalize_wan_model_variant(variant)
    n_keep = int(target_num_layers)
    n_total = int(full_num_layers)

    if min(n_keep, n_total) <= 0:
        raise ValueError(f"target_num_layers and full_num_layers must be positive, got {n_keep}, {n_total}")
    if n_keep > n_total:
        raise ValueError(f"target_num_layers={n_keep} exceeds full_num_layers={n_total}")

    # Keeping everything, or any non-14l variant: take a contiguous prefix.
    if n_keep == n_total or canonical != "wan2_2_14l":
        return list(range(n_keep))

    last_block = n_total - 1
    if n_keep == 1:
        return [0]
    if n_keep == 2:
        return [0, last_block]

    # Interior picks at evenly spaced fractional positions. The expression
    # is kept exactly as i * (n_total - 1) / (n_keep - 1), together with
    # Python's round() (round-half-to-even), so the selected indices stay
    # bit-for-bit stable across versions. The clamp keeps picks off the
    # pinned boundary blocks.
    middle = []
    for i in range(1, n_keep - 1):
        position = int(round(i * (n_total - 1) / (n_keep - 1)))
        middle.append(min(max(position, 1), n_total - 2))

    indices = [0] + middle + [last_block]
    if len(set(indices)) != n_keep:
        raise RuntimeError(
            f"Failed to resolve unique uniformly spaced indices with pinned boundaries: {indices}"
        )
    return indices
def resolve_wan_model_root(
    model_path: str | Path | None,
    *,
    variant: str | None = None,
) -> Path:
    """Locate the Wan diffusers checkpoint root directory.

    A checkpoint root is a directory containing both ``transformer/`` and
    ``vae/`` subdirectories.

    If ``model_path`` is empty or ``None``, a default under
    ``<repo>/hugg_model/`` is used: ``Wan2.1-T2V-1.3B-Diffusers`` for
    ``wan2_1``, and ``Wan2.2-TI2V-5B-Diffusers`` for all Wan2.2 variants.

    If the path itself is not a root but names one of the component
    subfolders (``transformer``, ``vae``, ``scheduler``, ``tokenizer``,
    ``text_encoder``), its parent is tried instead.

    Args:
        model_path: Explicit checkpoint path (``~`` is expanded), or ``None``.
        variant: Variant name; selects the default path and is validated.

    Returns:
        The resolved, absolute checkpoint root.

    Raises:
        ValueError: For an unknown variant.
        FileNotFoundError: If no valid root can be found.
    """
    canonical = canonicalize_wan_model_variant(variant)

    def _has_components(directory: Path) -> bool:
        # Minimal layout check: both the transformer and the VAE are present.
        return (directory / "transformer").is_dir() and (directory / "vae").is_dir()

    if model_path:
        root = Path(model_path).expanduser().resolve()
    else:
        default_name = (
            "Wan2.1-T2V-1.3B-Diffusers"
            if canonical == "wan2_1"
            else "Wan2.2-TI2V-5B-Diffusers"
        )
        root = (_REPO_ROOT / "hugg_model" / default_name).resolve()

    if root.is_dir() and _has_components(root):
        return root

    component_dirs = {"transformer", "vae", "scheduler", "tokenizer", "text_encoder"}
    if root.name in component_dirs and _has_components(root.parent):
        return root.parent

    raise FileNotFoundError(f"Cannot resolve Wan diffusers root from: {root}")