"""Fast version of make_cache_x_v2.py — reuses video frame buffer across ticks. Bottleneck of the original: for rolling per-tick manifests (e.g. 17 ticks per video on CARLA), every tick calls `sample_frames_from_mp4_by_indices`, which opens the video fresh and reads from frame 0 sequentially until it reaches the wanted indices. For 17 ticks this decodes the same video ~17 times. Fix: monkey-patch `sample_frames_from_mp4_by_indices` to keep an LRU-1 cache of the most recently decoded video's full frame list (already resized). Sort the manifest by video_path so consecutive ticks of the same video hit the cache. Expected speed-up on CARLA rolling (17 ticks/clip avg): 5-10x for the decode portion, bringing aggregate throughput close to the GPU-forward-bound limit. """ from __future__ import annotations import sys from pathlib import Path ROOT = Path(__file__).resolve().parents[1] sys.path.insert(0, str(ROOT)) # Order matters: apply PR fast_patch BEFORE importing model code. from tools import run_train_cot_belief_fast # noqa: F401, E402 import argparse # noqa: E402 import json # noqa: E402 import logging # noqa: E402 import time # noqa: E402 from typing import Dict, List # noqa: E402 import cv2 # noqa: E402 import numpy as np # noqa: E402 import torch # noqa: E402 from PIL import Image # noqa: E402 from tqdm import tqdm # noqa: E402 # ── monkey-patch sample_frames with video-level cache ──────────────────── from training.VLA import frame_utils as _fu # noqa: E402 _video_cache: Dict[str, List[Image.Image]] = {} _cache_path: str = "" def _resize_bgr(frame: np.ndarray, resize_short: int) -> Image.Image: h, w = frame.shape[:2] scale = resize_short / min(h, w) nh, nw = int(round(h * scale)), int(round(w * scale)) frame = cv2.resize(frame, (nw, nh), interpolation=cv2.INTER_AREA) frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) return Image.fromarray(frame) def _decode_full_video(video_path: str, resize_short: int) -> List[Image.Image]: """Decode every frame of a video once, return resized PIL RGB list.""" cap = cv2.VideoCapture(video_path) if not cap.isOpened(): raise RuntimeError(f"could not open: {video_path}") frames: List[Image.Image] = [] while True: ok, frame = cap.read() if not ok: break frames.append(_resize_bgr(frame, resize_short)) cap.release() return frames def _patched_sample_by_indices(video_path, indices: List[int], resize_short: int = 336, return_times: bool = False): """LRU-1 cached version: each video is decoded exactly once.""" global _cache_path, _video_cache vp = str(video_path) if vp != _cache_path: # Evict previous video to free memory _video_cache.clear() _video_cache[vp] = _decode_full_video(vp, resize_short) _cache_path = vp all_frames = _video_cache[vp] n_total = len(all_frames) if n_total <= 0: raise RuntimeError(f"bad video (0 frames decoded): {video_path}") clipped = [max(0, min(n_total - 1, int(i))) for i in indices] frames = [all_frames[i] for i in clipped] if return_times: # Caller doesn't have fps anymore; approximate via cv2 cap = cv2.VideoCapture(vp) fps = float(cap.get(cv2.CAP_PROP_FPS)) or 30.0 cap.release() return frames, [i / fps for i in clipped] return frames # Apply monkey-patch _fu.sample_frames_from_mp4_by_indices = _patched_sample_by_indices # ── now import the original extraction code with patches active ────────── from tools.make_cache_x_v2 import ( # noqa: E402 build_extraction_assistant, extract_split, ) logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") logger = logging.getLogger("make_cache_x_v2_fast") def main(): ap = argparse.ArgumentParser() ap.add_argument("--manifest", type=Path, required=True) ap.add_argument("--tag", default="sft_x_v2") ap.add_argument("--split", required=True) ap.add_argument("--out_dir", type=Path, default=ROOT / "data/belief_cache_v2") ap.add_argument("--ckpt", type=Path, default=ROOT / "checkpoints/sft_x_v2/best") ap.add_argument("--base_model", type=Path, default=ROOT / "models/Qwen3-VL-4B-Instruct") ap.add_argument("--belief_layers", nargs="+", type=int, default=[20, 24, 28, 32]) ap.add_argument("--policy_layer", type=int, default=33) ap.add_argument("--batch_size", type=int, default=4) ap.add_argument("--limit", type=int, default=0) ap.add_argument("--pool_mode", choices=["range", "open", "token_mean", "random_span"], default="range") ap.add_argument("--random_span_seed", type=int, default=0) args = ap.parse_args() # ── Pre-sort manifest by video_id so consecutive batches hit the cache ── sorted_manifest = args.out_dir / f"_sorted__{args.manifest.name}" args.out_dir.mkdir(parents=True, exist_ok=True) n_records = 0 with open(args.manifest) as fin: records = [] for ln in fin: ln = ln.strip() if not ln: continue try: r = json.loads(ln) except json.JSONDecodeError: continue records.append(r) n_records += 1 def key(r): return (r.get("video_path", ""), int(r.get("meta", {}).get("tick_index", 0))) records.sort(key=key) logger.info(f"[sort] {n_records} records sorted by (video_path, tick_index)") with open(sorted_manifest, "w") as fout: for r in records: fout.write(json.dumps(r) + "\n") logger.info(f"[save] sorted manifest → {sorted_manifest}") out_path = args.out_dir / f"{args.tag}__{args.split}.pt" extract_split( ckpt_dir=args.ckpt, base_model=args.base_model, manifest_path=sorted_manifest, out_path=out_path, belief_layers=tuple(args.belief_layers), policy_layer=args.policy_layer, batch_size=args.batch_size, limit=args.limit, n_frames=8, pool_mode=args.pool_mode, random_span_seed=args.random_span_seed, ) if __name__ == "__main__": main()