Rthur2003 commited on
Commit
5ac76c0
·
1 Parent(s): bfbdc0b

fix: adjust BCEWithLogitsLoss to handle class imbalance with pos_weight

Browse files
app/training/train_deep_classifiers.py CHANGED
@@ -209,7 +209,11 @@ def train_one_fold(
209
  scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
210
  optimizer, mode="max", factor=0.5, patience=5
211
  )
212
- criterion = nn.BCEWithLogitsLoss()
 
 
 
 
213
 
214
  best_auc = 0.0
215
  best_probs = None
 
209
  scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
210
  optimizer, mode="max", factor=0.5, patience=5
211
  )
212
+ # pos_weight compensates for class imbalance (n_neg / n_pos)
213
+ n_pos = max(int(y_train.sum()), 1)
214
+ n_neg = len(y_train) - n_pos
215
+ pos_weight = torch.tensor([n_neg / n_pos], dtype=torch.float32).to(DEVICE)
216
+ criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
217
 
218
  best_auc = 0.0
219
  best_probs = None