VLAlert / lkalert /utils /checkpoint.py
AsianPlayer's picture
Add VLAlert code
1e05592 verified
Raw
History Blame Contribute Delete
6.96 kB
"""
模型检查点管理
处理模型的保存、加载和版本管理
"""
import torch
from pathlib import Path
from typing import Dict, Optional, Any
import json
from datetime import datetime
class CheckpointManager:
    """
    Checkpoint manager.

    Saves model/optimizer checkpoints into a directory, keeps at most
    ``max_keep`` of them on disk (ranked by their metric value), tracks the
    best one and mirrors it to ``best_model.pt``, and persists this
    bookkeeping to ``checkpoint_info.json`` so it survives restarts.
    """

    def __init__(
        self,
        checkpoint_dir: str,
        max_keep: int = 5,
        metric_mode: str = 'min'
    ):
        """
        Args:
            checkpoint_dir: Directory where checkpoints are stored (created
                if it does not exist).
            max_keep: Maximum number of regular checkpoints kept on disk.
            metric_mode: 'min' if a lower metric is better (e.g. loss),
                'max' if a higher metric is better (e.g. accuracy).

        Raises:
            ValueError: If ``metric_mode`` is neither 'min' nor 'max'.
        """
        if metric_mode not in ('min', 'max'):
            # Any other value used to act as 'max' in _is_best but as 'min'
            # in _cleanup; reject it up front, before touching the filesystem.
            raise ValueError(
                f"metric_mode must be 'min' or 'max', got {metric_mode!r}"
            )
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.max_keep = max_keep
        self.metric_mode = metric_mode
        # Chronological list of (path, metric_value); the last entry is the
        # most recently saved checkpoint that still exists.
        self.checkpoints = []
        self.best_metric = float('inf') if metric_mode == 'min' else float('-inf')
        self.best_checkpoint = None
        # Restore bookkeeping left by a previous run, if any.
        self._load_checkpoint_info()

    def save(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        epoch: int,
        metric_value: float,
        extra_info: Optional[Dict] = None
    ) -> Path:
        """
        Save a checkpoint, update best-model tracking and apply retention.

        Args:
            model: Model whose ``state_dict()`` is saved.
            optimizer: Optimizer whose ``state_dict()`` is saved.
            epoch: Current epoch (part of the filename and stored in the file).
            metric_value: Validation metric used for best-model tracking and
                for deciding which checkpoints are retained.
            extra_info: Extra entries merged into the saved dict. They are
                applied last, so they can override the standard keys.

        Returns:
            Path of the checkpoint file that was written. Retention may delete
            that file immediately if it does not rank among the best
            ``max_keep`` checkpoints; ``best_model.pt`` is never deleted.
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filepath = self.checkpoint_dir / f"checkpoint_epoch{epoch}_{timestamp}.pt"

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'metric_value': metric_value,
            'timestamp': timestamp
        }
        if extra_info:
            checkpoint.update(extra_info)

        self._atomic_torch_save(checkpoint, filepath)

        # The same epoch saved twice within one second overwrites the same
        # file; drop the stale entry so each path is tracked exactly once.
        self.checkpoints = [cp for cp in self.checkpoints if cp[0] != filepath]
        self.checkpoints.append((filepath, metric_value))

        if self._is_best(metric_value):
            self.best_metric = metric_value
            self.best_checkpoint = filepath
            # Stable-named copy of the best model; _cleanup never removes it.
            self._atomic_torch_save(checkpoint, self.checkpoint_dir / "best_model.pt")
            print(f"✨ New best model saved! Metric: {metric_value:.4f}")

        self._cleanup()
        self._save_checkpoint_info()
        return filepath

    def load(
        self,
        model: torch.nn.Module,
        optimizer: Optional[torch.optim.Optimizer] = None,
        checkpoint_path: Optional[str] = None,
        load_best: bool = False
    ) -> Dict:
        """
        Load a checkpoint into ``model`` (and optionally ``optimizer``).

        Args:
            model: Model that receives ``model_state_dict``.
            optimizer: Optional optimizer that receives
                ``optimizer_state_dict`` when the checkpoint contains one.
            checkpoint_path: Explicit file to load. If omitted, the most
                recently saved tracked checkpoint is used.
            load_best: If True, load ``best_model.pt``. This takes precedence
                over ``checkpoint_path``.

        Returns:
            The full checkpoint dictionary, loaded with ``map_location='cpu'``.

        Raises:
            ValueError: If no path is given and no checkpoints are tracked.
            FileNotFoundError: If the selected file does not exist.
        """
        if load_best:
            filepath = self.checkpoint_dir / "best_model.pt"
        elif checkpoint_path:
            filepath = Path(checkpoint_path)
        else:
            # self.checkpoints is kept in chronological order, so the last
            # entry is the latest surviving checkpoint.
            if not self.checkpoints:
                raise ValueError("No checkpoints found!")
            filepath = self.checkpoints[-1][0]

        if not filepath.exists():
            raise FileNotFoundError(f"Checkpoint not found: {filepath}")

        print(f"Loading checkpoint from {filepath}")
        # NOTE: depending on the installed torch version, torch.load may
        # unpickle arbitrary objects; only load checkpoints from trusted sources.
        checkpoint = torch.load(filepath, map_location='cpu')

        model.load_state_dict(checkpoint['model_state_dict'])
        if optimizer is not None and 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        print(f"Loaded checkpoint from epoch {checkpoint.get('epoch', 'unknown')}")
        print(f"Metric value: {checkpoint.get('metric_value', 'N/A')}")
        return checkpoint

    def _is_best(self, metric_value: float) -> bool:
        """Return True if ``metric_value`` strictly beats the current best."""
        if self.metric_mode == 'min':
            return metric_value < self.best_metric
        return metric_value > self.best_metric

    def _cleanup(self):
        """
        Enforce the retention limit.

        Keeps the ``max_keep`` checkpoints with the best metric values and
        deletes the rest from disk. ``best_model.pt`` is never deleted.
        ``self.checkpoints`` stays in chronological order, so its last entry
        is still the most recent surviving checkpoint.
        """
        keep_count = max(self.max_keep, 0)  # a negative slice would count from the end
        if len(self.checkpoints) <= keep_count:
            return

        ranked = sorted(
            self.checkpoints,
            key=lambda cp: cp[1],
            reverse=(self.metric_mode == 'max')
        )
        keep_paths = {path for path, _ in ranked[:keep_count]}

        for filepath, _ in self.checkpoints:
            if filepath in keep_paths:
                continue
            if filepath.exists() and filepath.name != "best_model.pt":
                filepath.unlink()
                print(f"Removed old checkpoint: {filepath.name}")

        # Filter rather than reassign the ranked list, to preserve save order.
        self.checkpoints = [cp for cp in self.checkpoints if cp[0] in keep_paths]

    def _save_checkpoint_info(self):
        """Persist tracked checkpoints and best-model state to checkpoint_info.json."""
        info = {
            'checkpoints': [
                {'path': str(cp[0]), 'metric': cp[1]}
                for cp in self.checkpoints
            ],
            'best_checkpoint': str(self.best_checkpoint) if self.best_checkpoint else None,
            'best_metric': self.best_metric,
            'metric_mode': self.metric_mode
        }
        info_file = self.checkpoint_dir / "checkpoint_info.json"
        tmp_file = info_file.with_name(info_file.name + ".tmp")
        with open(tmp_file, 'w', encoding='utf-8') as f:
            json.dump(info, f, indent=2)
        # Atomic rename: a crash mid-write cannot leave a truncated JSON file
        # that would make the next __init__ fail to parse it.
        tmp_file.replace(info_file)

    def _load_checkpoint_info(self):
        """
        Restore bookkeeping from checkpoint_info.json, if present.

        Entries whose files no longer exist are dropped. ``best_metric`` is
        restored whenever the best checkpoint or ``best_model.pt`` still
        exists, so a later save cannot overwrite the best model with a
        worse one.
        """
        info_file = self.checkpoint_dir / "checkpoint_info.json"
        if not info_file.exists():
            return

        with open(info_file, 'r', encoding='utf-8') as f:
            info = json.load(f)

        self.checkpoints = [
            (Path(cp['path']), cp['metric'])
            for cp in info.get('checkpoints', [])
            if Path(cp['path']).exists()
        ]

        best_checkpoint = info.get('best_checkpoint')
        best_metric = info.get('best_metric')
        if best_checkpoint and Path(best_checkpoint).exists():
            self.best_checkpoint = Path(best_checkpoint)
            self.best_metric = best_metric
        elif best_metric is not None and (self.checkpoint_dir / "best_model.pt").exists():
            # The per-epoch file is gone, but the best_model.pt copy still
            # holds that model, so its metric remains the bar to beat.
            self.best_metric = best_metric

    @staticmethod
    def _atomic_torch_save(obj: Dict, path: Path) -> None:
        """Write ``obj`` with torch.save to a temp file, then rename it onto ``path``."""
        tmp_path = path.with_name(path.name + ".tmp")
        try:
            torch.save(obj, tmp_path)
            tmp_path.replace(path)
        except BaseException:
            # Do not leave a partial temp file behind on failure.
            tmp_path.unlink(missing_ok=True)
            raise