""" 模型检查点管理 处理模型的保存、加载和版本管理 """ import torch from pathlib import Path from typing import Dict, Optional, Any import json from datetime import datetime class CheckpointManager: """ 检查点管理器 自动管理模型保存、加载和最佳模型跟踪 """ def __init__( self, checkpoint_dir: str, max_keep: int = 5, metric_mode: str = 'min' ): """ Args: checkpoint_dir: 检查点保存目录 max_keep: 最多保留的检查点数量 metric_mode: 指标模式 ('min' 或 'max') """ self.checkpoint_dir = Path(checkpoint_dir) self.checkpoint_dir.mkdir(parents=True, exist_ok=True) self.max_keep = max_keep self.metric_mode = metric_mode self.checkpoints = [] # [(path, metric_value), ...] self.best_metric = float('inf') if metric_mode == 'min' else float('-inf') self.best_checkpoint = None # 加载已有检查点信息 self._load_checkpoint_info() def save( self, model: torch.nn.Module, optimizer: torch.optim.Optimizer, epoch: int, metric_value: float, extra_info: Optional[Dict] = None ) -> Path: """ 保存检查点 Args: model: 模型 optimizer: 优化器 epoch: 当前epoch metric_value: 验证指标值 extra_info: 额外信息 Returns: 保存的文件路径 """ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") filename = f"checkpoint_epoch{epoch}_{timestamp}.pt" filepath = self.checkpoint_dir / filename # 准备保存内容 checkpoint = { 'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'metric_value': metric_value, 'timestamp': timestamp } if extra_info: checkpoint.update(extra_info) # 保存 torch.save(checkpoint, filepath) # 更新检查点列表 self.checkpoints.append((filepath, metric_value)) # 检查是否是最佳模型 is_best = self._is_best(metric_value) if is_best: self.best_metric = metric_value self.best_checkpoint = filepath # 保存最佳模型的副本 best_path = self.checkpoint_dir / "best_model.pt" torch.save(checkpoint, best_path) print(f"✨ New best model saved! Metric: {metric_value:.4f}") # 清理旧检查点 self._cleanup() # 保存检查点信息 self._save_checkpoint_info() return filepath def load( self, model: torch.nn.Module, optimizer: Optional[torch.optim.Optimizer] = None, checkpoint_path: Optional[str] = None, load_best: bool = False ) -> Dict: """ 加载检查点 Args: model: 模型 optimizer: 优化器(可选) checkpoint_path: 检查点路径(可选,不指定则加载最新) load_best: 是否加载最佳模型 Returns: 检查点字典 """ if load_best: filepath = self.checkpoint_dir / "best_model.pt" elif checkpoint_path: filepath = Path(checkpoint_path) else: # 加载最新检查点 if not self.checkpoints: raise ValueError("No checkpoints found!") filepath = self.checkpoints[-1][0] if not filepath.exists(): raise FileNotFoundError(f"Checkpoint not found: {filepath}") print(f"Loading checkpoint from {filepath}") checkpoint = torch.load(filepath, map_location='cpu') # 加载模型权重 model.load_state_dict(checkpoint['model_state_dict']) # 加载优化器状态 if optimizer and 'optimizer_state_dict' in checkpoint: optimizer.load_state_dict(checkpoint['optimizer_state_dict']) print(f"Loaded checkpoint from epoch {checkpoint.get('epoch', 'unknown')}") print(f"Metric value: {checkpoint.get('metric_value', 'N/A')}") return checkpoint def _is_best(self, metric_value: float) -> bool: """判断是否是最佳模型""" if self.metric_mode == 'min': return metric_value < self.best_metric else: return metric_value > self.best_metric def _cleanup(self): """清理旧检查点,只保留最新的max_keep个""" if len(self.checkpoints) <= self.max_keep: return # 按指标排序 sorted_checkpoints = sorted( self.checkpoints, key=lambda x: x[1], reverse=(self.metric_mode == 'max') ) # 保留最好的max_keep个 keep_checkpoints = sorted_checkpoints[:self.max_keep] remove_checkpoints = [ cp for cp in self.checkpoints if cp not in keep_checkpoints ] # 删除多余的文件(除了best_model.pt) for filepath, _ in remove_checkpoints: if filepath.exists() and filepath.name != "best_model.pt": filepath.unlink() print(f"Removed old checkpoint: {filepath.name}") self.checkpoints = keep_checkpoints def _save_checkpoint_info(self): """保存检查点元信息""" info = { 'checkpoints': [ {'path': str(cp[0]), 'metric': cp[1]} for cp in self.checkpoints ], 'best_checkpoint': str(self.best_checkpoint) if self.best_checkpoint else None, 'best_metric': self.best_metric, 'metric_mode': self.metric_mode } info_file = self.checkpoint_dir / "checkpoint_info.json" with open(info_file, 'w') as f: json.dump(info, f, indent=2) def _load_checkpoint_info(self): """加载检查点元信息""" info_file = self.checkpoint_dir / "checkpoint_info.json" if not info_file.exists(): return with open(info_file, 'r') as f: info = json.load(f) self.checkpoints = [ (Path(cp['path']), cp['metric']) for cp in info['checkpoints'] if Path(cp['path']).exists() ] if info['best_checkpoint'] and Path(info['best_checkpoint']).exists(): self.best_checkpoint = Path(info['best_checkpoint']) self.best_metric = info['best_metric']