# Diff-Refine/src/trainer.py
import torch
import torch.nn.functional as F
from tqdm import tqdm
class Trainer:
def __init__(self, ae, flow, cfg, loader, pad_id, stop_id):
self.ae = ae.to(cfg.device)
self.flow = flow.to(cfg.device) if flow else None
self.cfg = cfg
self.loader = loader
self.device = cfg.device
self.pad_id = pad_id
self.stop_id = stop_id
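    # Assumed interfaces, all exercised below: `ae` provides encode(ids, mask),
    # decode(z, attention_mask=...) and a forward returning (logits, z);
    # `flow` is a conditional network called as flow(z_t, t, condition=...);
    # `cfg` must expose at least `device` and `grad_accum_steps`.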
def train_ae(self, optimizer):
self.ae.train()
total_loss = 0
pbar = tqdm(self.loader, desc="Train AE")
optimizer.zero_grad()
for step, batch in enumerate(pbar):
tgt = batch['tgt_ids'].to(self.device)
mask = batch['tgt_mask'].to(self.device)
            # `mask` marks real tokens (1) vs. padding (0); positions with
            # mask == 0 are excluded from the token loss below via the -100
            # ignore_index convention of F.cross_entropy.
            logits, z = self.ae(tgt, mask)  # decoder mask defaults to mask
V = logits.size(-1)
            B, L = tgt.shape
            # 1) Token loss: computed only where mask == 1
labels_tok = tgt.masked_fill(mask == 0, -100)
loss_tok = F.cross_entropy(
logits.view(-1, V),
labels_tok.view(-1),
ignore_index=-100,
reduction="mean"
)
            # 2) Pad loss: positions with mask == 0 are pushed toward PAD (small weight)
pad_pos = (mask == 0)
if pad_pos.any():
                # Per-position cross-entropy against the PAD id
ce_all = F.cross_entropy(
logits.view(-1, V),
tgt.new_full((B * L,), self.pad_id),
reduction="none"
).view(B, L)
loss_pad = (ce_all * pad_pos).sum() / (pad_pos.sum() + 1e-6)
else:
loss_pad = logits.new_tensor(0.0)
            # 3) Optional: upweight stop positions so SEP is predicted more reliably
stop_pos = ((tgt == self.stop_id) & (mask == 1))
if stop_pos.any():
ce_tok = F.cross_entropy(
logits.view(-1, V),
tgt.view(-1),
reduction="none"
).view(B, L)
loss_stop = (ce_tok * stop_pos).sum() / (stop_pos.sum() + 1e-6)
else:
loss_stop = logits.new_tensor(0.0)
            # Combine: keep the pad/stop weights small
lambda_pad = 0.1
lambda_stop = 0.2
loss = loss_tok + lambda_pad * loss_pad + lambda_stop * loss_stop
loss = loss / self.cfg.grad_accum_steps
loss.backward()
if (step + 1) % self.cfg.grad_accum_steps == 0:
optimizer.step()
optimizer.zero_grad()
total_loss += loss.item() * self.cfg.grad_accum_steps
pbar.set_postfix(loss=loss.item() * self.cfg.grad_accum_steps)
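        # The division by grad_accum_steps above makes the accumulated gradient
        # match one large batch; logging multiplies it back so the reported loss
        # stays per-batch. A trailing partial accumulation window is discarded
        # by the zero_grad() at the start of the next call.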
return total_loss / len(self.loader)
def train_robust_ae(self, optimizer):
self.ae.train()
total_loss = 0
noise_std = 0.05
for batch in self.loader:
tgt_ids = batch['tgt_ids'].to(self.device)
tgt_mask = batch['tgt_mask'].to(self.device)
            # 1. Get the clean latent z
with torch.no_grad():
z_clean = self.ae.encode(tgt_ids, tgt_mask)
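            # Encoding under no_grad freezes the encoder for this objective:
            # gradients flow only through decode(), training the decoder to be
            # robust to perturbed latents.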
            # 2. Add Gaussian noise (denoising training): the decoder must
            #    reconstruct the clean tokens from a perturbed latent
noise = torch.randn_like(z_clean) * noise_std
z_noisy = z_clean + noise
# 3. Decode
logits = self.ae.decode(z_noisy, attention_mask=tgt_mask)
# 4. Loss
loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
tgt_ids.view(-1),
reduction='none')
loss = (loss * tgt_mask.view(-1)).sum() / tgt_mask.sum()
# Backward
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
return total_loss / len(self.loader)
def train_ae_combined(self, optimizer, epoch, max_epochs):
"""
结合了 基础重建 + 课程去噪 + Pad/Stop 优化
"""
self.ae.train()
total_loss = 0
        # --- Curriculum noise schedule ---
        # No noise for the first 20% of epochs, so reconstruction is learned first;
        # afterwards the noise std rises linearly toward 0.1.
if epoch < max_epochs * 0.2:
current_noise = 0.0
else:
progress = (epoch - max_epochs * 0.2) / (max_epochs * 0.8)
            current_noise = 0.1 * progress  # max noise std 0.1
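        # Worked example with max_epochs = 10: epochs 0-1 train noise-free;
        # epoch 5 gives progress = (5 - 2) / 8 = 0.375, i.e. std 0.0375; the
        # final epoch 9 reaches 0.0875, so the 0.1 ceiling is approached but
        # never reached within the run.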
pbar = tqdm(self.loader, desc=f"Train AE (Noise={current_noise:.4f})")
for step, batch in enumerate(pbar):
tgt = batch['tgt_ids'].to(self.device)
mask = batch['tgt_mask'].to(self.device)
# 1. Encode Clean
with torch.no_grad():
z_clean = self.ae.encode(tgt, mask)
            # 2. Add noise (if the scheduled noise is > 0)
if current_noise > 0:
noise = torch.randn_like(z_clean) * current_noise
z_input = z_clean + noise
else:
z_input = z_clean
# 3. Decode
logits = self.ae.decode(z_input, attention_mask=mask)
            # 4. Composite loss (same scheme as train_ae)
V = logits.size(-1)
B, L = tgt.shape
            # Token loss (only where mask == 1)
labels_tok = tgt.masked_fill(mask == 0, -100)
loss_tok = F.cross_entropy(logits.view(-1, V), labels_tok.view(-1), ignore_index=-100)
# Pad Loss (mask==0)
pad_pos = (mask == 0)
if pad_pos.any():
ce_pad = F.cross_entropy(logits.view(-1, V), tgt.new_full((B*L,), self.pad_id), reduction='none').view(B,L)
loss_pad = (ce_pad * pad_pos).sum() / (pad_pos.sum() + 1e-6)
else:
loss_pad = torch.tensor(0.0, device=self.device)
# Stop Loss
stop_pos = ((tgt == self.stop_id) & (mask == 1))
if stop_pos.any():
ce_stop = F.cross_entropy(logits.view(-1, V), tgt.view(-1), reduction='none').view(B,L)
loss_stop = (ce_stop * stop_pos).sum() / (stop_pos.sum() + 1e-6)
else:
loss_stop = torch.tensor(0.0, device=self.device)
            # Combine losses
            loss = loss_tok + 0.1 * loss_pad + 0.5 * loss_stop  # slightly higher stop weight here
# Backward
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
pbar.set_postfix(loss=loss.item())
return total_loss / len(self.loader)
def train_flow(self, optimizer):
self.flow.train()
self.ae.eval()
total_loss = 0
pbar = tqdm(self.loader, desc="Train Flow")
optimizer.zero_grad()
        scale = getattr(self.ae, "latent_scale", 10.0)  # radius for the optional sphere projection
for step, batch in enumerate(pbar):
src = batch['src_ids'].to(self.device)
src_mask = batch['src_mask'].to(self.device)
tgt = batch['tgt_ids'].to(self.device)
tgt_mask = batch['tgt_mask'].to(self.device)
with torch.no_grad():
z_bad = self.ae.encode(src, src_mask) # norm ~ scale
z_good = self.ae.encode(tgt, tgt_mask) # norm ~ scale
# Rectified Flow
bs = z_bad.shape[0]
t = torch.rand(bs, device=self.device).view(bs, 1, 1)
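            # t is sampled once per example and reshaped to [B, 1, 1] so it
            # broadcasts over the sequence and feature dimensions below.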
            # Straight-line interpolation from the bad latent to the good one.
            z_t_linear = (1 - t) * z_bad + t * z_good
            # Optionally re-project z_t onto the sphere of radius `scale`:
            # z_t = F.normalize(z_t_linear, p=2, dim=-1) * scale
            # Currently disabled, so the interpolation stays purely linear.
            z_t = z_t_linear
            # x-prediction: instead of regressing the velocity v = z_good - z_bad,
            # the flow predicts the target latent z_good directly.
            pred_z1 = self.flow(z_t, t, condition=z_bad)
            # If the interpolation above is projected onto the sphere, project the
            # prediction too (pred_z1 = F.normalize(pred_z1, p=2, dim=-1) * scale);
            # keep both sides un-normalized or both projected, never mixed.
            # The loss is the distance to z_good, masked so only valid tokens count.
mse = (pred_z1 - z_good).pow(2).mean(dim=-1) # [B,L]
w = tgt_mask.float()
            # Upweight stop positions so SEP is matched more precisely.
            stop_pos = ((tgt == self.stop_id) & (tgt_mask == 1))
            w = w + stop_pos.float() * 2.0
loss = (mse * w).sum() / (w.sum() + 1e-6)
loss = loss / self.cfg.grad_accum_steps
loss.backward()
if (step + 1) % self.cfg.grad_accum_steps == 0:
optimizer.step()
optimizer.zero_grad()
total_loss += loss.item() * self.cfg.grad_accum_steps
pbar.set_postfix(loss=loss.item() * self.cfg.grad_accum_steps)
return total_loss / len(self.loader)
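
# ---------------------------------------------------------------------------
# Minimal inference sketch (an assumption, not part of the training code):
# because train_flow uses x-prediction, the flow regresses z_good directly,
# so refinement can be a single forward pass, optionally iterated. `ae` and
# `flow` are assumed to expose the same encode/decode/call interfaces used
# above; `n_steps` and re-using the bad latent as the condition are
# illustrative choices.
# ---------------------------------------------------------------------------
@torch.no_grad()
def refine_and_decode(ae, flow, src_ids, src_mask, n_steps=1):
    """Encode a bad sequence, refine its latent with the flow, decode greedily."""
    ae.eval()
    flow.eval()
    z_bad = ae.encode(src_ids, src_mask)  # [B, L, D] bad latent
    z = z_bad
    for _ in range(n_steps):
        # t = 0 marks the start of the bad -> good trajectory; the flow was
        # trained to output its estimate of z_good from any (z_t, t, cond).
        t = torch.zeros(z.size(0), 1, 1, device=z.device)
        z = flow(z, t, condition=z_bad)
    logits = ae.decode(z, attention_mask=src_mask)
    return logits.argmax(dim=-1)  # greedy token ids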