File size: 6,455 Bytes
"""Fast version of make_cache_x_v2.py — reuses video frame buffer across ticks.
Bottleneck of the original: for rolling per-tick manifests (e.g. 17 ticks per
video on CARLA), every tick calls `sample_frames_from_mp4_by_indices`, which
opens the video fresh and reads from frame 0 sequentially until it reaches the
wanted indices. For 17 ticks this decodes the same video ~17 times.
Fix: monkey-patch `sample_frames_from_mp4_by_indices` to keep an LRU-1 cache of
the most recently decoded video's full frame list (already resized). Sort the
manifest by video_path so consecutive ticks of the same video hit the cache.
Expected speed-up on CARLA rolling (17 ticks/clip avg): 5-10x for the decode
portion, bringing aggregate throughput close to the GPU-forward-bound limit.
"""
from __future__ import annotations
import sys
from pathlib import Path
# ROOT is the parent of the directory that contains this file; it is the base
# for the default --out_dir / --ckpt / --base_model paths built in main().
ROOT = Path(__file__).resolve().parents[1]
# Put ROOT first on sys.path -- presumably so the `tools` and `training`
# packages imported below resolve against this checkout (confirm against the
# repository layout).
sys.path.insert(0, str(ROOT))
# Order matters: apply PR fast_patch BEFORE importing model code.
from tools import run_train_cot_belief_fast # noqa: F401, E402
import argparse # noqa: E402
import json # noqa: E402
import logging # noqa: E402
import time # noqa: E402
from typing import Dict, List # noqa: E402
import cv2 # noqa: E402
import numpy as np # noqa: E402
import torch # noqa: E402
from PIL import Image # noqa: E402
from tqdm import tqdm # noqa: E402
# ── monkey-patch sample_frames with video-level cache ────────────────────
from training.VLA import frame_utils as _fu # noqa: E402
# Single-entry frame cache: maps the path of the most recently decoded video to
# the list of all its frames (already resized, RGB PIL images). It is cleared
# whenever a different video is requested, so it never holds more than one video.
_video_cache: Dict[str, List[Image.Image]] = {}
# Path of the video whose frames currently sit in _video_cache ("" before the
# first decode).
_cache_path: str = ""
def _resize_bgr(frame: np.ndarray, resize_short: int) -> Image.Image:
    """Resize a BGR frame so its shorter side is ``resize_short``; return RGB PIL.

    Both sides are scaled by the same factor (``resize_short / min(h, w)``) and
    rounded to the nearest integer, so the aspect ratio is kept up to rounding.

    Args:
        frame: Image array whose first two axes are (height, width), with
            channels in OpenCV's BGR order.
        resize_short: Target length in pixels of the shorter side.

    Returns:
        The resized frame converted to RGB and wrapped in a ``PIL.Image``.
    """
    height, width = frame.shape[:2]
    ratio = resize_short / min(height, width)
    # cv2.resize takes the destination size as (width, height).
    target_size = (int(round(width * ratio)), int(round(height * ratio)))
    resized = cv2.resize(frame, target_size, interpolation=cv2.INTER_AREA)
    return Image.fromarray(cv2.cvtColor(resized, cv2.COLOR_BGR2RGB))
def _decode_full_video(video_path: str, resize_short: int) -> List[Image.Image]:
    """Decode every frame of a video once and return them as resized RGB images.

    The whole video is held in memory at once, so the caller is expected to
    keep at most one decoded video alive at a time.

    Args:
        video_path: Path of the video file to open with OpenCV.
        resize_short: Target length of each frame's shorter side, forwarded to
            ``_resize_bgr``.

    Returns:
        One PIL image per successfully read frame, in stream order. The list
        is empty if the file opens but no frame can be read.

    Raises:
        RuntimeError: If OpenCV cannot open ``video_path``.
    """
    cap = cv2.VideoCapture(video_path)
    # try/finally guarantees the capture handle is released even when opening
    # fails or a frame conversion raises part-way through the decode.
    try:
        if not cap.isOpened():
            raise RuntimeError(f"could not open: {video_path}")
        frames: List[Image.Image] = []
        while True:
            ok, frame = cap.read()
            if not ok:
                # read() reports failure at end of stream (or on a decode
                # error); either way there is nothing more to collect.
                break
            frames.append(_resize_bgr(frame, resize_short))
        return frames
    finally:
        cap.release()
# resize_short that the frames currently held in _video_cache were decoded
# with; None while the cache is empty.
_cache_resize_short = None


def _patched_sample_by_indices(video_path, indices: List[int],
                               resize_short: int = 336,
                               return_times: bool = False):
    """Return the requested frames of a video, decoding the video at most once.

    Keeps a single-entry cache of the most recently decoded video. Consecutive
    calls for the same path and the same ``resize_short`` are served from
    memory; any other request evicts the cached video and decodes the new one.

    Args:
        video_path: Path of the video (anything ``str()`` can turn into a path).
        indices: Frame indices to return. Each index is clamped into
            ``[0, n_frames - 1]``, so out-of-range (including negative) values
            map to the first or last frame instead of raising.
        resize_short: Target length of each frame's shorter side.
        return_times: If true, also return one timestamp per frame.

    Returns:
        A list of PIL images, or ``(frames, times)`` when ``return_times`` is
        true, where ``times[k]`` is the clamped index divided by the fps that
        OpenCV reports (30.0 when it reports no usable value). The images are
        the cached objects themselves, not copies.

    Raises:
        RuntimeError: If the video cannot be opened or yields zero frames.
    """
    global _cache_path, _cache_resize_short
    vp = str(video_path)
    # A hit needs the same path AND the same resize, and the entry must really
    # be present (guards against an empty path matching the initial "" state).
    cache_hit = (
        vp == _cache_path
        and vp in _video_cache
        and resize_short == _cache_resize_short
    )
    if not cache_hit:
        # Evict the previous video to free memory, and mark the cache empty
        # *before* decoding so a failed decode cannot leave a stale path that
        # later looks like a hit and raises KeyError.
        _video_cache.clear()
        _cache_path = ""
        _cache_resize_short = None
        _video_cache[vp] = _decode_full_video(vp, resize_short)
        _cache_path = vp
        _cache_resize_short = resize_short
    all_frames = _video_cache[vp]
    n_total = len(all_frames)
    if n_total <= 0:
        raise RuntimeError(f"bad video (0 frames decoded): {video_path}")
    clipped = [max(0, min(n_total - 1, int(i))) for i in indices]
    frames = [all_frames[i] for i in clipped]
    if not return_times:
        return frames
    # The cached frame list carries no timing information, so ask OpenCV for
    # the container's fps and derive timestamps from the frame indices.
    cap = cv2.VideoCapture(vp)
    try:
        fps = float(cap.get(cv2.CAP_PROP_FPS))
    finally:
        cap.release()
    if not fps > 0:
        # Covers 0.0 (property unavailable), negative values and NaN.
        fps = 30.0
    return frames, [i / fps for i in clipped]
# Apply monkey-patch: rebind the attribute on the frame_utils module so that
# lookups of `frame_utils.sample_frames_from_mp4_by_indices` get the cached
# implementation defined above.
# NOTE(review): a module that already did `from ... import
# sample_frames_from_mp4_by_indices` before this line would keep the original
# function -- presumably why tools.make_cache_x_v2 is imported only afterwards.
_fu.sample_frames_from_mp4_by_indices = _patched_sample_by_indices
# ── now import the original extraction code with patches active ──────────
from tools.make_cache_x_v2 import ( # noqa: E402
build_extraction_assistant,
extract_split,
)
# Configure root logging at INFO with timestamped "<time> <level> <message>"
# lines; this runs at import time, after the imports above.
logging.basicConfig(level=logging.INFO,
                    format="%(asctime)s %(levelname)s %(message)s")
# Logger used by main() for the [sort] / [save] progress messages.
logger = logging.getLogger("make_cache_x_v2_fast")
def _manifest_sort_key(record):
    """Return the ``(video_path, tick_index)`` sort key for a manifest record.

    Tolerates a missing or null ``video_path``, ``meta`` or ``tick_index``
    (treated as ``""`` / ``0``) so one incomplete record cannot abort the sort.
    """
    meta = record.get("meta")
    tick = meta.get("tick_index") if isinstance(meta, dict) else None
    try:
        tick = int(tick or 0)
    except (TypeError, ValueError):
        tick = 0
    return (str(record.get("video_path") or ""), tick)


def _load_manifest_records(manifest_path):
    """Read a JSONL manifest, skipping blank lines and lines that are not JSON.

    Undecodable lines are still skipped (best effort), but their count is
    logged so dropped records do not go unnoticed.
    """
    records = []
    n_bad = 0
    with open(manifest_path, encoding="utf-8") as fin:
        for line in fin:
            line = line.strip()
            if not line:
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError:
                n_bad += 1
    if n_bad:
        logger.warning("[sort] skipped %d malformed JSON line(s) in %s",
                       n_bad, manifest_path)
    return records


def main(argv: List[str] | None = None) -> None:
    """CLI entry point: sort the manifest by video, then run the extraction.

    Writes a copy of the manifest sorted by ``(video_path, tick_index)`` to
    ``<out_dir>/_sorted__<manifest name>`` so that records of the same video
    are adjacent, then calls ``extract_split`` on that sorted copy with output
    path ``<out_dir>/<tag>__<split>.pt``.

    Args:
        argv: Command-line arguments to parse; ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--manifest", type=Path, required=True)
    ap.add_argument("--tag", default="sft_x_v2")
    ap.add_argument("--split", required=True)
    ap.add_argument("--out_dir", type=Path,
                    default=ROOT / "data/belief_cache_v2")
    ap.add_argument("--ckpt", type=Path,
                    default=ROOT / "checkpoints/sft_x_v2/best")
    ap.add_argument("--base_model", type=Path,
                    default=ROOT / "models/Qwen3-VL-4B-Instruct")
    ap.add_argument("--belief_layers", nargs="+", type=int,
                    default=[20, 24, 28, 32])
    ap.add_argument("--policy_layer", type=int, default=33)
    ap.add_argument("--batch_size", type=int, default=4)
    ap.add_argument("--limit", type=int, default=0)
    ap.add_argument("--pool_mode",
                    choices=["range", "open", "token_mean", "random_span"],
                    default="range")
    ap.add_argument("--random_span_seed", type=int, default=0)
    # Previously hard-coded to 8 in the extract_split call below; the default
    # keeps that behaviour.
    ap.add_argument("--n_frames", type=int, default=8,
                    help="forwarded to extract_split as n_frames")
    args = ap.parse_args(argv)

    # ── Pre-sort manifest by video_path so consecutive records hit the cache ──
    sorted_manifest = args.out_dir / f"_sorted__{args.manifest.name}"
    args.out_dir.mkdir(parents=True, exist_ok=True)
    records = _load_manifest_records(args.manifest)
    records.sort(key=_manifest_sort_key)
    logger.info("[sort] %d records sorted by (video_path, tick_index)",
                len(records))
    with open(sorted_manifest, "w", encoding="utf-8") as fout:
        fout.writelines(json.dumps(r) + "\n" for r in records)
    logger.info("[save] sorted manifest → %s", sorted_manifest)

    out_path = args.out_dir / f"{args.tag}__{args.split}.pt"
    # NOTE(review): --limit is applied by extract_split to the *sorted*
    # manifest, so it presumably selects records in sorted order rather than
    # original manifest order -- confirm this is intended.
    extract_split(
        ckpt_dir=args.ckpt,
        base_model=args.base_model,
        manifest_path=sorted_manifest,
        out_path=out_path,
        belief_layers=tuple(args.belief_layers),
        policy_layer=args.policy_layer,
        batch_size=args.batch_size,
        limit=args.limit,
        n_frames=args.n_frames,
        pool_mode=args.pool_mode,
        random_span_seed=args.random_span_seed,
    )
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|