"""Fast version of make_cache_x_v2.py — reuses video frame buffer across ticks.

Bottleneck of the original: for rolling per-tick manifests (e.g. 17 ticks per
video on CARLA), every tick calls `sample_frames_from_mp4_by_indices`, which
opens the video fresh and reads from frame 0 sequentially until it reaches the
wanted indices. For 17 ticks this decodes the same video ~17 times.

Fix: monkey-patch `sample_frames_from_mp4_by_indices` to keep an LRU-1 cache of
the most recently decoded video's full frame list (already resized). Sort the
manifest by video_path so consecutive ticks of the same video hit the cache.

Expected speed-up on CARLA rolling (17 ticks/clip avg): 5-10x for the decode
portion, bringing aggregate throughput close to the GPU-forward-bound limit.
"""
| from __future__ import annotations |
|
|
| import sys |
| from pathlib import Path |
|
|
| ROOT = Path(__file__).resolve().parents[1] |
| sys.path.insert(0, str(ROOT)) |
|
|
| |
| from tools import run_train_cot_belief_fast |
|
|
| import argparse |
| import json |
| import logging |
| import time |
| from typing import Dict, List |
|
|
| import cv2 |
| import numpy as np |
| import torch |
| from PIL import Image |
| from tqdm import tqdm |
|
|
| |
| from training.VLA import frame_utils as _fu |
|
|
# Single-entry frame cache used by the patched sampler defined further down:
# maps a video path to that video's fully decoded, already-resized frames.
# The sampler clears it before storing a new video, so it never holds more
# than one entry.
_video_cache: Dict[str, List[Image.Image]] = {}
# Path of the video whose frames currently sit in `_video_cache`; the empty
# string means nothing has been cached yet.
_cache_path: str = ""
|
|
|
|
def _resize_bgr(frame: np.ndarray, resize_short: int) -> Image.Image:
    """Scale a BGR frame so its shorter side equals *resize_short*; return RGB PIL.

    The aspect ratio is preserved (each side is scaled by the same factor and
    rounded to the nearest pixel). Resampling uses ``cv2.INTER_AREA`` and the
    channel order is flipped from OpenCV's BGR to RGB before the PIL
    conversion.

    Args:
        frame: Image array whose first two axes are (height, width), in BGR
            channel order as produced by ``cv2.VideoCapture.read``.
        resize_short: Target length in pixels of the shorter side.

    Returns:
        The resized frame as a ``PIL.Image.Image`` in RGB order.
    """
    src_h, src_w = frame.shape[:2]
    ratio = resize_short / min(src_h, src_w)
    # cv2.resize takes the target size as (width, height), not (height, width).
    dst_size = (int(round(src_w * ratio)), int(round(src_h * ratio)))
    resized = cv2.resize(frame, dst_size, interpolation=cv2.INTER_AREA)
    return Image.fromarray(cv2.cvtColor(resized, cv2.COLOR_BGR2RGB))
|
|
|
|
def _decode_full_video(video_path: str, resize_short: int) -> List[Image.Image]:
    """Decode every frame of a video once and return them resized.

    Frames are read sequentially until ``read()`` reports failure (end of
    stream or a decode error), and each one is passed through ``_resize_bgr``.

    Args:
        video_path: Path of the video file to open with OpenCV.
        resize_short: Target length of each frame's shorter side, forwarded
            to ``_resize_bgr``.

    Returns:
        All decoded frames, in order, as resized RGB PIL images. The list is
        empty if the file opens but yields no frames.

    Raises:
        RuntimeError: If OpenCV cannot open ``video_path``.
    """
    cap = cv2.VideoCapture(video_path)
    # try/finally guarantees the capture handle is released even when the open
    # check fails or a frame conversion raises part-way through the video.
    try:
        if not cap.isOpened():
            raise RuntimeError(f"could not open: {video_path}")
        frames: List[Image.Image] = []
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            frames.append(_resize_bgr(frame, resize_short))
        return frames
    finally:
        cap.release()
|
|
|
|
# Decode settings / metadata belonging to the video currently held in
# `_video_cache`; both are reset whenever a different video (or size) is decoded.
_cache_resize_short: int = -1  # resize_short the cached frames were decoded with
_cache_fps: float = 0.0        # FPS of the cached video; 0.0 means "not probed yet"


def _probe_fps(video_path: str) -> float:
    """Return the FPS OpenCV reports for *video_path*, or 30.0 if it is unusable."""
    cap = cv2.VideoCapture(video_path)
    try:
        fps = float(cap.get(cv2.CAP_PROP_FPS))
    finally:
        cap.release()
    # `fps > 0` is False for 0, negatives and NaN alike, so every unusable
    # value falls back to the same 30 FPS default.
    return fps if fps > 0 else 30.0


def _patched_sample_by_indices(video_path, indices: List[int],
                               resize_short: int = 336,
                               return_times: bool = False):
    """Return the frames at *indices*, decoding each video only once (LRU-1 cache).

    The most recently requested video is decoded in full and kept in
    ``_video_cache``; repeated calls for the same path and ``resize_short``
    are served from memory. Asking for another video, or for the same video
    at a different ``resize_short``, evicts the cached one and decodes again.

    Args:
        video_path: Path of the video (anything ``str()`` turns into a path).
        indices: Frame indices to fetch. Out-of-range values are clamped to
            ``[0, n_frames - 1]`` rather than rejected.
        resize_short: Shorter-side length the frames are resized to.
        return_times: If true, also return one timestamp per frame.

    Returns:
        The list of frames, or ``(frames, times)`` when ``return_times`` is
        true, where each time is ``clamped_index / fps`` in seconds. The
        frames are the cached objects themselves (not copies), so callers
        must not modify them in place.

    Raises:
        RuntimeError: If the video cannot be opened or yields zero frames.
    """
    global _cache_path, _cache_resize_short, _cache_fps
    vp = str(video_path)
    cache_hit = (
        vp == _cache_path
        and resize_short == _cache_resize_short
        and vp in _video_cache
    )
    if not cache_hit:
        # Drop the old video *and* its bookkeeping before decoding: peak memory
        # stays at one video, and a failed decode cannot leave `_cache_path`
        # pointing at frames that are no longer in the cache.
        _video_cache.clear()
        _cache_path = ""
        _cache_fps = 0.0
        _video_cache[vp] = _decode_full_video(vp, resize_short)
        _cache_path = vp
        _cache_resize_short = resize_short
    all_frames = _video_cache[vp]
    n_total = len(all_frames)
    if n_total <= 0:
        raise RuntimeError(f"bad video (0 frames decoded): {video_path}")
    clipped = [max(0, min(n_total - 1, int(i))) for i in indices]
    frames = [all_frames[i] for i in clipped]
    if not return_times:
        return frames
    if _cache_fps <= 0.0:
        # Probe the container once per cached video instead of reopening the
        # file on every call.
        _cache_fps = _probe_fps(vp)
    return frames, [i / _cache_fps for i in clipped]
|
|
|
|
| |
# Monkey-patch: replace the frame_utils sampler with the cached version. This
# rebinds the attribute on the `frame_utils` module object only — presumably
# why `tools.make_cache_x_v2` is imported after this line rather than with the
# other imports at the top.
# NOTE(review): any module that already executed
# `from training.VLA.frame_utils import sample_frames_from_mp4_by_indices`
# before this point (e.g. something pulled in by the earlier
# `tools.run_train_cot_belief_fast` import) keeps the unpatched function —
# confirm the extraction code resolves the sampler after the patch.
_fu.sample_frames_from_mp4_by_indices = _patched_sample_by_indices
|
|
|
|
| |
| from tools.make_cache_x_v2 import ( |
| build_extraction_assistant, |
| extract_split, |
| ) |
|
|
|
|
# Configure root logging at import time (INFO level, "<time> <level> <message>"
# lines) and create the named logger used by this script.
logging.basicConfig(level=logging.INFO,
                    format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger("make_cache_x_v2_fast")
|
|
|
|
def _manifest_sort_key(record: dict) -> tuple:
    """Return the ``(video_path, tick_index)`` sort key for one manifest record.

    Missing or null fields sort as ``""`` / ``0`` instead of raising, so one
    incomplete record cannot abort the whole sort. The key only influences
    ordering; the record itself is never modified.
    """
    meta = record.get("meta")
    tick = meta.get("tick_index") if isinstance(meta, dict) else None
    try:
        tick_index = int(tick) if tick is not None else 0
    except (TypeError, ValueError):
        tick_index = 0
    # str() keeps the keys mutually comparable even if a path is not a string.
    return (str(record.get("video_path") or ""), tick_index)


def _load_manifest_records(manifest_path: Path) -> tuple[list[dict], int]:
    """Read a JSONL manifest; return ``(records, n_skipped)``.

    Blank lines are ignored. Lines that are not valid JSON, or whose JSON
    value is not an object, are skipped and counted in ``n_skipped``.
    """
    records: list[dict] = []
    n_skipped = 0
    with open(manifest_path, encoding="utf-8") as fin:
        for line in fin:
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                n_skipped += 1
                continue
            if not isinstance(record, dict):
                n_skipped += 1
                continue
            records.append(record)
    return records, n_skipped


def main():
    """CLI entry point: sort the manifest by video, then run the extraction.

    Steps:
      1. Parse the command-line arguments.
      2. Load the JSONL manifest and sort it by ``(video_path, tick_index)``
         so all ticks of one video are consecutive and hit the one-video
         frame cache.
      3. Write the sorted copy to ``<out_dir>/_sorted__<manifest name>``.
      4. Call ``extract_split`` on the sorted copy, writing to
         ``<out_dir>/<tag>__<split>.pt``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--manifest", type=Path, required=True)
    ap.add_argument("--tag", default="sft_x_v2")
    ap.add_argument("--split", required=True)
    ap.add_argument("--out_dir", type=Path,
                    default=ROOT / "data/belief_cache_v2")
    ap.add_argument("--ckpt", type=Path,
                    default=ROOT / "checkpoints/sft_x_v2/best")
    ap.add_argument("--base_model", type=Path,
                    default=ROOT / "models/Qwen3-VL-4B-Instruct")
    ap.add_argument("--belief_layers", nargs="+", type=int,
                    default=[20, 24, 28, 32])
    ap.add_argument("--policy_layer", type=int, default=33)
    ap.add_argument("--batch_size", type=int, default=4)
    ap.add_argument("--limit", type=int, default=0)
    ap.add_argument("--pool_mode",
                    choices=["range", "open", "token_mean", "random_span"],
                    default="range")
    ap.add_argument("--random_span_seed", type=int, default=0)
    # Previously hard-coded to 8 in the extract_split call; the default keeps
    # the old behavior.
    ap.add_argument("--n_frames", type=int, default=8)
    args = ap.parse_args()

    args.out_dir.mkdir(parents=True, exist_ok=True)
    sorted_manifest = args.out_dir / f"_sorted__{args.manifest.name}"

    records, n_skipped = _load_manifest_records(args.manifest)
    if n_skipped:
        # Report dropped lines instead of discarding them silently.
        logger.warning("[sort] skipped %d malformed manifest line(s) in %s",
                       n_skipped, args.manifest)

    # list.sort is stable, so records sharing a key keep their manifest order.
    records.sort(key=_manifest_sort_key)
    logger.info("[sort] %d records sorted by (video_path, tick_index)",
                len(records))

    with open(sorted_manifest, "w", encoding="utf-8") as fout:
        fout.writelines(json.dumps(record) + "\n" for record in records)
    logger.info("[save] sorted manifest -> %s", sorted_manifest)

    out_path = args.out_dir / f"{args.tag}__{args.split}.pt"
    # NOTE(review): `limit` is forwarded unchanged; if extract_split takes the
    # first `limit` records, those are now the first in *sorted* order rather
    # than manifest order — confirm that is acceptable.
    extract_split(
        ckpt_dir=args.ckpt,
        base_model=args.base_model,
        manifest_path=sorted_manifest,
        out_path=out_path,
        belief_layers=tuple(args.belief_layers),
        policy_layer=args.policy_layer,
        batch_size=args.batch_size,
        limit=args.limit,
        n_frames=args.n_frames,
        pool_mode=args.pool_mode,
        random_span_seed=args.random_span_seed,
    )


if __name__ == "__main__":
    main()
|
|