Spaces:
Sleeping
Sleeping
fix: adjust BCEWithLogitsLoss to handle class imbalance with pos_weight
Browse files
app/training/train_deep_classifiers.py
CHANGED
|
@@ -209,7 +209,11 @@ def train_one_fold(
|
|
| 209 |
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
| 210 |
optimizer, mode="max", factor=0.5, patience=5
|
| 211 |
)
|
| 212 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
|
| 214 |
best_auc = 0.0
|
| 215 |
best_probs = None
|
|
|
|
| 209 |
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
| 210 |
optimizer, mode="max", factor=0.5, patience=5
|
| 211 |
)
|
| 212 |
+
# pos_weight compensates for class imbalance (n_neg / n_pos)
|
| 213 |
+
n_pos = max(int(y_train.sum()), 1)
|
| 214 |
+
n_neg = len(y_train) - n_pos
|
| 215 |
+
pos_weight = torch.tensor([n_neg / n_pos], dtype=torch.float32).to(DEVICE)
|
| 216 |
+
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
|
| 217 |
|
| 218 |
best_auc = 0.0
|
| 219 |
best_probs = None
|