Rthur2003 committed on
Commit
94e94a1
·
1 Parent(s): 9e93437

feat: add figure generation script for training results visualization

Browse files
Files changed (1) hide show
  1. app/training/generate_figures.py +408 -0
app/training/generate_figures.py ADDED
@@ -0,0 +1,408 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Generate academic-quality figures from training results.
2
+
3
+ Produces publication-ready figures in DataSet/figures/:
4
+ - confusion_matrix.png
5
+ - roc_curves_comparison.png
6
+ - precision_recall_curves.png
7
+ - feature_importance_top20.png
8
+ - calibration_plot.png
9
+ - feature_distribution_ai_vs_human.png
10
+ - shap_summary.png
11
+ - model_comparison_bars.png
12
+
13
+ Usage:
14
+ python -m app.training.generate_figures
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import csv
20
+ import json
21
+ import pickle
22
+ from pathlib import Path
23
+
24
+ import numpy as np
25
+ import matplotlib
26
+ matplotlib.use("Agg")
27
+ import matplotlib.pyplot as plt
28
+ from matplotlib.colors import LinearSegmentedColormap
29
+
30
+ from sklearn.metrics import (
31
+ confusion_matrix, roc_curve, auc,
32
+ precision_recall_curve, average_precision_score,
33
+ )
34
+ from sklearn.calibration import calibration_curve
35
+
36
# ── Paths ────────────────────────────────────────────────────────────────
# Every location is derived from this file's resolved path: BACKEND is its
# ``parents[2]`` ancestor and the dataset directory sits beside BACKEND.
BACKEND = Path(__file__).resolve().parents[2]
MODELS_DIR = BACKEND.joinpath("models")
DATASET_DIR = BACKEND.parent.joinpath("DataSet")
FIGURES_DIR = DATASET_DIR.joinpath("figures")
FEATURES_CSV = DATASET_DIR.joinpath("features.csv")

# ── Theme (AURIS parchment gold palette) ─────────────────────────────────
PALETTE = dict(
    bg="#faf6ed",
    fg="#3d2817",
    primary="#c99347",
    secondary="#7fb069",
    error="#a64b3c",
    grid="#d8c9a8",
    accent="#e7c77a",
)

# Apply the palette globally so every figure shares the same look.
plt.rcParams.update({
    # Backgrounds use the parchment tone.
    **{key: PALETTE["bg"] for key in ("figure.facecolor", "axes.facecolor")},
    # Spines, labels, ticks and text all use the dark foreground tone.
    **{
        key: PALETTE["fg"]
        for key in (
            "axes.edgecolor",
            "axes.labelcolor",
            "xtick.color",
            "ytick.color",
            "text.color",
        )
    },
    "font.family": "DejaVu Sans",
    "font.size": 11,
    "axes.grid": True,
    "grid.color": PALETTE["grid"],
    "grid.alpha": 0.4,
    "savefig.dpi": 150,
    "savefig.bbox": "tight",
    "figure.dpi": 100,
})
71
+
72
+
73
def _load_artifacts():
    """Load the training results, classifier, scaler and feature-column list.

    Reads four files from ``MODELS_DIR``:
    ``training_results.json``, ``auris_classifier_v1.pkl``,
    ``feature_scaler_v1.pkl`` and ``feature_columns_v1.json``.

    Returns:
        A ``(results, model, scaler, feature_cols)`` tuple in that order.

    Raises:
        OSError: If any of the artifact files cannot be opened.
        json.JSONDecodeError: If a JSON artifact is malformed.
    """
    # Decode JSON explicitly as UTF-8 (as the CSV readers in this module do)
    # instead of relying on the platform's locale-dependent default encoding.
    with open(MODELS_DIR / "training_results.json", "r", encoding="utf-8") as f:
        results = json.load(f)
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file. Only load artifacts produced by a trusted training run.
    with open(MODELS_DIR / "auris_classifier_v1.pkl", "rb") as f:
        model = pickle.load(f)
    with open(MODELS_DIR / "feature_scaler_v1.pkl", "rb") as f:
        scaler = pickle.load(f)
    with open(MODELS_DIR / "feature_columns_v1.json", "r", encoding="utf-8") as f:
        feature_cols = json.load(f)
    return results, model, scaler, feature_cols
84
+
85
+
86
def _load_csv_data(feature_cols):
    """Read the features CSV into a feature matrix and a label vector.

    Args:
        feature_cols: Names of the CSV columns to use as features, in the
            order they should appear in the matrix.

    Returns:
        ``(X, y)`` where ``X`` holds one row per CSV record with the
        requested columns parsed as floats (NaN replaced by 0.0, +inf by
        1.0, -inf by -1.0) and ``y`` holds the ``label_int`` column as ints.
    """
    with open(FEATURES_CSV, "r", encoding="utf-8") as handle:
        records = list(csv.DictReader(handle))

    raw_matrix = np.array(
        [[float(record[name]) for name in feature_cols] for record in records]
    )
    # Replace non-finite entries with fixed finite stand-ins.
    matrix = np.nan_to_num(raw_matrix, nan=0.0, posinf=1.0, neginf=-1.0)
    labels = np.array([int(record["label_int"]) for record in records])
    return matrix, labels
93
+
94
+
95
def fig_confusion_matrix(results: dict) -> None:
    """Render the best model's confusion matrix to ``confusion_matrix.png``.

    The best model's name is read from ``results["_best_model"]`` (falling
    back to ``"XGBoost"``). Its entry must provide ``y_true``, ``y_pred``,
    ``accuracy``, ``f1`` and ``roc_auc``. Nothing is drawn when that entry is
    missing or empty.

    Args:
        results: Parsed training-results mapping.
    """
    best = results.get("_best_model", "XGBoost")
    data = results.get(best)
    if not data:
        return
    y_true = np.array(data["y_true"])
    y_pred = np.array(data["y_pred"])
    # Pin the label set so the matrix is always 2x2. Without it, a run where
    # only one class occurs yields a 1x1 matrix and the cell loop below
    # raises IndexError. Assumes labels are encoded 0 = human, 1 = AI, which
    # is what the tick labels below imply.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

    fig, ax = plt.subplots(figsize=(6.5, 5.5))
    cmap = LinearSegmentedColormap.from_list(
        "auris", [PALETTE["bg"], PALETTE["primary"]],
    )
    im = ax.imshow(cm, cmap=cmap, aspect="auto")
    ax.set_title(
        f"Karışıklık Matrisi — {best}\n"
        f"Accuracy: {data['accuracy']:.1%} F1: {data['f1']:.3f} AUC: {data['roc_auc']:.3f}",
        fontsize=13, fontweight="bold",
    )
    classes = ["İnsan / Human", "AI"]
    ax.set_xticks([0, 1])
    ax.set_yticks([0, 1])
    ax.set_xticklabels(classes)
    ax.set_yticklabels(classes)
    ax.set_xlabel("Tahmin / Predicted")
    ax.set_ylabel("Gerçek / Actual")

    # Cell annotations with count + percentage of all samples.
    total = cm.sum()
    for i in range(2):
        for j in range(2):
            count = cm[i, j]
            # Guard against an empty prediction set (total == 0).
            pct = 100 * count / total if total else 0.0
            # Light text on well-filled (dark) cells, dark text otherwise.
            color = PALETTE["bg"] if count > total * 0.25 else PALETTE["fg"]
            ax.text(
                j, i, f"{count}\n({pct:.1f}%)",
                ha="center", va="center",
                color=color, fontsize=13, fontweight="bold",
            )

    fig.colorbar(im, ax=ax, shrink=0.7)
    fig.savefig(FIGURES_DIR / "confusion_matrix.png")
    plt.close(fig)
    print(" ✓ confusion_matrix.png")
140
+
141
+
142
def fig_roc_comparison(results: dict) -> None:
    """Overlay the ROC curves of all models in ``roc_curves_comparison.png``.

    Every key of ``results`` that does not start with ``_`` and maps to a
    dict is treated as a model entry; each entry must provide ``y_true`` and
    ``y_prob``. Models are drawn in descending ``roc_auc`` order and the model
    named by ``results["_best_model"]`` (default ``"XGBoost"``) is drawn with
    a thicker solid line.

    Args:
        results: Parsed training-results mapping.
    """
    best = results.get("_best_model", "XGBoost")
    items = [(k, v) for k, v in results.items() if not k.startswith("_") and isinstance(v, dict)]
    items.sort(key=lambda x: x[1].get("roc_auc", 0), reverse=True)

    fig, ax = plt.subplots(figsize=(8, 6.5))
    # A fixed table of 10 colours raises IndexError on the 11th model, so the
    # table grows with the model count. Keeping at least 10 samples leaves
    # the colours unchanged for runs with 10 models or fewer.
    colors = plt.cm.plasma(np.linspace(0.15, 0.85, max(10, len(items))))

    for idx, (name, data) in enumerate(items):
        y_true = np.array(data["y_true"])
        y_prob = np.array(data["y_prob"])
        fpr, tpr, _ = roc_curve(y_true, y_prob)
        roc_auc = auc(fpr, tpr)
        is_best = name == best
        ax.plot(
            fpr, tpr,
            color=colors[idx],
            lw=3 if is_best else 1.5,
            ls="-" if is_best else "--",
            label=f"{name} (AUC = {roc_auc:.4f})",
        )

    # Chance-level diagonal for reference.
    ax.plot([0, 1], [0, 1], "k:", alpha=0.3, lw=1)
    ax.set_xlabel("Yanlış Pozitif Oranı / False Positive Rate")
    ax.set_ylabel("Doğru Pozitif Oranı / True Positive Rate")
    ax.set_title("ROC Eğrileri — Model Karşılaştırması", fontsize=13, fontweight="bold")
    ax.legend(loc="lower right", framealpha=0.85)
    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1.02])
    fig.savefig(FIGURES_DIR / "roc_curves_comparison.png")
    plt.close(fig)
    print(" ✓ roc_curves_comparison.png")
175
+
176
+
177
def fig_pr_curves(results: dict) -> None:
    """Overlay precision-recall curves in ``precision_recall_curves.png``.

    Every key of ``results`` that does not start with ``_`` and maps to a
    dict is treated as a model entry; each entry must provide ``y_true`` and
    ``y_prob``. Models are drawn in descending ``roc_auc`` order and the model
    named by ``results["_best_model"]`` (default ``"XGBoost"``) is drawn with
    a thicker solid line. The legend reports each model's average precision.

    Args:
        results: Parsed training-results mapping.
    """
    best = results.get("_best_model", "XGBoost")
    items = [(k, v) for k, v in results.items() if not k.startswith("_") and isinstance(v, dict)]
    items.sort(key=lambda x: x[1].get("roc_auc", 0), reverse=True)

    fig, ax = plt.subplots(figsize=(8, 6.5))
    # A fixed table of 10 colours raises IndexError on the 11th model, so the
    # table grows with the model count. Keeping at least 10 samples leaves
    # the colours unchanged for runs with 10 models or fewer.
    colors = plt.cm.plasma(np.linspace(0.15, 0.85, max(10, len(items))))

    for idx, (name, data) in enumerate(items):
        y_true = np.array(data["y_true"])
        y_prob = np.array(data["y_prob"])
        prec, rec, _ = precision_recall_curve(y_true, y_prob)
        ap = average_precision_score(y_true, y_prob)
        is_best = name == best
        ax.plot(
            rec, prec,
            color=colors[idx],
            lw=3 if is_best else 1.5,
            ls="-" if is_best else "--",
            label=f"{name} (AP = {ap:.4f})",
        )

    ax.set_xlabel("Duyarlılık / Recall")
    ax.set_ylabel("Kesinlik / Precision")
    ax.set_title("Precision-Recall Eğrileri", fontsize=13, fontweight="bold")
    ax.legend(loc="lower left", framealpha=0.85)
    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1.02])
    fig.savefig(FIGURES_DIR / "precision_recall_curves.png")
    plt.close(fig)
    print(" ✓ precision_recall_curves.png")
209
+
210
+
211
def fig_feature_importance(results: dict, top_n: int = 20) -> None:
    """Draw a horizontal bar chart of the most important features.

    Importances are read from ``results["_feature_importance"]`` (a mapping
    of feature name to score). Nothing is drawn when that mapping is missing
    or empty. The figure is saved as ``feature_importance_top20.png``
    regardless of ``top_n``.

    Args:
        results: Parsed training-results mapping.
        top_n: Maximum number of features to show, highest score first.
    """
    imp = results.get("_feature_importance", {})
    if not imp:
        return
    items = sorted(imp.items(), key=lambda x: x[1], reverse=True)[:top_n]
    names = [n for n, _ in items]
    vals = [v for _, v in items]
    if not names:
        # top_n <= 0 leaves nothing to plot.
        return

    fig, ax = plt.subplots(figsize=(9, 7))
    y_pos = np.arange(len(names))
    colors_grad = plt.cm.copper(np.linspace(0.3, 0.85, len(names)))
    ax.barh(y_pos, vals, color=colors_grad, edgecolor=PALETTE["fg"], linewidth=0.5)
    ax.set_yticks(y_pos)
    ax.set_yticklabels(names, fontsize=10)
    # Highest-ranked feature at the top.
    ax.invert_yaxis()
    ax.set_xlabel("Normalize Önem / Normalized Importance")
    # Report the number of bars actually drawn: fewer than top_n features may
    # be available, in which case a title quoting top_n would be wrong.
    ax.set_title(f"En Önemli {len(names)} Özellik — {results.get('_best_model', 'XGBoost')}",
                 fontsize=13, fontweight="bold")
    # Value labels sit slightly right of each bar end; the offset is 1% of
    # the largest value, computed once rather than on every iteration.
    label_offset = max(vals) * 0.01
    for i, v in enumerate(vals):
        ax.text(v + label_offset, i, f"{v:.4f}", va="center", fontsize=8)
    fig.savefig(FIGURES_DIR / "feature_importance_top20.png")
    plt.close(fig)
    print(" ✓ feature_importance_top20.png")
235
+
236
+
237
def fig_calibration(results: dict) -> None:
    """Plot one calibration curve per model and save ``calibration_plot.png``.

    Every key of ``results`` that does not start with ``_`` and maps to a
    dict is treated as a model entry; each entry must provide ``y_true`` and
    ``y_prob``. The model named by ``results["_best_model"]`` (default
    ``"XGBoost"``) gets a thicker line and larger markers.

    Args:
        results: Parsed training-results mapping.
    """
    fig, ax = plt.subplots(figsize=(7, 6.5))
    best = results.get("_best_model", "XGBoost")
    model_entries = [
        (key, value)
        for key, value in results.items()
        if isinstance(value, dict) and not key.startswith("_")
    ]

    shades = plt.cm.plasma(np.linspace(0.2, 0.8, len(model_entries)))
    for (name, data), shade in zip(model_entries, shades):
        truth = np.array(data["y_true"])
        scores = np.array(data["y_prob"])
        frac_pos, mean_pred = calibration_curve(truth, scores, n_bins=10)
        if name == best:
            line_width, marker_size = 3, 6
        else:
            line_width, marker_size = 1.2, 4
        ax.plot(
            mean_pred, frac_pos, "o-",
            color=shade, lw=line_width,
            label=f"{name}", markersize=marker_size,
        )

    # Diagonal marks perfectly calibrated probabilities.
    ax.plot([0, 1], [0, 1], "k:", alpha=0.5, label="Mükemmel / Perfect")
    ax.set_xlabel("Ortalama Tahmin Olasılığı / Mean Predicted Probability")
    ax.set_ylabel("Gerçek Pozitif Oranı / Fraction of Positives")
    ax.set_title("Kalibrasyon Eğrisi", fontsize=13, fontweight="bold")
    ax.legend(loc="upper left", framealpha=0.85, fontsize=9)
    plt.savefig(FIGURES_DIR / "calibration_plot.png")
    plt.close()
    print(" ✓ calibration_plot.png")
260
+
261
+
262
def fig_feature_distributions(feature_cols: list[str], top_features: list[str]) -> None:
    """Plot AI-vs-human histograms for up to eight top-ranked features.

    Rows are read from ``FEATURES_CSV``. A row counts as AI when its
    ``label_int`` field is the string ``"1"`` and as human otherwise. Values
    that are missing, unparsable, NaN or infinite are skipped. The figure is
    a 2x4 grid saved as ``feature_distribution_ai_vs_human.png``; unused
    panels are hidden.

    Args:
        feature_cols: Not used by this function.
        top_features: Feature names ordered from most to least important.
    """
    with open(FEATURES_CSV, "r", encoding="utf-8") as handle:
        records = list(csv.DictReader(handle))

    shown = min(8, len(top_features))
    fig, grid = plt.subplots(2, 4, figsize=(16, 8))
    panels = grid.flatten()

    for position, (panel, column) in enumerate(zip(panels, top_features[:shown])):
        human_values, ai_values = [], []
        for record in records:
            try:
                value = float(record[column])
                if not np.isfinite(value):
                    continue
                target = ai_values if record["label_int"] == "1" else human_values
            except (ValueError, KeyError):
                continue
            target.append(value)

        # Overlaid, density-normalised histograms of the two classes.
        bin_count = 30
        panel.hist(human_values, bins=bin_count, alpha=0.55, color=PALETTE["secondary"],
                   label=f"İnsan (n={len(human_values)})", density=True)
        panel.hist(ai_values, bins=bin_count, alpha=0.55, color=PALETTE["error"],
                   label=f"AI (n={len(ai_values)})", density=True)
        panel.set_title(column, fontsize=10, fontweight="bold")
        # Only the first panel of each row carries the y-axis label.
        panel.set_ylabel("" if position % 4 else "Yoğunluk")
        panel.legend(fontsize=7, loc="best")
        panel.tick_params(labelsize=8)

    for unused in panels[shown:]:
        unused.axis("off")

    fig.suptitle("AI vs İnsan — En Önemli 8 Özelliğin Dağılımı",
                 fontsize=14, fontweight="bold", y=1.02)
    plt.tight_layout()
    plt.savefig(FIGURES_DIR / "feature_distribution_ai_vs_human.png")
    plt.close()
    print(" ✓ feature_distribution_ai_vs_human.png")
302
+
303
+
304
def fig_shap_summary(model, scaler, feature_cols, X, max_display: int = 20) -> None:
    """Draw a SHAP summary plot for ``model`` and save ``shap_summary.png``.

    The step is skipped with a console notice when the optional ``shap``
    package is not installed. ``X`` is scaled with ``scaler`` and, if it has
    more than 1000 rows, subsampled to 1000 rows with a fixed seed.

    Args:
        model: Fitted model passed to ``shap.TreeExplainer``.
        scaler: Fitted transformer exposing ``transform``.
        feature_cols: Feature names, one per column of ``X``.
        X: Unscaled feature matrix.
        max_display: Maximum number of features shown in the plot.
    """
    try:
        import shap
    except ImportError:
        print(" ! SHAP not available, skipping")
        return

    X_scaled = scaler.transform(X)
    # Subsample for speed; the fixed seed keeps the figure reproducible.
    if len(X_scaled) > 1000:
        idx = np.random.RandomState(42).choice(len(X_scaled), 1000, replace=False)
        X_sub = X_scaled[idx]
    else:
        X_sub = X_scaled

    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(X_sub)

    # Reduce multi-output SHAP values to a 2-D (samples, features) matrix
    # for the positive class (index 1). Older SHAP releases return one array
    # per class in a list; newer ones return a single 3-D array with the
    # outputs on the last axis, which would otherwise reach summary_plot
    # unsliced.
    if isinstance(shap_values, list):
        sv = shap_values[1] if len(shap_values) > 1 else shap_values[0]
    elif getattr(shap_values, "ndim", 2) == 3:
        class_idx = 1 if shap_values.shape[2] > 1 else 0
        sv = shap_values[:, :, class_idx]
    else:
        sv = shap_values

    # plot_size=None makes summary_plot keep the size of this figure.
    fig = plt.figure(figsize=(10, 8))
    shap.summary_plot(
        sv, X_sub,
        feature_names=feature_cols,
        max_display=max_display,
        show=False,
        plot_size=None,
    )
    plt.title("SHAP Özet Grafiği — Global Özellik Etkisi",
              fontsize=13, fontweight="bold", pad=14)
    plt.savefig(FIGURES_DIR / "shap_summary.png", bbox_inches="tight")
    plt.close(fig)
    print(" ✓ shap_summary.png")
341
+
342
+
343
def fig_model_comparison(results: dict) -> None:
    """Draw grouped bars of five metrics per model in ``model_comparison_bars.png``.

    Every key of ``results`` that does not start with ``_`` and maps to a
    dict is treated as a model entry; each entry must provide ``accuracy``,
    ``f1``, ``roc_auc``, ``precision`` and ``recall``. Models are ordered by
    descending ``roc_auc``.

    Args:
        results: Parsed training-results mapping.
    """
    ranked = sorted(
        (
            (key, value)
            for key, value in results.items()
            if isinstance(value, dict) and not key.startswith("_")
        ),
        key=lambda pair: pair[1].get("roc_auc", 0),
        reverse=True,
    )
    model_names = [name for name, _ in ranked]

    # Legend label -> key inside each model entry, in drawing order.
    metric_fields = (
        ("Accuracy", "accuracy"),
        ("F1 Score", "f1"),
        ("ROC-AUC", "roc_auc"),
        ("Precision", "precision"),
        ("Recall", "recall"),
    )
    series = {
        label: [entry[field] for _, entry in ranked]
        for label, field in metric_fields
    }

    x = np.arange(len(model_names))
    width = 0.16
    fig, ax = plt.subplots(figsize=(12, 6.5))
    bar_colors = [PALETTE["primary"], PALETTE["secondary"], PALETTE["error"],
                  PALETTE["accent"], "#7a5c3c"]

    # The five bars of a group are centred on the model's tick position.
    for slot, ((label, values), shade) in enumerate(zip(series.items(), bar_colors)):
        ax.bar(x + slot * width - 2 * width, values, width, label=label,
               color=shade, edgecolor=PALETTE["fg"], linewidth=0.3)

    ax.set_ylabel("Skor / Score")
    ax.set_title("Model Performans Karşılaştırması", fontsize=13, fontweight="bold")
    ax.set_xticks(x)
    ax.set_xticklabels(model_names, rotation=20, ha="right")
    ax.legend(loc="lower right", framealpha=0.85)
    ax.set_ylim([0.5, 1.0])
    ax.grid(True, axis="y", alpha=0.4)

    plt.savefig(FIGURES_DIR / "model_comparison_bars.png")
    plt.close()
    print(" ✓ model_comparison_bars.png")
378
+
379
+
380
def main() -> None:
    """Generate every figure into ``FIGURES_DIR``, creating it if needed.

    Loads the training artifacts, renders the figures that need only the
    results mapping, then the feature-distribution figure, and finally the
    SHAP summary, which needs the full feature matrix. Progress is printed
    to stdout.
    """
    FIGURES_DIR.mkdir(parents=True, exist_ok=True)
    print(f"Output directory: {FIGURES_DIR}")
    print("Loading artifacts...")
    results, model, scaler, feature_cols = _load_artifacts()

    # Feature names ordered from most to least important.
    importance = results.get("_feature_importance", {})
    top_features = sorted(importance, key=importance.get, reverse=True)

    print("\nGenerating figures...")
    results_only_figures = (
        fig_confusion_matrix,
        fig_roc_comparison,
        fig_pr_curves,
        fig_feature_importance,
        fig_calibration,
        fig_model_comparison,
    )
    for render in results_only_figures:
        render(results)
    fig_feature_distributions(feature_cols, top_features)

    print("\nLoading data for SHAP (this may take ~30s)...")
    X, _labels = _load_csv_data(feature_cols)
    fig_shap_summary(model, scaler, feature_cols, X)

    png_count = sum(1 for _ in FIGURES_DIR.glob('*.png'))
    print(f"\nDone. {png_count} figures in {FIGURES_DIR}")


if __name__ == "__main__":
    main()