| from __future__ import annotations |
|
|
| from pathlib import Path |
|
|
|
|
# Repository root: the grandparent of the directory that contains this file
# (index 2 of the resolved file path's ``parents``). resolve_wan_model_root
# builds its default "hugg_model/<checkpoint>" locations relative to it.
_REPO_ROOT = Path(__file__).resolve().parents[2]
|
|
|
|
def canonicalize_wan_model_variant(value: str | None) -> str:
    """Normalise a Wan model-variant name to its canonical spelling.

    Falsy input (``None``, ``""``) defaults to ``"wan2_1"``. Matching ignores
    case and surrounding whitespace, and treats ``"."`` and ``"-"`` as ``"_"``.
    The ``tiny`` / ``small`` aliases map to ``wan2_2_10l`` / ``wan2_2_14l``.

    Args:
        value: User-supplied variant name, or ``None`` for the default.

    Returns:
        One of ``"wan2_1"``, ``"wan2_2"``, ``"wan2_2_10l"``, ``"wan2_2_14l"``.

    Raises:
        ValueError: If the normalised name is not a recognised alias.
    """
    aliases = {
        "wan21": "wan2_1",
        "wan2_1": "wan2_1",
        "wan22": "wan2_2",
        "wan2_2": "wan2_2",
        "wan22_10l": "wan2_2_10l",
        "wan2_2_10l": "wan2_2_10l",
        "wan22_tiny": "wan2_2_10l",
        "wan2_2_tiny": "wan2_2_10l",
        "wan22_14l": "wan2_2_14l",
        "wan2_2_14l": "wan2_2_14l",
        "wan22_small": "wan2_2_14l",
        "wan2_2_small": "wan2_2_14l",
    }
    key = str(value or "wan2_1").strip().lower()
    for separator in (".", "-"):
        key = key.replace(separator, "_")
    canonical = aliases.get(key)
    if canonical is None:
        raise ValueError(f"Unsupported wan_model_variant={value!r}. Expected wan2_1, wan2_2, wan2_2_10l, or wan2_2_14l.")
    return canonical
|
|
|
|
def resolve_wan_transformer_num_layers(
    *,
    variant: str | None,
    requested_override: int | None,
    full_num_layers: int,
) -> int:
    """Decide how many transformer blocks to keep for a Wan model variant.

    A positive ``requested_override`` takes precedence and must not exceed
    ``full_num_layers``. ``None`` or a non-positive override falls back to the
    variant default: 10 blocks for ``wan2_2_10l``, 14 for ``wan2_2_14l`` (each
    capped at ``full_num_layers``), and ``full_num_layers`` for everything else.

    Raises:
        ValueError: If ``variant`` is not recognised, or a positive override
            exceeds ``full_num_layers``.
    """
    variant_norm = canonicalize_wan_model_variant(variant)
    override = int(requested_override) if requested_override is not None else None
    full = int(full_num_layers)

    # A zero/negative override is treated exactly like "no override".
    if override is None or override <= 0:
        depth_caps = {"wan2_2_10l": 10, "wan2_2_14l": 14}
        cap = depth_caps.get(variant_norm)
        return full if cap is None else min(cap, full)

    if override > full:
        raise ValueError(
            f"wan_transformer_num_layers_override={override} exceeds full_num_layers={int(full_num_layers)}"
        )
    return override
|
|
|
|
def resolve_wan_transformer_block_indices(
    *,
    variant: str | None,
    target_num_layers: int,
    full_num_layers: int,
) -> list[int]:
    """Choose which transformer block indices to keep when truncating a model.

    If ``target_num_layers == full_num_layers`` every index is returned. For
    ``wan2_2_14l`` the first and last blocks are always kept and the remaining
    ``target - 2`` indices are spread evenly over the interior (positions are
    rounded with the built-in ``round``). Every other variant keeps the
    contiguous prefix ``0 .. target - 1``.

    Raises:
        ValueError: If ``variant`` is not recognised, either count is not
            positive, or ``target_num_layers`` exceeds ``full_num_layers``.
        RuntimeError: Defensive guard if the evenly spaced indices collide.
    """
    variant_norm = canonicalize_wan_model_variant(variant)
    target = int(target_num_layers)
    full = int(full_num_layers)
    if min(target, full) <= 0:
        raise ValueError(f"target_num_layers and full_num_layers must be positive, got {target}, {full}")
    if target > full:
        raise ValueError(f"target_num_layers={target} exceeds full_num_layers={full}")
    if target == full:
        return list(range(full))
    if variant_norm != "wan2_2_14l":
        # Every other variant simply keeps a contiguous prefix of blocks.
        return list(range(target))

    # wan2_2_14l: pin both boundary blocks, spread the rest evenly.
    last = full - 1
    if target == 1:
        return [0]
    if target == 2:
        return [0, last]
    gaps = target - 1
    interior = []
    for k in range(1, gaps):
        # Multiply before dividing so the float (and any round-half-to-even
        # tie) matches i * (full - 1) / (target - 1) exactly.
        position = int(round(k * last / gaps))
        interior.append(min(max(position, 1), full - 2))
    indices = [0, *interior, last]
    if len(set(indices)) != target:
        raise RuntimeError(
            f"Failed to resolve unique uniformly spaced indices with pinned boundaries: {indices}"
        )
    return indices
|
|
|
|
def resolve_wan_model_root(
    model_path: str | Path | None,
    *,
    variant: str | None = None,
) -> Path:
    """Locate the root directory of a Wan diffusers checkpoint.

    When ``model_path`` is falsy, a default under ``_REPO_ROOT / "hugg_model"``
    is used: ``Wan2.2-TI2V-5B-Diffusers`` for any Wan 2.2 variant, otherwise
    ``Wan2.1-T2V-1.3B-Diffusers``. A directory counts as a root when it has
    both ``transformer/`` and ``vae/`` subdirectories. If the path points at
    one of the component subfolders (``transformer``, ``vae``, ``scheduler``,
    ``tokenizer``, ``text_encoder``), its parent is tried instead.

    Raises:
        ValueError: If ``variant`` is not recognised.
        FileNotFoundError: If no root with the expected layout is found.
    """
    variant_norm = canonicalize_wan_model_variant(variant)

    def _has_diffusers_layout(candidate: Path) -> bool:
        # A diffusers root must contain both component folders.
        return (candidate / "transformer").is_dir() and (candidate / "vae").is_dir()

    if model_path:
        path = Path(model_path).expanduser().resolve()
    else:
        checkpoint_name = (
            "Wan2.2-TI2V-5B-Diffusers"
            if variant_norm in {"wan2_2", "wan2_2_10l", "wan2_2_14l"}
            else "Wan2.1-T2V-1.3B-Diffusers"
        )
        path = (_REPO_ROOT / "hugg_model" / checkpoint_name).resolve()

    if path.is_dir() and _has_diffusers_layout(path):
        return path

    component_dirs = {"transformer", "vae", "scheduler", "tokenizer", "text_encoder"}
    if path.name in component_dirs:
        parent = path.parent
        if _has_diffusers_layout(parent):
            return parent
    raise FileNotFoundError(f"Cannot resolve Wan diffusers root from: {path}")
|
|