Rthur2003 commited on
Commit
bb6655d
·
1 Parent(s): b8d143b

feat: update model parameters and fix data leakage by removing duration_sec and sample_rate from features

Browse files
Files changed (1) hide show
  1. app/training/train_classifier.py +17 -8
app/training/train_classifier.py CHANGED
@@ -214,6 +214,7 @@ def train(
214
  json_results["_n_samples"] = len(y)
215
  json_results["_n_features"] = X.shape[1]
216
  json_results["_n_folds"] = n_folds
 
217
  if importance_data:
218
  json_results["_feature_importance"] = {
219
  name: round(imp, 6) for name, imp in importance_data
@@ -252,9 +253,10 @@ def _build_candidates() -> list[tuple[str, Any]]:
252
  (
253
  "Random Forest",
254
  RandomForestClassifier(
255
- n_estimators=300,
256
- max_depth=20,
257
- min_samples_leaf=5,
 
258
  class_weight="balanced",
259
  random_state=42,
260
  n_jobs=-1,
@@ -301,11 +303,15 @@ def _build_candidates() -> list[tuple[str, Any]]:
301
  candidates.append((
302
  "XGBoost",
303
  xgb.XGBClassifier(
304
- n_estimators=300,
305
- max_depth=8,
306
  learning_rate=0.05,
307
  subsample=0.8,
308
  colsample_bytree=0.8,
 
 
 
 
309
  scale_pos_weight=1.0,
310
  eval_metric="logloss",
311
  random_state=42,
@@ -317,12 +323,15 @@ def _build_candidates() -> list[tuple[str, Any]]:
317
  candidates.append((
318
  "LightGBM",
319
  lgb.LGBMClassifier(
320
- n_estimators=300,
321
- max_depth=8,
322
  learning_rate=0.05,
323
- num_leaves=31,
324
  subsample=0.8,
325
  colsample_bytree=0.8,
 
 
 
326
  class_weight="balanced",
327
  random_state=42,
328
  verbose=-1,
 
214
  json_results["_n_samples"] = len(y)
215
  json_results["_n_features"] = X.shape[1]
216
  json_results["_n_folds"] = n_folds
217
+ json_results["_data_leakage_fix"] = "duration_sec and sample_rate removed from features (v2)"
218
  if importance_data:
219
  json_results["_feature_importance"] = {
220
  name: round(imp, 6) for name, imp in importance_data
 
253
  (
254
  "Random Forest",
255
  RandomForestClassifier(
256
+ n_estimators=200,
257
+ max_depth=12,
258
+ min_samples_leaf=8,
259
+ min_samples_split=10,
260
  class_weight="balanced",
261
  random_state=42,
262
  n_jobs=-1,
 
303
  candidates.append((
304
  "XGBoost",
305
  xgb.XGBClassifier(
306
+ n_estimators=200,
307
+ max_depth=5,
308
  learning_rate=0.05,
309
  subsample=0.8,
310
  colsample_bytree=0.8,
311
+ min_child_weight=5,
312
+ reg_alpha=0.1,
313
+ reg_lambda=1.0,
314
+ gamma=0.1,
315
  scale_pos_weight=1.0,
316
  eval_metric="logloss",
317
  random_state=42,
 
323
  candidates.append((
324
  "LightGBM",
325
  lgb.LGBMClassifier(
326
+ n_estimators=200,
327
+ max_depth=5,
328
  learning_rate=0.05,
329
+ num_leaves=24,
330
  subsample=0.8,
331
  colsample_bytree=0.8,
332
+ min_child_weight=5,
333
+ reg_alpha=0.1,
334
+ reg_lambda=1.0,
335
  class_weight="balanced",
336
  random_state=42,
337
  verbose=-1,