anwm / batch_make_compare.py
de99's picture
Upload batch_make_compare.py
cff9ece verified
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
批量调用 make_compare.py 为多个ID生成对比视频
"""
import subprocess
import sys
import os
import argparse
import tempfile
import shutil
from pathlib import Path
def read_id_list(id_file: str) -> list:
"""读取ID列表文件,返回ID数字列表(用于make_compare.py)"""
ids = []
with open(id_file, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
if line and not line.startswith('#'):
# 处理 id_xxx 格式,提取数字部分
if line.startswith('id_'):
id_num = line.replace('id_', '')
try:
ids.append(int(id_num))
except ValueError:
print(f"[WARN] 无法解析ID: {line},跳过")
else:
# 如果只是数字,直接使用
try:
ids.append(int(line))
except ValueError:
print(f"[WARN] 无法解析ID: {line},跳过")
return ids
def find_max_frame_in_dir(base_dir: str, id_val: int) -> int:
"""查找目录中最大的帧索引"""
id_dir = os.path.join(base_dir, f"id_{id_val}")
if not os.path.isdir(id_dir):
return -1
max_frame = -1
for fname in os.listdir(id_dir):
if fname.endswith('.png'):
try:
frame_num = int(fname.replace('.png', ''))
max_frame = max(max_frame, frame_num)
except ValueError:
pass
return max_frame
def run_make_compare_combined(id_val: int, args_dict: dict, dry_run: bool = False):
"""运行 make_compare.py 为单个ID生成合并视频(包含4fps和1fps)"""
# 只使用 4fps 的 GT,合并 4fps 和 1fps 的 pred
gt_dir = args_dict['gt'] # 只用 4fps 的 GT
pred_dirs = args_dict['pred'].copy() # 4fps 的 pred
# 如果有 1fps 数据,只添加 pred(不添加 GT)
if args_dict.get('pred_1fps'):
pred_dirs.extend(args_dict['pred_1fps'])
# 合并所有列:GT(4fps) + Pred(4fps) + Pred(1fps)
all_cols = [gt_dir] + pred_dirs
# 为 1fps 列创建缺失帧的符号链接(映射到对应的 1fps 帧)
# 1fps 列每4帧更新一次,所以帧 0,1,2,3 -> 显示帧0,帧 4,5,6,7 -> 显示帧4,等等
end_frame = args_dict['end']
if args_dict.get('pred_1fps'):
import tempfile
import shutil
# 为每个 1fps 列创建临时目录,包含所有需要的帧(通过符号链接)
temp_dirs = []
for pred_1fps_dir in args_dict['pred_1fps']:
# 创建临时目录
temp_dir = tempfile.mkdtemp(prefix=f'temp_1fps_{id_val}_')
temp_dirs.append(temp_dir)
# 复制 id 目录结构
src_id_dir = os.path.join(pred_1fps_dir, f"id_{id_val}")
dst_id_dir = os.path.join(temp_dir, f"id_{id_val}")
os.makedirs(dst_id_dir, exist_ok=True)
# 先找到 1fps 列中实际存在的最大帧
max_available_frame = find_max_frame_in_dir(pred_1fps_dir, id_val)
# 为每一帧创建符号链接(映射到对应的 1fps 帧)
# 映射规则:4fps 的帧 fid -> 1fps 的帧 (fid // 4)
# 例如:4fps 帧 0,1,2,3 -> 1fps 帧 0
# 4fps 帧 4,5,6,7 -> 1fps 帧 1
# 4fps 帧 8,9,10,11 -> 1fps 帧 2
start_frame = args_dict.get('start') if args_dict.get('start') is not None else 0
for fid in range(start_frame, end_frame + 1):
# 映射规则:fid -> fid // 4(每4帧更新一次,1fps的1帧对应4fps的4帧)
mapped_fid = fid // 4
# 如果映射的帧超过了实际存在的最大帧,使用最大帧
if mapped_fid > max_available_frame:
mapped_fid = max_available_frame
src_frame = os.path.join(src_id_dir, f"{mapped_fid}.png")
dst_frame = os.path.join(dst_id_dir, f"{fid}.png")
if os.path.isfile(src_frame):
# 创建符号链接
if os.path.exists(dst_frame):
os.remove(dst_frame)
os.symlink(src_frame, dst_frame)
else:
# 如果源帧不存在,尝试找最近的帧(向下查找)
found = False
for check_fid in range(mapped_fid, -1, -1):
check_src = os.path.join(src_id_dir, f"{check_fid}.png")
if os.path.isfile(check_src):
if os.path.exists(dst_frame):
os.remove(dst_frame)
os.symlink(check_src, dst_frame)
found = True
break
if not found:
print(f"[WARN] id_{id_val} 1fps 列找不到任何帧")
# 更新 pred_dirs 中的路径,使用临时目录
pred_dirs[pred_dirs.index(pred_1fps_dir)] = temp_dir
# 保存临时目录列表,以便后续清理
args_dict['_temp_dirs'] = temp_dirs
# 使用用户提供的标签(如果提供了完整的标签列表)
if args_dict.get('labels_combined'):
labels = args_dict['labels_combined']
else:
# 否则自动生成标签
labels = args_dict['labels'].copy() # GT + 4fps pred
if args_dict.get('pred_1fps'):
# 为 1fps 的 pred 添加后缀
labels_1fps = [f"{label}_1fps" for label in args_dict['labels'][1:]] # 跳过 GT
labels = labels + labels_1fps
# 使用 4fps 的 fps 值(因为合并后视频统一使用一个 fps)
fps = args_dict.get('fps', 1)
output_name = f'compare_id{id_val}.mp4'
cmd = [
sys.executable,
'make_compare.py',
'--gt', gt_dir, # 只用 4fps 的 GT
'--pred'] + pred_dirs + [ # 4fps pred + 1fps pred
'--id', str(id_val),
'--out_dir', args_dict['out_dir'],
'--end', str(end_frame),
'--labels'] + labels + [
'--fps', str(int(fps)) if fps == int(fps) else str(fps),
'--out', output_name,
]
if args_dict.get('safe_even', False):
cmd.append('--safe_even')
if args_dict.get('start', None):
cmd.extend(['--start', str(args_dict['start'])])
print(f"\n{'='*80}")
print(f"处理 ID: id_{id_val} (合并 4fps + 1fps, fps={fps})")
print(f"列数: {len(all_cols)} (GT_4fps + {len(args_dict['pred'])} pred_4fps + {len(args_dict.get('pred_1fps', []))} pred_1fps)")
print(f"命令: {' '.join(cmd)}")
print(f"{'='*80}")
if dry_run:
print("[DRY RUN] 不会实际执行")
return True
try:
result = subprocess.run(cmd, check=True, cwd=args_dict.get('work_dir', '.'))
print(f"[OK] id_{id_val} (合并视频) 完成")
# 清理临时目录
if args_dict.get('_temp_dirs'):
for temp_dir in args_dict['_temp_dirs']:
try:
shutil.rmtree(temp_dir)
except Exception as e:
print(f"[WARN] 清理临时目录失败 {temp_dir}: {e}")
return True
except subprocess.CalledProcessError as e:
print(f"[ERROR] id_{id_val} (合并视频) 失败: {e}", file=sys.stderr)
# 清理临时目录
if args_dict.get('_temp_dirs'):
for temp_dir in args_dict['_temp_dirs']:
try:
shutil.rmtree(temp_dir)
except Exception:
pass
return False
except Exception as e:
print(f"[ERROR] id_{id_val} (合并视频) 出错: {e}", file=sys.stderr)
# 清理临时目录
if args_dict.get('_temp_dirs'):
for temp_dir in args_dict['_temp_dirs']:
try:
shutil.rmtree(temp_dir)
except Exception:
pass
return False
def main():
parser = argparse.ArgumentParser(description='批量生成对比视频')
parser.add_argument('--id_file', type=str, required=True,
help='ID列表文件路径')
parser.add_argument('--gt', type=str, required=True,
help='GT rollout目录')
parser.add_argument('--pred', nargs='+', required=True,
help='预测 rollout 目录(一个或多个)')
parser.add_argument('--out_dir', type=str, required=True,
help='输出目录')
parser.add_argument('--labels', nargs='+', required=True,
help='标签列表(与GT+pred数量对应)')
parser.add_argument('--end', type=int, default=51,
help='结束帧索引')
parser.add_argument('--start', type=int, default=None,
help='开始帧索引(可选)')
parser.add_argument('--fps', type=float, default=1,
help='视频FPS(用于rollout_4fps,默认1)')
parser.add_argument('--gt_1fps', type=str, default=None,
help='GT rollout_1fps目录(已废弃,不再使用)')
parser.add_argument('--pred_1fps', nargs='+', default=None,
help='预测 rollout_1fps 目录(一个或多个,将合并到视频中,但不包含GT)')
parser.add_argument('--fps_1fps', type=float, default=0.25,
help='视频FPS(用于rollout_1fps,默认0.25)')
parser.add_argument('--safe_even', action='store_true',
help='使用安全偶数尺寸')
parser.add_argument('--work_dir', type=str, default=None,
help='工作目录(默认当前目录)')
parser.add_argument('--dry_run', action='store_true',
help='仅显示命令,不实际执行')
parser.add_argument('--skip_existing', action='store_true',
help='跳过已存在的视频文件')
parser.add_argument('--max_workers', type=int, default=1,
help='最大并行数(默认1,串行执行)')
args = parser.parse_args()
# 读取ID列表
ids = read_id_list(args.id_file)
print(f"[INFO] 从 {args.id_file} 读取到 {len(ids)} 个ID")
if not ids:
print("[ERROR] 没有找到有效的ID", file=sys.stderr)
sys.exit(1)
# 确定处理模式:合并模式(如果指定了 1fps pred 数据)
combine_mode = args.pred_1fps is not None
if combine_mode:
if len(args.pred_1fps) != len(args.pred):
print(f"[ERROR] rollout_1fps 的预测目录数量 ({len(args.pred_1fps)}) 与 rollout_4fps ({len(args.pred)}) 不匹配", file=sys.stderr)
sys.exit(1)
# 准备参数
args_dict = {
'gt': args.gt,
'pred': args.pred,
'out_dir': args.out_dir,
'labels': args.labels,
'labels_combined': args.labels if combine_mode and len(args.labels) == (1 + len(args.pred) + len(args.pred_1fps)) else None,
'end': args.end,
'start': args.start,
'fps': args.fps,
'fps_1fps': args.fps_1fps,
'gt_1fps': args.gt_1fps,
'pred_1fps': args.pred_1fps,
'safe_even': args.safe_even,
'work_dir': args.work_dir or os.getcwd(),
'dry_run': args.dry_run,
}
# 验证标签数量
if combine_mode:
expected_labels = 1 + len(args.pred) + len(args.pred_1fps) # GT + 4fps pred + 1fps pred
if len(args.labels) != expected_labels:
print(f"[WARN] 标签数量 ({len(args.labels)}) 与预期 ({expected_labels}) 不匹配")
print(f" 预期: GT + {len(args.pred)} 个 4fps 预测方法 + {len(args.pred_1fps)} 个 1fps 预测方法")
else:
expected_labels = 1 + len(args.pred) # GT + pred数量
if len(args.labels) != expected_labels:
print(f"[WARN] 标签数量 ({len(args.labels)}) 与预期 ({expected_labels}) 不匹配")
print(f" 预期: GT + {len(args.pred)} 个预测方法")
# 检查输出目录
os.makedirs(args.out_dir, exist_ok=True)
# 执行批量处理
success_count = 0
skip_count = 0
fail_count = 0
total_tasks = len(ids)
current_task = 0
if combine_mode:
print(f"\n开始批量处理 {len(ids)} 个ID(合并模式:4fps + 1fps)...")
else:
print(f"\n开始批量处理 {len(ids)} 个ID(仅 4fps)...")
print(f"输出目录: {args.out_dir}")
print(f"并行数: {args.max_workers}")
print(f"跳过已存在: {args.skip_existing}")
print(f"总任务数: {total_tasks}\n")
for i, id_val in enumerate(ids, 1):
current_task += 1
# 检查是否已存在
if args.skip_existing:
video_path = os.path.join(args.out_dir, f"id_{id_val}", f"compare_id{id_val}.mp4")
if os.path.exists(video_path):
print(f"[SKIP] id_{id_val} 已存在,跳过 ({current_task}/{total_tasks})")
skip_count += 1
continue
print(f"\n[{current_task}/{total_tasks}] 处理 id_{id_val}...")
if combine_mode:
success = run_make_compare_combined(id_val, args_dict, args.dry_run)
else:
# 如果没有 1fps 数据,只生成 4fps 视频(简化版本)
cmd = [
sys.executable,
'make_compare.py',
'--gt', args_dict['gt'],
'--pred'] + args_dict['pred'] + [
'--id', str(id_val),
'--out_dir', args_dict['out_dir'],
'--end', str(args_dict['end']),
'--labels'] + args_dict['labels'] + [
'--fps', str(int(args_dict['fps'])) if args_dict['fps'] == int(args_dict['fps']) else str(args_dict['fps']),
]
if args_dict.get('safe_even', False):
cmd.append('--safe_even')
if args_dict.get('start', None):
cmd.extend(['--start', str(args_dict['start'])])
if args_dict.get('dry_run', False):
print(f"[DRY RUN] 命令: {' '.join(cmd)}")
success = True
else:
try:
result = subprocess.run(cmd, check=True, cwd=args_dict.get('work_dir', '.'))
success = True
except Exception as e:
print(f"[ERROR] id_{id_val} 失败: {e}", file=sys.stderr)
success = False
if success:
success_count += 1
else:
fail_count += 1
if not args.dry_run:
# 询问是否继续
response = input(f"\nid_{id_val} 失败,是否继续?(y/n): ").strip().lower()
if response != 'y':
print("用户取消")
return
# 总结
print(f"\n{'='*80}")
print("批量处理完成!")
if combine_mode:
print(f"总计: {len(ids)} 个ID(每个包含 4fps + 1fps 合并视频)")
else:
print(f"总计: {len(ids)} 个ID(仅 4fps)")
print(f"成功: {success_count} 个")
print(f"跳过: {skip_count} 个")
print(f"失败: {fail_count} 个")
print(f"{'='*80}")
if __name__ == '__main__':
main()