| """ |
| 模型检查点管理 |
| 处理模型的保存、加载和版本管理 |
| """ |
|
|
| import torch |
| from pathlib import Path |
| from typing import Dict, Optional, Any |
| import json |
| from datetime import datetime |
|
|
|
|
class CheckpointManager:
    """Save, load and prune training checkpoints inside one directory.

    Every call to :meth:`save` writes a timestamped ``.pt`` file.  When the
    reported metric beats the best value seen so far, the same payload is
    also written to ``best_model.pt``.  At most ``max_keep`` timestamped
    files (the best ones by metric) are retained; bookkeeping is persisted
    to ``checkpoint_info.json`` so that a new manager pointed at the same
    directory resumes with the same state.
    """

    def __init__(
        self,
        checkpoint_dir: str,
        max_keep: int = 5,
        metric_mode: str = 'min'
    ):
        """Create the manager and restore any bookkeeping found on disk.

        Args:
            checkpoint_dir: Directory that holds the checkpoints; created
                (including parents) when it does not exist.
            max_keep: Maximum number of timestamped checkpoints to retain.
            metric_mode: ``'min'`` when a lower metric is better, ``'max'``
                when a higher metric is better.

        Raises:
            ValueError: If ``metric_mode`` is neither ``'min'`` nor ``'max'``.
        """
        # Reject unknown modes up front: otherwise `_is_best` would treat
        # them as "max" while `_cleanup` would rank them as "min".
        if metric_mode not in ('min', 'max'):
            raise ValueError(
                f"metric_mode must be 'min' or 'max', got {metric_mode!r}"
            )

        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        self.max_keep = max_keep
        self.metric_mode = metric_mode

        # (path, metric) pairs, always kept in chronological (save) order so
        # that the last entry is the most recent retained checkpoint.
        self.checkpoints = []
        self.best_metric = float('inf') if metric_mode == 'min' else float('-inf')
        self.best_checkpoint = None

        # Pick up state left behind by a previous run in the same directory.
        self._load_checkpoint_info()

    def save(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        epoch: int,
        metric_value: float,
        extra_info: Optional[Dict] = None
    ) -> Path:
        """Write a checkpoint and update best-model tracking.

        Args:
            model: Model whose ``state_dict`` is stored.
            optimizer: Optimizer whose ``state_dict`` is stored.
            epoch: Current epoch; embedded in the file name and payload.
            metric_value: Validation metric used for ranking checkpoints.
            extra_info: Optional extra entries merged into the payload.
                Keys that collide with the built-in ones overwrite them.

        Returns:
            Path the checkpoint was written to.  Note that pruning runs
            before this method returns, so the file may already have been
            removed if its metric is not among the best ``max_keep``.
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"checkpoint_epoch{epoch}_{timestamp}.pt"
        filepath = self.checkpoint_dir / filename

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'metric_value': metric_value,
            'timestamp': timestamp
        }

        if extra_info:
            checkpoint.update(extra_info)

        torch.save(checkpoint, filepath)

        self.checkpoints.append((filepath, metric_value))

        if self._is_best(metric_value):
            self.best_metric = metric_value
            self.best_checkpoint = filepath

            # Keep a stable, well-known copy of the best payload.
            best_path = self.checkpoint_dir / "best_model.pt"
            torch.save(checkpoint, best_path)
            print(f"✨ New best model saved! Metric: {metric_value:.4f}")

        self._cleanup()
        self._save_checkpoint_info()

        return filepath

    def load(
        self,
        model: torch.nn.Module,
        optimizer: Optional[torch.optim.Optimizer] = None,
        checkpoint_path: Optional[str] = None,
        load_best: bool = False
    ) -> Dict:
        """Restore model (and optionally optimizer) state from a checkpoint.

        The source file is chosen in this order: ``best_model.pt`` when
        ``load_best`` is true, else ``checkpoint_path`` when given, else the
        most recently saved checkpoint that is still retained.

        Args:
            model: Model that receives the stored ``model_state_dict``.
            optimizer: Optional optimizer that receives the stored
                ``optimizer_state_dict`` when the payload contains one.
            checkpoint_path: Explicit checkpoint file to load.
            load_best: Load ``best_model.pt`` from the checkpoint directory.

        Returns:
            The full checkpoint dictionary read from disk.

        Raises:
            ValueError: If no path is given and no checkpoints are tracked.
            FileNotFoundError: If the selected file does not exist.
        """
        if load_best:
            filepath = self.checkpoint_dir / "best_model.pt"
        elif checkpoint_path:
            filepath = Path(checkpoint_path)
        else:
            if not self.checkpoints:
                raise ValueError("No checkpoints found!")
            # The list is chronological, so the last entry is the newest.
            filepath = self.checkpoints[-1][0]

        if not filepath.exists():
            raise FileNotFoundError(f"Checkpoint not found: {filepath}")

        print(f"Loading checkpoint from {filepath}")
        # NOTE(security): torch.load unpickles data; only load checkpoint
        # files that come from a trusted source.
        checkpoint = torch.load(filepath, map_location='cpu')

        model.load_state_dict(checkpoint['model_state_dict'])

        if optimizer is not None and 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        print(f"Loaded checkpoint from epoch {checkpoint.get('epoch', 'unknown')}")
        print(f"Metric value: {checkpoint.get('metric_value', 'N/A')}")

        return checkpoint

    def _is_best(self, metric_value: float) -> bool:
        """Return True when ``metric_value`` strictly beats the best so far."""
        if self.metric_mode == 'min':
            return metric_value < self.best_metric
        return metric_value > self.best_metric

    def _cleanup(self):
        """Delete surplus checkpoints, keeping the best ``max_keep`` by metric.

        The retained entries stay in chronological order so that
        ``self.checkpoints[-1]`` remains the most recent retained checkpoint.
        ``best_model.pt`` is never deleted here.
        """
        if len(self.checkpoints) <= self.max_keep:
            return

        # Rank best-first: ascending for 'min', descending for 'max'.
        ranked = sorted(
            self.checkpoints,
            key=lambda cp: cp[1],
            reverse=(self.metric_mode == 'max')
        )
        winners = ranked[:self.max_keep]

        kept = []
        dropped = []
        for cp in self.checkpoints:
            (kept if cp in winners else dropped).append(cp)

        for filepath, _ in dropped:
            if filepath.exists() and filepath.name != "best_model.pt":
                filepath.unlink()
                print(f"Removed old checkpoint: {filepath.name}")

        self.checkpoints = kept

    def _save_checkpoint_info(self):
        """Persist the tracked checkpoints and best-model state as JSON."""
        info = {
            'checkpoints': [
                {'path': str(path), 'metric': metric}
                for path, metric in self.checkpoints
            ],
            'best_checkpoint': str(self.best_checkpoint) if self.best_checkpoint else None,
            'best_metric': self.best_metric,
            'metric_mode': self.metric_mode
        }

        info_file = self.checkpoint_dir / "checkpoint_info.json"
        with open(info_file, 'w', encoding='utf-8') as f:
            json.dump(info, f, indent=2)

    def _load_checkpoint_info(self):
        """Restore tracked checkpoints and best-model state from JSON.

        Does nothing when the metadata file is absent.  Entries whose files
        no longer exist are dropped, and the best metric is only restored
        when the recorded best checkpoint file is still present.
        """
        info_file = self.checkpoint_dir / "checkpoint_info.json"
        if not info_file.exists():
            return

        with open(info_file, 'r', encoding='utf-8') as f:
            info = json.load(f)

        self.checkpoints = [
            (Path(entry['path']), entry['metric'])
            for entry in info.get('checkpoints', [])
            if Path(entry['path']).exists()
        ]

        best_path = info.get('best_checkpoint')
        if best_path and Path(best_path).exists():
            self.best_checkpoint = Path(best_path)
            self.best_metric = info['best_metric']