Rthur2003 committed on
Commit
076d979
·
1 Parent(s): 8e5e154

feat: implement mixed precision training and gradient accumulation in training loop

Browse files
Files changed (1) hide show
  1. app/training/wav2vec2_classifier.py +29 -10
app/training/wav2vec2_classifier.py CHANGED
@@ -241,26 +241,40 @@ def train_wav2vec2(
241
 
242
  criterion = nn.BCEWithLogitsLoss()
243
 
244
- # Training loop
245
  best_val_auc = 0.0
246
  patience_counter = 0
247
  history = []
 
 
248
 
249
  for epoch in range(config.epochs):
250
  model.train()
251
  train_loss = 0.0
 
252
 
253
- for batch_audio, batch_labels in train_loader:
254
  batch_audio = batch_audio.to(device)
255
  batch_labels = batch_labels.to(device)
256
 
257
- optimizer.zero_grad()
258
- logits, _ = model(batch_audio)
259
- loss = criterion(logits.squeeze(-1), batch_labels)
260
- loss.backward()
261
- optimizer.step()
 
 
 
 
 
 
 
 
 
 
 
262
 
263
- train_loss += loss.item()
264
 
265
  avg_train_loss = train_loss / len(train_loader)
266
 
@@ -272,7 +286,11 @@ def train_wav2vec2(
272
  with torch.no_grad():
273
  for batch_audio, batch_labels in val_loader:
274
  batch_audio = batch_audio.to(device)
275
- logits, _ = model(batch_audio)
 
 
 
 
276
  probs = torch.sigmoid(logits.squeeze(-1))
277
  val_probs.extend(probs.cpu().numpy())
278
  val_labels.extend(batch_labels.numpy())
@@ -289,7 +307,8 @@ def train_wav2vec2(
289
  f"Epoch {epoch + 1}/{config.epochs} | "
290
  f"Loss: {avg_train_loss:.4f} | "
291
  f"Val Acc: {val_acc:.4f} | "
292
- f"Val AUC: {val_auc:.4f}"
 
293
  )
294
 
295
  history.append({
 
241
 
242
  criterion = nn.BCEWithLogitsLoss()
243
 
244
+ # Training loop with mixed precision + gradient accumulation
245
  best_val_auc = 0.0
246
  patience_counter = 0
247
  history = []
248
+ scaler_amp = torch.amp.GradScaler("cuda") if device.type == "cuda" else None
249
+ accum_steps = 4 # effective batch = batch_size * accum_steps
250
 
251
  for epoch in range(config.epochs):
252
  model.train()
253
  train_loss = 0.0
254
+ optimizer.zero_grad()
255
 
256
+ for step, (batch_audio, batch_labels) in enumerate(train_loader):
257
  batch_audio = batch_audio.to(device)
258
  batch_labels = batch_labels.to(device)
259
 
260
+ if scaler_amp is not None:
261
+ with torch.amp.autocast("cuda"):
262
+ logits, _ = model(batch_audio)
263
+ loss = criterion(logits.squeeze(-1), batch_labels) / accum_steps
264
+ scaler_amp.scale(loss).backward()
265
+ if (step + 1) % accum_steps == 0 or (step + 1) == len(train_loader):
266
+ scaler_amp.step(optimizer)
267
+ scaler_amp.update()
268
+ optimizer.zero_grad()
269
+ else:
270
+ logits, _ = model(batch_audio)
271
+ loss = criterion(logits.squeeze(-1), batch_labels) / accum_steps
272
+ loss.backward()
273
+ if (step + 1) % accum_steps == 0 or (step + 1) == len(train_loader):
274
+ optimizer.step()
275
+ optimizer.zero_grad()
276
 
277
+ train_loss += loss.item() * accum_steps
278
 
279
  avg_train_loss = train_loss / len(train_loader)
280
 
 
286
  with torch.no_grad():
287
  for batch_audio, batch_labels in val_loader:
288
  batch_audio = batch_audio.to(device)
289
+ if scaler_amp is not None:
290
+ with torch.amp.autocast("cuda"):
291
+ logits, _ = model(batch_audio)
292
+ else:
293
+ logits, _ = model(batch_audio)
294
  probs = torch.sigmoid(logits.squeeze(-1))
295
  val_probs.extend(probs.cpu().numpy())
296
  val_labels.extend(batch_labels.numpy())
 
307
  f"Epoch {epoch + 1}/{config.epochs} | "
308
  f"Loss: {avg_train_loss:.4f} | "
309
  f"Val Acc: {val_acc:.4f} | "
310
+ f"Val AUC: {val_auc:.4f}",
311
+ flush=True,
312
  )
313
 
314
  history.append({