Rthur2003 commited on
Commit
57f19bf
·
1 Parent(s): 058eadc

feat: enhance candidate selection with class ratio and calibrated SVC for improved model performance

Browse files
Files changed (1) hide show
  1. app/training/train_classifier.py +25 -36
app/training/train_classifier.py CHANGED
@@ -323,7 +323,7 @@ def _select_best_candidates(
323
  selected: list[tuple[str, Any]] = []
324
  tuning_results: dict[str, dict[str, Any]] = {}
325
 
326
- for name, variants in _build_candidate_families().items():
327
  print("\n" + "." * 56)
328
  print(f"Selecting hyperparameters for: {name}")
329
  print("." * 56)
@@ -362,8 +362,14 @@ def _select_best_candidates(
362
  return selected, tuning_results
363
 
364
 
365
- def _build_candidate_families() -> dict[str, list[Any]]:
366
- families: dict[str, list[Any]] = {
 
 
 
 
 
 
367
  "Logistic Regression": [
368
  LogisticRegression(
369
  C=value,
@@ -435,38 +441,11 @@ def _build_candidate_families() -> dict[str, list[Any]]:
435
  ),
436
  ],
437
  "SVM (RBF)": [
438
- SVC(
439
- kernel="rbf",
440
- C=1.0,
441
- gamma="scale",
442
- class_weight="balanced",
443
- probability=True,
444
- random_state=42,
445
- ),
446
- SVC(
447
- kernel="rbf",
448
- C=3.0,
449
- gamma="scale",
450
- class_weight="balanced",
451
- probability=True,
452
- random_state=42,
453
- ),
454
- SVC(
455
- kernel="rbf",
456
- C=6.0,
457
- gamma=0.02,
458
- class_weight="balanced",
459
- probability=True,
460
- random_state=42,
461
- ),
462
- SVC(
463
- kernel="rbf",
464
- C=10.0,
465
- gamma=0.05,
466
- class_weight="balanced",
467
- probability=True,
468
- random_state=42,
469
- ),
470
  ],
471
  "MLP Neural Network": [
472
  MLPClassifier(
@@ -506,6 +485,7 @@ def _build_candidate_families() -> dict[str, list[Any]]:
506
  }
507
 
508
  if HAS_XGB:
 
509
  families["XGBoost"] = [
510
  xgb.XGBClassifier(
511
  n_estimators=300,
@@ -517,6 +497,7 @@ def _build_candidate_families() -> dict[str, list[Any]]:
517
  reg_alpha=0.2,
518
  reg_lambda=1.2,
519
  gamma=0.1,
 
520
  eval_metric="logloss",
521
  tree_method="hist",
522
  random_state=42,
@@ -533,6 +514,7 @@ def _build_candidate_families() -> dict[str, list[Any]]:
533
  reg_alpha=0.1,
534
  reg_lambda=1.0,
535
  gamma=0.0,
 
536
  eval_metric="logloss",
537
  tree_method="hist",
538
  random_state=42,
@@ -549,6 +531,7 @@ def _build_candidate_families() -> dict[str, list[Any]]:
549
  reg_alpha=0.4,
550
  reg_lambda=1.5,
551
  gamma=0.2,
 
552
  eval_metric="logloss",
553
  tree_method="hist",
554
  random_state=42,
@@ -635,7 +618,13 @@ def _safe_model_name(name: str) -> str:
635
  def _summarize_selected_params(name: str, model: Any) -> dict[str, Any]:
636
  tuned_keys = _TUNED_PARAM_KEYS.get(name, ())
637
  params = model.get_params()
638
- return {key: params[key] for key in tuned_keys if key in params}
 
 
 
 
 
 
639
 
640
 
641
  def _extract_importance(
 
323
  selected: list[tuple[str, Any]] = []
324
  tuning_results: dict[str, dict[str, Any]] = {}
325
 
326
+ for name, variants in _build_candidate_families(y_train).items():
327
  print("\n" + "." * 56)
328
  print(f"Selecting hyperparameters for: {name}")
329
  print("." * 56)
 
362
  return selected, tuning_results
363
 
364
 
365
+ def _class_ratio(y: np.ndarray) -> float:
366
+ """Returns n_negative / n_positive for scale_pos_weight in XGBoost."""
367
+ n_pos = int(np.sum(y == 1))
368
+ n_neg = int(np.sum(y == 0))
369
+ return n_neg / n_pos if n_pos > 0 else 1.0
370
+
371
+
372
+ def _build_candidate_families(y: np.ndarray) -> dict[str, list[Any]]:
373
  "Logistic Regression": [
374
  LogisticRegression(
375
  C=value,
 
441
  ),
442
  ],
443
  "SVM (RBF)": [
444
+ CalibratedClassifierCV(
445
+ SVC(kernel="rbf", C=c, gamma=g, class_weight="balanced", random_state=42),
446
+ method="isotonic", cv=3,
447
+ )
448
+ for c, g in ((1.0, "scale"), (3.0, "scale"), (6.0, 0.02), (10.0, 0.05))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
449
  ],
450
  "MLP Neural Network": [
451
  MLPClassifier(
 
485
  }
486
 
487
  if HAS_XGB:
488
+ _spw = _class_ratio(y)
489
  families["XGBoost"] = [
490
  xgb.XGBClassifier(
491
  n_estimators=300,
 
497
  reg_alpha=0.2,
498
  reg_lambda=1.2,
499
  gamma=0.1,
500
+ scale_pos_weight=_spw,
501
  eval_metric="logloss",
502
  tree_method="hist",
503
  random_state=42,
 
514
  reg_alpha=0.1,
515
  reg_lambda=1.0,
516
  gamma=0.0,
517
+ scale_pos_weight=_spw,
518
  eval_metric="logloss",
519
  tree_method="hist",
520
  random_state=42,
 
531
  reg_alpha=0.4,
532
  reg_lambda=1.5,
533
  gamma=0.2,
534
+ scale_pos_weight=_spw,
535
  eval_metric="logloss",
536
  tree_method="hist",
537
  random_state=42,
 
618
  def _summarize_selected_params(name: str, model: Any) -> dict[str, Any]:
619
  tuned_keys = _TUNED_PARAM_KEYS.get(name, ())
620
  params = model.get_params()
621
+ # CalibratedClassifierCV nests params as "estimator__<key>"
622
+ flat: dict[str, Any] = {}
623
+ for key, value in params.items():
624
+ flat_key = key.split("__")[-1]
625
+ if flat_key not in flat:
626
+ flat[flat_key] = value
627
+ return {key: flat[key] for key in tuned_keys if key in flat}
628
 
629
 
630
  def _extract_importance(