#!/usr/bin/env python3 """ Pre-compute and cache belief vectors for all policy label windows. Since the SFT backbone is fully frozen, belief[i] = SFTModel(frames[i]) is deterministic. Computing it once and saving it eliminates the 3B-param VLM forward pass from every training step, making PolicyHead training ~1000× faster. Output: data/belief_cache/train.pt — tensors for all train samples data/belief_cache/val.pt — tensors for all val samples Cache format (per split): { "beliefs": FloatTensor [N, hidden_dim] (float32) "tta_means": FloatTensor [N] "tta_vars": FloatTensor [N] } Indices match exactly the sample order in data/policy_labels/{split}.json. Usage: cd PROJECT_ROOT python -m training.Policy.make_belief_cache \ --sft_checkpoint checkpoints/SFT/sft_v2/best \ --label_dir data/policy_labels \ --out_dir data/belief_cache \ --batch_size 8 """ from __future__ import annotations import argparse import json import logging from pathlib import Path from typing import List import torch from torch.utils.data import DataLoader from tqdm import tqdm import sys sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent)) from .policy_model import PolicyModel from .policy_dataset import PolicyDataset, policy_collate_fn logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") logger = logging.getLogger("Policy.make_cache") @torch.no_grad() def build_cache( model: PolicyModel, loader: DataLoader, split_name: str, ) -> dict: """Run VLM on all samples, collect belief + tta statistics.""" model.eval() all_beliefs: List[torch.Tensor] = [] all_tta_means: List[torch.Tensor] = [] all_tta_vars: List[torch.Tensor] = [] for batch in tqdm(loader, desc=f" Caching {split_name}"): inputs = model._build_inputs(batch["images"], batch["metadata"]) from torch.amp import autocast with autocast(device_type="cuda", dtype=model._amp_dtype, enabled=True): belief = model.sft.encode_observation(inputs) tta_mean, tta_logvar = model.sft.tta_head(belief) tta_var = torch.exp(tta_logvar.float().clamp(-20.0, 20.0)) all_beliefs.append(belief.float().cpu()) all_tta_means.append(tta_mean.float().cpu()) all_tta_vars.append(tta_var.cpu()) beliefs = torch.cat(all_beliefs, dim=0) tta_means = torch.cat(all_tta_means, dim=0) tta_vars = torch.cat(all_tta_vars, dim=0) logger.info( f" {split_name}: cached {beliefs.shape[0]} samples " f"belief shape={tuple(beliefs.shape)} " f"size={beliefs.nbytes / 1e6:.1f} MB" ) return {"beliefs": beliefs, "tta_means": tta_means, "tta_vars": tta_vars} def main(): parser = argparse.ArgumentParser("make_belief_cache") parser.add_argument("--sft_checkpoint", required=True) parser.add_argument("--label_dir", default="data/policy_labels") parser.add_argument("--out_dir", default="data/belief_cache") parser.add_argument("--batch_size", type=int, default=8, help="Larger = faster caching (no grad, more GPU memory)") parser.add_argument("--splits", nargs="+", default=["train", "val"]) args = parser.parse_args() odir = Path(args.out_dir) odir.mkdir(parents=True, exist_ok=True) logger.info("Loading SFTModel (frozen backbone for belief extraction)...") model = PolicyModel(args.sft_checkpoint, use_bf16=True) for split in args.splits: label_path = Path(args.label_dir) / f"{split}.json" if not label_path.exists(): logger.warning(f" {label_path} not found — skipping {split}") continue out_path = odir / f"{split}.pt" if out_path.exists(): logger.info(f" Cache already exists: {out_path} — skipping") continue logger.info(f"\nBuilding cache for split: {split}") ds = PolicyDataset([label_path], split=split) loader = DataLoader( ds, batch_size=args.batch_size, shuffle=False, num_workers=4, collate_fn=policy_collate_fn, pin_memory=True, ) cache = build_cache(model, loader, split) torch.save(cache, out_path) logger.info(f" Saved → {out_path}") logger.info("\n✅ Belief cache complete.") if __name__ == "__main__": main()