| """
|
| Transformer模型训练器
|
| 支持分布式训练、混合精度、梯度累积等特性
|
| 训练日志会保存在TensorBoard中,使用命令查看: tensorboard --logdir=<log_dir>
|
| """
|
| import os
|
| import shutil
|
| import time
|
| import torch
|
| from datetime import datetime
|
| from typing import Optional
|
| from accelerate import Accelerator, DistributedDataParallelKwargs
|
| from transformers import get_cosine_schedule_with_warmup
|
| from tqdm.auto import tqdm
|
|
|
class TransformerTrainer:
    """Trainer for Transformer language models built on ``accelerate``.

    Wraps the complete training workflow:

    - optimizer and learning-rate scheduler construction
    - distributed / mixed-precision training via ``Accelerator``
    - training and evaluation loops
    - model checkpointing (``save_pretrained``) and resumable training state
    - TensorBoard logging

    Config attributes read by this class: ``learning_rate``, ``num_epochs``,
    ``gradient_accumulation_steps``, ``lr_warmup_ratio``, ``mixed_precision``,
    ``tensorboard_log_dir``, ``tensorboard_log_name``, ``use_test_set``,
    ``test_frequency``, ``test_save_results``, ``save_steps``,
    ``save_model_epochs``, ``output_dir`` and ``log``.

    ``global_step`` counts micro-batches (one per training dataloader batch),
    not optimizer updates; with gradient accumulation the two differ.
    """

    def __init__(self, config, model, train_dataloader, test_dataloader=None, resume_path=None):
        """
        Args:
            config: Training configuration object (attributes listed in the class docstring).
            model: Model whose forward accepts ``input_ids``/``labels``/``attention_mask``/
                ``position_ids`` keyword arguments and returns an object with ``.loss``.
            train_dataloader: Training dataloader yielding dict batches.
            test_dataloader: Optional evaluation dataloader yielding dict batches.
            resume_path: Optional directory previously written by ``_save_training_state``.
        """
        self.config = config
        self.model = model
        self.train_dataloader = train_dataloader
        self.test_dataloader = test_dataloader
        self.global_step = 0
        self.best_eval_loss = float("inf")
        self.completed_epochs = 0

        # Optimizer and scheduler are built before `prepare()`, so the scheduler
        # sees the unsharded training dataloader length.
        self.optimizer = self._setup_optimizer()
        self.lr_scheduler = self._setup_lr_scheduler()
        self.accelerator = self._initialize_accelerator()
        self._prepare_training_components()
        self.accelerator.init_trackers(self.config.tensorboard_log_name)
        self.test_interval_steps = self._calculate_test_interval()

        # Restore only after `prepare()` so the prepared objects receive the state.
        if resume_path:
            self._load_training_state(resume_path)

    def _setup_optimizer(self) -> torch.optim.Optimizer:
        """Build AdamW with separate weight-decay / no-weight-decay parameter groups.

        A trainable parameter goes into the no-decay group when it has fewer
        than 2 dimensions (biases and LayerNorm/RMSNorm weights, regardless of
        how the norm module is named, e.g. ``ln_1`` or a ``Sequential`` index)
        or when its name contains "bias" or "norm". Frozen parameters
        (``requires_grad=False``) are not given to the optimizer.
        """
        # "norm" also matches "layernorm" and "layer_norm".
        no_decay_keywords = ("bias", "norm")
        decay_params, no_decay_params = [], []
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            lowered = name.lower()
            if param.ndim < 2 or any(kw in lowered for kw in no_decay_keywords):
                no_decay_params.append(param)
            else:
                decay_params.append(param)
        return torch.optim.AdamW(
            [
                {"params": decay_params, "weight_decay": 0.1},
                {"params": no_decay_params, "weight_decay": 0.0},
            ],
            lr=self.config.learning_rate,
            betas=(0.9, 0.99),
        )

    def _setup_lr_scheduler(self):
        """Build a cosine learning-rate schedule with linear warmup.

        The step count is divided by ``gradient_accumulation_steps``,
        presumably because the prepared scheduler only advances on real
        optimizer updates.
        """
        # NOTE(review): the `* 1.5` factor stretches the schedule past the real
        # number of updates, so the cosine never reaches its minimum during
        # training -- presumably intentional (keeps a non-trivial final LR); confirm.
        # len(train_dataloader) here is the length before `accelerator.prepare`.
        num_training_steps = int(
            len(self.train_dataloader) * self.config.num_epochs * 1.5
            / self.config.gradient_accumulation_steps
        )
        num_warmup_steps = int(num_training_steps * self.config.lr_warmup_ratio)
        return get_cosine_schedule_with_warmup(
            optimizer=self.optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
        )

    def _initialize_accelerator(self) -> "Accelerator":
        """Create the Accelerator (distributed, mixed precision, TensorBoard logging)."""
        ddp_kwargs = DistributedDataParallelKwargs()
        accelerator = Accelerator(
            mixed_precision=self.config.mixed_precision,
            gradient_accumulation_steps=self.config.gradient_accumulation_steps,
            log_with="tensorboard",
            project_dir=self.config.tensorboard_log_dir,
            device_placement=True,
            kwargs_handlers=[ddp_kwargs],
        )
        if accelerator.is_main_process and self.config.tensorboard_log_dir:
            os.makedirs(self.config.tensorboard_log_dir, exist_ok=True)
        return accelerator

    def _prepare_training_components(self):
        """Wrap model, optimizer, dataloaders and scheduler with ``accelerator.prepare``.

        ``prepare`` returns the objects in the order they were passed; the test
        dataloader is included only when one was given.
        """
        if self.test_dataloader is None:
            (
                self.model,
                self.optimizer,
                self.train_dataloader,
                self.lr_scheduler,
            ) = self.accelerator.prepare(
                self.model, self.optimizer, self.train_dataloader, self.lr_scheduler
            )
        else:
            (
                self.model,
                self.optimizer,
                self.train_dataloader,
                self.test_dataloader,
                self.lr_scheduler,
            ) = self.accelerator.prepare(
                self.model,
                self.optimizer,
                self.train_dataloader,
                self.test_dataloader,
                self.lr_scheduler,
            )

    def _calculate_test_interval(self) -> Optional[int]:
        """Return the evaluation interval in micro-batches, or None when evaluation is off.

        ``config.test_frequency`` is expressed in epochs (0.5 = twice per epoch).
        The interval is clamped to at least 1: a frequency shorter than one
        batch used to yield 0 and crash ``_should_evaluate`` with a modulo by zero.
        """
        if not (self.config.use_test_set and self.test_dataloader is not None):
            return None
        interval = max(1, int(len(self.train_dataloader) * self.config.test_frequency))
        self.accelerator.print(f"测试间隔: 每 {interval} 步测试一次 (约{self.config.test_frequency}个epoch)")
        return interval

    def train(self):
        """Train from ``completed_epochs`` up to ``config.num_epochs``.

        After each epoch an inference checkpoint may be written and the full
        resumable state is saved to ``<output_dir>/latest``.
        """
        for epoch in range(self.completed_epochs, self.config.num_epochs):
            self._train_one_epoch(epoch)
            self._save_checkpoint_if_needed(epoch)
            self.completed_epochs = epoch + 1
            self._save_training_state()

        # Close the trackers (flushes and releases the TensorBoard writer).
        # `accelerator.end_training()` also does this, but depending on the
        # accelerate version it may destroy the process group as well, which
        # callers could still need after train() returns.
        for tracker in self.accelerator.trackers:
            tracker.finish()

    def _train_one_epoch(self, epoch: int):
        """Run one pass over the training dataloader.

        ``global_step`` is incremented right after each training step, so every
        interval check sees the number of micro-batches completed so far.
        Checking before incrementing fired each check one batch late: it
        skipped every evaluation landing on an epoch boundary (never evaluating
        at all for integer ``test_frequency``) and saved a checkpoint after step 0.
        """
        self.model.train()
        # Only the local main process draws the bar; other ranks would duplicate it.
        show_progress = bool(self.config.log) and self.accelerator.is_local_main_process
        epoch_start_step = self.global_step

        # Throughput window for this process's tokens.
        tokens_since_log = 0
        last_log_time = time.time()

        with tqdm(total=len(self.train_dataloader), disable=not show_progress) as progress_bar:
            progress_bar.set_description(f"Epoch {epoch}")
            for batch in self.train_dataloader:
                loss = self._training_step(batch)
                self.global_step += 1
                tokens_since_log += batch["input_ids"].numel()

                if self.global_step % 10 == 0:
                    elapsed = time.time() - last_log_time
                    self._log_training_metrics(loss, tokens_since_log / max(elapsed, 1e-6))
                    tokens_since_log = 0
                    last_log_time = time.time()

                if self._should_evaluate(epoch_start_step):
                    self._evaluate_test()

                if self._should_save_checkpoint():
                    self._save_checkpoint(f"steps_{self.global_step}")

                if show_progress:
                    progress_bar.update(1)
                    progress_bar.set_postfix(
                        loss=loss.detach().item(),
                        lr=self.lr_scheduler.get_last_lr()[0],
                    )

    def _forward(self, batch):
        """Run the model on one dict batch (shared by training and evaluation)."""
        return self.model(
            input_ids=batch["input_ids"],
            labels=batch["labels"],
            attention_mask=batch["attention_mask"],
            position_ids=batch.get("position_ids"),
        )

    def _training_step(self, batch) -> torch.Tensor:
        """Forward and backward on one micro-batch, stepping under gradient accumulation.

        The optimizer/scheduler calls are issued on every micro-batch inside
        ``accelerator.accumulate``; accelerate is expected to make them no-ops
        until gradients are synchronized.

        Args:
            batch: Dict batch with ``input_ids``, ``labels``, ``attention_mask``
                and optionally ``position_ids``.

        Returns:
            The (non-detached) loss of this micro-batch.
        """
        with self.accelerator.accumulate(self.model):
            loss = self._forward(batch).loss
            self.accelerator.backward(loss)
            # Clip only when gradients are synchronized (i.e. on the update step).
            if self.accelerator.sync_gradients:
                self.accelerator.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()
            self.lr_scheduler.step()
            self.optimizer.zero_grad()
        return loss

    def _log_training_metrics(self, loss: torch.Tensor, tokens_per_sec: float = 0):
        """Log this process's loss, current LR and throughput to the trackers."""
        metrics = {
            "train/loss": loss.detach().item(),
            "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
            "train/tokens_per_sec": tokens_per_sec,
            "train/step": self.global_step,
        }
        self.accelerator.log(metrics, step=self.global_step)

    def _should_evaluate(self, epoch_start_step: int) -> bool:
        """Return True when ``global_step`` hits the evaluation interval.

        ``epoch_start_step`` guards against evaluating before any batch of the
        current epoch has been completed.
        """
        if self.test_interval_steps is None:
            return False
        return (
            self.global_step % self.test_interval_steps == 0
            and self.global_step > epoch_start_step
        )

    def _evaluate_test(self):
        """Evaluate on the test set, log loss/perplexity and keep the best model.

        Must be called on every process: it contains barriers and
        ``gather_for_metrics`` collectives. The reported loss is the mean of
        per-batch (per-process) losses, not a token-weighted mean.
        """
        if self.test_dataloader is None:
            return

        self.accelerator.wait_for_everyone()
        self.model.eval()
        is_main = self.accelerator.is_main_process
        total_loss = 0.0
        num_batches = 0

        if is_main:
            print(f"\n开始测试集评估 (step {self.global_step})...")

        with torch.no_grad():
            eval_bar = tqdm(self.test_dataloader, desc="测试中", disable=not is_main)
            for batch in eval_bar:
                outputs = self._forward(batch)
                # Gather one loss per process so every rank computes the same average.
                gathered_losses = self.accelerator.gather_for_metrics(
                    outputs.loss.detach().reshape(1)
                )
                total_loss += gathered_losses.sum().item()
                num_batches += gathered_losses.numel()

        if num_batches > 0:
            avg_loss = total_loss / num_batches
            perplexity = torch.exp(torch.tensor(avg_loss)).item()

            if self.config.test_save_results and is_main:
                test_metrics = {
                    "test/loss": avg_loss,
                    "test/perplexity": perplexity,
                    "test/step": self.global_step,
                }
                self.accelerator.log(test_metrics, step=self.global_step)

            if is_main:
                print(f"测试集 - 损失: {avg_loss:.4f}, 困惑度: {perplexity:.4f}")

            # avg_loss is identical on every rank, so track the best loss on all
            # ranks to keep state consistent; only the main process writes files.
            if avg_loss < self.best_eval_loss:
                self.best_eval_loss = avg_loss
                if is_main:
                    self._save_checkpoint("best_model")
                    print(f"best model 已更新: loss={self.best_eval_loss:.4f}")
        elif is_main:
            # An empty loader previously produced avg_loss=0 and a bogus "best" model.
            print("测试集为空,跳过评估")

        self.model.train()
        self.accelerator.wait_for_everyone()

    def _should_save_checkpoint(self) -> bool:
        """Return True on the main process every ``config.save_steps`` micro-batches.

        A missing or non-positive ``save_steps`` disables step checkpoints, and
        nothing is saved before the first completed step.
        """
        if self.config.save_steps is None or self.config.save_steps <= 0:
            return False
        return (
            self.global_step > 0
            and self.global_step % self.config.save_steps == 0
            and self.accelerator.is_main_process
        )

    def _save_checkpoint_if_needed(self, epoch: int):
        """Save an epoch checkpoint every ``save_model_epochs`` epochs and after the last one.

        A missing or non-positive ``save_model_epochs`` disables the periodic
        saves (previously a TypeError/ZeroDivisionError); the final epoch is
        always saved. Runs on the main process only.
        """
        if not self.accelerator.is_main_process:
            return

        every = self.config.save_model_epochs
        is_periodic = bool(every) and every > 0 and (epoch + 1) % every == 0
        is_last_epoch = epoch == self.config.num_epochs - 1

        if is_periodic or is_last_epoch:
            self._save_checkpoint(f"epoch_{epoch}")

    def _save_checkpoint(self, prefix: str):
        """Save model weights for inference with ``save_pretrained``.

        ``"best_model"`` is written to a fixed directory and replaced on each
        improvement; any other prefix gets a timestamped directory. Callers in
        this class invoke it on the main process only.
        """
        unwrapped_model = self.accelerator.unwrap_model(self.model)
        # Compiled modules (e.g. torch.compile) keep the original under `_orig_mod`.
        unwrapped_model = getattr(unwrapped_model, "_orig_mod", unwrapped_model)

        if prefix == "best_model":
            save_path = os.path.join(self.config.output_dir, prefix)
            # Write to a temp dir first so a failed save cannot wipe the previous best.
            tmp_path = f"{save_path}.tmp"
            if os.path.isdir(tmp_path):
                shutil.rmtree(tmp_path)
            unwrapped_model.save_pretrained(tmp_path)
            if os.path.isdir(save_path):
                shutil.rmtree(save_path)
            os.replace(tmp_path, save_path)
        else:
            timestamp = datetime.now().strftime("%m%d_%H%M")
            save_path = os.path.join(self.config.output_dir, f"{prefix}_{timestamp}")
            unwrapped_model.save_pretrained(save_path)
        print(f"模型已保存至: {save_path}")

    def _save_training_state(self):
        """Save the full resumable state to ``<output_dir>/latest``.

        ``accelerator.save_state`` is called on every process; the main process
        additionally writes the trainer counters to ``training_meta.pt``.
        """
        save_path = os.path.join(self.config.output_dir, "latest")
        self.accelerator.save_state(save_path)
        if self.accelerator.is_main_process:
            torch.save(
                {
                    "global_step": self.global_step,
                    "best_eval_loss": self.best_eval_loss,
                    "completed_epochs": self.completed_epochs,
                },
                os.path.join(save_path, "training_meta.pt"),
            )
            print(f"训练状态已保存至: {save_path}")

    def _load_training_state(self, load_path):
        """Restore state written by ``_save_training_state`` from ``load_path``."""
        self.accelerator.load_state(load_path)
        meta_path = os.path.join(load_path, "training_meta.pt")
        if not os.path.exists(meta_path):
            self.accelerator.print(f"未找到 {meta_path}, 训练计数器从头开始")
            return
        # torch.load unpickles the file: only resume from trusted checkpoints.
        meta = torch.load(meta_path, map_location="cpu")
        self.global_step = meta["global_step"]
        self.best_eval_loss = meta["best_eval_loss"]
        self.completed_epochs = meta["completed_epochs"]
        self.accelerator.print(
            f"已恢复训练状态: epoch={self.completed_epochs}, step={self.global_step}, best_loss={self.best_eval_loss:.4f}"
        )
|
|
|