Spaces:
Configuration error
Configuration error
| """ | |
| PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation | |
| Official implementation of the paper: | |
| "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" | |
| by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis | |
| Licensed under a modified MIT license | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import shutil | |
| from pathlib import Path | |
| from typing import Iterable, Optional, Sequence, Union | |
# Hugging Face repository hosting the PRIMA demo assets (overridable via the
# PRIMA_HF_REPO_ID environment variable at download time).
HF_REPO_ID = "MLAdaptiveIntelligence/PRIMA"
DEFAULT_HF_REPO_ID = HF_REPO_ID

# Default checkpoint locations inside the conventional ``data/`` directory.
DEFAULT_STAGE1_CHECKPOINT = Path("data", "PRIMAS1", "checkpoints", "s1ckpt_inference.ckpt")
DEFAULT_STAGE3_CHECKPOINT = Path("data", "PRIMAS3", "checkpoints", "s3ckpt_inference.ckpt")

# Remote filenames of the SMAL model / pose-prior files (stored under data/smal/).
SMAL_ASSET_PATHS = [
    "my_smpl_00781_4_all.pkl",
    "my_smpl_data_00781_4_all.pkl",
    "walking_toy_symmetric_pose_prior_with_cov_35parts.pkl",
]

# Remote filenames of the backbone weights and of each stage's config/checkpoint.
BACKBONE_ASSET_PATH = "amr_vitbb.pth"
STAGE1_CONFIG_ASSET_PATH = "config_s1_HYDRA.yaml"
STAGE1_CHECKPOINT_ASSET_PATH = "s1ckpt_inference.ckpt"
STAGE3_CONFIG_ASSET_PATH = "config_s3_HYDRA.yaml"
STAGE3_CHECKPOINT_ASSET_PATH = "s3ckpt_inference.ckpt"

# stage name -> (remote config file, remote checkpoint file, local checkpoint filename)
STAGE_ASSETS = {
    "PRIMAS1": (STAGE1_CONFIG_ASSET_PATH, STAGE1_CHECKPOINT_ASSET_PATH, "s1ckpt_inference.ckpt"),
    "PRIMAS3": (STAGE3_CONFIG_ASSET_PATH, STAGE3_CHECKPOINT_ASSET_PATH, "s3ckpt_inference.ckpt"),
}

# stage name -> checkpoint location relative to the data directory.
STAGE_CHECKPOINTS = {
    stage: Path(stage, "checkpoints", local_name)
    for stage, (_, _, local_name) in STAGE_ASSETS.items()
}

# A path argument may be given either as a string or as a Path.
PathLike = Union[str, Path]
| def _resolve_hf_repo_id(hf_repo_id: Optional[str]) -> str: | |
| return hf_repo_id or os.environ.get("PRIMA_HF_REPO_ID", HF_REPO_ID) | |
def _default_checkpoint_path(data_dir: PathLike = "data") -> Path:
    """Return where the stage-1 (PRIMAS1) checkpoint lives inside ``data_dir``."""
    stage1_relative = STAGE_CHECKPOINTS["PRIMAS1"]
    return Path(data_dir).joinpath(stage1_relative)
| def _config_path_for_checkpoint(checkpoint_path: PathLike) -> Path: | |
| checkpoint_path = Path(checkpoint_path) | |
| return checkpoint_path.parent.parent / ".hydra" / "config.yaml" | |
| def _stage_for_checkpoint(checkpoint_path: PathLike) -> Optional[str]: | |
| checkpoint_path = Path(checkpoint_path) | |
| if len(checkpoint_path.parents) < 2: | |
| return None | |
| stage_name = checkpoint_path.parent.parent.name | |
| stage_assets = STAGE_ASSETS.get(stage_name) | |
| if stage_assets is None: | |
| return None | |
| _, _, checkpoint_name = stage_assets | |
| if checkpoint_path.name != checkpoint_name: | |
| return None | |
| return stage_name | |
def _download_file(
    hf_repo_id: str,
    remote_filename: str,
    destination: Path,
    force_download: bool = False,
) -> None:
    """Fetch ``remote_filename`` from the Hub and leave it exactly at ``destination``.

    The file is downloaded into ``destination.parent`` as a regular file (no
    symlinks into the HF cache). If the remote name differs from the wanted
    local name, the file is then moved onto ``destination``, replacing any
    existing file there.

    Raises:
        ImportError: if ``huggingface_hub`` is not installed.
    """
    try:
        from huggingface_hub import hf_hub_download
    except ImportError:
        raise ImportError(
            "huggingface_hub is required to download PRIMA demo assets. "
            "Install it with: pip install huggingface_hub\n"
            "Or download the assets manually and pass a local checkpoint path."
        ) from None

    local_dir = destination.parent
    local_dir.mkdir(parents=True, exist_ok=True)
    fetched = Path(
        hf_hub_download(
            repo_id=hf_repo_id,
            filename=remote_filename,
            local_dir=str(local_dir),
            local_dir_use_symlinks=False,
            force_download=force_download,
        )
    ).resolve()
    final = destination.resolve()
    if fetched == final:
        return
    # Remove any stale copy under the final name before moving the new file in.
    if final.exists():
        final.unlink()
    shutil.move(str(fetched), str(final))
| def _validate_torch_checkpoint(path: Path) -> None: | |
| import inspect | |
| import pickle | |
| import zipfile | |
| import torch | |
| if zipfile.is_zipfile(path): | |
| with zipfile.ZipFile(path) as checkpoint_zip: | |
| corrupt_member = checkpoint_zip.testzip() | |
| if corrupt_member is not None: | |
| raise RuntimeError( | |
| f"Checkpoint file is invalid or incomplete: {path}\n" | |
| f"Corrupt archive member: {corrupt_member}\n" | |
| "Please redownload the checkpoint and try again." | |
| ) | |
| supports_weights_only = "weights_only" in inspect.signature(torch.load).parameters | |
| load_kwargs = {"map_location": "cpu"} | |
| if supports_weights_only: | |
| load_kwargs["weights_only"] = True | |
| try: | |
| torch.load(path, **load_kwargs) | |
| except pickle.UnpicklingError as exc: | |
| message = str(exc) | |
| if ( | |
| supports_weights_only | |
| and "Weights only load failed" in message | |
| and ("Unsupported global" in message or "Unsupported class" in message) | |
| ): | |
| return | |
| raise RuntimeError( | |
| f"Checkpoint file is invalid or incomplete: {path}\n" | |
| "Downloaded checkpoint is not loadable. " | |
| "Please verify the uploaded Hugging Face file and try again." | |
| ) from exc | |
| except Exception as exc: | |
| raise RuntimeError( | |
| f"Checkpoint file is invalid or incomplete: {path}\n" | |
| "Downloaded checkpoint is not loadable. " | |
| "Please verify the uploaded Hugging Face file and try again." | |
| ) from exc | |
def _ensure_backbone(data_dir: Path, force: bool, hf_repo_id: str) -> None:
    """Download the pretrained backbone weights into ``data_dir`` if needed.

    The local filename is the basename of ``BACKBONE_ASSET_PATH`` rather than
    a second hard-coded literal, so the remote and local names cannot drift
    apart.

    Args:
        data_dir: root directory for the demo assets.
        force: re-download even if the file already exists.
        hf_repo_id: Hugging Face repository to download from.
    """
    target = data_dir / Path(BACKBONE_ASSET_PATH).name
    if target.exists() and not force:
        print(f"[skip] {target} already exists")
        return
    print("[download] pretrained backbone")
    _download_file(hf_repo_id, BACKBONE_ASSET_PATH, target, force_download=force)
    print(f"[ok] {target}")
def _ensure_smal_assets(data_dir: Path, force: bool, hf_repo_id: str) -> None:
    """Make sure the SMAL model/prior files exist under ``data_dir / "smal"``.

    Only missing files are fetched (all of them when ``force`` is set). A
    partially populated directory therefore does not re-download the files
    it already has.

    Args:
        data_dir: root directory for the demo assets.
        force: re-download every SMAL file even if present.
        hf_repo_id: Hugging Face repository to download from.
    """
    smal_dir = data_dir / "smal"
    targets = [(asset_path, smal_dir / Path(asset_path).name) for asset_path in SMAL_ASSET_PATHS]
    pending = [(asset_path, target) for asset_path, target in targets if force or not target.exists()]
    if not pending:
        print("[skip] SMAL files already exist")
        return
    print("[download] SMAL assets")
    for asset_path, target in pending:
        _download_file(hf_repo_id, asset_path, target, force_download=force)
    print(f"[ok] {smal_dir}")
def _ensure_stage_assets(
    stage_name: str,
    data_dir: Path,
    force: bool,
    hf_repo_id: str,
    validate_existing: bool = True,
) -> None:
    """Make sure a stage's Hydra config and checkpoint exist under ``data_dir``.

    Resulting layout::

        <data_dir>/<stage_name>/.hydra/config.yaml
        <data_dir>/<stage_name>/checkpoints/<checkpoint filename>

    Args:
        stage_name: a key of ``STAGE_ASSETS`` (e.g. ``"PRIMAS1"``).
        data_dir: root directory for the demo assets.
        force: re-download both files even if they already exist.
        hf_repo_id: Hugging Face repository to download from.
        validate_existing: when True, a pre-existing checkpoint is loaded with
            ``_validate_torch_checkpoint`` and re-downloaded if that fails. This
            happens whether or not the config is present. When False, existing
            files are trusted as-is.

    Raises:
        ValueError: if ``stage_name`` is unknown.
        RuntimeError: if a freshly downloaded checkpoint fails validation.
    """
    if stage_name not in STAGE_ASSETS:
        known = ", ".join(sorted(STAGE_ASSETS))
        raise ValueError(f"Unknown PRIMA stage '{stage_name}'. Expected one of: {known}")
    config_asset_path, checkpoint_asset_path, checkpoint_name = STAGE_ASSETS[stage_name]
    stage_dir = data_dir / stage_name
    config_target = stage_dir / ".hydra" / "config.yaml"
    checkpoint_target = stage_dir / "checkpoints" / checkpoint_name

    need_config = force or not config_target.exists()

    # Validate an existing checkpoint independently of the config. A truncated
    # checkpoint is then replaced even when the config must also be fetched.
    redownload_checkpoint = False
    if validate_existing and not force and checkpoint_target.exists():
        try:
            _validate_torch_checkpoint(checkpoint_target)
        except RuntimeError:
            print(f"[warn] {stage_name} checkpoint is incomplete, redownloading checkpoint only.")
            redownload_checkpoint = True
    need_checkpoint = force or redownload_checkpoint or not checkpoint_target.exists()

    if not (need_config or need_checkpoint):
        print(f"[skip] {stage_name} assets already exist")
        return

    print(f"[download] {stage_name} assets")
    config_target.parent.mkdir(parents=True, exist_ok=True)
    checkpoint_target.parent.mkdir(parents=True, exist_ok=True)
    if need_config:
        _download_file(hf_repo_id, config_asset_path, config_target, force_download=force)
    if need_checkpoint:
        # Remove the bad copy so a failed download cannot leave it in place.
        if redownload_checkpoint and checkpoint_target.exists():
            checkpoint_target.unlink()
        _download_file(
            hf_repo_id,
            checkpoint_asset_path,
            checkpoint_target,
            force_download=force or redownload_checkpoint,
        )
        # Only a checkpoint fetched just now needs checking here. An existing
        # one was already checked above, or deliberately trusted
        # (validate_existing=False).
        _validate_torch_checkpoint(checkpoint_target)
    print(f"[ok] {stage_dir}")
| def _normalize_stages(stages: Union[str, Iterable[str]]) -> Sequence[str]: | |
| if isinstance(stages, str): | |
| return (stages,) | |
| return tuple(stages) | |
def _verify_assets(data_dir: Path, stages: Sequence[str]) -> None:
    """Check that every demo asset for ``stages`` is present and loadable.

    The required SMAL and backbone filenames are derived from
    ``SMAL_ASSET_PATHS`` and ``BACKBONE_ASSET_PATH``, not repeated as
    literals, so this check cannot drift from the asset list.

    Args:
        data_dir: root directory for the demo assets.
        stages: stage names (keys of ``STAGE_ASSETS``) to verify.

    Raises:
        ValueError: if a stage name is unknown.
        FileNotFoundError: listing every missing file, if any are absent.
        RuntimeError: if a stage checkpoint exists but cannot be loaded.
    """
    required_paths = [data_dir / "smal" / Path(p).name for p in SMAL_ASSET_PATHS]
    required_paths.append(data_dir / Path(BACKBONE_ASSET_PATH).name)

    stage_checkpoints = []
    for stage_name in stages:
        if stage_name not in STAGE_ASSETS:
            known = ", ".join(sorted(STAGE_ASSETS))
            raise ValueError(f"Unknown PRIMA stage '{stage_name}'. Expected one of: {known}")
        _, _, checkpoint_name = STAGE_ASSETS[stage_name]
        stage_dir = data_dir / stage_name
        checkpoint = stage_dir / "checkpoints" / checkpoint_name
        required_paths.extend([stage_dir / ".hydra" / "config.yaml", checkpoint])
        stage_checkpoints.append(checkpoint)

    missing = [p for p in required_paths if not p.exists()]
    if missing:
        raise FileNotFoundError("Missing required files:\n" + "\n".join(str(p) for p in missing))
    # Existence alone is not enough: also make sure each checkpoint loads.
    for checkpoint in stage_checkpoints:
        _validate_torch_checkpoint(checkpoint)
def _ensure_assets_for_checkpoint(
    checkpoint_path: PathLike,
    force: bool = False,
    hf_repo_id: Optional[str] = None,
) -> None:
    """Make sure everything needed to load ``checkpoint_path`` is on disk.

    Standard PRIMA layout (``<data_dir>/PRIMAS1/checkpoints/s1ckpt_inference.ckpt``
    or the PRIMAS3 equivalent): the SMAL files, the backbone and that stage's
    config/checkpoint are downloaded into ``<data_dir>`` as needed. An existing
    stage checkpoint is not re-validated here (``validate_existing=False``).

    Custom path: nothing can be downloaded, so the checkpoint and its
    ``.hydra/config.yaml`` (two levels up) must already exist. If they do, they
    are used as-is. ``force`` has nothing to re-fetch for such a path and is
    ignored with a warning.

    Args:
        checkpoint_path: checkpoint file the caller intends to load.
        force: re-download standard-layout assets even if present.
        hf_repo_id: repository override; see ``_resolve_hf_repo_id``.

    Raises:
        FileNotFoundError: if a custom checkpoint or its config is missing.
    """
    checkpoint_path = Path(checkpoint_path)
    config_path = _config_path_for_checkpoint(checkpoint_path)
    stage_name = _stage_for_checkpoint(checkpoint_path)
    if stage_name is None:
        if checkpoint_path.exists() and config_path.exists():
            if force:
                # Nothing can be re-downloaded for a custom path. Use the local
                # files instead of reporting them as missing.
                print(
                    "[warn] force download is not supported for custom checkpoint paths; "
                    f"using local PRIMA checkpoint {checkpoint_path}"
                )
            else:
                print(f"[skip] Using local PRIMA checkpoint {checkpoint_path}")
            return
        raise FileNotFoundError(
            "Missing checkpoint or config for a custom path:\n"
            f" checkpoint: {checkpoint_path}\n"
            f" config: {config_path}\n"
            "Auto-download supports the standard PRIMA demo layouts only:\n"
            " data/PRIMAS1/checkpoints/s1ckpt_inference.ckpt\n"
            " data/PRIMAS3/checkpoints/s3ckpt_inference.ckpt\n"
            "Pass one of those paths, or download/copy your custom checkpoint manually."
        )
    # <data_dir>/<stage>/checkpoints/<file> -> <data_dir>
    data_dir = checkpoint_path.parent.parent.parent
    repo_id = _resolve_hf_repo_id(hf_repo_id)
    print(f"[download] Ensuring PRIMA demo assets under {data_dir}")
    _ensure_smal_assets(data_dir, force=force, hf_repo_id=repo_id)
    _ensure_backbone(data_dir, force=force, hf_repo_id=repo_id)
    _ensure_stage_assets(
        stage_name,
        data_dir,
        force=force,
        hf_repo_id=repo_id,
        validate_existing=False,
    )
def ensure_demo_assets(
    data_dir: PathLike = "data",
    *,
    stages: Union[str, Iterable[str]] = ("PRIMAS1",),
    force: bool = False,
    hf_repo_id: Optional[str] = None,
) -> None:
    """Ensure PRIMA demo assets exist in the expected ``data/`` layout.

    Downloads, as needed, the SMAL files, the pretrained backbone and, for
    each requested stage, its Hydra config and checkpoint. It then verifies
    that everything is present and that the checkpoints load.

    Args:
        data_dir: root directory for the assets (created if missing).
        stages: one stage name or an iterable of them (keys of ``STAGE_ASSETS``).
        force: re-download assets even if they already exist.
        hf_repo_id: repository override; see ``_resolve_hf_repo_id``.

    Raises:
        ValueError: if a requested stage is unknown. This is checked before
            any download starts, so a typo does not first trigger the SMAL and
            backbone downloads.
        FileNotFoundError: if an asset is still missing after downloading.
        RuntimeError: if a checkpoint cannot be loaded.
    """
    selected_stages = _normalize_stages(stages)
    for stage_name in selected_stages:
        if stage_name not in STAGE_ASSETS:
            known = ", ".join(sorted(STAGE_ASSETS))
            raise ValueError(f"Unknown PRIMA stage '{stage_name}'. Expected one of: {known}")

    data_dir = Path(data_dir).resolve()
    data_dir.mkdir(parents=True, exist_ok=True)
    repo_id = _resolve_hf_repo_id(hf_repo_id)
    _ensure_smal_assets(data_dir, force=force, hf_repo_id=repo_id)
    _ensure_backbone(data_dir, force=force, hf_repo_id=repo_id)
    for stage_name in selected_stages:
        _ensure_stage_assets(stage_name, data_dir, force=force, hf_repo_id=repo_id)
    _verify_assets(data_dir, selected_stages)
def resolve_prima_checkpoint_path(
    checkpoint_path: PathLike = "",
    *,
    data_dir: PathLike = "data",
    auto_download: bool = True,
    hf_repo_id: Optional[str] = None,
    force: bool = False,
) -> str:
    """Return a PRIMA checkpoint path, downloading default demo assets if needed.

    An empty ``checkpoint_path`` selects the stage-1 checkpoint under
    ``data_dir``. With ``auto_download`` enabled, assets for the standard
    layouts are fetched first; a custom path must already exist on disk.
    The result is returned as a string.
    """
    if not checkpoint_path:
        resolved = _default_checkpoint_path(data_dir)
    else:
        resolved = Path(checkpoint_path)

    if auto_download:
        _ensure_assets_for_checkpoint(resolved, force=force, hf_repo_id=hf_repo_id)
    return str(resolved)
# Names exported by ``from <this module> import *``. The underscore-prefixed
# helpers above are implementation details and are intentionally not listed.
__all__ = [
    "DEFAULT_HF_REPO_ID",
    "DEFAULT_STAGE1_CHECKPOINT",
    "DEFAULT_STAGE3_CHECKPOINT",
    "HF_REPO_ID",
    "ensure_demo_assets",
    "resolve_prima_checkpoint_path",
]