"""
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation

Official implementation of the paper:
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis

Licensed under a modified MIT license
"""

from __future__ import annotations

import inspect
import os
import pickle
import shutil
import zipfile
from pathlib import Path
from typing import Callable, Iterable, Optional, Sequence, Tuple, Union

HF_REPO_ID = "MLAdaptiveIntelligence/PRIMA"
DEFAULT_HF_REPO_ID = HF_REPO_ID

# Remote file names inside the Hugging Face repository.
SMAL_ASSET_PATHS = [
    "my_smpl_00781_4_all.pkl",
    "my_smpl_data_00781_4_all.pkl",
    "walking_toy_symmetric_pose_prior_with_cov_35parts.pkl",
]
BACKBONE_ASSET_PATH = "amr_vitbb.pth"
STAGE1_CONFIG_ASSET_PATH = "config_s1_HYDRA.yaml"
STAGE1_CHECKPOINT_ASSET_PATH = "s1ckpt_inference.ckpt"
STAGE3_CONFIG_ASSET_PATH = "config_s3_HYDRA.yaml"
STAGE3_CHECKPOINT_ASSET_PATH = "s3ckpt_inference.ckpt"

# stage name -> (remote config file, remote checkpoint file, local checkpoint file name)
STAGE_ASSETS = {
    "PRIMAS1": (
        STAGE1_CONFIG_ASSET_PATH,
        STAGE1_CHECKPOINT_ASSET_PATH,
        Path(STAGE1_CHECKPOINT_ASSET_PATH).name,
    ),
    "PRIMAS3": (
        STAGE3_CONFIG_ASSET_PATH,
        STAGE3_CHECKPOINT_ASSET_PATH,
        Path(STAGE3_CHECKPOINT_ASSET_PATH).name,
    ),
}

# stage name -> checkpoint location relative to the data directory
STAGE_CHECKPOINTS = {
    stage: Path(stage) / "checkpoints" / checkpoint_name
    for stage, (_, _, checkpoint_name) in STAGE_ASSETS.items()
}

DEFAULT_STAGE1_CHECKPOINT = Path("data") / STAGE_CHECKPOINTS["PRIMAS1"]
DEFAULT_STAGE3_CHECKPOINT = Path("data") / STAGE_CHECKPOINTS["PRIMAS3"]

PathLike = Union[str, Path]


def _accepts_kwarg(func: Callable, name: str) -> bool:
    """Return True when *func* declares a parameter called *name*."""
    return name in inspect.signature(func).parameters


def _resolve_hf_repo_id(hf_repo_id: Optional[str]) -> str:
    """Pick the repo id: explicit argument, then ``PRIMA_HF_REPO_ID``, then the default."""
    return hf_repo_id or os.environ.get("PRIMA_HF_REPO_ID", HF_REPO_ID)


def _default_checkpoint_path(data_dir: PathLike = "data") -> Path:
    """Return the stage-1 checkpoint path under *data_dir*."""
    return Path(data_dir) / STAGE_CHECKPOINTS["PRIMAS1"]


def _config_path_for_checkpoint(checkpoint_path: PathLike) -> Path:
    """Return ``<run>/.hydra/config.yaml`` for ``<run>/<subdir>/<checkpoint>``."""
    checkpoint_path = Path(checkpoint_path)
    return checkpoint_path.parent.parent / ".hydra" / "config.yaml"


def _stage_for_checkpoint(checkpoint_path: PathLike) -> Optional[str]:
    """Return the known stage name *checkpoint_path* belongs to, or ``None``.

    A path matches when the directory two levels above the file is named after
    a key of ``STAGE_ASSETS`` and the file name equals that stage's checkpoint
    name. The intermediate directory name is not checked.
    """
    checkpoint_path = Path(checkpoint_path)
    if len(checkpoint_path.parents) < 2:
        return None
    stage_name = checkpoint_path.parent.parent.name
    stage_assets = STAGE_ASSETS.get(stage_name)
    if stage_assets is None:
        return None
    _, _, checkpoint_name = stage_assets
    if checkpoint_path.name != checkpoint_name:
        return None
    return stage_name


def _stage_assets_or_raise(stage_name: str) -> Tuple[str, str, str]:
    """Return the ``STAGE_ASSETS`` entry for *stage_name*.

    Raises:
        ValueError: if *stage_name* is not a key of ``STAGE_ASSETS``.
    """
    if stage_name not in STAGE_ASSETS:
        known = ", ".join(sorted(STAGE_ASSETS))
        raise ValueError(f"Unknown PRIMA stage '{stage_name}'. Expected one of: {known}")
    return STAGE_ASSETS[stage_name]


def _download_file(
    hf_repo_id: str,
    remote_filename: str,
    destination: Path,
    force_download: bool = False,
) -> None:
    """Download *remote_filename* from *hf_repo_id* and place it at *destination*.

    The file is fetched into ``destination.parent`` and then moved onto
    *destination* when the two paths differ, replacing any existing file there.

    Raises:
        ImportError: if ``huggingface_hub`` is not installed.
    """
    try:
        from huggingface_hub import hf_hub_download
    except ImportError:
        raise ImportError(
            "huggingface_hub is required to download PRIMA demo assets. "
            "Install it with: pip install huggingface_hub\n"
            "Or download the assets manually and pass a local checkpoint path."
        ) from None

    destination.parent.mkdir(parents=True, exist_ok=True)

    download_kwargs = {
        "repo_id": hf_repo_id,
        "filename": remote_filename,
        "local_dir": str(destination.parent),
        "force_download": force_download,
    }
    # Only pass ``local_dir_use_symlinks`` to huggingface_hub versions that
    # still declare it; passing an unknown keyword would raise TypeError.
    # Where it exists we keep requesting a real file, because the path is
    # resolved and moved below.
    if _accepts_kwarg(hf_hub_download, "local_dir_use_symlinks"):
        download_kwargs["local_dir_use_symlinks"] = False

    downloaded = hf_hub_download(**download_kwargs)

    downloaded_path = Path(downloaded).resolve()
    target = destination.resolve()
    if downloaded_path != target:
        if target.exists():
            target.unlink()
        shutil.move(str(downloaded_path), str(target))


def _validate_torch_checkpoint(path: Path) -> None:
    """Check that *path* is a readable checkpoint.

    Zip archives are CRC-checked first, then the file is passed to
    ``torch.load`` on CPU (with ``weights_only=True`` when the installed torch
    supports it). A weights-only failure caused solely by an unsupported
    global/class is treated as success: the file was readable, it just
    contains objects outside the weights-only allow-list.

    Raises:
        RuntimeError: if the archive is corrupt or the file cannot be loaded.
    """
    import torch

    if zipfile.is_zipfile(path):
        # Opening or scanning a truncated archive can itself raise (for
        # example BadZipFile); convert that to RuntimeError so callers that
        # redownload on RuntimeError also handle this case.
        try:
            with zipfile.ZipFile(path) as checkpoint_zip:
                corrupt_member = checkpoint_zip.testzip()
        except Exception as exc:
            raise RuntimeError(
                f"Checkpoint file is invalid or incomplete: {path}\n"
                "Please redownload the checkpoint and try again."
            ) from exc
        if corrupt_member is not None:
            raise RuntimeError(
                f"Checkpoint file is invalid or incomplete: {path}\n"
                f"Corrupt archive member: {corrupt_member}\n"
                "Please redownload the checkpoint and try again."
            )

    supports_weights_only = _accepts_kwarg(torch.load, "weights_only")
    load_kwargs = {"map_location": "cpu"}
    if supports_weights_only:
        load_kwargs["weights_only"] = True
    # NOTE(review): when ``weights_only`` is unavailable this is a full pickle
    # load of a downloaded file; only use repositories you trust.

    try:
        torch.load(path, **load_kwargs)
    except Exception as exc:
        message = str(exc)
        if (
            isinstance(exc, pickle.UnpicklingError)
            and supports_weights_only
            and "Weights only load failed" in message
            and ("Unsupported global" in message or "Unsupported class" in message)
        ):
            return
        raise RuntimeError(
            f"Checkpoint file is invalid or incomplete: {path}\n"
            "Downloaded checkpoint is not loadable. "
            "Please verify the uploaded Hugging Face file and try again."
        ) from exc


def _ensure_backbone(data_dir: Path, force: bool, hf_repo_id: str) -> None:
    """Download the pretrained backbone into *data_dir* unless it already exists."""
    target = data_dir / Path(BACKBONE_ASSET_PATH).name
    if target.exists() and not force:
        print(f"[skip] {target} already exists")
        return

    print("[download] pretrained backbone")
    _download_file(hf_repo_id, BACKBONE_ASSET_PATH, target, force_download=force)
    print(f"[ok] {target}")


def _ensure_smal_assets(data_dir: Path, force: bool, hf_repo_id: str) -> None:
    """Download the SMAL files into ``data_dir/smal`` unless all of them exist."""
    required = [Path(p).name for p in SMAL_ASSET_PATHS]
    smal_dir = data_dir / "smal"
    if smal_dir.exists() and all((smal_dir / n).exists() for n in required) and not force:
        print("[skip] SMAL files already exist")
        return

    print("[download] SMAL assets")
    for asset_path in SMAL_ASSET_PATHS:
        target = smal_dir / Path(asset_path).name
        _download_file(hf_repo_id, asset_path, target, force_download=force)
    print(f"[ok] {smal_dir}")


def _ensure_stage_assets(
    stage_name: str,
    data_dir: Path,
    force: bool,
    hf_repo_id: str,
    validate_existing: bool = True,
) -> None:
    """Ensure the config and checkpoint of *stage_name* exist under *data_dir*.

    Existing files are kept unless *force* is set. With *validate_existing*,
    an existing checkpoint that fails validation is deleted and downloaded
    again (the config is left alone). Whenever this function gets past the
    "already exist" shortcut, the checkpoint is validated before returning.

    Raises:
        ValueError: if *stage_name* is unknown.
        RuntimeError: if the final checkpoint fails validation.
    """
    config_asset_path, checkpoint_asset_path, checkpoint_name = _stage_assets_or_raise(stage_name)
    stage_dir = data_dir / stage_name
    config_target = stage_dir / ".hydra" / "config.yaml"
    checkpoint_target = stage_dir / "checkpoints" / checkpoint_name

    redownload_checkpoint = False
    if config_target.exists() and checkpoint_target.exists() and not force:
        if not validate_existing:
            print(f"[skip] {stage_name} assets already exist")
            return
        try:
            _validate_torch_checkpoint(checkpoint_target)
        except RuntimeError:
            print(f"[warn] {stage_name} checkpoint is incomplete, redownloading checkpoint only.")
            redownload_checkpoint = True
        else:
            print(f"[skip] {stage_name} assets already exist")
            return

    print(f"[download] {stage_name} assets")
    config_target.parent.mkdir(parents=True, exist_ok=True)
    checkpoint_target.parent.mkdir(parents=True, exist_ok=True)

    if force or not config_target.exists():
        _download_file(hf_repo_id, config_asset_path, config_target, force_download=force)

    if redownload_checkpoint and checkpoint_target.exists():
        checkpoint_target.unlink()

    if force or redownload_checkpoint or not checkpoint_target.exists():
        _download_file(
            hf_repo_id,
            checkpoint_asset_path,
            checkpoint_target,
            force_download=force or redownload_checkpoint,
        )

    _validate_torch_checkpoint(checkpoint_target)
    print(f"[ok] {stage_dir}")


def _normalize_stages(stages: Union[str, Iterable[str]]) -> Sequence[str]:
    """Return *stages* as a tuple, wrapping a single stage name."""
    if isinstance(stages, str):
        return (stages,)
    return tuple(stages)


def _verify_assets(data_dir: Path, stages: Sequence[str]) -> None:
    """Check that every required file exists and each stage checkpoint validates.

    Raises:
        ValueError: if a stage name is unknown.
        FileNotFoundError: listing every missing file.
        RuntimeError: if a stage checkpoint fails validation.
    """
    required_paths = [data_dir / "smal" / Path(p).name for p in SMAL_ASSET_PATHS]
    required_paths.append(data_dir / Path(BACKBONE_ASSET_PATH).name)

    checkpoint_paths = []
    for stage_name in stages:
        _, _, checkpoint_name = _stage_assets_or_raise(stage_name)
        stage_dir = data_dir / stage_name
        checkpoint_target = stage_dir / "checkpoints" / checkpoint_name
        required_paths.extend([stage_dir / ".hydra" / "config.yaml", checkpoint_target])
        checkpoint_paths.append(checkpoint_target)

    missing = [p for p in required_paths if not p.exists()]
    if missing:
        raise FileNotFoundError("Missing required files:\n" + "\n".join(str(p) for p in missing))

    for checkpoint_target in checkpoint_paths:
        _validate_torch_checkpoint(checkpoint_target)


def _ensure_assets_for_checkpoint(
    checkpoint_path: PathLike,
    force: bool = False,
    hf_repo_id: Optional[str] = None,
) -> None:
    """Make sure everything needed to use *checkpoint_path* is on disk.

    For a standard stage layout the SMAL files, backbone and stage assets are
    downloaded into the directory three levels above the checkpoint. For any
    other path nothing can be downloaded: the checkpoint and its
    ``.hydra/config.yaml`` must already exist, and *force* is ignored with a
    warning because there is no remote source for a custom checkpoint.

    Raises:
        FileNotFoundError: for a custom path whose checkpoint or config is missing.
    """
    checkpoint_path = Path(checkpoint_path)
    config_path = _config_path_for_checkpoint(checkpoint_path)
    stage_name = _stage_for_checkpoint(checkpoint_path)

    if stage_name is None:
        if checkpoint_path.exists() and config_path.exists():
            if force:
                print(
                    f"[warn] force=True is ignored for custom checkpoint {checkpoint_path}; "
                    "auto-download supports the standard PRIMA demo layouts only."
                )
            print(f"[skip] Using local PRIMA checkpoint {checkpoint_path}")
            return
        raise FileNotFoundError(
            "Missing checkpoint or config for a custom path:\n"
            f"  checkpoint: {checkpoint_path}\n"
            f"  config: {config_path}\n"
            "Auto-download supports the standard PRIMA demo layouts only:\n"
            "  data/PRIMAS1/checkpoints/s1ckpt_inference.ckpt\n"
            "  data/PRIMAS3/checkpoints/s3ckpt_inference.ckpt\n"
            "Pass one of those paths, or download/copy your custom checkpoint manually."
        )

    data_dir = checkpoint_path.parent.parent.parent
    repo_id = _resolve_hf_repo_id(hf_repo_id)
    print(f"[download] Ensuring PRIMA demo assets under {data_dir}")
    _ensure_smal_assets(data_dir, force=force, hf_repo_id=repo_id)
    _ensure_backbone(data_dir, force=force, hf_repo_id=repo_id)
    _ensure_stage_assets(
        stage_name,
        data_dir,
        force=force,
        hf_repo_id=repo_id,
        validate_existing=False,
    )


def ensure_demo_assets(
    data_dir: PathLike = "data",
    *,
    stages: Union[str, Iterable[str]] = ("PRIMAS1",),
    force: bool = False,
    hf_repo_id: Optional[str] = None,
) -> None:
    """Ensure PRIMA demo assets exist in the expected ``data/`` layout.

    Args:
        data_dir: Directory that receives the assets; created if missing.
        stages: One stage name or an iterable of stage names to fetch.
        force: Redownload files even when they already exist.
        hf_repo_id: Hugging Face repository; falls back to ``PRIMA_HF_REPO_ID``
            and then to ``HF_REPO_ID``.

    Raises:
        ValueError: if a stage name is unknown.
        FileNotFoundError: if a required file is still missing afterwards.
        RuntimeError: if a stage checkpoint fails validation.
    """
    data_dir = Path(data_dir).resolve()
    data_dir.mkdir(parents=True, exist_ok=True)
    repo_id = _resolve_hf_repo_id(hf_repo_id)
    selected_stages = _normalize_stages(stages)

    _ensure_smal_assets(data_dir, force=force, hf_repo_id=repo_id)
    _ensure_backbone(data_dir, force=force, hf_repo_id=repo_id)
    for stage_name in selected_stages:
        _ensure_stage_assets(stage_name, data_dir, force=force, hf_repo_id=repo_id)
    _verify_assets(data_dir, selected_stages)


def resolve_prima_checkpoint_path(
    checkpoint_path: PathLike = "",
    *,
    data_dir: PathLike = "data",
    auto_download: bool = True,
    hf_repo_id: Optional[str] = None,
    force: bool = False,
) -> str:
    """Return a PRIMA checkpoint path, downloading default demo assets if needed.

    Args:
        checkpoint_path: Checkpoint to use; when empty, the stage-1 checkpoint
            under *data_dir* is used.
        data_dir: Base directory for the default checkpoint.
        auto_download: When true, make sure the assets for the resolved path
            are present before returning.
        hf_repo_id: Hugging Face repository override.
        force: Redownload standard demo assets even when they exist.

    Returns:
        The resolved checkpoint path as a string.

    Raises:
        FileNotFoundError: with *auto_download*, for a custom path whose
            checkpoint or config is missing.
    """
    resolved = Path(checkpoint_path) if checkpoint_path else _default_checkpoint_path(data_dir)
    if auto_download:
        _ensure_assets_for_checkpoint(resolved, force=force, hf_repo_id=hf_repo_id)
    return str(resolved)


__all__ = [
    "DEFAULT_HF_REPO_ID",
    "DEFAULT_STAGE1_CHECKPOINT",
    "DEFAULT_STAGE3_CHECKPOINT",
    "HF_REPO_ID",
    "ensure_demo_assets",
    "resolve_prima_checkpoint_path",
]