PRIMA-demo / app.py
HF Space deploy
Deploy snapshot (LFS for demo images per .gitattributes)
cdad419
"""
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
Official implementation of the paper:
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
Licensed under a modified MIT license
"""
"""Gradio demo for PRIMA + SuperAnimal + TTA.
This script wraps the ``demo_tta.py`` pipeline into an interactive
Gradio interface. The overall logic follows:
1. Given an input image, run Detectron2 to detect animals (falling back to a full-image box if Detectron2 is unavailable).
2. For each detected animal, run PRIMA for 3D pose/shape estimation.
3. Run the fine-tuned DeepLabCut SuperAnimal model to obtain PRIMA 26-keypoint
2D predictions.
4. Optionally run test-time adaptation (TTA) with a user-specified learning rate and number of iterations (0 iterations skips TTA).
5. Render and save before/after TTA results and keypoint visualizations.
"""
import argparse
import os
import sys
import tempfile
import traceback
from dataclasses import dataclass
from functools import lru_cache
from types import SimpleNamespace
from typing import List, Optional, Tuple
from pathlib import Path
import cv2
import gradio as gr
import numpy as np
import torch
import torch.utils.data
# Space demo on macOS: limit BLAS threads (PyRender + PyTorch on main thread only).
# Checks the same SPACE_ID / HF_SPACE_ID variables as ``_running_on_space`` so the
# thread cap applies whenever the Space profile can be auto-selected.
# NOTE(review): torch / numpy / cv2 are already imported at this point, so setting
# OMP_NUM_THREADS here may not reach runtimes that read it at load time;
# ``torch.set_num_threads(1)`` is what caps PyTorch intra-op threads.
if sys.platform == "darwin" and (os.environ.get("SPACE_ID") or os.environ.get("HF_SPACE_ID")):
    os.environ.setdefault("OMP_NUM_THREADS", "1")
    torch.set_num_threads(1)

# Repo-local minimal ``chumpy`` shim (see ``chumpy/__init__.py``) so SMAL pickles load
# without installing the full chumpy package in Space builds. Putting the repo root
# first on sys.path also makes the local ``prima`` / ``demo_tta`` modules importable.
_REPO_ROOT = Path(__file__).resolve().parent
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))
from prima.utils.weights import (
DEFAULT_HF_REPO_ID,
resolve_prima_checkpoint_path,
)
from prima.utils.detection import select_animal_boxes
# Default checkpoint path following README instructions (absolute, under the repo's data/ dir).
DEFAULT_CHECKPOINT = str(_REPO_ROOT / "data" / "PRIMAS1" / "checkpoints" / "s1ckpt_inference.ckpt")
# Hugging Face Hub repo used to auto-download demo assets; the PRIMA_HF_REPO_ID
# environment variable overrides it at load time.
DEFAULT_HF_ASSET_REPO = DEFAULT_HF_REPO_ID
# Output folder for rendered images/meshes and keypoints.
# NOTE(review): this is a relative path, so it resolves against the current working
# directory, not the repo root like DEFAULT_CHECKPOINT.
DEFAULT_OUT_FOLDER = "demo_out_tta_gradio"
# Detectron2 model-zoo config + COCO weights for the lighter Faster R-CNN R50-FPN
# detector (selected by the Space profile).
_D2_R50_CFG = "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"
_D2_R50_URL = (
    "https://dl.fbaipublicfiles.com/detectron2/COCO-Detection/"
    "faster_rcnn_R_50_FPN_3x/137849458/model_final_280758.pkl"
)
# Detectron2 model-zoo config + COCO weights for the heavier Faster R-CNN
# X-101-32x8d-FPN detector (selected by the local profile).
_D2_X101_CFG = "COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml"
_D2_X101_URL = (
    "https://dl.fbaipublicfiles.com/detectron2/COCO-Detection/"
    "faster_rcnn_X_101_32x8d_FPN_3x/139173657/model_final_68b088.pkl"
)
# One Gradio example: (image path relative to the repo root, TTA learning rate,
# TTA iterations, detection threshold, keypoint confidence threshold,
# render side view, save mesh).
ExampleRow = Tuple[str, float, int, float, float, bool, bool]


@dataclass(frozen=True)
class DemoProfile:
    """Immutable bundle of runtime knobs for one flavour of the demo.

    Two instances exist below: one for running locally (GPU when present) and
    one for the CPU-only Hugging Face Space.
    """

    mode: str  # label for logs: "local" or "space"
    prima_device: str  # "cpu" forces CPU; any other value means "CUDA if available"
    detectron_config_yaml: str  # Detectron2 model-zoo config name
    detectron_weights_url: str  # weights URL matching the config above
    detectron_device: str  # "cpu" forces CPU; any other value means "CUDA if available"
    default_tta_iters: int  # initial value of the TTA-iterations slider
    max_tta_iters: int  # upper bound of the TTA-iterations slider
    default_save_mesh: bool  # initial state of the "Save meshes" checkbox
    default_side_view: bool  # initial state of the "Render side view" checkbox
    preload_assets: bool  # download checkpoints before launching the UI
    example_rows: Tuple[ExampleRow, ...]  # Gradio examples (missing files are skipped)
    description: str  # Markdown shown under the interface title
    interface_title: str  # Gradio interface title

    def resolve_prima_device(self) -> torch.device:
        """Pick the torch device for PRIMA: CPU if forced or CUDA is absent, else CUDA."""
        gpu_ok = self.prima_device != "cpu" and torch.cuda.is_available()
        return torch.device("cuda" if gpu_ok else "cpu")

    def resolve_detectron_device(self) -> str:
        """Pick Detectron2's ``MODEL.DEVICE`` string with the same rule as PRIMA."""
        wants_gpu = self.detectron_device != "cpu"
        return "cuda" if wants_gpu and torch.cuda.is_available() else "cpu"
# Profile for running on a local machine: heavier X-101 detector, CUDA for PRIMA
# and Detectron2 when available, TTA enabled by default, meshes saved by default.
LOCAL_DEMO_PROFILE = DemoProfile(
    mode="local",
    prima_device="auto",
    detectron_config_yaml=_D2_X101_CFG,
    detectron_weights_url=_D2_X101_URL,
    detectron_device="auto",
    default_tta_iters=30,
    max_tta_iters=100,
    default_save_mesh=True,
    default_side_view=False,
    preload_assets=False,
    # Each row: (image_rel, tta_lr, tta_iters, det_thresh, kp_thresh, side_view, save_mesh).
    # The beagle / shepherd rows use 0 TTA iterations (PRIMA-only inference).
    example_rows=(
        ("demo_data/000000015956_horse.png", 1e-6, 30, 0.7, 0.1, False, True),
        ("demo_data/n02412080_12159.png", 1e-6, 30, 0.7, 0.1, False, True),
        ("demo_data/000000315905_zebra.jpg", 1e-6, 30, 0.7, 0.1, False, True),
        ("demo_data/beagle.jpg", 1e-6, 0, 0.7, 0.1, False, True),
        ("demo_data/shepherd_hati.jpg", 1e-6, 0, 0.7, 0.1, False, True),
    ),
    description=(
        "**Local demo** — full pipeline on your machine (GPU when available).\n\n"
        "Detectron2 **X-101-FPN**, PRIMA mesh recovery, optional **DeepLabCut SuperAnimal + TTA**. "
        "Set TTA iterations to **0** to skip adaptation. Outputs are saved under "
        f"`{DEFAULT_OUT_FOLDER}`."
    ),
    interface_title=(
        "PRIMA local demo (GPU/CPU) — detection, mesh recovery, optional TTA"
    ),
)
# Profile for the Hugging Face Space (cpu-basic hardware): CPU-only PRIMA and
# Detectron2, lighter R50 detector, TTA off and mesh export off by default, and
# assets preloaded at startup so the first request does not have to download them.
SPACE_DEMO_PROFILE = DemoProfile(
    mode="space",
    prima_device="cpu",
    detectron_config_yaml=_D2_R50_CFG,
    detectron_weights_url=_D2_R50_URL,
    detectron_device="cpu",
    default_tta_iters=0,
    max_tta_iters=30,
    default_save_mesh=False,
    default_side_view=False,
    preload_assets=True,
    # Each row: (image_rel, tta_lr, tta_iters, det_thresh, kp_thresh, side_view, save_mesh);
    # all examples run PRIMA-only (0 TTA iterations) without mesh export.
    example_rows=(
        ("demo_data/beagle.jpg", 1e-6, 0, 0.7, 0.1, False, False),
        ("demo_data/000000015956_horse.png", 1e-6, 0, 0.7, 0.1, False, False),
        ("demo_data/000000315905_zebra.jpg", 1e-6, 0, 0.7, 0.1, False, False),
    ),
    description=(
        "**Hugging Face Space (cpu-basic)** — lightweight demo: **CPU-only**, Detectron2 **R50-FPN**, "
        "PRIMA inference. TTA is optional (0 by default; increases runtime). Mesh `.obj` export is off "
        "by default to save time and disk."
    ),
    interface_title="PRIMA on Hugging Face — lightweight CPU demo",
)
def _is_truthy_env(var_name: str) -> bool:
return os.environ.get(var_name, "").strip().lower() in {"1", "true", "yes", "on"}
def _running_on_space() -> bool:
return bool(os.environ.get("SPACE_ID") or os.environ.get("HF_SPACE_ID"))
@lru_cache(maxsize=1)
def get_demo_profile() -> DemoProfile:
    """Return the active demo profile (computed once per process).

    ``PRIMA_DEMO_MODE=local`` or ``PRIMA_DEMO_MODE=space`` forces a profile.
    Any other value, or no value, auto-detects: the Space profile when running
    on a Hugging Face Space, otherwise the local profile.
    """
    forced_mode = os.environ.get("PRIMA_DEMO_MODE", "").strip().lower()
    named_profiles = {"local": LOCAL_DEMO_PROFILE, "space": SPACE_DEMO_PROFILE}
    if forced_mode in named_profiles:
        return named_profiles[forced_mode]
    return SPACE_DEMO_PROFILE if _running_on_space() else LOCAL_DEMO_PROFILE
def _gradio_examples_for_interface(profile: DemoProfile) -> List[List]:
    """Build Gradio example rows for *profile*, keeping only images present on disk.

    Gradio prefetches example media at startup, so each image path is made
    absolute against the repo root, and rows whose file is missing are dropped.
    Setting PRIMA_DISABLE_GRADIO_EXAMPLES to a truthy value disables examples.
    """
    if _is_truthy_env("PRIMA_DISABLE_GRADIO_EXAMPLES"):
        return []
    candidates = ((_REPO_ROOT / row[0], row[1:]) for row in profile.example_rows)
    return [[str(image_path), *settings] for image_path, settings in candidates if image_path.is_file()]
def _should_preload_assets(profile: DemoProfile) -> bool:
    """Decide whether to download assets before launching the UI.

    If PRIMA_PRELOAD_ASSETS is set (even to an empty string), its truthiness
    decides. Otherwise the profile's ``preload_assets`` default applies.
    """
    if "PRIMA_PRELOAD_ASSETS" in os.environ:
        return _is_truthy_env("PRIMA_PRELOAD_ASSETS")
    return profile.preload_assets
def _deeplabcut_available() -> bool:
try:
from deeplabcut.pose_estimation_pytorch.apis import superanimal_analyze_images # noqa: F401
return True
except Exception:
return False
def _preload_assets_once(checkpoint_path: str) -> None:
    """Resolve (and auto-download from the Hub if needed) the PRIMA checkpoint at startup.

    The resolved path is discarded here; ``_load_prima_model`` resolves it again
    when the model is first loaded.
    """
    hub_repo = os.environ.get("PRIMA_HF_REPO_ID", DEFAULT_HF_ASSET_REPO)
    print("[startup] Ensuring demo assets from Hugging Face Hub...")
    resolve_prima_checkpoint_path(
        checkpoint_path,
        data_dir=_REPO_ROOT / "data",
        auto_download=True,
        hf_repo_id=hub_repo,
    )
    print("[startup] Asset preload complete.")
def _load_prima_model(checkpoint_path: str = DEFAULT_CHECKPOINT):
    """Load the PRIMA model and its renderer once for the Gradio app.

    The checkpoint is resolved via ``resolve_prima_checkpoint_path`` (with Hub
    auto-download). The Hydra config is expected two directories up, at
    ``<checkpoint>/../../.hydra/config.yaml``.

    Returns:
        ``(model, model_cfg, renderer, cam_crop_to_full, device)``. The model is
        moved to the active profile's PRIMA device and put in eval mode.

    Raises:
        FileNotFoundError: if the checkpoint or its Hydra config is missing.
    """
    from prima.models import load_prima
    from prima.utils.renderer import Renderer, cam_crop_to_full

    resolved_path = resolve_prima_checkpoint_path(
        checkpoint_path,
        data_dir=_REPO_ROOT / "data",
        auto_download=True,
        hf_repo_id=os.environ.get("PRIMA_HF_REPO_ID", DEFAULT_HF_ASSET_REPO),
    )
    ckpt_file = Path(resolved_path)
    hydra_cfg = ckpt_file.parent.parent / ".hydra" / "config.yaml"
    if not ckpt_file.exists():
        raise FileNotFoundError(
            f"Missing checkpoint: {ckpt_file}. Download demo checkpoints/data as described in README."
        )
    if not hydra_cfg.exists():
        raise FileNotFoundError(
            f"Missing model config: {hydra_cfg}. Ensure the full checkpoint folder layout from README is present."
        )

    active_profile = get_demo_profile()
    prima_model, prima_cfg = load_prima(resolved_path)
    target_device = active_profile.resolve_prima_device()
    prima_model = prima_model.to(target_device)
    prima_model.eval()
    mesh_renderer = Renderer(prima_cfg, faces=prima_model.smal.faces)
    return prima_model, prima_cfg, mesh_renderer, cam_crop_to_full, target_device
def _build_detector(profile: Optional[DemoProfile] = None, score_thresh: float = 0.3):
    """Build the Detectron2 animal detector for a demo profile.

    The profile selects the model: X-101-FPN (GPU when available) for the local
    demo, and R50-FPN on CPU for the Space.

    Args:
        profile: Demo profile to use; defaults to ``get_demo_profile()``.
        score_thresh: Detectron2's internal ``ROI_HEADS.SCORE_THRESH_TEST``. It only
            pre-filters candidates. The user-facing "Detection threshold" is applied
            afterwards via ``select_animal_boxes`` in ``_detect_animal_boxes``, so this
            value must not exceed the slider's minimum (0.3). A higher cutoff
            silently discards detections the user explicitly asked to keep.

    Returns:
        A ``detectron2.engine.DefaultPredictor``, or None when Detectron2 cannot be
        imported (callers then fall back to a single full-image box).
    """
    try:
        import detectron2.config
        import detectron2.engine
        from detectron2 import model_zoo
    except Exception as e:
        # Best-effort: the demo still runs PRIMA on the whole image without a detector.
        print(f"[warn] Detectron2 unavailable ({type(e).__name__}: {e}); using full-image fallback bbox.")
        return None
    if profile is None:
        profile = get_demo_profile()
    config_yaml = profile.detectron_config_yaml
    weights = profile.detectron_weights_url
    device_str = profile.resolve_detectron_device()
    print(
        f"[detectron2] mode={profile.mode} config={config_yaml} device={device_str} "
        f"score_thresh={float(score_thresh)}"
    )
    cfg = detectron2.config.get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file(config_yaml))
    # Keep Detectron2's cutoff at or below the lowest user-selectable threshold so the
    # user's threshold (applied later) is what actually decides which boxes survive.
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = float(score_thresh)
    cfg.MODEL.WEIGHTS = weights
    cfg.MODEL.DEVICE = device_str
    return detectron2.engine.DefaultPredictor(cfg)
def _load_model_and_detector_for_demo(checkpoint_path: str, profile: DemoProfile):
    """Load PRIMA (model, config, renderer, crop->full camera fn, device), then the detector.

    Meant to be called once per Gradio session, on the main thread only.

    Returns:
        ``(model, model_cfg, renderer, cam_crop_to_full_fn, device, detector)``.
        ``detector`` is None when Detectron2 is unavailable.
    """
    prima_parts = _load_prima_model(checkpoint_path)
    return (*prima_parts, _build_detector(profile))
def _detect_animal_boxes(
detector,
img_bgr: np.ndarray,
det_thresh: float,
) -> Optional[np.ndarray]:
"""Return Nx4 XYXY boxes or None if no animal detections."""
if detector is None:
h, w = img_bgr.shape[:2]
return np.array([[0.0, 0.0, float(max(1, w - 1)), float(max(1, h - 1))]], dtype=np.float32)
det_out = detector(img_bgr)
det_instances = det_out["instances"]
boxes, suppressed = select_animal_boxes(det_instances, score_threshold=float(det_thresh))
if suppressed > 0:
print(f"[INFO] Suppressed {suppressed} duplicate animal detection(s)")
if len(boxes) == 0:
return None
return boxes
# SuperAnimal defaults (same as in demo_tta parser).
# NOTE: this is a module-level, mutable namespace shared by all requests.
# ``_collect_animal_results`` fills in ``saved_2d_model_path`` lazily (via
# ``demo_tta.resolve_sa_weights_path``) the first time TTA is requested.
SUPER_ANIMAL_ARGS = SimpleNamespace(
    superanimal_name="superanimal_quadruped",
    superanimal_model_name="hrnet_w32",
    superanimal_detector_name="fasterrcnn_resnet50_fpn_v2",
    superanimal_max_individuals=1,
    # Empty until resolved at first TTA use (see note above).
    saved_2d_model_path="",
    # Fine-tuning config shipped in the repo's configs/ directory.
    pytorch_config_2d_path=str(_REPO_ROOT / "configs" / "sa_finetune_hrnet_w32.yaml"),
)
def _render_and_load(
    renderer,
    cam_crop_to_full_fn,
    output,
    batch,
    img_fn: str,
    animal_id: int,
    out_folder: str,
    suffix: str,
    side_view: bool,
    save_mesh: bool,
) -> Tuple[Optional[np.ndarray], Optional[str]]:
    """Render one PRIMA prediction with ``demo_tta.render_and_save`` and read the outputs back.

    The renderer is expected to write ``{img_fn}_{animal_id}_{suffix}.png`` (and
    ``.obj`` when ``save_mesh``) under ``out_folder``.

    Returns:
        ``(rgb_image, mesh_path)``. ``rgb_image`` is None if the PNG is missing or
        unreadable. ``mesh_path`` is None if meshes are disabled or the ``.obj``
        was not written.
    """
    from demo_tta import render_and_save  # imported lazily to avoid circular issues

    render_and_save(
        renderer,
        cam_crop_to_full_fn,
        output,
        batch,
        img_fn,
        animal_id,
        out_folder,
        suffix=suffix,
        side_view=side_view,
        save_mesh=save_mesh,
    )
    stem = os.path.join(out_folder, f"{img_fn}_{animal_id}_{suffix}")

    rgb: Optional[np.ndarray] = None
    png_path = f"{stem}.png"
    if os.path.exists(png_path):
        bgr = cv2.imread(png_path)
        if bgr is not None:
            rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    mesh_path: Optional[str] = None
    obj_path = f"{stem}.obj"
    if save_mesh and os.path.exists(obj_path):
        mesh_path = obj_path
    return rgb, mesh_path


def _superanimal_tta(
    model,
    batch,
    device,
    img_fn: str,
    animal_id: int,
    out_folder: str,
    tta_lr: float,
    tta_num_iters: int,
    kp_conf_thresh: float,
):
    """Predict PRIMA-26 keypoints with SuperAnimal on the crop, then run TTA against them.

    Saves ``{img_fn}_{animal_id}_prima26_kpts.png`` and ``.npy`` to ``out_folder``.

    Returns:
        ``(out_after, kpt_rgb)``. ``out_after`` is the TTA-adapted model output, or
        None when SuperAnimal returns no keypoints (TTA is then skipped).
        ``kpt_rgb`` is the keypoint visualization as RGB, or None if it could not
        be produced or read back.
    """
    from demo_tta import (
        denorm_patch_to_rgb,
        run_superanimal_on_patch,
        save_keypoint_vis,
        tta_optimize,
    )

    patch_rgb = denorm_patch_to_rgb(batch["img"][0])
    with tempfile.TemporaryDirectory(prefix=f"dlc_{img_fn}_{animal_id}_") as tmp_dir:
        kpts_xyc = run_superanimal_on_patch(patch_rgb, SUPER_ANIMAL_ARGS, tmp_dir)
    if kpts_xyc is None:
        return None, None

    # Zero the confidence column of keypoints below the user's threshold.
    kpts_xyc[kpts_xyc[:, 2] < float(kp_conf_thresh), 2] = 0.0

    # Save keypoint visualization and raw (x, y, conf) array.
    kpt_png_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_prima26_kpts.png")
    save_keypoint_vis(patch_rgb, kpts_xyc, kpt_png_path)
    np.save(os.path.join(out_folder, f"{img_fn}_{animal_id}_prima26_kpts.npy"), kpts_xyc)
    kpt_rgb: Optional[np.ndarray] = None
    if os.path.exists(kpt_png_path):
        kpt_bgr = cv2.imread(kpt_png_path)
        if kpt_bgr is not None:
            kpt_rgb = cv2.cvtColor(kpt_bgr, cv2.COLOR_BGR2RGB)

    # Normalize pixel coordinates to [-0.5, 0.5] of the crop, as in demo_tta.
    patch_h, patch_w = patch_rgb.shape[:2]
    kpts_norm = kpts_xyc.copy()
    kpts_norm[:, 0] = kpts_norm[:, 0] / float(patch_w) - 0.5
    kpts_norm[:, 1] = kpts_norm[:, 1] / float(patch_h) - 0.5
    gt_kpts_norm = torch.from_numpy(kpts_norm[None]).to(device=device, dtype=batch["img"].dtype)

    out_after = tta_optimize(
        model,
        batch,
        gt_kpts_norm,
        num_iters=int(tta_num_iters),
        lr=float(tta_lr),
    )
    return out_after, kpt_rgb


def _collect_animal_results(
    model,
    model_cfg,
    renderer,
    cam_crop_to_full_fn,
    device,
    detector,
    out_folder: str,
    img_rgb: np.ndarray,
    tta_lr: float,
    tta_num_iters: int,
    det_thresh: float,
    kp_conf_thresh: float,
    side_view: bool,
    save_mesh: bool,
    boxes: Optional[np.ndarray] = None,
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray], str | None, str | None]:
    """Run detection + PRIMA + SuperAnimal + TTA on a single RGB image.

    Args:
        boxes: Precomputed XYXY animal boxes. When None, detection runs here using
            ``detector`` and ``det_thresh``.
        tta_num_iters: Number of TTA steps; ``<= 0`` disables SuperAnimal and TTA.
        Other arguments are the loaded PRIMA components and the UI settings.

    Returns:
        before_imgs: RGB renders of the initial PRIMA prediction, one per animal
            whose render could be read back.
        after_imgs: RGB renders after TTA. When TTA is disabled, or SuperAnimal finds
            no keypoints for an animal, that animal's initial prediction is rendered
            instead, so this list stays aligned with ``before_imgs``.
        kpt_imgs: RGB keypoint visualizations, only for animals where SuperAnimal
            returned keypoints (so indices may not align with the lists above).
        first_before_mesh: path to the first animal's before-TTA mesh (.obj) or None.
        first_after_mesh: path to the first animal's after-TTA mesh (.obj) or None.
    """
    import uuid

    from prima.utils import recursive_to
    from prima.datasets.vitdet_dataset import ViTDetDataset
    from demo_tta import resolve_sa_weights_path

    run_tta = int(tta_num_iters) > 0
    if run_tta and not SUPER_ANIMAL_ARGS.saved_2d_model_path:
        # Resolve the SuperAnimal weights once and cache them on the shared namespace.
        SUPER_ANIMAL_ARGS.saved_2d_model_path = resolve_sa_weights_path("")

    img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
    if boxes is None:
        boxes = _detect_animal_boxes(detector, img_bgr, det_thresh)
    if boxes is None:
        return [], [], [], None, None

    dataset = ViTDetDataset(model_cfg, img_bgr, boxes)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)

    before_imgs: List[np.ndarray] = []
    after_imgs: List[np.ndarray] = []
    kpt_imgs: List[np.ndarray] = []
    before_mesh_paths: List[str] = []
    after_mesh_paths: List[str] = []
    # Random per-call file prefix so successive runs do not overwrite each other's
    # outputs (public uuid API instead of the private tempfile._get_candidate_names).
    img_fn = uuid.uuid4().hex[:8]

    def _render_into(output, cur_batch, animal_id, suffix, imgs, meshes):
        # Render one prediction and append whatever was produced to the given lists.
        rgb, mesh = _render_and_load(
            renderer,
            cam_crop_to_full_fn,
            output,
            cur_batch,
            img_fn,
            animal_id,
            out_folder,
            suffix,
            side_view,
            save_mesh,
        )
        if rgb is not None:
            imgs.append(rgb)
        if mesh is not None:
            meshes.append(mesh)

    for batch in dataloader:
        batch = recursive_to(batch, device)
        with torch.no_grad():
            out_before = model(batch)
        animal_id = int(batch["animalid"][0])

        _render_into(out_before, batch, animal_id, "before_tta", before_imgs, before_mesh_paths)

        # Without TTA (disabled, or no keypoints) "after" shows the initial prediction.
        out_after = out_before
        if run_tta:
            adapted, kpt_rgb = _superanimal_tta(
                model,
                batch,
                device,
                img_fn,
                animal_id,
                out_folder,
                tta_lr,
                tta_num_iters,
                kp_conf_thresh,
            )
            if kpt_rgb is not None:
                kpt_imgs.append(kpt_rgb)
            if adapted is None:
                print(
                    f"[warn] SuperAnimal returned no keypoints for animal {animal_id}; "
                    "skipping TTA and showing the initial prediction."
                )
            else:
                out_after = adapted

        _render_into(out_after, batch, animal_id, "after_tta", after_imgs, after_mesh_paths)

    first_before_mesh = before_mesh_paths[0] if before_mesh_paths else None
    first_after_mesh = after_mesh_paths[0] if after_mesh_paths else None
    return before_imgs, after_imgs, kpt_imgs, first_before_mesh, first_after_mesh
def build_demo(checkpoint_path: str = DEFAULT_CHECKPOINT, out_folder: str = DEFAULT_OUT_FOLDER) -> gr.Interface:
    """Build the Gradio interface for the active demo profile.

    Models are not loaded here. The first request loads PRIMA and the Detectron2
    detector and stores them in ``runtime_cache`` for later requests.

    Args:
        checkpoint_path: PRIMA checkpoint path (resolved/downloaded on first request).
        out_folder: Folder for rendered images, keypoints and meshes; created if missing.

    Returns:
        A ``gr.Interface`` with queueing enabled (at most 8 queued requests and
        one running at a time).
    """
    profile = get_demo_profile()
    print(
        f"[demo] profile={profile.mode} prima={profile.resolve_prima_device()} "
        f"detectron={profile.detectron_config_yaml} d2_device={profile.resolve_detectron_device()}"
    )
    os.makedirs(out_folder, exist_ok=True)
    # Filled on the first successful model load and reused by every later request.
    runtime_cache = {
        "model": None,
        "model_cfg": None,
        "renderer": None,
        "cam_crop_to_full_fn": None,
        "device": None,
        "detector": None,
    }

    def gradio_inference(
        image: np.ndarray,
        tta_lr: float,
        tta_num_iters: int,
        det_thresh: float,
        kp_conf_thresh: float,
        side_view: bool,
        save_mesh: bool,
    ):
        """Wrapper for Gradio. ``image`` is an RGB numpy array.

        Yields intermediate status so long first-run (Hub downloads + model load)
        and long inference do not hit silent client/proxy WebSocket timeouts.
        Every yield is a 4-tuple ``(before_img, after_img, keypoint_img, status)``
        matching the interface outputs, in that order.
        """
        if image is None:
            yield None, None, None, "No image provided."
            return
        # TTA needs SuperAnimal keypoints, so refuse early when DeepLabCut is missing.
        if int(tta_num_iters) > 0 and not _deeplabcut_available():
            yield (
                None,
                None,
                None,
                "DeepLabCut is not installed. Set **TTA iterations** to **0** for PRIMA-only inference, "
                "or install `deeplabcut` (see README / requirements.txt).",
            )
            return
        # Non-uint8 input is clipped to [0, 255] and cast (values are not rescaled).
        if image.dtype != np.uint8:
            img_rgb = np.clip(image, 0, 255).astype(np.uint8)
        else:
            img_rgb = image
        yield None, None, None, "Queued; preparing run…"
        if runtime_cache["model"] is None:
            yield (
                None,
                None,
                None,
                "First run: downloading demo assets from Hugging Face (large checkpoint) "
                "and loading the model. This can take many minutes.",
            )
            try:
                model, model_cfg, renderer, cam_crop_to_full_fn, device, detector = _load_model_and_detector_for_demo(
                    checkpoint_path, profile
                )
            except Exception:
                # The cache stays empty, so the next request retries the load.
                yield None, None, None, f"Model initialization failed:\n{traceback.format_exc()}"
                return
            runtime_cache["model"] = model
            runtime_cache["model_cfg"] = model_cfg
            runtime_cache["renderer"] = renderer
            runtime_cache["cam_crop_to_full_fn"] = cam_crop_to_full_fn
            runtime_cache["device"] = device
            runtime_cache["detector"] = detector
            yield None, None, None, "Model loaded."
        try:
            yield None, None, None, "Running animal detection…"
            img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
            boxes = _detect_animal_boxes(runtime_cache["detector"], img_bgr, det_thresh)
            if boxes is None:
                yield (
                    None,
                    None,
                    None,
                    "No animal detected. Try lowering the detection threshold or another image.",
                )
                return
            yield (
                None,
                None,
                None,
                f"Detected {len(boxes)} animal region(s). Running PRIMA (+ SuperAnimal/TTA if enabled)…",
            )
            # Boxes are passed in so detection is not run a second time.
            # The mesh paths are returned but not shown in the UI.
            before_imgs, after_imgs, kpt_imgs, mesh_before, mesh_after = _collect_animal_results(
                runtime_cache["model"],
                runtime_cache["model_cfg"],
                runtime_cache["renderer"],
                runtime_cache["cam_crop_to_full_fn"],
                runtime_cache["device"],
                runtime_cache["detector"],
                out_folder,
                img_rgb,
                tta_lr=tta_lr,
                tta_num_iters=tta_num_iters,
                det_thresh=det_thresh,
                kp_conf_thresh=kp_conf_thresh,
                side_view=side_view,
                save_mesh=save_mesh,
                boxes=boxes,
            )
        except Exception:
            yield None, None, None, f"Inference failed:\n{traceback.format_exc()}"
            return
        # Only the first animal's results are displayed.
        first_before = before_imgs[0] if before_imgs else None
        first_after = after_imgs[0] if after_imgs else None
        first_kpts = kpt_imgs[0] if kpt_imgs else None
        if first_before is None and first_after is None:
            yield (
                None,
                None,
                None,
                "No output generated. Try an image with a clearly visible quadruped.",
            )
            return
        yield first_before, first_after, first_kpts, "OK"

    _gradio_examples = _gradio_examples_for_interface(profile)
    # Input components are positional: their order must match gradio_inference's parameters.
    _iface_kw = dict(
        fn=gradio_inference,
        analytics_enabled=False,
        cache_examples=False,
        inputs=[
            gr.Image(
                label="Input image",
                type="numpy",
                sources=["upload", "clipboard"],
            ),
            gr.Slider(
                label="TTA learning rate",
                minimum=1e-7,
                maximum=1e-4,
                value=1e-6,
                step=1e-7,
            ),
            gr.Slider(
                label="TTA iterations",
                minimum=0,
                maximum=profile.max_tta_iters,
                value=profile.default_tta_iters,
                step=1,
                info="Set to 0 to disable TTA and reuse the initial PRIMA prediction.",
            ),
            gr.Slider(
                label="Detection threshold",
                minimum=0.3,
                maximum=0.9,
                value=0.7,
                step=0.05,
            ),
            gr.Slider(
                label="Keypoint confidence threshold",
                minimum=0.0,
                maximum=1.0,
                value=0.1,
                step=0.05,
            ),
            gr.Checkbox(label="Render side view", value=profile.default_side_view),
            gr.Checkbox(label="Save meshes (.obj)", value=profile.default_save_mesh),
        ],
        outputs=[
            gr.Image(label="Before TTA"),
            gr.Image(label="After TTA"),
            gr.Image(label="PRIMA 26 keypoints"),
            gr.Textbox(label="Status / Traceback", lines=12),
        ],
        title=profile.interface_title,
        description=profile.description,
    )
    # Examples are attached only when at least one example image exists on disk.
    if _gradio_examples:
        _iface_kw["examples"] = _gradio_examples
    demo = gr.Interface(**_iface_kw)
    # One job at a time (models and cache are shared); up to 8 requests may wait.
    demo.queue(max_size=8, default_concurrency_limit=1)
    return demo
def parse_args() -> argparse.Namespace:
    """Parse the demo's command-line flags (``--checkpoint``, ``--out_folder``)."""
    cli = argparse.ArgumentParser(description="Gradio demo for PRIMA + SuperAnimal + TTA")
    flag_specs = (
        ("--checkpoint", DEFAULT_CHECKPOINT, "Path to the pretrained PRIMA checkpoint"),
        ("--out_folder", DEFAULT_OUT_FOLDER, "Folder used to save rendered outputs and meshes"),
    )
    for flag, default_value, help_text in flag_specs:
        cli.add_argument(flag, type=str, default=default_value, help=help_text)
    return cli.parse_args()
if __name__ == "__main__":
    args = parse_args()
    # Optionally fetch the checkpoint up front so the first request does not
    # have to wait for the Hub download.
    if _should_preload_assets(get_demo_profile()):
        _preload_assets_once(args.checkpoint)
    demo = build_demo(checkpoint_path=args.checkpoint, out_folder=args.out_folder)
    demo.launch(inbrowser=False)