# prompt_model/trainer.py
# Uploaded by marisa0v0: "Upload prompt-conditioned mel→acc model (best checkpoint)" (commit 8b64619, verified)
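"""
Trainer for Transformer models.
Supports distributed training, mixed precision, gradient accumulation, and related features.
Training logs are written to TensorBoard; view them with: tensorboard --logdir=<log_dir>
"""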
import os
import shutil
import time
import torch
from datetime import datetime
from typing import Optional
from accelerate import Accelerator, DistributedDataParallelKwargs
from transformers import get_cosine_schedule_with_warmup
from tqdm.auto import tqdm
class TransformerTrainer:
    """Training driver for a Transformer model, built on Hugging Face Accelerate.

    Wraps the complete training workflow:
    - optimizer and learning-rate scheduler construction
    - distributed / mixed-precision preparation through ``Accelerator``
    - the training and evaluation loops
    - model checkpoints and resumable training state
    - TensorBoard metric logging
    """

    # Training metrics are logged whenever ``global_step`` is a multiple of this.
    LOG_INTERVAL_STEPS = 10
    # Maximum gradient norm applied when gradients are synchronized.
    MAX_GRAD_NORM = 1.0

    def __init__(self, config, model, train_dataloader, test_dataloader=None, resume_path=None):
        """Build every training component and optionally resume a previous run.

        Args:
            config: Settings object. Attributes read by this class:
                ``learning_rate``, ``num_epochs``, ``gradient_accumulation_steps``,
                ``lr_warmup_ratio``, ``mixed_precision``, ``tensorboard_log_dir``,
                ``tensorboard_log_name``, ``use_test_set``, ``test_frequency``,
                ``test_save_results``, ``save_steps``, ``save_model_epochs``,
                ``output_dir`` and ``log``.
            model: Model called with ``input_ids`` / ``labels`` /
                ``attention_mask`` / ``position_ids`` keyword arguments; its
                output must expose ``.loss``. ``save_pretrained`` is called on
                the unwrapped model when checkpointing.
            train_dataloader: Iterable of dict-like batches with the keys
                ``input_ids``, ``labels`` and ``attention_mask``
                (``position_ids`` is optional).
            test_dataloader: Optional evaluation dataloader with the same
                batch layout.
            resume_path: Directory previously written by
                ``_save_training_state``; when truthy, state is restored from it.
        """
        self.config = config
        self.model = model
        self.train_dataloader = train_dataloader
        self.test_dataloader = test_dataloader
        self.global_step = 0
        self.best_eval_loss = float("inf")
        self.completed_epochs = 0

        # Order matters: the scheduler length is derived from the dataloader
        # length *before* ``accelerator.prepare`` wraps the dataloader.
        self.optimizer = self._setup_optimizer()
        self.lr_scheduler = self._setup_lr_scheduler()
        self.accelerator = self._initialize_accelerator()
        self._prepare_training_components()
        self.accelerator.init_trackers(self.config.tensorboard_log_name)
        self.test_interval_steps = self._calculate_test_interval()

        # Resume from a saved training state.
        if resume_path:
            self._load_training_state(resume_path)

    def _setup_optimizer(self) -> torch.optim.Optimizer:
        """Create AdamW with weight decay disabled for bias / normalization params.

        Parameters are classified by name: any trainable parameter whose
        lower-cased name contains one of the keywords below goes into the
        no-decay group; every other trainable parameter gets weight decay 0.1.

        Returns:
            The configured ``torch.optim.AdamW`` instance.
        """
        no_decay_keywords = {"bias", "layernorm", "layer_norm", "norm"}
        decay_params, no_decay_params = [], []
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            lowered = name.lower()
            if any(kw in lowered for kw in no_decay_keywords):
                no_decay_params.append(param)
            else:
                decay_params.append(param)
        return torch.optim.AdamW(
            [
                {"params": decay_params, "weight_decay": 0.1},
                {"params": no_decay_params, "weight_decay": 0.0},
            ],
            lr=self.config.learning_rate,
            betas=(0.9, 0.99),
        )

    def _setup_lr_scheduler(self):
        """Create a cosine learning-rate schedule with linear warmup.

        The schedule length is ``len(train_dataloader) * num_epochs * 1.5``
        divided by the gradient-accumulation factor; the warmup length is
        ``lr_warmup_ratio`` of that.

        Returns:
            The scheduler returned by ``get_cosine_schedule_with_warmup``.
        """
        # NOTE(review): the 1.5 factor makes the schedule longer than the
        # nominal number of updates -- presumably intentional; confirm.
        num_training_steps = int(
            (len(self.train_dataloader) * self.config.num_epochs * 1.5)
            / self.config.gradient_accumulation_steps
        )
        num_warmup_steps = int(num_training_steps * self.config.lr_warmup_ratio)
        return get_cosine_schedule_with_warmup(
            optimizer=self.optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
        )

    def _initialize_accelerator(self) -> "Accelerator":
        """Create the ``Accelerator`` used for distributed and mixed-precision training.

        Also creates the TensorBoard log directory on the main process when
        one is configured.

        Returns:
            The configured ``Accelerator``.
        """
        ddp_kwargs = DistributedDataParallelKwargs()
        accelerator = Accelerator(
            mixed_precision=self.config.mixed_precision,
            gradient_accumulation_steps=self.config.gradient_accumulation_steps,
            log_with="tensorboard",
            project_dir=self.config.tensorboard_log_dir,
            device_placement=True,
            kwargs_handlers=[ddp_kwargs],
        )
        # Create the log directory.
        if accelerator.is_main_process and self.config.tensorboard_log_dir:
            os.makedirs(self.config.tensorboard_log_dir, exist_ok=True)
        return accelerator

    def _prepare_training_components(self):
        """Pass model, optimizer, dataloaders and scheduler through ``accelerator.prepare``.

        The test dataloader is only prepared when one was supplied.
        """
        if self.test_dataloader is None:
            (
                self.model,
                self.optimizer,
                self.train_dataloader,
                self.lr_scheduler,
            ) = self.accelerator.prepare(
                self.model,
                self.optimizer,
                self.train_dataloader,
                self.lr_scheduler,
            )
            return
        (
            self.model,
            self.optimizer,
            self.train_dataloader,
            self.test_dataloader,
            self.lr_scheduler,
        ) = self.accelerator.prepare(
            self.model,
            self.optimizer,
            self.train_dataloader,
            self.test_dataloader,
            self.lr_scheduler,
        )

    def _calculate_test_interval(self) -> Optional[int]:
        """Compute how many steps lie between two test-set evaluations.

        Returns:
            ``len(train_dataloader) * test_frequency`` truncated to an int and
            clamped to at least 1, or ``None`` when testing is disabled or no
            test dataloader exists.
        """
        if not (self.config.use_test_set and self.test_dataloader is not None):
            return None
        # Clamp to 1: a zero interval would make the modulo in
        # ``_should_evaluate`` raise ZeroDivisionError.
        interval = max(1, int(len(self.train_dataloader) * self.config.test_frequency))
        print(f"测试间隔: 每 {interval} 步测试一次 (约{self.config.test_frequency}个epoch)")
        return interval

    def train(self):
        """Run training from ``completed_epochs`` up to ``config.num_epochs``.

        After each epoch an inference checkpoint is written when due, and the
        resumable training state is always saved.
        """
        for epoch in range(self.completed_epochs, self.config.num_epochs):
            self._train_one_epoch(epoch)
            self._save_checkpoint_if_needed(epoch)
            # Persist resumable state at the end of every epoch.
            self.completed_epochs = epoch + 1
            self._save_training_state()

    def _train_one_epoch(self, epoch: int):
        """Iterate once over the training dataloader.

        Per batch: run a training step, log metrics every
        ``LOG_INTERVAL_STEPS`` steps, trigger evaluation and step-based
        checkpoints when due, then advance ``global_step`` (which therefore
        counts dataloader batches).

        Args:
            epoch: Zero-based epoch index, used for the progress-bar label.
        """
        self.model.train()
        epoch_start_step = self.global_step
        # Throughput counters: elements of ``input_ids`` seen by this process
        # since the last log line, and when that window started.
        tokens_since_log = 0
        time_since_log = time.time()

        # The context manager closes the bar at the end of the epoch (or on
        # error) so bars do not pile up across epochs.
        with tqdm(total=len(self.train_dataloader), disable=not self.config.log) as progress_bar:
            progress_bar.set_description(f"Epoch {epoch}")
            for batch in self.train_dataloader:
                loss = self._training_step(batch)
                tokens_since_log += batch["input_ids"].numel()

                # Log metrics, including throughput.
                if self.global_step % self.LOG_INTERVAL_STEPS == 0:
                    elapsed = time.time() - time_since_log
                    tokens_per_sec = tokens_since_log / max(elapsed, 1e-6)
                    self._log_training_metrics(loss, tokens_per_sec)
                    tokens_since_log = 0
                    time_since_log = time.time()

                if self._should_evaluate(epoch_start_step):
                    self._evaluate_test()
                if self._should_save_checkpoint():
                    self._save_checkpoint(f"steps_{self.global_step}")

                if self.config.log:
                    progress_bar.update(1)
                    progress_bar.set_postfix(
                        loss=loss.detach().item(),
                        lr=self.lr_scheduler.get_last_lr()[0],
                    )
                self.global_step += 1

    def _training_step(self, batch) -> torch.Tensor:
        """Run forward, backward and optimizer update for one batch.

        Everything happens inside ``accelerator.accumulate`` so gradient
        accumulation is handled by Accelerate.

        Args:
            batch: Dict-like batch with ``input_ids``, ``labels``,
                ``attention_mask`` and optionally ``position_ids``.

        Returns:
            The loss of this batch.
        """
        with self.accelerator.accumulate(self.model):
            outputs = self.model(
                input_ids=batch["input_ids"],
                labels=batch["labels"],
                attention_mask=batch["attention_mask"],
                position_ids=batch.get("position_ids"),
            )
            loss = outputs.loss
            # Backward pass.
            self.accelerator.backward(loss)
            # Clip gradients only when they have been synchronized.
            if self.accelerator.sync_gradients:
                self.accelerator.clip_grad_norm_(self.model.parameters(), self.MAX_GRAD_NORM)
            # Optimizer / scheduler update.
            self.optimizer.step()
            self.lr_scheduler.step()
            self.optimizer.zero_grad()
        return loss

    def _log_training_metrics(self, loss: torch.Tensor, tokens_per_sec: float = 0):
        """Send the current training metrics to the tracker.

        Args:
            loss: Loss tensor of the current batch.
            tokens_per_sec: Throughput measured since the previous log call.
        """
        metrics = {
            "train/loss": loss.detach().item(),
            "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
            "train/tokens_per_sec": tokens_per_sec,
            "train/step": self.global_step,
        }
        self.accelerator.log(metrics, step=self.global_step)

    def _should_evaluate(self, epoch_start_step: int) -> bool:
        """Tell whether a test-set evaluation is due at the current step.

        Args:
            epoch_start_step: Value of ``global_step`` when the current epoch
                began; evaluation never fires on that exact step.

        Returns:
            True when ``global_step`` is a multiple of the test interval and
            lies past the epoch's first step.
        """
        # A missing (None) or zero interval both mean "never evaluate"; this
        # also keeps the modulo below safe.
        if not self.test_interval_steps:
            return False
        return (
            self.global_step % self.test_interval_steps == 0
            and self.global_step > epoch_start_step
        )

    def _evaluate_test(self):
        """Evaluate on the test set, log the result and track the best model.

        Computes the mean of the gathered per-batch losses and its
        exponential (perplexity). When the mean loss improves on
        ``best_eval_loss`` the main process overwrites the ``best_model``
        checkpoint. The model is put back into training mode afterwards.
        """
        if self.test_dataloader is None:
            return
        self.accelerator.wait_for_everyone()
        self.model.eval()
        total_loss = 0.0
        num_batches = 0
        if self.accelerator.is_main_process:
            print(f"\n开始测试集评估 (step {self.global_step})...")
        with torch.no_grad():
            eval_bar = tqdm(
                self.test_dataloader,
                desc="测试中",
                disable=not self.accelerator.is_main_process,
            )
            for batch in eval_bar:
                outputs = self.model(
                    input_ids=batch["input_ids"],
                    labels=batch["labels"],
                    attention_mask=batch["attention_mask"],
                    position_ids=batch.get("position_ids"),
                )
                # NOTE(review): one loss value per process is gathered here;
                # confirm that gather_for_metrics' last-batch handling is the
                # desired behavior for such per-batch (not per-sample) tensors.
                gathered_losses = self.accelerator.gather_for_metrics(
                    outputs.loss.detach().reshape(1)
                )
                total_loss += gathered_losses.sum().item()
                num_batches += gathered_losses.numel()

        # Mean loss and perplexity.
        avg_loss = total_loss / max(num_batches, 1)
        perplexity = torch.exp(torch.tensor(avg_loss)).item()

        # Log test metrics.
        if self.config.test_save_results and self.accelerator.is_main_process:
            test_metrics = {
                "test/loss": avg_loss,
                "test/perplexity": perplexity,
                "test/step": self.global_step,
            }
            self.accelerator.log(test_metrics, step=self.global_step)

        # The averaged loss comes from gathered values, so every process
        # records the new best; only the main process writes the checkpoint.
        improved = avg_loss < self.best_eval_loss
        if improved:
            self.best_eval_loss = avg_loss
        if self.accelerator.is_main_process:
            print(f"测试集 - 损失: {avg_loss:.4f}, 困惑度: {perplexity:.4f}")
            if improved:
                self._save_checkpoint("best_model")
                print(f"best model 已更新: loss={self.best_eval_loss:.4f}")

        self.model.train()
        self.accelerator.wait_for_everyone()

    def _should_save_checkpoint(self) -> bool:
        """Tell whether a step-based checkpoint is due on this process.

        Returns:
            True only on the main process, when ``config.save_steps`` is a
            positive number and ``global_step`` is a positive multiple of it.
        """
        save_steps = self.config.save_steps
        if save_steps is None or save_steps <= 0:
            return False
        # ``global_step > 0`` avoids writing a checkpoint right after the very
        # first batch, where ``0 % save_steps == 0`` would otherwise match.
        return (
            self.global_step > 0
            and self.global_step % save_steps == 0
            and self.accelerator.is_main_process
        )

    def _save_checkpoint_if_needed(self, epoch: int):
        """Write an epoch checkpoint when the epoch interval or the last epoch is hit.

        Only the main process saves. A missing or non-positive
        ``config.save_model_epochs`` disables interval saving, leaving only
        the final-epoch checkpoint.

        Args:
            epoch: Zero-based index of the epoch that just finished.
        """
        if not self.accelerator.is_main_process:
            return
        interval = self.config.save_model_epochs
        # Guard the modulo against None / 0 / negative intervals.
        on_interval = bool(interval) and interval > 0 and (epoch + 1) % interval == 0
        is_last_epoch = epoch == self.config.num_epochs - 1
        if on_interval or is_last_epoch:
            self._save_checkpoint(f"epoch_{epoch}")

    def _save_checkpoint(self, prefix: str):
        """Save the unwrapped model weights (for inference) via ``save_pretrained``.

        ``"best_model"`` is written to a fixed directory that is replaced on
        every call; any other prefix gets a ``<prefix>_<MMDD_HHMM>`` directory
        under ``config.output_dir``.

        Args:
            prefix: Checkpoint name prefix.
        """
        with torch.no_grad():
            unwrapped_model = self.accelerator.unwrap_model(self.model)
            # torch.compile compatibility: fetch the original module.
            if hasattr(unwrapped_model, "_orig_mod"):
                unwrapped_model = unwrapped_model._orig_mod
            if prefix == "best_model":
                save_path = os.path.join(self.config.output_dir, prefix)
                # Remove the previous best model so no stale files remain.
                if os.path.isdir(save_path):
                    shutil.rmtree(save_path)
            else:
                timestamp = datetime.now().strftime("%m%d_%H%M")
                save_path = os.path.join(self.config.output_dir, f"{prefix}_{timestamp}")
            unwrapped_model.save_pretrained(save_path)
            print(f"模型已保存至: {save_path}")

    def _save_training_state(self):
        """Save the full, resumable training state to ``<output_dir>/latest``.

        ``accelerator.save_state`` writes the Accelerate-managed state; the
        main process additionally stores ``global_step``, ``best_eval_loss``
        and ``completed_epochs`` in ``training_meta.pt``.
        """
        save_path = os.path.join(self.config.output_dir, "latest")
        self.accelerator.save_state(save_path)
        if self.accelerator.is_main_process:
            torch.save(
                {
                    "global_step": self.global_step,
                    "best_eval_loss": self.best_eval_loss,
                    "completed_epochs": self.completed_epochs,
                },
                os.path.join(save_path, "training_meta.pt"),
            )
            print(f"训练状态已保存至: {save_path}")

    def _load_training_state(self, load_path):
        """Restore the full training state from a directory written by ``_save_training_state``.

        Args:
            load_path: Directory holding the Accelerate state and, optionally,
                ``training_meta.pt``. When the metadata file is absent the
                counters keep their current values.
        """
        self.accelerator.load_state(load_path)
        meta_path = os.path.join(load_path, "training_meta.pt")
        if os.path.exists(meta_path):
            meta = torch.load(meta_path, map_location="cpu")
            self.global_step = meta["global_step"]
            self.best_eval_loss = meta["best_eval_loss"]
            self.completed_epochs = meta["completed_epochs"]
        print(f"已恢复训练状态: epoch={self.completed_epochs}, step={self.global_step}, best_loss={self.best_eval_loss:.4f}")