Rthur2003 commited on
Commit
0ea8f54
·
1 Parent(s): 853f480

feat: implement cross-validated predictions caching and update confusion matrix function

Browse files
Files changed (1) hide show
  1. app/training/generate_figures.py +20 -7
app/training/generate_figures.py CHANGED
@@ -94,14 +94,24 @@ def _load_csv_data(feature_cols):
94
  return X, y
95
 
96
 
97
- def fig_confusion_matrix(results: dict) -> None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  """Confusion matrix for the best model (CV predictions)."""
99
  best = results.get("_best_model", "XGBoost")
100
- data = results.get(best)
101
- if not data:
102
- return
103
- y_true = np.array(data["y_true"])
104
- y_pred = np.array(data["y_pred"])
105
  cm = confusion_matrix(y_true, y_pred)
106
 
107
  fig, ax = plt.subplots(figsize=(6.5, 5.5))
@@ -109,9 +119,12 @@ def fig_confusion_matrix(results: dict) -> None:
109
  "auris", [PALETTE["bg"], PALETTE["primary"]],
110
  )
111
  im = ax.imshow(cm, cmap=cmap, aspect="auto")
 
 
 
112
  ax.set_title(
113
  f"Karışıklık Matrisi — {best}\n"
114
- f"Accuracy: {data['accuracy']:.1%} F1: {data['f1']:.3f} AUC: {data['roc_auc']:.3f}",
115
  fontsize=13, fontweight="bold",
116
  )
117
  classes = ["İnsan / Human", "AI"]
 
94
  return X, y
95
 
96
 
97
+ def _get_cv_predictions(model, X_scaled, y, cache: dict) -> tuple:
98
+ """Cross-validated predictions with caching across figures."""
99
+ key = id(model)
100
+ if key in cache:
101
+ return cache[key]
102
+ cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
103
+ y_prob = cross_val_predict(
104
+ clone(model), X_scaled, y, cv=cv, method="predict_proba", n_jobs=-1,
105
+ )[:, 1]
106
+ y_pred = (y_prob > 0.5).astype(int)
107
+ cache[key] = (y, y_pred, y_prob)
108
+ return y, y_pred, y_prob
109
+
110
+
111
+ def fig_confusion_matrix(results: dict, y_true: np.ndarray, y_pred: np.ndarray) -> None:
112
  """Confusion matrix for the best model (CV predictions)."""
113
  best = results.get("_best_model", "XGBoost")
114
+ data = results.get(best, {})
 
 
 
 
115
  cm = confusion_matrix(y_true, y_pred)
116
 
117
  fig, ax = plt.subplots(figsize=(6.5, 5.5))
 
119
  "auris", [PALETTE["bg"], PALETTE["primary"]],
120
  )
121
  im = ax.imshow(cm, cmap=cmap, aspect="auto")
122
+ acc = data.get("accuracy", (y_true == y_pred).mean())
123
+ f1v = data.get("f1", 0.0)
124
+ aucv = data.get("roc_auc", 0.0)
125
  ax.set_title(
126
  f"Karışıklık Matrisi — {best}\n"
127
+ f"Accuracy: {acc:.1%} F1: {f1v:.3f} AUC: {aucv:.3f}",
128
  fontsize=13, fontweight="bold",
129
  )
130
  classes = ["İnsan / Human", "AI"]