VLAlert / training /Policy /make_belief_cache.py
AsianPlayer's picture
Add VLAlert code
1e05592 verified
Raw
History Blame Contribute Delete
4.46 kB
#!/usr/bin/env python3
"""
Pre-compute and cache belief vectors for all policy label windows.
Since the SFT backbone is fully frozen, belief[i] = SFTModel(frames[i]) is
deterministic. Computing it once and saving it eliminates the 3B-param VLM
forward pass from every training step, making PolicyHead training ~1000× faster.
Output:
    data/belief_cache/train.pt — tensors for all train samples
    data/belief_cache/val.pt   — tensors for all val samples
Cache format (per split):
    {
        "beliefs":   FloatTensor [N, hidden_dim]  (float32)
        "tta_means": FloatTensor [N]
        "tta_vars":  FloatTensor [N]
    }
Row i of every tensor corresponds to sample i of data/policy_labels/{split}.json
(the loader does not shuffle, so the order is preserved exactly).
Usage:
    cd PROJECT_ROOT
    python -m training.Policy.make_belief_cache \
        --sft_checkpoint checkpoints/SFT/sft_v2/best \
        --label_dir data/policy_labels \
        --out_dir data/belief_cache \
        --batch_size 8
"""
from __future__ import annotations
import argparse
import json
import logging
from pathlib import Path
from typing import List
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import sys
sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent))
from .policy_model import PolicyModel
from .policy_dataset import PolicyDataset, policy_collate_fn
# CLI-wide logging: timestamped, INFO level and above, routed via the root logger.
_LOG_FORMAT = "%(asctime)s %(levelname)s %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger("Policy.make_cache")
@torch.no_grad()
def build_cache(
    model: PolicyModel,
    loader: DataLoader,
    split_name: str,
    *,
    device_type: str = "cuda",
) -> dict:
    """Run the frozen SFT backbone over every batch and collect cache tensors.

    For each batch the inputs are built with ``model._build_inputs``, encoded
    into a belief vector by ``model.sft.encode_observation`` under autocast,
    and fed to ``model.sft.tta_head`` to obtain a TTA mean and log-variance.
    The log-variance is exponentiated in float32 after clamping to
    [-20, 20], so every cached variance is finite and strictly positive.

    Args:
        model: Policy model wrapping the SFT backbone; put into eval mode here.
        loader: Should iterate in a fixed order (e.g. ``shuffle=False``) so
            that row ``i`` of each output tensor lines up with dataset sample
            ``i``.
        split_name: Label used only for the progress bar and the log line.
        device_type: Device type passed to ``torch.amp.autocast``
            (default ``"cuda"``).

    Returns:
        Dict with keys ``"beliefs"``, ``"tta_means"`` and ``"tta_vars"``,
        each a float32 CPU tensor concatenated along dim 0 (one row per
        sample).

    Raises:
        ValueError: if ``loader`` yields no batches, so there is nothing to
            concatenate into a cache.
    """
    # Hoisted out of the batch loop: the import only needs to run once.
    from torch.amp import autocast

    model.eval()
    belief_chunks: List[torch.Tensor] = []
    mean_chunks: List[torch.Tensor] = []
    var_chunks: List[torch.Tensor] = []

    for batch in tqdm(loader, desc=f" Caching {split_name}"):
        inputs = model._build_inputs(batch["images"], batch["metadata"])
        with autocast(device_type=device_type, dtype=model._amp_dtype, enabled=True):
            belief = model.sft.encode_observation(inputs)
            tta_mean, tta_logvar = model.sft.tta_head(belief)
        # exp() in float32 on a clamped exponent: avoids inf/0 variances that
        # low-precision autocast outputs could otherwise produce.
        tta_var = torch.exp(tta_logvar.float().clamp(-20.0, 20.0))
        # Move to CPU per batch so GPU memory does not grow with dataset size.
        belief_chunks.append(belief.float().cpu())
        mean_chunks.append(tta_mean.float().cpu())
        var_chunks.append(tta_var.cpu())

    if not belief_chunks:
        raise ValueError(
            f"No batches produced for split '{split_name}'; nothing to cache."
        )

    beliefs = torch.cat(belief_chunks, dim=0)
    tta_means = torch.cat(mean_chunks, dim=0)
    tta_vars = torch.cat(var_chunks, dim=0)
    logger.info(
        f" {split_name}: cached {beliefs.shape[0]} samples "
        f"belief shape={tuple(beliefs.shape)} "
        f"size={beliefs.nbytes / 1e6:.1f} MB"
    )
    return {"beliefs": beliefs, "tta_means": tta_means, "tta_vars": tta_vars}
def main():
    """Parse CLI arguments and build one belief-cache file per requested split.

    For each split in ``--splits``, the label file ``<label_dir>/<split>.json``
    is run through the frozen SFT backbone (via ``build_cache``). The result is
    saved to ``<out_dir>/<split>.pt``. A split is skipped with a warning if its
    label file is missing. A split is also skipped if its cache file already
    exists, unless ``--overwrite`` is given. Repeated split names are processed
    once.

    The model is only loaded when at least one split actually needs building.
    Each cache is first written to a sibling ``.tmp`` file and then atomically
    renamed into place. An interrupted run therefore never leaves a truncated
    ``.pt`` that a later run would mistake for a finished cache.
    """
    parser = argparse.ArgumentParser("make_belief_cache")
    parser.add_argument("--sft_checkpoint", required=True)
    parser.add_argument("--label_dir", default="data/policy_labels")
    parser.add_argument("--out_dir", default="data/belief_cache")
    parser.add_argument("--batch_size", type=int, default=8,
                        help="Larger = faster caching (no grad, more GPU memory)")
    parser.add_argument("--num_workers", type=int, default=4,
                        help="DataLoader worker processes")
    parser.add_argument("--overwrite", action="store_true",
                        help="Rebuild caches even if the output file already exists")
    parser.add_argument("--splits", nargs="+", default=["train", "val"])
    args = parser.parse_args()

    odir = Path(args.out_dir)
    odir.mkdir(parents=True, exist_ok=True)

    # Decide what needs building before paying for the (large) model load.
    # dict.fromkeys drops repeated split names while preserving their order.
    pending = []
    for split in dict.fromkeys(args.splits):
        label_path = Path(args.label_dir) / f"{split}.json"
        if not label_path.exists():
            logger.warning(f" {label_path} not found — skipping {split}")
            continue
        out_path = odir / f"{split}.pt"
        if out_path.exists() and not args.overwrite:
            logger.info(f" Cache already exists: {out_path} — skipping")
            continue
        pending.append((split, label_path, out_path))

    if not pending:
        logger.info("No splits to build; SFT model not loaded.")
    else:
        logger.info("Loading SFTModel (frozen backbone for belief extraction)...")
        model = PolicyModel(args.sft_checkpoint, use_bf16=True)
        for split, label_path, out_path in pending:
            logger.info(f"\nBuilding cache for split: {split}")
            ds = PolicyDataset([label_path], split=split)
            loader = DataLoader(
                ds, batch_size=args.batch_size, shuffle=False,
                num_workers=args.num_workers, collate_fn=policy_collate_fn,
                # Pinned host memory only benefits host->GPU copies.
                pin_memory=torch.cuda.is_available(),
            )
            cache = build_cache(model, loader, split)
            # Write-then-rename: out_path only ever appears fully written.
            tmp_path = out_path.with_name(out_path.name + ".tmp")
            try:
                torch.save(cache, tmp_path)
                tmp_path.replace(out_path)
            finally:
                # Clean up a partial temp file if save/rename failed.
                if tmp_path.exists():
                    tmp_path.unlink()
            logger.info(f" Saved → {out_path}")
    logger.info("\n✅ Belief cache complete.")
# Script entry point: only run the CLI when executed directly or via `python -m`,
# never when this module is merely imported.
if __name__ == "__main__":
    main()