| |
| """ |
| Pre-compute and cache belief vectors for all policy label windows. |
| |
| Since the SFT backbone is fully frozen, belief[i] = SFTModel(frames[i]) is |
| deterministic. Computing it once and saving it eliminates the 3B-param VLM |
| forward pass from every training step, making PolicyHead training ~1000× faster. |
| |
| Output: |
| data/belief_cache/train.pt — tensors for all train samples |
| data/belief_cache/val.pt — tensors for all val samples |
| |
| Cache format (per split): |
| { |
| "beliefs": FloatTensor [N, hidden_dim] (float32) |
| "tta_means": FloatTensor [N] |
| "tta_vars": FloatTensor [N] |
| } |
| Indices match exactly the sample order in data/policy_labels/{split}.json. |
| |
| Usage: |
| cd PROJECT_ROOT |
| python -m training.Policy.make_belief_cache \ |
| --sft_checkpoint checkpoints/SFT/sft_v2/best \ |
| --label_dir data/policy_labels \ |
| --out_dir data/belief_cache \ |
| --batch_size 8 |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import logging |
| from pathlib import Path |
| from typing import List |
|
|
| import torch |
| from torch.utils.data import DataLoader |
| from tqdm import tqdm |
|
|
| import sys |
| sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent)) |
|
|
| from .policy_model import PolicyModel |
| from .policy_dataset import PolicyDataset, policy_collate_fn |
|
|
# Timestamped, level-tagged log lines for this command-line script.
_LOG_FORMAT = "%(asctime)s %(levelname)s %(message)s"

logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger("Policy.make_cache")
|
|
|
|
@torch.no_grad()
def build_cache(
    model: PolicyModel,
    loader: DataLoader,
    split_name: str,
) -> dict:
    """Run the frozen SFT encoder over every sample and collect belief + TTA stats.

    Args:
        model: PolicyModel whose ``_build_inputs``, ``_amp_dtype`` and ``sft``
            (``encode_observation`` / ``tta_head``) members are used. It is put
            into eval mode as a side effect.
        loader: DataLoader yielding dicts with ``"images"`` and ``"metadata"``
            keys. It must iterate in dataset order (no shuffling, no dropped
            samples) so that row ``i`` of every output tensor belongs to
            dataset sample ``i``.
        split_name: Split label; used only in the progress bar and log output.

    Returns:
        dict with keys ``"beliefs"``, ``"tta_means"`` and ``"tta_vars"``, each a
        CPU float32 tensor built by concatenating per-batch results along dim 0
        in loader order.

    Raises:
        ValueError: if the loader yields no batches (``torch.cat`` would
            otherwise fail with an opaque error).
        RuntimeError: if the number of cached rows differs from
            ``len(loader.dataset)``, i.e. the cache would not line up with the
            label file by index.
    """
    # Imported once per call instead of on every loop iteration.
    from torch.amp import autocast

    model.eval()

    all_beliefs: List[torch.Tensor] = []
    all_tta_means: List[torch.Tensor] = []
    all_tta_vars: List[torch.Tensor] = []

    for batch in tqdm(loader, desc=f" Caching {split_name}"):
        inputs = model._build_inputs(batch["images"], batch["metadata"])

        with autocast(device_type="cuda", dtype=model._amp_dtype, enabled=True):
            belief = model.sft.encode_observation(inputs)
            tta_mean, tta_logvar = model.sft.tta_head(belief)

        # Exponentiate in float32 on a clamped log-variance so the variance
        # stays finite regardless of the autocast output precision.
        tta_var = torch.exp(tta_logvar.float().clamp(-20.0, 20.0))

        # Convert to CPU float32 per batch (outputs are presumably on the GPU)
        # so accumulated results do not stay on the device.
        all_beliefs.append(belief.float().cpu())
        all_tta_means.append(tta_mean.float().cpu())
        all_tta_vars.append(tta_var.cpu())

    if not all_beliefs:
        raise ValueError(
            f"build_cache: loader for split '{split_name}' yielded no batches; "
            "nothing to cache"
        )

    beliefs = torch.cat(all_beliefs, dim=0)
    tta_means = torch.cat(all_tta_means, dim=0)
    tta_vars = torch.cat(all_tta_vars, dim=0)

    # The cache is consumed by index, so a row-count mismatch (e.g. a collate
    # fn that drops samples, or drop_last=True) would silently misalign it.
    try:
        expected = len(getattr(loader, "dataset", None))
    except TypeError:  # no dataset / iterable dataset without __len__
        expected = None
    if expected is not None and beliefs.shape[0] != expected:
        raise RuntimeError(
            f"build_cache: split '{split_name}' produced {beliefs.shape[0]} rows "
            f"but the dataset has {expected} samples; cache would be misaligned"
        )

    logger.info(
        f" {split_name}: cached {beliefs.shape[0]} samples "
        f"belief shape={tuple(beliefs.shape)} "
        f"size={beliefs.nbytes / 1e6:.1f} MB"
    )
    return {"beliefs": beliefs, "tta_means": tta_means, "tta_vars": tta_vars}
|
|
|
|
| def main(): |
| parser = argparse.ArgumentParser("make_belief_cache") |
| parser.add_argument("--sft_checkpoint", required=True) |
| parser.add_argument("--label_dir", default="data/policy_labels") |
| parser.add_argument("--out_dir", default="data/belief_cache") |
| parser.add_argument("--batch_size", type=int, default=8, |
| help="Larger = faster caching (no grad, more GPU memory)") |
| parser.add_argument("--splits", nargs="+", default=["train", "val"]) |
| args = parser.parse_args() |
|
|
| odir = Path(args.out_dir) |
| odir.mkdir(parents=True, exist_ok=True) |
|
|
| logger.info("Loading SFTModel (frozen backbone for belief extraction)...") |
| model = PolicyModel(args.sft_checkpoint, use_bf16=True) |
|
|
| for split in args.splits: |
| label_path = Path(args.label_dir) / f"{split}.json" |
| if not label_path.exists(): |
| logger.warning(f" {label_path} not found — skipping {split}") |
| continue |
|
|
| out_path = odir / f"{split}.pt" |
| if out_path.exists(): |
| logger.info(f" Cache already exists: {out_path} — skipping") |
| continue |
|
|
| logger.info(f"\nBuilding cache for split: {split}") |
| ds = PolicyDataset([label_path], split=split) |
| loader = DataLoader( |
| ds, batch_size=args.batch_size, shuffle=False, |
| num_workers=4, collate_fn=policy_collate_fn, pin_memory=True, |
| ) |
|
|
| cache = build_cache(model, loader, split) |
| torch.save(cache, out_path) |
| logger.info(f" Saved → {out_path}") |
|
|
| logger.info("\n✅ Belief cache complete.") |
|
|
|
|
# Script entry point: building the cache runs only when this file is executed,
# not when it is imported.
if __name__ == "__main__":
    main()
|
|