Spaces:
Sleeping
Sleeping
feat: implement cross-validated predictions caching and update confusion matrix function
Browse files
app/training/generate_figures.py
CHANGED
|
@@ -94,14 +94,24 @@ def _load_csv_data(feature_cols):
|
|
| 94 |
return X, y
|
| 95 |
|
| 96 |
|
| 97 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
"""Confusion matrix for the best model (CV predictions)."""
|
| 99 |
best = results.get("_best_model", "XGBoost")
|
| 100 |
-
data = results.get(best)
|
| 101 |
-
if not data:
|
| 102 |
-
return
|
| 103 |
-
y_true = np.array(data["y_true"])
|
| 104 |
-
y_pred = np.array(data["y_pred"])
|
| 105 |
cm = confusion_matrix(y_true, y_pred)
|
| 106 |
|
| 107 |
fig, ax = plt.subplots(figsize=(6.5, 5.5))
|
|
@@ -109,9 +119,12 @@ def fig_confusion_matrix(results: dict) -> None:
|
|
| 109 |
"auris", [PALETTE["bg"], PALETTE["primary"]],
|
| 110 |
)
|
| 111 |
im = ax.imshow(cm, cmap=cmap, aspect="auto")
|
|
|
|
|
|
|
|
|
|
| 112 |
ax.set_title(
|
| 113 |
f"Karışıklık Matrisi — {best}\n"
|
| 114 |
-
f"Accuracy: {
|
| 115 |
fontsize=13, fontweight="bold",
|
| 116 |
)
|
| 117 |
classes = ["İnsan / Human", "AI"]
|
|
|
|
| 94 |
return X, y
|
| 95 |
|
| 96 |
|
| 97 |
+
def _get_cv_predictions(model, X_scaled, y, cache: dict) -> tuple:
|
| 98 |
+
"""Cross-validated predictions with caching across figures."""
|
| 99 |
+
key = id(model)
|
| 100 |
+
if key in cache:
|
| 101 |
+
return cache[key]
|
| 102 |
+
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
|
| 103 |
+
y_prob = cross_val_predict(
|
| 104 |
+
clone(model), X_scaled, y, cv=cv, method="predict_proba", n_jobs=-1,
|
| 105 |
+
)[:, 1]
|
| 106 |
+
y_pred = (y_prob > 0.5).astype(int)
|
| 107 |
+
cache[key] = (y, y_pred, y_prob)
|
| 108 |
+
return y, y_pred, y_prob
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def fig_confusion_matrix(results: dict, y_true: np.ndarray, y_pred: np.ndarray) -> None:
|
| 112 |
"""Confusion matrix for the best model (CV predictions)."""
|
| 113 |
best = results.get("_best_model", "XGBoost")
|
| 114 |
+
data = results.get(best, {})
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
cm = confusion_matrix(y_true, y_pred)
|
| 116 |
|
| 117 |
fig, ax = plt.subplots(figsize=(6.5, 5.5))
|
|
|
|
| 119 |
"auris", [PALETTE["bg"], PALETTE["primary"]],
|
| 120 |
)
|
| 121 |
im = ax.imshow(cm, cmap=cmap, aspect="auto")
|
| 122 |
+
acc = data.get("accuracy", (y_true == y_pred).mean())
|
| 123 |
+
f1v = data.get("f1", 0.0)
|
| 124 |
+
aucv = data.get("roc_auc", 0.0)
|
| 125 |
ax.set_title(
|
| 126 |
f"Karışıklık Matrisi — {best}\n"
|
| 127 |
+
f"Accuracy: {acc:.1%} F1: {f1v:.3f} AUC: {aucv:.3f}",
|
| 128 |
fontsize=13, fontweight="bold",
|
| 129 |
)
|
| 130 |
classes = ["İnsan / Human", "AI"]
|