ai-internet-diagnostic-model / tests /test_train_classifier.py
WolfDavid's picture
feat(02-01): feature matrix assembly + calibrated LightGBM classifier
10cb234
"""Calibrated classifier tests (D-CAL-01..03; OQ-2 single-wrap)."""
from pathlib import Path
import numpy as np
import pytest
from lightgbm import LGBMClassifier
from sklearn.calibration import CalibratedClassifierCV
from model.features import CLASSES, load_split
from model.train_classifier import LGBM_PARAMS, train_calibrated_classifier
def _smoke_subset(
    X: np.ndarray, y: np.ndarray, per_class: int = 100
) -> tuple[np.ndarray, np.ndarray]:
    """Take first `per_class` rows of each class — fast smoke fixture.

    Rows come back grouped by class (class 0 first, then 1, ...), keeping their
    original order within each class. A class with fewer than `per_class` rows
    contributes all of its rows. Labels outside ``range(len(CLASSES))`` are
    dropped.

    With cv=5 stratified k-fold we need >=5 samples per class; 100 is plenty
    and keeps test wall time <30s. This helper does not itself enforce that
    minimum.
    """
    picks = [np.where(y == c)[0][:per_class] for c in range(len(CLASSES))]
    # Keep the index integer-typed even when nothing was selected:
    # np.array([]) is float64, and NumPy rejects float arrays as fancy indices.
    idx = np.concatenate(picks) if picks else np.empty(0, dtype=np.intp)
    return X[idx], y[idx]
def test_lgbm_params_pin_determinism_trio() -> None:
    """Pin the LightGBM settings the reproducibility contract depends on.

    Pitfall 1: random_state, deterministic and force_col_wise are all required
    for the 1e-4 check. random_state is supplied as a keyword argument rather
    than stored in LGBM_PARAMS, so only its two companion flags are checked
    here. The remaining pinned hyper-parameters are checked alongside them.
    """
    # Identity, not equality: each flag must be the bool True, not just truthy.
    for flag in ("deterministic", "force_col_wise"):
        assert LGBM_PARAMS[flag] is True
    pinned = {
        "class_weight": "balanced",
        "n_estimators": 500,
        "learning_rate": 0.05,
        "num_leaves": 63,
    }
    for key, expected in pinned.items():
        assert LGBM_PARAMS[key] == expected
@pytest.mark.skipif(
    not Path("data/train.parquet").exists(),
    reason="data/train.parquet not generated (run `make synth` first)",
)
def test_classifier_predicts_all_10_classes() -> None:
    """CLASS-01: predicted labels are valid indices into CLASSES.

    ``predict`` returns the argmax over the calibrated probabilities, so every
    label must fall in ``[0, len(CLASSES))``. On the smoke subset at least five
    distinct classes must also be predicted.
    """
    features, labels, _ = load_split(Path("data/train.parquet"))
    X_small, y_small = _smoke_subset(features, labels, per_class=100)
    model = train_calibrated_classifier(
        X_small, y_small, classifier_seed=42, cv_seed=43
    )
    predicted = model.predict(X_small)
    n_classes = len(CLASSES)
    assert predicted.min() >= 0
    assert predicted.max() < n_classes
    # A floor on label diversity for the smoke set (5), not all 10 classes.
    assert np.unique(predicted).size >= 5
@pytest.mark.skipif(
    not Path("data/train.parquet").exists(),
    reason="data/train.parquet not generated (run `make synth` first)",
)
def test_predict_proba_rows_sum_to_one() -> None:
    """CLASS-02: calibrated probabilities lie on the probability simplex.

    Each row of ``predict_proba`` must have one column per entry in CLASSES,
    contain only non-negative values, and sum to 1 (within 1e-9).
    """
    X, y, _ = load_split(Path("data/train.parquet"))
    Xs, ys = _smoke_subset(X, y, per_class=100)
    clf = train_calibrated_classifier(Xs, ys, classifier_seed=42, cv_seed=43)
    proba = clf.predict_proba(Xs)
    assert proba.shape == (len(Xs), len(CLASSES))
    # Summing to one is not enough for a simplex: negative entries could cancel
    # out. The comparison is also False for NaN, so NaNs are rejected here too.
    assert np.all(proba >= 0.0), f"negative or NaN probability, min={proba.min()!r}"
    np.testing.assert_allclose(proba.sum(axis=1), 1.0, atol=1e-9)
@pytest.mark.skipif(
    not Path("data/train.parquet").exists(),
    reason="data/train.parquet not generated (run `make synth` first)",
)
def test_no_double_ovr_pitfall_4() -> None:
    """Pitfall 4 / OQ-2: the LightGBM model is wrapped exactly once.

    The trained model must be a CalibratedClassifierCV with five calibrated
    members (the cv=5 ensemble). The fitted estimator inside the first member
    must be a bare LGBMClassifier, not a one-vs-rest wrapper.
    """
    features, labels, _ = load_split(Path("data/train.parquet"))
    X_small, y_small = _smoke_subset(features, labels, per_class=100)
    model = train_calibrated_classifier(
        X_small, y_small, classifier_seed=42, cv_seed=43
    )
    assert isinstance(model, CalibratedClassifierCV)
    members = model.calibrated_classifiers_
    assert len(members) == 5  # cv=5 ensemble: one calibrated member per fold
    wrapped = members[0].estimator
    inner_type = type(wrapped).__name__
    assert isinstance(wrapped, LGBMClassifier), (
        f"Pitfall 4: inner estimator should be LGBMClassifier (single-wrap), "
        f"got {inner_type}"
    )