from __future__ import annotations

from pathlib import Path

_REPO_ROOT = Path(__file__).resolve().parents[2]


def canonicalize_wan_model_variant(value: str | None) -> str:
    """Map a user-supplied Wan variant name onto its canonical identifier.

    The input is stringified, stripped, lower-cased, and has ``.`` and ``-``
    replaced by ``_`` before being matched against the known aliases.  A
    falsy value (``None``, ``""``) is treated as ``"wan2_1"``.

    Returns:
        One of ``"wan2_1"``, ``"wan2_2"``, ``"wan2_2_10l"``, ``"wan2_2_14l"``.

    Raises:
        ValueError: If the normalized name is not a recognized alias.
    """
    aliases = {
        "wan21": "wan2_1",
        "wan2_1": "wan2_1",
        "wan22": "wan2_2",
        "wan2_2": "wan2_2",
        "wan22_10l": "wan2_2_10l",
        "wan2_2_10l": "wan2_2_10l",
        "wan22_tiny": "wan2_2_10l",
        "wan2_2_tiny": "wan2_2_10l",
        "wan22_14l": "wan2_2_14l",
        "wan2_2_14l": "wan2_2_14l",
        "wan22_small": "wan2_2_14l",
        "wan2_2_small": "wan2_2_14l",
    }
    raw = str(value or "wan2_1").strip().lower().replace(".", "_").replace("-", "_")
    try:
        return aliases[raw]
    except KeyError:
        raise ValueError(
            f"Unsupported wan_model_variant={value!r}. Expected wan2_1, wan2_2, wan2_2_10l, or wan2_2_14l."
        ) from None


def resolve_wan_transformer_num_layers(
    *,
    variant: str | None,
    requested_override: int | None,
    full_num_layers: int,
) -> int:
    """Decide how many transformer layers to use for ``variant``.

    Precedence:
      1. A positive ``requested_override`` wins, provided it does not exceed
         ``full_num_layers``.
      2. Otherwise ``wan2_2_10l`` / ``wan2_2_14l`` cap the depth at 10 / 14
         layers (never more than ``full_num_layers``).
      3. Every other variant uses ``full_num_layers``.

    A ``requested_override`` of ``None``, zero, or a negative number is
    ignored (treated as "no override").

    Raises:
        ValueError: If the variant is unknown, or a positive override exceeds
            ``full_num_layers``.
    """
    variant_norm = canonicalize_wan_model_variant(variant)
    full = int(full_num_layers)

    if requested_override is not None:
        override = int(requested_override)
        if override > 0:
            if override > full:
                raise ValueError(
                    f"wan_transformer_num_layers_override={override} exceeds full_num_layers={full}"
                )
            return override

    reduced_depths = {"wan2_2_10l": 10, "wan2_2_14l": 14}
    if variant_norm in reduced_depths:
        return min(reduced_depths[variant_norm], full)
    return full


def resolve_wan_transformer_block_indices(
    *,
    variant: str | None,
    target_num_layers: int,
    full_num_layers: int,
) -> list[int]:
    """Pick which of the ``full_num_layers`` blocks a reduced model keeps.

    * ``target_num_layers == full_num_layers``: all indices, in order.
    * ``wan2_2_14l``: indices spread uniformly over the full depth, always
      including block ``0`` and block ``full_num_layers - 1`` (when the target
      is at least 2).
    * Any other variant: the first ``target_num_layers`` blocks.

    Returns:
        A list of ``target_num_layers`` block indices in ascending order.

    Raises:
        ValueError: If the variant is unknown, either count is non-positive,
            or the target exceeds the full depth.
        RuntimeError: If the uniform selection fails to produce
            ``target_num_layers`` distinct indices.
    """
    variant_norm = canonicalize_wan_model_variant(variant)
    target = int(target_num_layers)
    full = int(full_num_layers)

    if target <= 0 or full <= 0:
        raise ValueError(f"target_num_layers and full_num_layers must be positive, got {target}, {full}")
    if target > full:
        raise ValueError(f"target_num_layers={target} exceeds full_num_layers={full}")
    if target == full:
        return list(range(full))

    if variant_norm != "wan2_2_14l":
        return list(range(target))

    # Keep early/mid/late depth coverage by selecting blocks uniformly across
    # the full pretrained transformer, while explicitly pinning the first and
    # last DiT blocks to preserve the input/output boundary behavior.
    if target == 1:
        return [0]
    if target == 2:
        return [0, full - 1]

    # From here target >= 3, so there is always at least one interior slot.
    last = full - 1
    steps = target - 1
    # Interior picks are clamped to [1, full - 2] so they can never collide
    # with the pinned boundary blocks.
    inner = [
        min(max(int(round(i * last / steps)), 1), full - 2)
        for i in range(1, steps)
    ]
    indices = [0, *inner, last]
    if len(set(indices)) != target:
        raise RuntimeError(
            f"Failed to resolve unique uniformly spaced indices with pinned boundaries: {indices}"
        )
    return indices


def resolve_wan_model_root(
    model_path: str | Path | None,
    *,
    variant: str | None = None,
) -> Path:
    """Locate the root directory of a Wan diffusers checkpoint.

    If ``model_path`` is truthy it is expanded and resolved; otherwise a
    default under ``_REPO_ROOT / "hugg_model"`` is chosen from ``variant``
    (the Wan2.2 TI2V-5B folder for any ``wan2_2*`` variant, the Wan2.1
    T2V-1.3B folder otherwise).

    A directory counts as a root when it contains both ``transformer`` and
    ``vae`` sub-directories.  If the resolved path is itself named like one of
    the component folders (``transformer``, ``vae``, ``scheduler``,
    ``tokenizer``, ``text_encoder``), its parent is tried as the root.

    Raises:
        ValueError: If ``variant`` is not a recognized alias.
        FileNotFoundError: If no valid root can be derived from the path.
    """
    variant_norm = canonicalize_wan_model_variant(variant)

    if model_path:
        path = Path(model_path).expanduser().resolve()
    elif variant_norm in {"wan2_2", "wan2_2_10l", "wan2_2_14l"}:
        path = (_REPO_ROOT / "hugg_model" / "Wan2.2-TI2V-5B-Diffusers").resolve()
    else:
        path = (_REPO_ROOT / "hugg_model" / "Wan2.1-T2V-1.3B-Diffusers").resolve()

    if path.is_dir() and (path / "transformer").is_dir() and (path / "vae").is_dir():
        return path

    # The caller may have pointed at a component sub-folder; step up one level.
    if path.name in {"transformer", "vae", "scheduler", "tokenizer", "text_encoder"}:
        candidate = path.parent
        if (candidate / "transformer").is_dir() and (candidate / "vae").is_dir():
            return candidate

    raise FileNotFoundError(f"Cannot resolve Wan diffusers root from: {path}")