WolfDavid's picture
feat(02-03): implement per-class ECE/Brier + eval_metrics builder
eefdcf1
"""Eval metrics tests (D-CAL-04..07, D-MASK-04, Pitfalls 8 + 12)."""
import json
import numpy as np
from model.eval import (
UNIFORM_PRIOR_PROB,
build_eval_metrics,
per_class_brier,
per_class_brier_baseline,
per_class_classification_report,
per_class_ece,
per_mode_macro_f1,
)
from model.features import CLASSES
def _well_calibrated_synthetic(
    n_per_class: int = 200,
    confidence: float = 0.70,
    seed: int = 42,
) -> tuple[np.ndarray, np.ndarray]:
    """Generate a balanced synthetic ``(y_true, y_proba)`` pair.

    Every class in ``CLASSES`` gets ``n_per_class`` rows. Each row puts
    ``confidence`` on its true label and spreads the remaining mass evenly
    over the other classes. It then adds small Gaussian jitter (sd 0.005),
    clips to keep every entry positive, and renormalises so rows sum to 1.

    Note: despite the name, with the defaults the argmax of every row is the
    true label (~0.70 vs ~0.033). The predictor is therefore 100% accurate
    while only ~70% confident on the true class. It is under-confident rather
    than perfectly calibrated, so callers should only assert loose ECE bounds
    on this data.

    Args:
        n_per_class: number of rows generated for each class.
        confidence: probability placed on the true label before jitter.
        seed: seed for the ``np.random.default_rng`` jitter stream.

    Returns:
        ``y`` with shape ``(n,)`` holding integer labels, and ``proba`` with
        shape ``(n, len(CLASSES))`` whose rows sum to 1.
    """
    n_classes = len(CLASSES)
    rng = np.random.default_rng(seed)
    n = n_per_class * n_classes
    y = np.repeat(np.arange(n_classes), n_per_class)
    # Remaining mass shared evenly by the non-true classes (0.3 / 9 for 10 classes).
    proba = np.full((n, n_classes), (1.0 - confidence) / (n_classes - 1))
    proba[np.arange(n), y] = confidence
    # Small zero-mean jitter so rows are not all identical.
    proba += rng.normal(0.0, 0.005, size=(n, n_classes))
    # Keep every entry strictly positive before renormalising.
    proba = np.clip(proba, 1e-6, None)
    proba /= proba.sum(axis=1, keepdims=True)
    return y, proba
def test_uniform_prior_prob() -> None:
    """D-CAL-05: the uniform prior for 10 classes equals 0.1 to within 1e-12."""
    tolerance = 1e-12
    delta = UNIFORM_PRIOR_PROB - 0.1
    assert -tolerance < delta < tolerance
def test_per_class_ece_finite_and_in_unit_interval() -> None:
    """CLASS-02 / D-CAL-07: all 10 ECEs are plain Python floats in [0, 1].

    ``np.float64`` subclasses Python ``float``, so ``isinstance(v, float)``
    cannot detect a leaked NumPy scalar (Pitfall 8). The exact type is
    compared instead.
    """
    y, proba = _well_calibrated_synthetic()
    ece = per_class_ece(y, proba, n_bins=10)
    assert set(ece.keys()) == set(CLASSES)
    for slug, v in ece.items():
        assert type(v) is float, f"{slug} ECE is {type(v).__name__}, expected float"
        # Also rejects NaN: every comparison against NaN is False.
        assert 0.0 <= v <= 1.0, f"{slug} ECE out of range: {v}"
def test_per_class_ece_n_bins_default_is_ten() -> None:
    """D-CAL-07: omitting ``n_bins`` gives the same result as ``n_bins=10``."""
    y, proba = _well_calibrated_synthetic()
    implicit_default = per_class_ece(y, proba)
    explicit_ten = per_class_ece(y, proba, n_bins=10)
    assert implicit_default == explicit_ten
def test_per_class_ece_finite_and_reasonable_on_well_calibrated() -> None:
    """CLASS-02: mean ECE on the synthetic helper data stays under a loose bound.

    ``_well_calibrated_synthetic`` puts ~0.70 on the true class and ~0.033 on
    every other class, so its argmax is always correct. The predictor is 100%
    accurate but only ~70% confident, i.e. under-confident rather than
    perfectly calibrated. The 0.30 ceiling is therefore only a sanity bound.
    The meaningful calibrated-vs-miscalibrated comparison lives in
    ``test_calibrated_ece_lower_than_raw_signal``.

    NOTE(review): if ``per_class_ece`` bins only the predicted-class confidence
    (rather than one-vs-rest probabilities), the expected gap here is ~0.30 and
    this bound would be marginal. Confirm against the implementation.
    """
    y, proba = _well_calibrated_synthetic(n_per_class=500)
    ece = per_class_ece(y, proba, n_bins=10)
    mean_ece = np.mean(list(ece.values()))
    assert mean_ece < 0.30, f"well-calibrated synthetic should have bounded ECE; got {mean_ece}"
def test_calibrated_ece_lower_than_raw_signal() -> None:
    """CLASS-02 (D-CAL-06 portfolio signal): calibrated probs have lower mean ECE
    than badly-miscalibrated probs.

    Construction: the predictor is right ~70% of the time, and wrong picks are
    uniform over the other classes.
    - The "bad" variant puts 0.99 on every predicted class. This is wild
      overconfidence on the ~30% wrong cases, and those high-confidence
      mistakes drive a big calibration gap.
    - The "calibrated" variant puts 0.70 there, matching the empirical
      accuracy.

    The leftover mass is spread evenly over the other classes. For the
    calibrated variant that spread is also calibrated one-vs-rest:
    P(y = c | pred != c) = 0.1 * 0.3 / 0.9 = 1/30 = 0.3 / 9.
    Calibration error is therefore structurally lower on ``cal``.
    """
    n_classes = len(CLASSES)
    rng = np.random.default_rng(0)
    n = 3000
    y = rng.integers(0, n_classes, n)
    # 70% correct predictions; the other 30% shift to a uniformly chosen other class.
    correct = rng.random(n) < 0.7
    pred = np.where(correct, y, (y + rng.integers(1, n_classes, n)) % n_classes)

    def _confident_on_pred(confidence: float) -> np.ndarray:
        """Build probability rows with ``confidence`` on ``pred`` and the rest spread evenly."""
        proba = np.full((n, n_classes), (1.0 - confidence) / (n_classes - 1))
        proba[np.arange(n), pred] = confidence
        # Rows already sum to 1 (up to float rounding), so no renormalisation is needed.
        return proba

    bad = _confident_on_pred(0.99)  # overconfident
    cal = _confident_on_pred(0.70)  # matches empirical accuracy
    ece_bad = np.mean(list(per_class_ece(y, bad).values()))
    ece_cal = np.mean(list(per_class_ece(y, cal).values()))
    assert ece_cal < ece_bad, (
        f"Calibrated ECE ({ece_cal}) should be lower than bad ECE ({ece_bad}) "
        "— D-CAL-06 portfolio signal"
    )
def test_per_class_brier_finite() -> None:
    """D-CAL-04: every per-class Brier score is a plain Python float in [0, 1].

    The test uses ``type(v) is float`` rather than ``isinstance``, because
    ``np.float64`` subclasses ``float`` and would otherwise slip through
    (Pitfall 8).
    """
    y, proba = _well_calibrated_synthetic()
    brier = per_class_brier(y, proba)
    assert set(brier.keys()) == set(CLASSES)
    for slug, v in brier.items():
        assert type(v) is float, f"{slug} Brier is {type(v).__name__}, expected float"
        assert 0.0 <= v <= 1.0, f"{slug} Brier out of range: {v}"
def test_per_class_brier_baseline_uniform_prior() -> None:
    """D-CAL-05: baseline = Brier score of constant 0.1 (uniform-prior) predictions.

    For a class with prevalence pi, the expected one-vs-rest Brier of a
    constant 0.1 prediction is (1 - pi) * 0.1**2 + pi * 0.9**2 = 0.01 + 0.8 * pi.
    That is 0.09 when pi = 0.1.
    """
    rng = np.random.default_rng(0)
    y = rng.integers(0, len(CLASSES), 500)
    baseline = per_class_brier_baseline(y)
    # Without this check, an empty or partial dict would pass the loop below vacuously.
    assert set(baseline.keys()) == set(CLASSES)
    for slug, v in baseline.items():
        # np.float64 subclasses float, so only an exact type check catches it (Pitfall 8).
        assert type(v) is float, f"{slug} baseline is {type(v).__name__}, expected float"
        assert 0.0 <= v <= 1.0
        # Loose bound: with 500 rows the sampled prevalence wobbles around 0.1.
        assert abs(v - 0.09) < 0.05, f"{slug} baseline Brier = {v}, expected ~0.09"
def test_per_class_classification_report_python_types() -> None:
    """Pitfall 8: every leaf value is a builtin float / int (NOT np.float64 / np.int64).

    Exact ``type(...) is`` checks are required:
    - ``np.float64`` subclasses ``float``, so ``isinstance`` would accept it.
    - ``bool`` subclasses ``int``, so ``isinstance`` would accept it for support.
    """
    rng = np.random.default_rng(0)
    n = 1000
    y_true = rng.integers(0, len(CLASSES), n)
    y_pred = rng.integers(0, len(CLASSES), n)
    report = per_class_classification_report(y_true, y_pred)
    for slug in CLASSES:
        row = report[slug]
        for key in ("precision", "recall", "f1"):
            assert type(row[key]) is float, f"{slug}.{key} is {type(row[key]).__name__}"
        assert type(row["support"]) is int, f"{slug}.support is {type(row['support']).__name__}"
def test_per_mode_macro_f1_subsets_by_actual_mode() -> None:
    """D-MASK-04 / OQ-4: subset by actual `network_mode` column, NOT mask-uniform."""
    n_classes = len(CLASSES)
    rng = np.random.default_rng(0)
    n = 400
    y_true = rng.integers(0, n_classes, n)
    modes = np.array(["enterprise"] * (n // 2) + ["home"] * (n // 2))
    # Inject errors: the first half of the enterprise rows get a wrong label
    # (a 50% error rate); home rows stay perfect.
    enterprise_idx = np.flatnonzero(modes == "enterprise")
    wrong_idx = enterprise_idx[: enterprise_idx.size // 2]
    y_pred = y_true.copy()
    y_pred[wrong_idx] = (y_true[wrong_idx] + 1) % n_classes
    f1_by_mode = per_mode_macro_f1(y_true, y_pred, modes)
    assert set(f1_by_mode.keys()) == {"enterprise", "captive", "home", "unknown"}
    # Home is perfect: F1 should be ~1.0.
    assert f1_by_mode["home"] > 0.95, f"home F1 = {f1_by_mode['home']}"
    # Enterprise has 50% errors: F1 should be much lower (about 0.5 per class).
    assert f1_by_mode["enterprise"] < f1_by_mode["home"]
    assert f1_by_mode["enterprise"] < 0.75, f"enterprise F1 = {f1_by_mode['enterprise']}"
    # Captive and unknown have no rows: F1 should be 0.0 (no support).
    assert f1_by_mode["captive"] == 0.0
    assert f1_by_mode["unknown"] == 0.0
def test_build_eval_metrics_json_serializable() -> None:
    """Pitfall 8: build_eval_metrics output must json.dumps cleanly to strict JSON.

    ``json.dumps`` emits the non-standard tokens ``NaN`` / ``Infinity`` by
    default. ``allow_nan=False`` makes a NaN or inf leaf fail here instead of
    producing a file that strict JSON parsers reject.
    """
    rng = np.random.default_rng(0)
    n = 200
    y_eval = rng.integers(0, len(CLASSES), n)
    calibrated_proba = rng.dirichlet(np.ones(len(CLASSES)), size=n)
    y_pred = np.argmax(calibrated_proba, axis=1)
    modes = rng.choice(["enterprise", "captive", "home", "unknown"], n)
    metrics = build_eval_metrics(
        y_eval=y_eval,
        calibrated_proba=calibrated_proba,
        y_pred_after_mask=y_pred,
        network_mode_per_row=modes,
        anomaly_threshold=0.04,
        per_class_lead_times={slug: np.array([5.0, 10.0]) for slug in CLASSES},
        per_class_miss_rates={slug: 0.05 for slug in CLASSES},
    )
    # The acid test: does json.dumps work, with no NaN/inf leaking through?
    s = json.dumps(metrics, sort_keys=True, allow_nan=False)
    assert isinstance(s, str)
    # Schema sanity:
    assert "schema_version" in metrics
    assert "macro_f1" in metrics
    assert "ece_mean" in metrics
    assert "per_class" in metrics
    assert set(metrics["per_class"].keys()) == set(CLASSES)
    assert "anomaly" in metrics
    assert "threshold_95p_normal" in metrics["anomaly"]
    assert "lead_time_aggregate_median_s" in metrics["anomaly"]
    assert "per_class_lead_time_median_s" in metrics["anomaly"]
    assert "per_class_miss_rate" in metrics["anomaly"]
    assert "by_network_mode_macro_f1" in metrics
    assert set(metrics["by_network_mode_macro_f1"].keys()) == {
        "enterprise", "captive", "home", "unknown",
    }
def test_per_class_lead_time_non_negative() -> None:
    """ANOM-02: the per-class and aggregate lead-time medians reported by
    build_eval_metrics are never negative."""
    n_rows = 100
    rng = np.random.default_rng(0)
    y_eval = rng.integers(0, len(CLASSES), n_rows)
    calibrated_proba = rng.dirichlet(np.ones(len(CLASSES)), size=n_rows)
    modes = rng.choice(["enterprise", "captive", "home", "unknown"], n_rows)
    # Positive sample lead-times (the real pipeline enforces >= 0 via D-ANOM-01).
    lead_times = {slug: np.array([1.0, 5.0, 10.0]) for slug in CLASSES}
    zero_miss = {slug: 0.0 for slug in CLASSES}
    anomaly = build_eval_metrics(
        y_eval=y_eval,
        calibrated_proba=calibrated_proba,
        y_pred_after_mask=np.argmax(calibrated_proba, axis=1),
        network_mode_per_row=modes,
        anomaly_threshold=0.04,
        per_class_lead_times=lead_times,
        per_class_miss_rates=zero_miss,
    )["anomaly"]
    medians = anomaly["per_class_lead_time_median_s"]
    for slug in CLASSES:
        assert medians[slug] >= 0.0
    assert anomaly["lead_time_aggregate_median_s"] >= 0.0
def test_per_class_metrics_complete() -> None:
    """CLASS-06 (B-3 plan-checker fix): ``per_class`` covers all 10 CLASSES and
    each entry's precision / recall / f1 lies in [0, 1]."""
    n_rows = 200
    rng = np.random.default_rng(0)
    labels = rng.integers(0, len(CLASSES), n_rows)
    proba = rng.dirichlet(np.ones(len(CLASSES)), size=n_rows)
    modes = rng.choice(["enterprise", "captive", "home", "unknown"], n_rows)
    metrics = build_eval_metrics(
        y_eval=labels,
        calibrated_proba=proba,
        y_pred_after_mask=np.argmax(proba, axis=1),
        network_mode_per_row=modes,
        anomaly_threshold=0.04,
        per_class_lead_times={slug: np.array([1.0]) for slug in CLASSES},
        per_class_miss_rates=dict.fromkeys(CLASSES, 0.0),
    )
    per_class = metrics["per_class"]
    assert set(per_class.keys()) == set(CLASSES)
    for slug in CLASSES:
        entry = per_class[slug]
        for metric_name in ("precision", "recall", "f1"):
            value = entry[metric_name]
            assert 0.0 <= value <= 1.0, f"{slug}.{metric_name} = {value} out of [0,1]"