Rthur2003 committed on
Commit
9b6c85d
·
1 Parent(s): 5ac76c0

fix: update evaluation to use optimal threshold and adjust BCEWithLogitsLoss for class imbalance

Browse files
app/training/train_deep_classifiers.py CHANGED
@@ -272,13 +272,15 @@ def evaluate_cv(
272
  print(f" Fold {fold+1}: AUC={auc:.4f}")
273
 
274
  elapsed = time.time() - t0
275
- y_pred = (all_probs > 0.5).astype(int)
 
276
  return {
277
  "accuracy": round(float(accuracy_score(y, y_pred)), 4),
278
  "precision": round(float(precision_score(y, y_pred, zero_division=0)), 4),
279
  "recall": round(float(recall_score(y, y_pred, zero_division=0)), 4),
280
  "f1": round(float(f1_score(y, y_pred, zero_division=0)), 4),
281
  "roc_auc": round(float(roc_auc_score(y, all_probs)), 4),
 
282
  "fold_aucs": [round(a, 4) for a in aucs],
283
  "train_time_sec": round(elapsed, 1),
284
  }
@@ -356,7 +358,10 @@ def train_final_model(
356
  n_features = X.shape[1]
357
  model = model_class(n_features).to(DEVICE)
358
  optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
359
- criterion = nn.BCEWithLogitsLoss()
 
 
 
360
  loader = DataLoader(
361
  TensorDataset(
362
  torch.tensor(X_tr, dtype=torch.float32),
 
272
  print(f" Fold {fold+1}: AUC={auc:.4f}")
273
 
274
  elapsed = time.time() - t0
275
+ threshold = _optimal_threshold(y, all_probs)
276
+ y_pred = (all_probs >= threshold).astype(int)
277
  return {
278
  "accuracy": round(float(accuracy_score(y, y_pred)), 4),
279
  "precision": round(float(precision_score(y, y_pred, zero_division=0)), 4),
280
  "recall": round(float(recall_score(y, y_pred, zero_division=0)), 4),
281
  "f1": round(float(f1_score(y, y_pred, zero_division=0)), 4),
282
  "roc_auc": round(float(roc_auc_score(y, all_probs)), 4),
283
+ "optimal_threshold": round(threshold, 4),
284
  "fold_aucs": [round(a, 4) for a in aucs],
285
  "train_time_sec": round(elapsed, 1),
286
  }
 
358
  n_features = X.shape[1]
359
  model = model_class(n_features).to(DEVICE)
360
  optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
361
+ n_pos = max(int(y_tr.sum()), 1)
362
+ n_neg = len(y_tr) - n_pos
363
+ pos_weight = torch.tensor([n_neg / n_pos], dtype=torch.float32).to(DEVICE)
364
+ criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
365
  loader = DataLoader(
366
  TensorDataset(
367
  torch.tensor(X_tr, dtype=torch.float32),