| |
| |
| """ |
| 批量调用 make_compare.py 为多个ID生成对比视频 |
| """ |
| import subprocess |
| import sys |
| import os |
| import argparse |
| import tempfile |
| import shutil |
| from pathlib import Path |
|
|
def read_id_list(id_file: str) -> list:
    """Read an ID list file and return the numeric IDs for make_compare.py.

    Blank lines and lines starting with ``#`` are ignored. Each remaining
    (stripped) line is either a bare integer such as ``12`` or an
    ``id_``-prefixed form such as ``id_12``; for the prefixed form every
    occurrence of ``id_`` is removed before the integer conversion. Lines
    that cannot be converted are reported with a warning and skipped.

    Args:
        id_file: Path of the UTF-8 text file to read.

    Returns:
        The parsed integer IDs, in file order.
    """
    parsed = []
    with open(id_file, 'r', encoding='utf-8') as handle:
        for raw in handle:
            text = raw.strip()
            if not text or text.startswith('#'):
                continue
            # Only the prefixed form has its "id_" removed; bare lines go to int() as-is.
            candidate = text.replace('id_', '') if text.startswith('id_') else text
            try:
                parsed.append(int(candidate))
            except ValueError:
                print(f"[WARN] 无法解析ID: {text},跳过")
    return parsed
|
|
def find_max_frame_in_dir(base_dir: str, id_val: int) -> int:
    """Return the largest frame index found in ``<base_dir>/id_<id_val>/``.

    Only entries whose name ends in ``.png`` and whose remainder (after
    removing ``.png``) parses as an ``int`` are counted.

    Args:
        base_dir: Rollout root directory.
        id_val: Numeric ID selecting the ``id_<id_val>`` sub-directory.

    Returns:
        The maximum frame index, or -1 if the directory does not exist or
        holds no matching file.
    """
    id_dir = os.path.join(base_dir, f"id_{id_val}")
    if not os.path.isdir(id_dir):
        return -1
    # Seed with -1 so an empty directory yields -1.
    frames = [-1]
    for entry in os.listdir(id_dir):
        if not entry.endswith('.png'):
            continue
        try:
            frames.append(int(entry.replace('.png', '')))
        except ValueError:
            continue
    return max(frames)
|
|
def _remove_temp_dirs(temp_dirs: list) -> None:
    """Best-effort removal of temporary directories; failures are printed, not raised."""
    for temp_dir in temp_dirs:
        try:
            shutil.rmtree(temp_dir)
        except OSError as e:
            print(f"[WARN] 清理临时目录失败 {temp_dir}: {e}")


def _link_1fps_frames(pred_1fps_dir: str, temp_dir: str, id_val: int,
                      start_frame: int, end_frame: int) -> None:
    """Fill ``<temp_dir>/id_<id_val>/`` with ``<fid>.png`` symlinks into a 1fps rollout.

    Frame ``fid`` of the 4fps range is mapped to 1fps frame ``fid // 4``,
    clamped to the largest available 1fps frame. If that file is missing,
    the nearest lower-numbered existing frame is used instead. Frames with
    no usable source get no link, and a single warning is printed.

    Args:
        pred_1fps_dir: 1fps rollout root; frames are read from ``id_<id_val>/``.
        temp_dir: Freshly created directory that receives the links.
        id_val: Numeric ID.
        start_frame: First 4fps frame index to create (inclusive).
        end_frame: Last 4fps frame index to create (inclusive).
    """
    src_id_dir = os.path.join(pred_1fps_dir, f"id_{id_val}")
    dst_id_dir = os.path.join(temp_dir, f"id_{id_val}")
    os.makedirs(dst_id_dir, exist_ok=True)
    max_available_frame = find_max_frame_in_dir(pred_1fps_dir, id_val)

    any_missing = False
    for fid in range(start_frame, end_frame + 1):
        # Four 4fps frames share one 1fps frame; never point past the last one.
        mapped_fid = min(fid // 4, max_available_frame)

        src_frame = None
        for check_fid in range(mapped_fid, -1, -1):
            candidate = os.path.join(src_id_dir, f"{check_fid}.png")
            if os.path.isfile(candidate):
                src_frame = candidate
                break
        if src_frame is None:
            any_missing = True
            continue

        dst_frame = os.path.join(dst_id_dir, f"{fid}.png")
        # lexists also sees a dangling link, which exists() would miss.
        if os.path.lexists(dst_frame):
            os.remove(dst_frame)
        # Absolute target: a relative symlink target is resolved against the
        # link's own directory (the temp dir), which would leave it dangling.
        os.symlink(os.path.abspath(src_frame), dst_frame)

    if any_missing:
        print(f"[WARN] id_{id_val} 1fps 列找不到任何帧")


def run_make_compare_combined(id_val: int, args_dict: dict, dry_run: bool = False):
    """Run make_compare.py once for one ID, producing a video with 4fps and 1fps columns.

    The column order is: GT, every ``args_dict['pred']`` directory, then one
    column per ``args_dict['pred_1fps']`` directory. make_compare.py receives
    a single frame range for all columns, so each 1fps rollout is re-indexed
    (via ``_link_1fps_frames``) into a temporary directory of symlinks that
    is laid out like a 4fps rollout. Those temporary directories are always
    removed before returning: on success, on dry run, and on error.

    Args:
        id_val: Numeric ID; frames live under ``<dir>/id_<id_val>/``.
        args_dict: Settings built by ``main()``. Keys read here: ``gt``,
            ``pred``, ``pred_1fps``, ``labels``, ``labels_combined``,
            ``end``, ``start``, ``fps``, ``safe_even``, ``out_dir``,
            ``work_dir``. The dict is not modified.
        dry_run: If True, print the command instead of running it.

    Returns:
        True on success or dry run. False if make_compare.py exits non-zero
        or any step raises an exception (the error is printed to stderr).
    """
    gt_dir = args_dict['gt']
    pred_4fps_dirs = list(args_dict['pred'])
    # `or []` also covers the key being present with a None value.
    pred_1fps_dirs = list(args_dict.get('pred_1fps') or [])
    work_dir = args_dict.get('work_dir', '.')
    end_frame = args_dict['end']
    start = args_dict.get('start')
    start_frame = start if start is not None else 0

    if args_dict.get('labels_combined'):
        labels = list(args_dict['labels_combined'])
    else:
        labels = list(args_dict['labels'])
        if pred_1fps_dirs:
            # Reuse the prediction labels (everything after GT) for the 1fps columns.
            labels += [f"{label}_1fps" for label in args_dict['labels'][1:]]

    fps = args_dict.get('fps', 1)
    # Pass integral rates as "4" rather than "4.0".
    fps_arg = str(int(fps)) if fps == int(fps) else str(fps)
    output_name = f'compare_id{id_val}.mp4'

    temp_dirs = []
    try:
        pred_dirs = list(pred_4fps_dirs)
        for pred_1fps_dir in pred_1fps_dirs:
            temp_dir = tempfile.mkdtemp(prefix=f'temp_1fps_{id_val}_')
            # Record it before filling so `finally` removes it if linking fails.
            temp_dirs.append(temp_dir)
            # make_compare.py runs with cwd=work_dir and resolves --gt/--pred
            # there; resolve relative 1fps paths the same way.
            _link_1fps_frames(os.path.join(work_dir, pred_1fps_dir), temp_dir,
                              id_val, start_frame, end_frame)
            # Appending keeps the 1fps column order aligned with pred_1fps even
            # when the same directory also appears in --pred.
            pred_dirs.append(temp_dir)

        all_cols = [gt_dir] + pred_dirs
        cmd = [
            sys.executable,
            'make_compare.py',
            '--gt', gt_dir,
            '--pred', *pred_dirs,
            '--id', str(id_val),
            '--out_dir', args_dict['out_dir'],
            '--end', str(end_frame),
            '--labels', *labels,
            '--fps', fps_arg,
            '--out', output_name,
        ]
        if args_dict.get('safe_even', False):
            cmd.append('--safe_even')
        # Compare against None so an explicit --start 0 is forwarded as well.
        if start is not None:
            cmd.extend(['--start', str(start)])

        print(f"\n{'='*80}")
        print(f"处理 ID: id_{id_val} (合并 4fps + 1fps, fps={fps})")
        print(f"列数: {len(all_cols)} (GT_4fps + {len(pred_4fps_dirs)} pred_4fps + {len(pred_1fps_dirs)} pred_1fps)")
        print(f"命令: {' '.join(cmd)}")
        print(f"{'='*80}")

        if dry_run:
            print("[DRY RUN] 不会实际执行")
            return True

        subprocess.run(cmd, check=True, cwd=work_dir)
        print(f"[OK] id_{id_val} (合并视频) 完成")
        return True
    except subprocess.CalledProcessError as e:
        print(f"[ERROR] id_{id_val} (合并视频) 失败: {e}", file=sys.stderr)
        return False
    except Exception as e:
        # Per-ID boundary: report and let the caller decide whether to continue.
        print(f"[ERROR] id_{id_val} (合并视频) 出错: {e}", file=sys.stderr)
        return False
    finally:
        _remove_temp_dirs(temp_dirs)
|
|
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the batch comparison script."""
    parser = argparse.ArgumentParser(description='批量生成对比视频')
    parser.add_argument('--id_file', type=str, required=True,
                        help='ID列表文件路径')
    parser.add_argument('--gt', type=str, required=True,
                        help='GT rollout目录')
    parser.add_argument('--pred', nargs='+', required=True,
                        help='预测 rollout 目录(一个或多个)')
    parser.add_argument('--out_dir', type=str, required=True,
                        help='输出目录')
    parser.add_argument('--labels', nargs='+', required=True,
                        help='标签列表(与GT+pred数量对应)')
    parser.add_argument('--end', type=int, default=51,
                        help='结束帧索引')
    parser.add_argument('--start', type=int, default=None,
                        help='开始帧索引(可选)')
    parser.add_argument('--fps', type=float, default=1,
                        help='视频FPS(用于rollout_4fps,默认1)')
    parser.add_argument('--gt_1fps', type=str, default=None,
                        help='GT rollout_1fps目录(已废弃,不再使用)')
    parser.add_argument('--pred_1fps', nargs='+', default=None,
                        help='预测 rollout_1fps 目录(一个或多个,将合并到视频中,但不包含GT)')
    parser.add_argument('--fps_1fps', type=float, default=0.25,
                        help='视频FPS(用于rollout_1fps,默认0.25)')
    parser.add_argument('--safe_even', action='store_true',
                        help='使用安全偶数尺寸')
    parser.add_argument('--work_dir', type=str, default=None,
                        help='工作目录(默认当前目录)')
    parser.add_argument('--dry_run', action='store_true',
                        help='仅显示命令,不实际执行')
    parser.add_argument('--skip_existing', action='store_true',
                        help='跳过已存在的视频文件')
    parser.add_argument('--max_workers', type=int, default=1,
                        help='最大并行数(默认1,串行执行)')
    return parser


def _run_make_compare_single(id_val: int, args_dict: dict, dry_run: bool) -> bool:
    """Run make_compare.py for one ID in 4fps-only mode.

    Returns True on success or dry run. Returns False if the process exits
    non-zero or cannot be started; the error is printed to stderr.
    """
    fps = args_dict['fps']
    cmd = [
        sys.executable,
        'make_compare.py',
        '--gt', args_dict['gt'],
        '--pred', *args_dict['pred'],
        '--id', str(id_val),
        '--out_dir', args_dict['out_dir'],
        '--end', str(args_dict['end']),
        '--labels', *args_dict['labels'],
        # Pass integral rates as "4" rather than "4.0".
        '--fps', str(int(fps)) if fps == int(fps) else str(fps),
    ]
    if args_dict.get('safe_even', False):
        cmd.append('--safe_even')
    # Compare against None so an explicit --start 0 is forwarded as well.
    if args_dict.get('start') is not None:
        cmd.extend(['--start', str(args_dict['start'])])

    if dry_run:
        print(f"[DRY RUN] 命令: {' '.join(cmd)}")
        return True
    try:
        subprocess.run(cmd, check=True, cwd=args_dict.get('work_dir', '.'))
    except (subprocess.CalledProcessError, OSError) as e:
        print(f"[ERROR] id_{id_val} 失败: {e}", file=sys.stderr)
        return False
    return True


def _ask_continue(id_val: int) -> bool:
    """Ask whether to continue after ``id_val`` failed; EOF on stdin counts as "no"."""
    try:
        response = input(f"\nid_{id_val} 失败,是否继续?(y/n): ")
    except EOFError:
        # Non-interactive stdin (pipe, cron, nohup): stop instead of crashing.
        return False
    return response.strip().lower() == 'y'


def main():
    """Batch entry point: generate one comparison video per ID listed in ``--id_file``.

    Without ``--pred_1fps``, each ID is rendered from the GT and 4fps
    prediction rollouts only. With ``--pred_1fps`` (one directory per
    ``--pred`` entry), each ID is rendered by ``run_make_compare_combined``
    into a single video that also contains the 1fps columns. IDs are
    processed one after another. After a failure (outside dry-run mode)
    the user is asked whether to continue. Exits with status 1 on invalid
    input.
    """
    args = _build_arg_parser().parse_args()

    ids = read_id_list(args.id_file)
    print(f"[INFO] 从 {args.id_file} 读取到 {len(ids)} 个ID")

    if not ids:
        print("[ERROR] 没有找到有效的ID", file=sys.stderr)
        sys.exit(1)

    combine_mode = args.pred_1fps is not None
    if combine_mode and len(args.pred_1fps) != len(args.pred):
        print(f"[ERROR] rollout_1fps 的预测目录数量 ({len(args.pred_1fps)}) 与 rollout_4fps ({len(args.pred)}) 不匹配", file=sys.stderr)
        sys.exit(1)

    # One label per column: GT + 4fps predictions (+ 1fps predictions when combining).
    expected_labels = 1 + len(args.pred) + (len(args.pred_1fps) if combine_mode else 0)

    args_dict = {
        'gt': args.gt,
        'pred': args.pred,
        'out_dir': args.out_dir,
        'labels': args.labels,
        # Only a full per-column label list is used verbatim in combined mode;
        # otherwise the 1fps labels are derived from the 4fps ones.
        'labels_combined': args.labels if combine_mode and len(args.labels) == expected_labels else None,
        'end': args.end,
        'start': args.start,
        'fps': args.fps,
        'fps_1fps': args.fps_1fps,
        'gt_1fps': args.gt_1fps,
        'pred_1fps': args.pred_1fps,
        'safe_even': args.safe_even,
        'work_dir': args.work_dir or os.getcwd(),
        'dry_run': args.dry_run,
    }

    if len(args.labels) != expected_labels:
        print(f"[WARN] 标签数量 ({len(args.labels)}) 与预期 ({expected_labels}) 不匹配")
        if combine_mode:
            print(f" 预期: GT + {len(args.pred)} 个 4fps 预测方法 + {len(args.pred_1fps)} 个 1fps 预测方法")
        else:
            print(f" 预期: GT + {len(args.pred)} 个预测方法")

    os.makedirs(args.out_dir, exist_ok=True)

    success_count = 0
    skip_count = 0
    fail_count = 0
    total_tasks = len(ids)

    if combine_mode:
        print(f"\n开始批量处理 {len(ids)} 个ID(合并模式:4fps + 1fps)...")
    else:
        print(f"\n开始批量处理 {len(ids)} 个ID(仅 4fps)...")
    print(f"输出目录: {args.out_dir}")
    print(f"并行数: {args.max_workers}")
    print(f"跳过已存在: {args.skip_existing}")
    print(f"总任务数: {total_tasks}\n")
    if args.max_workers > 1:
        # The option is accepted, but this loop only runs serially.
        print("[WARN] --max_workers 暂未实现并行,将串行执行")

    for current_task, id_val in enumerate(ids, 1):
        if args.skip_existing:
            video_path = os.path.join(args.out_dir, f"id_{id_val}", f"compare_id{id_val}.mp4")
            if os.path.exists(video_path):
                print(f"[SKIP] id_{id_val} 已存在,跳过 ({current_task}/{total_tasks})")
                skip_count += 1
                continue

        print(f"\n[{current_task}/{total_tasks}] 处理 id_{id_val}...")
        if combine_mode:
            success = run_make_compare_combined(id_val, args_dict, args.dry_run)
        else:
            success = _run_make_compare_single(id_val, args_dict, args.dry_run)

        if success:
            success_count += 1
            continue

        fail_count += 1
        if not args.dry_run and not _ask_continue(id_val):
            print("用户取消")
            return

    print(f"\n{'='*80}")
    print("批量处理完成!")
    if combine_mode:
        print(f"总计: {len(ids)} 个ID(每个包含 4fps + 1fps 合并视频)")
    else:
        print(f"总计: {len(ids)} 个ID(仅 4fps)")
    print(f"成功: {success_count} 个")
    print(f"跳过: {skip_count} 个")
    print(f"失败: {fail_count} 个")
    print(f"{'='*80}")
|
|
if __name__ == "__main__":
    # Run the batch only when executed as a script, not when imported.
    main()
|
|
|
|