Rthur2003 committed on
Commit
337d9ae
·
1 Parent(s): b74a8cb

feat: add AURIS classifier training module with model evaluation and feature importance

Browse files
Files changed (1) hide show
  1. app/training/train_classifier.py +237 -0
app/training/train_classifier.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train AURIS classifier on extracted audio features.
3
+
4
+ Increment 1: RandomForest / GradientBoosting on librosa + vocal features.
5
+ This replaces the heuristic scoring with a data-driven classifier.
6
+
7
+ Usage:
8
+ python -m app.training.train_classifier data/sonics/features.csv
9
+
10
+ Outputs:
11
+ models/auris_classifier_v1.pkl — trained model
12
+ models/feature_scaler_v1.pkl — fitted StandardScaler
13
+ models/feature_columns_v1.json — ordered feature column names
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import csv
19
+ import json
20
+ import pickle
21
+ import sys
22
+ from pathlib import Path
23
+
24
+ import numpy as np
25
+
26
+ from sklearn.ensemble import (
27
+ GradientBoostingClassifier,
28
+ RandomForestClassifier,
29
+ )
30
+ from sklearn.model_selection import (
31
+ StratifiedKFold,
32
+ cross_val_predict,
33
+ )
34
+ from sklearn.preprocessing import StandardScaler
35
+ from sklearn.metrics import (
36
+ accuracy_score,
37
+ f1_score,
38
+ roc_auc_score,
39
+ )
40
+
41
+ # Optional: LightGBM for better performance
42
+ try:
43
+ import lightgbm as lgb
44
+ HAS_LGBM = True
45
+ except ImportError:
46
+ HAS_LGBM = False
47
+
48
+ sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
49
+ from app.training.evaluate import (
50
+ load_features_csv,
51
+ evaluate_predictions,
52
+ )
53
+
54
+
55
def train(
    features_csv: str | Path,
    models_dir: str | Path = "models",
    n_folds: int = 5,
) -> dict:
    """
    Train and evaluate classifier candidates on extracted features.

    Every candidate returned by ``_build_candidates()`` is scored with
    stratified ``n_folds``-fold cross-validation. The candidate with the
    highest cross-validated ROC-AUC is then refit on all data and saved
    together with the fitted scaler and the ordered feature column names.

    Args:
        features_csv: Path to the features CSV. Every header column other
            than ``file_path`` and ``label_int`` is treated as a feature.
        models_dir: Directory the artifacts are written to (created if
            missing).
        n_folds: Number of stratified cross-validation folds.

    Returns:
        Dict with keys ``best_model`` (candidate name), ``best_auc``,
        ``results`` (per-candidate accuracy / f1 / roc_auc, rounded to
        four decimals) and ``model_path``.

    Raises:
        ValueError: If the number of feature columns found in the CSV
            header does not match the width of the loaded feature matrix.
        RuntimeError: If no candidate produced a comparable ROC-AUC.
    """
    models_dir = Path(models_dir)
    models_dir.mkdir(parents=True, exist_ok=True)

    # ── Load data ──────────────────────────────────
    X, y = load_features_csv(features_csv)

    # Feature column names come from the header only; an empty file has
    # no header (fieldnames is None), which is treated as "no columns".
    with open(features_csv, "r", encoding="utf-8", newline="") as f:
        header = csv.DictReader(f).fieldnames or []
    feature_cols = [c for c in header if c not in ("file_path", "label_int")]

    # ── Handle NaN/Inf ─────────────────────────────
    X = np.nan_to_num(X, nan=0.0, posinf=1.0, neginf=-1.0)

    # The column list is persisted for later use and zipped with the
    # importances below; a silent length mismatch would mislabel features.
    if X.ndim == 2 and X.shape[1] != len(feature_cols):
        raise ValueError(
            f"Feature matrix has {X.shape[1]} columns but the CSV header "
            f"yields {len(feature_cols)} feature names"
        )

    # ── Scale features ─────────────────────────────
    # NOTE(review): the scaler is fitted on all rows before cross-validation,
    # so fold statistics are not strictly held out — confirm this is acceptable.
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)

    # ── Train multiple models, pick best ───────────
    best_model = None
    best_name = ""
    best_auc = float("-inf")
    best_prob = None
    results = {}

    cv = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)

    for name, model in _build_candidates():
        print(f"\n{'─' * 40}")
        print(f"Training: {name}")
        print(f"{'─' * 40}")

        # Cross-validated probability of the second class column
        # (assumes a binary target — TODO confirm against the dataset).
        y_prob = cross_val_predict(
            model, X_scaled, y,
            cv=cv, method="predict_proba",
        )[:, 1]
        y_pred = (y_prob > 0.5).astype(int)

        acc = accuracy_score(y, y_pred)
        f1 = f1_score(y, y_pred)
        auc = roc_auc_score(y, y_prob)

        print(f" CV Accuracy: {acc:.4f}")
        print(f" CV F1: {f1:.4f}")
        print(f" CV ROC-AUC: {auc:.4f}")

        results[name] = {
            "accuracy": round(acc, 4),
            "f1": round(f1, 4),
            "roc_auc": round(auc, 4),
        }

        if auc > best_auc:
            best_auc = auc
            best_name = name
            best_model = model
            # Keep the out-of-fold probabilities so the detailed report
            # does not need a second full cross-validation run.
            best_prob = y_prob

    if best_model is None:
        raise RuntimeError("No candidate model produced a valid ROC-AUC")

    # ── Final evaluation of best model ─────────────
    print(f"\n{'=' * 50}")
    print(f" Best model: {best_name} (AUC={best_auc:.4f})")
    print(f"{'=' * 50}")

    # Detailed report from the cached cross-val predictions; the splitter
    # and the candidates are seeded, so recomputing would repeat the work.
    y_prob_best = best_prob
    y_pred_best = (y_prob_best > 0.5).astype(int)

    evaluate_predictions(
        y, y_pred_best, y_prob_best,
        title=f"Best: {best_name}",
    )

    # ── Train final model on ALL data ──────────────
    print(f"\nTraining final {best_name} on all data...")
    best_model.fit(X_scaled, y)

    # ── Feature importance ─────────────────────────
    if hasattr(best_model, "feature_importances_"):
        importances = best_model.feature_importances_
        top_features = sorted(
            zip(feature_cols, importances),
            key=lambda x: x[1],
            reverse=True,
        )
        print("\nTop 10 features:")
        for fname, imp in top_features[:10]:
            bar = "█" * int(imp * 100)
            print(f" {fname:<30} {imp:.4f} {bar}")

    # ── Save artifacts ─────────────────────────────
    model_path = models_dir / "auris_classifier_v1.pkl"
    scaler_path = models_dir / "feature_scaler_v1.pkl"
    columns_path = models_dir / "feature_columns_v1.json"

    with open(model_path, "wb") as f:
        pickle.dump(best_model, f)
    with open(scaler_path, "wb") as f:
        pickle.dump(scaler, f)
    with open(columns_path, "w", encoding="utf-8") as f:
        json.dump(feature_cols, f, indent=2)

    print("\nSaved:")
    print(f" Model: {model_path}")
    print(f" Scaler: {scaler_path}")
    print(f" Columns: {columns_path}")

    return {
        "best_model": best_name,
        "best_auc": best_auc,
        "results": results,
        "model_path": str(model_path),
    }
187
+
188
+
189
def _build_candidates() -> list[tuple[str, object]]:
    """Return the ``(name, unfitted estimator)`` pairs to be compared.

    RandomForest and GradientBoosting are always included; LightGBM is
    appended only when the optional ``lightgbm`` import succeeded
    (``HAS_LGBM`` is true).
    """
    random_forest = RandomForestClassifier(
        n_estimators=300,
        max_depth=20,
        min_samples_leaf=5,
        class_weight="balanced",
        random_state=42,
        n_jobs=-1,
    )
    gradient_boosting = GradientBoostingClassifier(
        n_estimators=200,
        max_depth=6,
        learning_rate=0.1,
        subsample=0.8,
        random_state=42,
    )
    models = [
        ("RandomForest", random_forest),
        ("GradientBoosting", gradient_boosting),
    ]

    if not HAS_LGBM:
        return models

    lightgbm_model = lgb.LGBMClassifier(
        n_estimators=300,
        max_depth=8,
        learning_rate=0.05,
        num_leaves=31,
        subsample=0.8,
        colsample_bytree=0.8,
        class_weight="balanced",
        random_state=42,
        verbose=-1,
    )
    models.append(("LightGBM", lightgbm_model))
    return models
232
+
233
+
234
if __name__ == "__main__":
    # Two optional positional arguments: the features CSV path, then the
    # directory the model artifacts are written to.
    cli_args = sys.argv[1:]
    csv_path = cli_args[0] if cli_args else "data/sonics/features.csv"
    model_dir = cli_args[1] if len(cli_args) >= 2 else "models"
    train(csv_path, model_dir)