Spatial-BEATs / scripts /analyze_csv_dump.py
dieKarotte's picture
Add files using upload-large-folder tool
29615e9 verified
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import csv
import json
import math
import statistics
from collections import Counter, defaultdict
from pathlib import Path
from typing import Any
def angular_distance_deg(azi1: float, ele1: float, azi2: float, ele2: float) -> float:
    """Return the angle, in degrees, between two azimuth/elevation directions.

    Each (azimuth, elevation) pair is given in degrees and mapped to a unit
    vector ``(cos(ele)cos(azi), cos(ele)sin(azi), sin(ele))``. The result is
    the arc-cosine of the dot product of the two vectors, so it lies in
    ``[0, 180]``.
    """

    def to_unit_vector(azi_deg: float, ele_deg: float) -> tuple[float, float, float]:
        azi = math.radians(azi_deg)
        ele = math.radians(ele_deg)
        return (
            math.cos(ele) * math.cos(azi),
            math.cos(ele) * math.sin(azi),
            math.sin(ele),
        )

    ux, uy, uz = to_unit_vector(azi1, ele1)
    vx, vy, vz = to_unit_vector(azi2, ele2)
    cosine = ux * vx + uy * vy + uz * vz
    # Rounding can push the dot product marginally outside acos's domain.
    clamped = max(-1.0, min(1.0, cosine))
    return math.degrees(math.acos(clamped))
def infer_split(name: str) -> str | None:
    """Map a dump file name to its evaluation split label.

    The name is tested against a fixed list of prefixes and the label of the
    first matching prefix is returned. ``None`` is returned when no prefix
    matches.
    """
    # Insertion order matters: the "..._real_" prefixes extend the plain
    # "valid__ovN_" ones and therefore have to be tested first.
    split_by_prefix = {
        "valid__ov1_real_": "real_ov1",
        "valid__ov2_real_": "real_ov2",
        "valid__ov3_real_": "real_ov3",
        "valid__ov1_": "ov1",
        "valid__ov2_": "ov2",
        "valid__ov3_": "ov3",
        "valid__hm3d__": "ov1",
    }
    return next(
        (label for prefix, label in split_by_prefix.items() if name.startswith(prefix)),
        None,
    )
def load_frame_rows(csv_path: Path, threshold: float | None) -> dict[int, list[dict[str, Any]]]:
    """Load a dumped CSV and group its rows by frame index.

    Args:
        csv_path: CSV file with a header row providing the columns
            ``frame_idx``, ``class_idx``, ``class_name``, ``azimuth_deg``,
            ``elevation_deg``, ``activity_prob`` and ``src_or_track_idx``.
        threshold: Rows whose ``activity_prob`` is strictly below this value
            are dropped. ``None`` keeps every row.

    Returns:
        A ``defaultdict(list)`` mapping each frame index to its kept rows in
        file order. Every row is a dict with the keys ``class_idx``,
        ``class_name``, ``azi``, ``ele``, ``activity_prob`` and
        ``track_or_src``. Frames without any kept row are absent.

    Raises:
        KeyError: If a required column is missing.
        ValueError: If a numeric column cannot be parsed.
    """
    rows_by_frame: dict[int, list[dict[str, Any]]] = defaultdict(list)
    # newline="" hands line endings to the csv module untouched, as its
    # documentation requires; otherwise newlines embedded in quoted fields
    # are rewritten before the parser sees them.
    with csv_path.open(newline="") as f:
        for row in csv.DictReader(f):
            activity_prob = float(row["activity_prob"])
            if threshold is not None and activity_prob < threshold:
                continue
            rows_by_frame[int(row["frame_idx"])].append(
                {
                    "class_idx": int(row["class_idx"]),
                    "class_name": row["class_name"],
                    "azi": float(row["azimuth_deg"]),
                    "ele": float(row["elevation_deg"]),
                    "activity_prob": activity_prob,
                    "track_or_src": int(row["src_or_track_idx"]),
                }
            )
    return rows_by_frame
def _tally_frame_relation(num_pred: int, num_gt: int, frame_relation: Counter) -> None:
    """Record how one frame's prediction count compares to its GT count."""
    if num_pred < num_gt:
        frame_relation["under"] += 1
    elif num_pred == num_gt:
        frame_relation["equal"] += 1
    else:
        frame_relation["over"] += 1
    # Counted in addition to "under": frames with GT but no prediction at all.
    if num_gt > 0 and num_pred == 0:
        frame_relation["gt_no_pred"] += 1


def _tally_gt_outcomes(
    preds: list[dict[str, Any]],
    gts: list[dict[str, Any]],
    angle_threshold_deg: float,
    gt_outcomes: Counter,
    best_angles: list[float],
) -> None:
    """Classify each GT of one frame by its closest same-class prediction."""
    for gt in gts:
        same_class_angles = [
            angular_distance_deg(gt["azi"], gt["ele"], pred["azi"], pred["ele"])
            for pred in preds
            if pred["class_idx"] == gt["class_idx"]
        ]
        if same_class_angles:
            best_angle = min(same_class_angles)
            best_angles.append(best_angle)
            if best_angle <= angle_threshold_deg:
                gt_outcomes["hit_cls_and_angle"] += 1
            else:
                gt_outcomes["class_right_angle_wrong"] += 1
        elif preds:
            gt_outcomes["no_same_class_pred_but_other_preds_exist"] += 1
        else:
            gt_outcomes["no_pred_in_frame"] += 1


def _tally_pred_outcomes(
    preds: list[dict[str, Any]],
    gts: list[dict[str, Any]],
    angle_threshold_deg: float,
    pred_outcomes: Counter,
) -> None:
    """Greedily match one frame's predictions to GTs and label the leftovers."""
    candidates: list[tuple[float, int, int]] = []
    for pred_idx, pred in enumerate(preds):
        for gt_idx, gt in enumerate(gts):
            if pred["class_idx"] != gt["class_idx"]:
                continue
            angle = angular_distance_deg(gt["azi"], gt["ele"], pred["azi"], pred["ele"])
            if angle <= angle_threshold_deg:
                candidates.append((angle, pred_idx, gt_idx))

    # Smallest angle first; every prediction and every GT is used at most once.
    candidates.sort()
    matched_preds: set[int] = set()
    matched_gts: set[int] = set()
    for _, pred_idx, gt_idx in candidates:
        if pred_idx in matched_preds or gt_idx in matched_gts:
            continue
        matched_preds.add(pred_idx)
        matched_gts.add(gt_idx)
        pred_outcomes["matched_tp"] += 1

    # Built once per frame so unmatched predictions need no rescan of the GTs.
    gt_classes = {gt["class_idx"] for gt in gts}
    for pred_idx, pred in enumerate(preds):
        if pred_idx in matched_preds:
            continue
        if pred["class_idx"] in gt_classes:
            pred_outcomes["same_class_angle_wrong_fp"] += 1
        else:
            pred_outcomes["wrong_class_or_spurious_fp"] += 1


def analyze_one_pair(
    pred_path: Path,
    gt_path: Path,
    threshold: float | None,
    angle_threshold_deg: float = 20.0,
) -> dict[str, Any]:
    """Compare one prediction CSV with its ground-truth CSV frame by frame.

    Args:
        pred_path: Prediction CSV; rows are filtered with ``threshold``.
        gt_path: Ground-truth CSV; always loaded unfiltered.
        threshold: Activity threshold forwarded to ``load_frame_rows`` for the
            predictions. ``None`` keeps every prediction row.
        angle_threshold_deg: Largest angular distance, in degrees, at which a
            same-class prediction still counts as hitting a GT. Defaults to
            the previously hard-coded 20 degrees.

    Returns:
        A dict with the keys ``file`` (prediction file name), ``frames``
        (number of frames that have at least one prediction or GT row),
        ``avg_gt_per_frame``, ``avg_pred_per_frame``, the ``Counter`` objects
        ``frame_relation``, ``gt_outcomes`` and ``pred_outcomes``, and
        ``mean_same_class_best_angle`` (``None`` when no GT had a same-class
        prediction in its frame).
    """
    pred_by_frame = load_frame_rows(pred_path, threshold=threshold)
    gt_by_frame = load_frame_rows(gt_path, threshold=None)
    all_frames = sorted(set(pred_by_frame) | set(gt_by_frame))

    gt_outcomes: Counter = Counter()
    pred_outcomes: Counter = Counter()
    frame_relation: Counter = Counter()
    same_class_best_angles: list[float] = []
    total_gt = 0
    total_pred = 0

    for frame_idx in all_frames:
        preds = pred_by_frame.get(frame_idx, [])
        gts = gt_by_frame.get(frame_idx, [])
        total_gt += len(gts)
        total_pred += len(preds)
        _tally_frame_relation(len(preds), len(gts), frame_relation)
        _tally_gt_outcomes(preds, gts, angle_threshold_deg, gt_outcomes, same_class_best_angles)
        _tally_pred_outcomes(preds, gts, angle_threshold_deg, pred_outcomes)

    num_frames = len(all_frames)
    return {
        "file": pred_path.name,
        "frames": num_frames,
        "avg_gt_per_frame": total_gt / num_frames if num_frames else 0.0,
        "avg_pred_per_frame": total_pred / num_frames if num_frames else 0.0,
        "frame_relation": frame_relation,
        "gt_outcomes": gt_outcomes,
        "pred_outcomes": pred_outcomes,
        "mean_same_class_best_angle": (
            statistics.mean(same_class_best_angles) if same_class_best_angles else None
        ),
    }
def aggregate_rows(rows: list[dict[str, Any]]) -> dict[str, Any]:
    """Pool the per-file results of one split into a single summary dict.

    Outcome counters are summed across files. The per-frame averages are
    weighted by each file's frame count. ``mean_best_angle_when_same_class_exists``
    is the plain mean of the per-file means (files are not weighted), taken
    over the files that report one. ``worst_under_predicted`` lists up to
    three files with the smallest ``avg_pred_per_frame - avg_gt_per_frame``.
    """

    def prediction_surplus(entry: dict[str, Any]) -> float:
        return entry["avg_pred_per_frame"] - entry["avg_gt_per_frame"]

    total_frames = sum(entry["frames"] for entry in rows)
    if total_frames:
        avg_gt = sum(entry["avg_gt_per_frame"] * entry["frames"] for entry in rows) / total_frames
        avg_pred = sum(entry["avg_pred_per_frame"] * entry["frames"] for entry in rows) / total_frames
    else:
        avg_gt = 0.0
        avg_pred = 0.0

    gt_counts: Counter = Counter()
    pred_counts: Counter = Counter()
    frame_counts: Counter = Counter()
    per_file_mean_angles = []
    for entry in rows:
        gt_counts.update(entry["gt_outcomes"])
        pred_counts.update(entry["pred_outcomes"])
        frame_counts.update(entry["frame_relation"])
        file_mean = entry["mean_same_class_best_angle"]
        if file_mean is not None:
            per_file_mean_angles.append(file_mean)

    hits = gt_counts["hit_cls_and_angle"]
    gts_with_same_class_pred = hits + gt_counts["class_right_angle_wrong"]
    if gts_with_same_class_pred:
        hit_share = hits / gts_with_same_class_pred
    else:
        hit_share = None

    if per_file_mean_angles:
        mean_best_angle = statistics.mean(per_file_mean_angles)
    else:
        mean_best_angle = None

    worst_under_predicted = []
    for entry in sorted(rows, key=prediction_surplus)[:3]:
        worst_under_predicted.append(
            {
                "file": entry["file"],
                "avg_pred_per_frame": entry["avg_pred_per_frame"],
                "avg_gt_per_frame": entry["avg_gt_per_frame"],
            }
        )

    return {
        "samples": len(rows),
        "frames": total_frames,
        "avg_gt_per_frame": avg_gt,
        "avg_pred_per_frame": avg_pred,
        "frame_relation": dict(frame_counts),
        "gt_outcomes": dict(gt_counts),
        "pred_outcomes": dict(pred_counts),
        "gt_total": sum(gt_counts.values()),
        "pred_total": sum(pred_counts.values()),
        "same_class_angle_le_20_share": hit_share,
        "mean_best_angle_when_same_class_exists": mean_best_angle,
        "worst_under_predicted": worst_under_predicted,
    }
def format_pct(numerator: int, denominator: int) -> str:
    """Format ``numerator / denominator`` as a percentage with one decimal.

    A zero or negative denominator yields ``"0.0%"`` instead of dividing.
    """
    share = 0.0 if denominator <= 0 else 100.0 * numerator / denominator
    return f"{share:.1f}%"
def print_summary(threshold: float | None, summary: dict[str, dict[str, Any]]) -> None:
    """Print a human-readable report of the per-split statistics to stdout.

    A mode header derived from ``threshold`` is printed first. Every split
    with at least one sample then gets its frame counts, GT-side outcomes,
    prediction-side outcomes and worst under-predicted files, followed by a
    blank line. Splits whose ``samples`` is 0 are skipped.
    """
    frame_relation_keys = ("under", "equal", "over", "gt_no_pred")
    gt_outcome_keys = (
        "hit_cls_and_angle",
        "class_right_angle_wrong",
        "no_same_class_pred_but_other_preds_exist",
        "no_pred_in_frame",
    )
    pred_outcome_keys = ("matched_tp", "same_class_angle_wrong_fp", "wrong_class_or_spurious_fp")

    def print_outcomes(
        heading: str,
        stats: dict[str, Any],
        outcomes_key: str,
        total_key: str,
        keys: tuple[str, ...],
    ) -> None:
        # One "key: count (share of total)" line per outcome, under a heading.
        print(heading)
        for key in keys:
            value = stats[outcomes_key].get(key, 0)
            print(f" {key}: {value} ({format_pct(value, stats[total_key])})")

    if threshold is None:
        mode = "raw_all_tracks"
    else:
        mode = f"activity>={threshold:g}"
    print(f"=== mode: {mode} ===")

    for split, stats in summary.items():
        if stats["samples"] == 0:
            continue

        print(f"--- {split} ---")
        print(
            f"samples={stats['samples']} frames={stats['frames']} "
            f"avg_gt/frame={stats['avg_gt_per_frame']:.2f} avg_pred/frame={stats['avg_pred_per_frame']:.2f}"
        )

        frame_rel = stats["frame_relation"]
        frame_parts = " ".join(f"{key}={frame_rel.get(key, 0)}" for key in frame_relation_keys)
        print("frame_rel " + frame_parts)

        print_outcomes("GT-side:", stats, "gt_outcomes", "gt_total", gt_outcome_keys)

        hit_share = stats["same_class_angle_le_20_share"]
        if hit_share is not None:
            print(
                " among GTs with same-class pred, angle<=20 share: "
                f"{100.0 * stats['same_class_angle_le_20_share']:.1f}%"
            )
        mean_best_angle = stats["mean_best_angle_when_same_class_exists"]
        if mean_best_angle is not None:
            print(
                " mean best angle when same-class pred exists: "
                f"{stats['mean_best_angle_when_same_class_exists']:.2f}°"
            )

        print_outcomes("Pred-side:", stats, "pred_outcomes", "pred_total", pred_outcome_keys)

        print(" worst under-predicted samples:")
        for worst in stats["worst_under_predicted"]:
            print(
                f" {worst['file']}: avg_pred={worst['avg_pred_per_frame']:.2f} "
                f"avg_gt={worst['avg_gt_per_frame']:.2f}"
            )
        print()
def main() -> None:
    """Command-line entry point: analyze every prediction/GT pair in a dump dir.

    Prediction files matching ``*__pred.csv`` whose name maps to a known
    split are paired with the ``*__gt.csv`` file of the same stem. For each
    requested threshold the pairs are analyzed, aggregated per split, and
    printed. With ``--json-out`` the results of all thresholds are also
    written to one JSON file.
    """
    parser = argparse.ArgumentParser(description="Analyze dumped __pred.csv / __gt.csv frame-track outputs.")
    parser.add_argument(
        "--dump-dir",
        type=Path,
        required=True,
        help="Directory containing paired *__pred.csv and *__gt.csv files.",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=None,
        help="Activity threshold. Omit to analyze raw all-track outputs.",
    )
    parser.add_argument(
        "--threshold-sweep",
        type=float,
        nargs="*",
        default=None,
        help="Optional thresholds to analyze in addition to --threshold.",
    )
    parser.add_argument(
        "--json-out",
        type=Path,
        default=None,
        help="Optional path to write the aggregated result as JSON.",
    )
    args = parser.parse_args()

    # A missing --threshold (None) selects the raw, unfiltered mode.
    thresholds: list[float | None] = [args.threshold]
    if args.threshold_sweep:
        thresholds.extend(args.threshold_sweep)
    # Drop repeated thresholds (keeping first-seen order): a repeat would only
    # redo the same analysis and overwrite the same JSON key.
    thresholds = list(dict.fromkeys(thresholds))

    # The file pairing does not depend on the threshold, so resolve it once.
    pred_suffix = "__pred.csv"
    gt_suffix = "__gt.csv"
    pairs: list[tuple[str, Path, Path]] = []
    for pred_path in sorted(args.dump_dir.glob("*__pred.csv")):
        split = infer_split(pred_path.name)
        if split is None:
            continue
        # Swap only the trailing suffix of the file name; a replace over the
        # whole path would also rewrite matching directory components.
        gt_name = pred_path.name[: -len(pred_suffix)] + gt_suffix
        pairs.append((split, pred_path, pred_path.with_name(gt_name)))

    json_payload: dict[str, Any] = {
        "dump_dir": str(args.dump_dir),
        "results": {},
    }
    for threshold in thresholds:
        rows_by_split: dict[str, list[dict[str, Any]]] = defaultdict(list)
        for split, pred_path, gt_path in pairs:
            rows_by_split[split].append(analyze_one_pair(pred_path, gt_path, threshold=threshold))
        summary = {split: aggregate_rows(rows) for split, rows in sorted(rows_by_split.items())}
        print_summary(threshold, summary)
        thr_key = "raw_all_tracks" if threshold is None else f"thr_{threshold:g}"
        json_payload["results"][thr_key] = summary

    if args.json_out is not None:
        # ensure_ascii=False can emit non-ASCII characters (e.g. from file
        # names), so the output encoding must not depend on the locale.
        args.json_out.write_text(
            json.dumps(json_payload, indent=2, ensure_ascii=False),
            encoding="utf-8",
        )
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()