import torch
import torch.nn.functional as F
from tqdm import tqdm


class Trainer:
    def __init__(self, ae, flow, cfg, loader, pad_id, stop_id):
        self.ae = ae.to(cfg.device)
        self.flow = flow.to(cfg.device) if flow is not None else None
        self.cfg = cfg
        self.loader = loader
        self.device = cfg.device
        self.pad_id = pad_id
        self.stop_id = stop_id
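
    # Interfaces this trainer assumes, inferred from the calls below
    # (the actual signatures live in the project's model code):
    #   ae(ids, mask)                     -> (logits, z)  full forward pass
    #   ae.encode(ids, mask)              -> z            per-token latents (B, L, D)
    #   ae.decode(z, attention_mask=mask) -> logits       over the vocabulary
    #   flow(z_t, t, condition=z_bad)     -> prediction of the clean latent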
    def train_ae(self, optimizer):
        self.ae.train()
        total_loss = 0
        pbar = tqdm(self.loader, desc="Train AE")
        optimizer.zero_grad()

        for step, batch in enumerate(pbar):
            tgt = batch['tgt_ids'].to(self.device)
            mask = batch['tgt_mask'].to(self.device)

            logits, z = self.ae(tgt, mask)

            V = logits.size(-1)
            B, L = tgt.shape

            # Reconstruction loss on real (non-pad) tokens only.
            labels_tok = tgt.masked_fill(mask == 0, -100)
            loss_tok = F.cross_entropy(
                logits.view(-1, V),
                labels_tok.view(-1),
                ignore_index=-100,
                reduction="mean"
            )

            # Auxiliary loss: push every padded position toward the pad token.
            pad_pos = (mask == 0)
            if pad_pos.any():
                ce_all = F.cross_entropy(
                    logits.view(-1, V),
                    tgt.new_full((B * L,), self.pad_id),
                    reduction="none"
                ).view(B, L)
                loss_pad = (ce_all * pad_pos).sum() / (pad_pos.sum() + 1e-6)
            else:
                loss_pad = logits.new_tensor(0.0)

            # Auxiliary loss: extra cross-entropy weight on stop-token positions.
            stop_pos = ((tgt == self.stop_id) & (mask == 1))
            if stop_pos.any():
                ce_tok = F.cross_entropy(
                    logits.view(-1, V),
                    tgt.view(-1),
                    reduction="none"
                ).view(B, L)
                loss_stop = (ce_tok * stop_pos).sum() / (stop_pos.sum() + 1e-6)
            else:
                loss_stop = logits.new_tensor(0.0)

            lambda_pad = 0.1
            lambda_stop = 0.2
            loss = loss_tok + lambda_pad * loss_pad + lambda_stop * loss_stop

            # Gradient accumulation: scale the loss down, step every
            # grad_accum_steps mini-batches.
            loss = loss / self.cfg.grad_accum_steps
            loss.backward()

            if (step + 1) % self.cfg.grad_accum_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

            total_loss += loss.item() * self.cfg.grad_accum_steps
            pbar.set_postfix(loss=loss.item() * self.cfg.grad_accum_steps)

        return total_loss / len(self.loader)

    def train_robust_ae(self, optimizer):
        # Denoising pass: the encoder is run under no_grad (frozen), and the
        # decoder is trained to reconstruct from a Gaussian-perturbed latent.
        self.ae.train()
        total_loss = 0
        noise_std = 0.05

        for batch in self.loader:
            tgt_ids = batch['tgt_ids'].to(self.device)
            tgt_mask = batch['tgt_mask'].to(self.device)

            with torch.no_grad():
                z_clean = self.ae.encode(tgt_ids, tgt_mask)

            # Perturb the clean latent with small Gaussian noise.
            noise = torch.randn_like(z_clean) * noise_std
            z_noisy = z_clean + noise

            logits = self.ae.decode(z_noisy, attention_mask=tgt_mask)

            # Token-level cross-entropy, averaged over real (non-pad) positions.
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
                                   tgt_ids.view(-1),
                                   reduction='none')
            loss = (loss * tgt_mask.view(-1)).sum() / tgt_mask.sum()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        return total_loss / len(self.loader)

    def train_ae_combined(self, optimizer, epoch, max_epochs):
        """
        Combines basic reconstruction, curriculum denoising,
        and the pad/stop auxiliary losses.
        """
        self.ae.train()
        total_loss = 0

        # Curriculum noise schedule: no noise for the first 20% of training,
        # then ramp linearly from 0 toward 0.1 over the remaining 80%.
        if epoch < max_epochs * 0.2:
            current_noise = 0.0
        else:
            progress = (epoch - max_epochs * 0.2) / (max_epochs * 0.8)
            current_noise = 0.1 * progress

        pbar = tqdm(self.loader, desc=f"Train AE (Noise={current_noise:.4f})")

        for step, batch in enumerate(pbar):
            tgt = batch['tgt_ids'].to(self.device)
            mask = batch['tgt_mask'].to(self.device)

            # Encoder is frozen here; only the decoder receives gradients.
            with torch.no_grad():
                z_clean = self.ae.encode(tgt, mask)

            if current_noise > 0:
                noise = torch.randn_like(z_clean) * current_noise
                z_input = z_clean + noise
            else:
                z_input = z_clean

            logits = self.ae.decode(z_input, attention_mask=mask)

            V = logits.size(-1)
            B, L = tgt.shape

            # Reconstruction loss on real (non-pad) tokens.
            labels_tok = tgt.masked_fill(mask == 0, -100)
            loss_tok = F.cross_entropy(logits.view(-1, V), labels_tok.view(-1), ignore_index=-100)

            # Pad auxiliary loss, as in train_ae.
            pad_pos = (mask == 0)
            if pad_pos.any():
                ce_pad = F.cross_entropy(logits.view(-1, V), tgt.new_full((B * L,), self.pad_id),
                                         reduction='none').view(B, L)
                loss_pad = (ce_pad * pad_pos).sum() / (pad_pos.sum() + 1e-6)
            else:
                loss_pad = torch.tensor(0.0, device=self.device)

            # Stop-token auxiliary loss, as in train_ae.
            stop_pos = ((tgt == self.stop_id) & (mask == 1))
            if stop_pos.any():
                ce_stop = F.cross_entropy(logits.view(-1, V), tgt.view(-1), reduction='none').view(B, L)
                loss_stop = (ce_stop * stop_pos).sum() / (stop_pos.sum() + 1e-6)
            else:
                loss_stop = torch.tensor(0.0, device=self.device)

            loss = loss_tok + 0.1 * loss_pad + 0.5 * loss_stop

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            pbar.set_postfix(loss=loss.item())

        return total_loss / len(self.loader)

    def train_flow(self, optimizer):
        self.flow.train()
        self.ae.eval()
        total_loss = 0
        pbar = tqdm(self.loader, desc="Train Flow")
        optimizer.zero_grad()

        # NOTE: fetched for latent normalization, but not currently applied.
        scale = getattr(self.ae, "latent_scale", 10.0)

        for step, batch in enumerate(pbar):
            src = batch['src_ids'].to(self.device)
            src_mask = batch['src_mask'].to(self.device)
            tgt = batch['tgt_ids'].to(self.device)
            tgt_mask = batch['tgt_mask'].to(self.device)

            # The frozen autoencoder provides the endpoint latents.
            with torch.no_grad():
                z_bad = self.ae.encode(src, src_mask)
                z_good = self.ae.encode(tgt, tgt_mask)

            # Sample a random time t per example and interpolate linearly
            # between the "bad" and "good" latents (rectified-flow style).
            bs = z_bad.shape[0]
            t = torch.rand(bs, device=self.device).view(bs, 1, 1)
            z_t = (1 - t) * z_bad + t * z_good

            # The flow model predicts the clean endpoint z_good directly,
            # conditioned on the corrupted source latent.
            pred_z1 = self.flow(z_t, t, condition=z_bad)

            # Per-position MSE, weighted by the target mask, with extra
            # weight (3x total) on stop-token positions.
            mse = (pred_z1 - z_good).pow(2).mean(dim=-1)
            w = tgt_mask.float()
            stop_pos = ((tgt == self.stop_id) & (tgt_mask == 1))
            w = w + stop_pos.float() * 2.0

            loss = (mse * w).sum() / (w.sum() + 1e-6)

            # Gradient accumulation, as in train_ae.
            loss = loss / self.cfg.grad_accum_steps
            loss.backward()

            if (step + 1) % self.cfg.grad_accum_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

            total_loss += loss.item() * self.cfg.grad_accum_steps
            pbar.set_postfix(loss=loss.item() * self.cfg.grad_accum_steps)

        return total_loss / len(self.loader)
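

# --- Illustrative additions below, not part of the original trainer ---

# Because `flow` is trained above to predict the clean endpoint z_good
# directly (x1-prediction), one simple inference scheme is Euler integration
# of the implied velocity (z1_pred - z_t) / (1 - t), starting from the
# corrupted latent at t = 0. This is a minimal sketch under that assumption;
# the project may integrate the flow differently.
@torch.no_grad()
def sample_flow(flow, z_bad, n_steps=10):
    z = z_bad
    bs = z.shape[0]
    for i in range(n_steps):
        # Current time t = i / n_steps, broadcast to (B, 1, 1) as in training.
        t = torch.full((bs, 1, 1), i / n_steps, device=z.device)
        z1_pred = flow(z, t, condition=z_bad)
        # Euler step: z += dt * (z1_pred - z) / (1 - t) with dt = 1/n_steps,
        # which simplifies to the update below; the final step lands on z1_pred.
        z = z + (z1_pred - z) / (n_steps - i)
    return z


# Usage sketch (assumption): `AE`, `Flow`, `cfg`, `loader`, `PAD_ID`, and
# `STOP_ID` are placeholders for the project's actual models, config,
# dataloader, and tokenizer ids.
#
#     trainer = Trainer(AE(), Flow(), cfg, loader, pad_id=PAD_ID, stop_id=STOP_ID)
#     ae_opt = torch.optim.AdamW(trainer.ae.parameters(), lr=cfg.lr)
#     for epoch in range(cfg.max_epochs):
#         trainer.train_ae_combined(ae_opt, epoch, cfg.max_epochs)
#     flow_opt = torch.optim.AdamW(trainer.flow.parameters(), lr=cfg.lr)
#     trainer.train_flow(flow_opt)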