PRIMA-demo / prima /utils /weights.py
HF Space deploy
Deploy snapshot (LFS for demo images per .gitattributes)
cdad419
"""
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
Official implementation of the paper:
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
Licensed under a modified MIT license
"""
from __future__ import annotations
import os
import shutil
from pathlib import Path
from typing import Iterable, Optional, Sequence, Union
# Hugging Face Hub repository hosting the PRIMA demo assets. It can be
# overridden at runtime via the ``PRIMA_HF_REPO_ID`` environment variable
# (read by ``_resolve_hf_repo_id``).
HF_REPO_ID = "MLAdaptiveIntelligence/PRIMA"
DEFAULT_HF_REPO_ID = HF_REPO_ID

# Remote filenames of the assets shared by every stage.
SMAL_ASSET_PATHS = [
    "my_smpl_00781_4_all.pkl",
    "my_smpl_data_00781_4_all.pkl",
    "walking_toy_symmetric_pose_prior_with_cov_35parts.pkl",
]
BACKBONE_ASSET_PATH = "amr_vitbb.pth"

# Remote filenames of the per-stage Hydra configs and inference checkpoints.
STAGE1_CONFIG_ASSET_PATH = "config_s1_HYDRA.yaml"
STAGE1_CHECKPOINT_ASSET_PATH = "s1ckpt_inference.ckpt"
STAGE3_CONFIG_ASSET_PATH = "config_s3_HYDRA.yaml"
STAGE3_CHECKPOINT_ASSET_PATH = "s3ckpt_inference.ckpt"

# stage directory name -> (remote config, remote checkpoint, local checkpoint filename)
STAGE_ASSETS = {
    "PRIMAS1": (STAGE1_CONFIG_ASSET_PATH, STAGE1_CHECKPOINT_ASSET_PATH, "s1ckpt_inference.ckpt"),
    "PRIMAS3": (STAGE3_CONFIG_ASSET_PATH, STAGE3_CHECKPOINT_ASSET_PATH, "s3ckpt_inference.ckpt"),
}

# Each stage's checkpoint location relative to the data directory, derived
# from STAGE_ASSETS so the two tables cannot disagree.
STAGE_CHECKPOINTS = {
    stage: Path(stage, "checkpoints", local_name)
    for stage, (_, _, local_name) in STAGE_ASSETS.items()
}

# Default checkpoint locations under the conventional ``data/`` directory.
DEFAULT_STAGE1_CHECKPOINT = Path("data") / STAGE_CHECKPOINTS["PRIMAS1"]
DEFAULT_STAGE3_CHECKPOINT = Path("data") / STAGE_CHECKPOINTS["PRIMAS3"]

PathLike = Union[str, Path]
def _resolve_hf_repo_id(hf_repo_id: Optional[str]) -> str:
return hf_repo_id or os.environ.get("PRIMA_HF_REPO_ID", HF_REPO_ID)
def _default_checkpoint_path(data_dir: PathLike = "data") -> Path:
    """Return where the stage-1 (PRIMAS1) checkpoint lives inside ``data_dir``."""
    stage1_relative = STAGE_CHECKPOINTS["PRIMAS1"]
    return Path(data_dir).joinpath(stage1_relative)
def _config_path_for_checkpoint(checkpoint_path: PathLike) -> Path:
checkpoint_path = Path(checkpoint_path)
return checkpoint_path.parent.parent / ".hydra" / "config.yaml"
def _stage_for_checkpoint(checkpoint_path: PathLike) -> Optional[str]:
    """Return the PRIMA stage name if ``checkpoint_path`` uses the standard layout.

    The standard layout is ``<data_dir>/<stage>/checkpoints/<checkpoint_name>``.
    ``<stage>`` must be a key of ``STAGE_ASSETS`` and ``<checkpoint_name>``
    must be that stage's expected checkpoint filename. Any other path returns
    ``None``.
    """
    checkpoint_path = Path(checkpoint_path)
    if len(checkpoint_path.parents) < 2:
        return None
    # The auto-download always writes to <stage>/checkpoints/<name>. Matching
    # any other directory would report a stage for a path that never gets
    # created.
    if checkpoint_path.parent.name != "checkpoints":
        return None
    stage_name = checkpoint_path.parent.parent.name
    stage_assets = STAGE_ASSETS.get(stage_name)
    if stage_assets is None:
        return None
    _, _, checkpoint_name = stage_assets
    if checkpoint_path.name != checkpoint_name:
        return None
    return stage_name
def _download_file(
    hf_repo_id: str,
    remote_filename: str,
    destination: Path,
    force_download: bool = False,
) -> None:
    """Fetch ``remote_filename`` from the Hub repo and store it at ``destination``.

    The file is downloaded into ``destination.parent`` with
    ``local_dir_use_symlinks=False``. If the downloaded file's path differs
    from ``destination`` (the remote name differs from the local name), it is
    moved onto ``destination``, replacing any existing file there.

    Raises:
        ImportError: if ``huggingface_hub`` is not installed.
    """
    try:
        from huggingface_hub import hf_hub_download
    except ImportError:
        raise ImportError(
            "huggingface_hub is required to download PRIMA demo assets. "
            "Install it with: pip install huggingface_hub\n"
            "Or download the assets manually and pass a local checkpoint path."
        ) from None

    local_dir = destination.parent
    local_dir.mkdir(parents=True, exist_ok=True)
    fetched = Path(
        hf_hub_download(
            repo_id=hf_repo_id,
            filename=remote_filename,
            local_dir=str(local_dir),
            local_dir_use_symlinks=False,
            force_download=force_download,
        )
    ).resolve()

    final_path = destination.resolve()
    if fetched == final_path:
        return
    # Rename to the filename the demo layout expects, dropping any stale copy.
    if final_path.exists():
        final_path.unlink()
    shutil.move(str(fetched), str(final_path))
def _validate_torch_checkpoint(path: Path) -> None:
import inspect
import pickle
import zipfile
import torch
if zipfile.is_zipfile(path):
with zipfile.ZipFile(path) as checkpoint_zip:
corrupt_member = checkpoint_zip.testzip()
if corrupt_member is not None:
raise RuntimeError(
f"Checkpoint file is invalid or incomplete: {path}\n"
f"Corrupt archive member: {corrupt_member}\n"
"Please redownload the checkpoint and try again."
)
supports_weights_only = "weights_only" in inspect.signature(torch.load).parameters
load_kwargs = {"map_location": "cpu"}
if supports_weights_only:
load_kwargs["weights_only"] = True
try:
torch.load(path, **load_kwargs)
except pickle.UnpicklingError as exc:
message = str(exc)
if (
supports_weights_only
and "Weights only load failed" in message
and ("Unsupported global" in message or "Unsupported class" in message)
):
return
raise RuntimeError(
f"Checkpoint file is invalid or incomplete: {path}\n"
"Downloaded checkpoint is not loadable. "
"Please verify the uploaded Hugging Face file and try again."
) from exc
except Exception as exc:
raise RuntimeError(
f"Checkpoint file is invalid or incomplete: {path}\n"
"Downloaded checkpoint is not loadable. "
"Please verify the uploaded Hugging Face file and try again."
) from exc
def _ensure_backbone(data_dir: Path, force: bool, hf_repo_id: str) -> None:
    """Download the pretrained backbone weights into ``data_dir`` if absent.

    An existing file is reused as-is, without an integrity check, unless
    ``force`` is set, in which case it is downloaded again.
    """
    backbone_file = data_dir / "amr_vitbb.pth"
    if not force and backbone_file.exists():
        print(f"[skip] {backbone_file} already exists")
        return

    print("[download] pretrained backbone")
    _download_file(hf_repo_id, BACKBONE_ASSET_PATH, backbone_file, force_download=force)
    print(f"[ok] {backbone_file}")
def _ensure_smal_assets(data_dir: Path, force: bool, hf_repo_id: str) -> None:
    """Make sure every file in ``SMAL_ASSET_PATHS`` exists under ``data_dir / "smal"``.

    Only missing files are downloaded, so files that are already present
    (including manually copied ones) are neither re-fetched nor overwritten.
    With ``force`` set, every file is downloaded again.
    """
    smal_dir = data_dir / "smal"
    targets = [
        (asset_path, smal_dir / Path(asset_path).name)
        for asset_path in SMAL_ASSET_PATHS
    ]
    pending = targets if force else [
        (asset_path, target) for asset_path, target in targets if not target.exists()
    ]
    if not pending:
        print("[skip] SMAL files already exist")
        return

    print("[download] SMAL assets")
    for asset_path, target in pending:
        _download_file(hf_repo_id, asset_path, target, force_download=force)
    print(f"[ok] {smal_dir}")
def _ensure_stage_assets(
    stage_name: str,
    data_dir: Path,
    force: bool,
    hf_repo_id: str,
    validate_existing: bool = True,
) -> None:
    """Make sure one stage's Hydra config and checkpoint exist under ``data_dir``.

    Files are placed at ``<data_dir>/<stage_name>/.hydra/config.yaml`` and
    ``<data_dir>/<stage_name>/checkpoints/<checkpoint_name>``.

    If both files already exist and ``force`` is false, they are kept. When
    ``validate_existing`` is true, the checkpoint is test-loaded first; if that
    fails, only the checkpoint is fetched again. Whenever the download path is
    taken, the checkpoint is validated at the end.

    Raises:
        ValueError: for an unknown ``stage_name``.
        RuntimeError: if the final checkpoint validation fails.
    """
    assets = STAGE_ASSETS.get(stage_name)
    if assets is None:
        known = ", ".join(sorted(STAGE_ASSETS))
        raise ValueError(f"Unknown PRIMA stage '{stage_name}'. Expected one of: {known}")
    config_remote, checkpoint_remote, checkpoint_name = assets

    stage_dir = data_dir / stage_name
    config_target = stage_dir / ".hydra" / "config.yaml"
    checkpoint_target = stage_dir / "checkpoints" / checkpoint_name

    refetch_checkpoint = False
    both_present = config_target.exists() and checkpoint_target.exists()
    if both_present and not force:
        if not validate_existing:
            print(f"[skip] {stage_name} assets already exist")
            return
        try:
            _validate_torch_checkpoint(checkpoint_target)
        except RuntimeError:
            print(f"[warn] {stage_name} checkpoint is incomplete, redownloading checkpoint only.")
            refetch_checkpoint = True
        else:
            print(f"[skip] {stage_name} assets already exist")
            return

    print(f"[download] {stage_name} assets")
    config_target.parent.mkdir(parents=True, exist_ok=True)
    checkpoint_target.parent.mkdir(parents=True, exist_ok=True)

    if force or not config_target.exists():
        _download_file(hf_repo_id, config_remote, config_target, force_download=force)

    # Remove the unloadable checkpoint before fetching a fresh copy.
    if refetch_checkpoint and checkpoint_target.exists():
        checkpoint_target.unlink()

    if force or refetch_checkpoint or not checkpoint_target.exists():
        _download_file(
            hf_repo_id,
            checkpoint_remote,
            checkpoint_target,
            force_download=force or refetch_checkpoint,
        )

    _validate_torch_checkpoint(checkpoint_target)
    print(f"[ok] {stage_dir}")
def _normalize_stages(stages: Union[str, Iterable[str]]) -> Sequence[str]:
if isinstance(stages, str):
return (stages,)
return tuple(stages)
def _verify_assets(data_dir: Path, stages: Sequence[str]) -> None:
    """Check that every asset needed for ``stages`` exists under ``data_dir``.

    Required files are the shared SMAL pickles and backbone weights, plus
    each stage's Hydra config and checkpoint. Existence is checked for all of
    them first; then each stage checkpoint is test-loaded.

    Raises:
        ValueError: for an unknown stage name.
        FileNotFoundError: listing every missing file.
        RuntimeError: if a checkpoint exists but cannot be loaded.
    """
    # Derive the shared filenames from the same constants the download
    # helpers use, so the two lists cannot drift apart.
    required_paths = [data_dir / "smal" / Path(p).name for p in SMAL_ASSET_PATHS]
    required_paths.append(data_dir / Path(BACKBONE_ASSET_PATH).name)

    checkpoint_paths = []
    for stage_name in stages:
        if stage_name not in STAGE_ASSETS:
            known = ", ".join(sorted(STAGE_ASSETS))
            raise ValueError(f"Unknown PRIMA stage '{stage_name}'. Expected one of: {known}")
        _, _, checkpoint_name = STAGE_ASSETS[stage_name]
        stage_dir = data_dir / stage_name
        checkpoint_path = stage_dir / "checkpoints" / checkpoint_name
        required_paths.extend([stage_dir / ".hydra" / "config.yaml", checkpoint_path])
        checkpoint_paths.append(checkpoint_path)

    missing = [p for p in required_paths if not p.exists()]
    if missing:
        raise FileNotFoundError("Missing required files:\n" + "\n".join(str(p) for p in missing))

    for checkpoint_path in checkpoint_paths:
        _validate_torch_checkpoint(checkpoint_path)
def _ensure_assets_for_checkpoint(
    checkpoint_path: PathLike,
    force: bool = False,
    hf_repo_id: Optional[str] = None,
) -> None:
    """Make sure the assets needed to run ``checkpoint_path`` are available.

    A checkpoint in the standard demo layout
    (``<data_dir>/<stage>/checkpoints/<name>``) triggers a download of the
    shared SMAL/backbone assets and that stage's config and checkpoint into
    ``<data_dir>``. Existing stage files are trusted without a test-load.

    Any other path is a custom checkpoint. It is used as-is when both it and
    its Hydra config exist. There is nothing to download it from, so
    ``force`` has no effect on it.

    Raises:
        FileNotFoundError: for a custom checkpoint whose file or config is missing.
    """
    checkpoint_path = Path(checkpoint_path)
    config_path = _config_path_for_checkpoint(checkpoint_path)
    stage_name = _stage_for_checkpoint(checkpoint_path)
    if stage_name is None:
        if checkpoint_path.exists() and config_path.exists():
            # A custom checkpoint cannot be re-downloaded, so ``force`` is
            # ignored rather than reporting existing files as missing.
            if force:
                print(f"[warn] force has no effect for custom checkpoint {checkpoint_path}")
            print(f"[skip] Using local PRIMA checkpoint {checkpoint_path}")
            return
        raise FileNotFoundError(
            "Missing checkpoint or config for a custom path:\n"
            f"  checkpoint: {checkpoint_path}\n"
            f"  config:     {config_path}\n"
            "Auto-download supports the standard PRIMA demo layouts only:\n"
            "  data/PRIMAS1/checkpoints/s1ckpt_inference.ckpt\n"
            "  data/PRIMAS3/checkpoints/s3ckpt_inference.ckpt\n"
            "Pass one of those paths, or download/copy your custom checkpoint manually."
        )

    # <data_dir>/<stage>/checkpoints/<name> -> <data_dir>
    data_dir = checkpoint_path.parent.parent.parent
    repo_id = _resolve_hf_repo_id(hf_repo_id)
    print(f"[download] Ensuring PRIMA demo assets under {data_dir}")
    _ensure_smal_assets(data_dir, force=force, hf_repo_id=repo_id)
    _ensure_backbone(data_dir, force=force, hf_repo_id=repo_id)
    _ensure_stage_assets(
        stage_name,
        data_dir,
        force=force,
        hf_repo_id=repo_id,
        validate_existing=False,
    )
def ensure_demo_assets(
    data_dir: PathLike = "data",
    *,
    stages: Union[str, Iterable[str]] = ("PRIMAS1",),
    force: bool = False,
    hf_repo_id: Optional[str] = None,
) -> None:
    """Ensure PRIMA demo assets exist in the expected ``data/`` layout.

    The data directory is resolved and created. The shared SMAL pickles and
    backbone weights are fetched, followed by each requested stage's Hydra
    config and checkpoint. The source repo is ``hf_repo_id`` if given,
    otherwise the one chosen by ``_resolve_hf_repo_id``. Finally, the method
    checks that every file exists and that each stage checkpoint loads.

    ``stages`` may be a single stage name or an iterable of names.
    """
    root = Path(data_dir).resolve()
    root.mkdir(parents=True, exist_ok=True)
    repo_id = _resolve_hf_repo_id(hf_repo_id)
    stage_names = _normalize_stages(stages)

    common = {"force": force, "hf_repo_id": repo_id}
    _ensure_smal_assets(root, **common)
    _ensure_backbone(root, **common)
    for stage_name in stage_names:
        _ensure_stage_assets(stage_name, root, **common)

    _verify_assets(root, stage_names)
def resolve_prima_checkpoint_path(
    checkpoint_path: PathLike = "",
    *,
    data_dir: PathLike = "data",
    auto_download: bool = True,
    hf_repo_id: Optional[str] = None,
    force: bool = False,
) -> str:
    """Return a PRIMA checkpoint path, downloading default demo assets if needed.

    An empty ``checkpoint_path`` selects the stage-1 checkpoint inside
    ``data_dir``; ``data_dir`` is otherwise unused. With ``auto_download``,
    the assets for the chosen checkpoint are ensured before returning. The
    result is the path as a string; it is not made absolute.
    """
    if checkpoint_path:
        resolved = Path(checkpoint_path)
    else:
        resolved = _default_checkpoint_path(data_dir)

    if auto_download:
        _ensure_assets_for_checkpoint(resolved, force=force, hf_repo_id=hf_repo_id)
    return str(resolved)
# Names exported by a star-import of this module. Kept as a plain list literal
# so static analyzers can read it.
__all__ = [
    "DEFAULT_HF_REPO_ID",
    "DEFAULT_STAGE1_CHECKPOINT",
    "DEFAULT_STAGE3_CHECKPOINT",
    "HF_REPO_ID",
    "ensure_demo_assets",
    "resolve_prima_checkpoint_path",
]