| |
| """Archive older checkpoint .pt files while keeping key recovery points. |
| |
| Default behavior is a dry run. Use --execute to actually move files. |
| """ |
|
|
| from __future__ import annotations |
|
|
import argparse
import errno
import os
import re
import subprocess
import sys
from dataclasses import dataclass
from pathlib import Path
|
|
|
|
# Archive destination used when --archive-root is not given.
DEFAULT_ARCHIVE_ROOT = Path("/apdcephfs_cq10/share_1603164/user/schmittzhu/code/ckpts")

# Checkpoints that are always kept, whatever the retention policy. Entries are
# paths relative to the current working directory; build_keep_set compares
# them against each scanned file after passing it through relative_to_cwd.
EXTRA_KEEP_RELATIVE = {
    "checkpoints/spatial_beats_ov1_local_spatial_v9_ov123_exp"
    "/03_ov123_top4/epoch_0003.pt",
}

# Matches epoch_12.pt, epoch-12.pt and epoch12.pt; group 1 is the number.
EPOCH_RE = re.compile(r"^epoch[_-]?(\d+)\.pt$")
# Matches any "<prefix><digits>.pt"; group 1 is the prefix, group 2 the number.
TRAILING_NUM_RE = re.compile(r"^(.*?)(\d+)\.pt$")
|
|
|
|
@dataclass(frozen=True)
class PtFile:
    """One ``*.pt`` file discovered by the checkpoint scan.

    The dataclass is frozen and compares by value, so instances are
    immutable and hashable.
    """

    # Location as reported by the scan (relative or absolute, following the
    # form of the scan root passed to collect_pt_files).
    path: Path
    # File size in bytes.
    size: int
    # Modification time as a float timestamp, as produced by collect_pt_files.
    mtime: float
|
|
|
|
def format_size(num_bytes: int) -> str:
    """Render a byte count using binary (1024-based) units.

    Counts below 1024 are printed as an exact integer ("512 B"). Larger
    counts are scaled to KiB/MiB/GiB/TiB with one decimal place. Anything
    that is still >= 1024 TiB is expressed in PiB, however large.
    """
    scaled = float(num_bytes)
    if scaled < 1024:
        return f"{num_bytes} B"
    for suffix in ("KiB", "MiB", "GiB", "TiB"):
        scaled /= 1024
        if scaled < 1024:
            return f"{scaled:.1f} {suffix}"
    # PiB is the largest unit: no further threshold check.
    return f"{scaled / 1024:.1f} PiB"
|
|
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command line.

    Args:
        argv: Arguments to parse, without the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which matches the
            previous zero-argument behaviour.

    Returns:
        Namespace with ``checkpoints_root``, ``archive_root``, ``execute``,
        ``max_depth``, ``list_files`` and ``policy``.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Move old checkpoints to an archive directory, preserving original "
            "relative paths. Keeps best.pt, last.pt, the max-numbered epoch "
            "checkpoint per directory, and explicit keep-list exceptions."
        )
    )
    parser.add_argument(
        "--checkpoints-root",
        type=Path,
        default=Path("checkpoints"),
        help="Checkpoint root to scan. Default: checkpoints",
    )
    parser.add_argument(
        "--archive-root",
        type=Path,
        default=DEFAULT_ARCHIVE_ROOT,
        help=f"Archive root. Default: {DEFAULT_ARCHIVE_ROOT}",
    )
    parser.add_argument(
        "--execute",
        action="store_true",
        help="Actually move files. Without this flag, only prints a dry run.",
    )
    parser.add_argument(
        "--max-depth",
        type=int,
        default=3,
        # The value is handed to `find <checkpoints-root> -maxdepth N`, so depth
        # is counted from --checkpoints-root, not from the repository root.
        help=(
            "Maximum depth under --checkpoints-root to scan (passed to find "
            "-maxdepth). Default: 3, matching the current checkpoint layout."
        ),
    )
    parser.add_argument(
        "--list-files",
        action="store_true",
        help="Print every move candidate, not just summary and top directories.",
    )
    parser.add_argument(
        "--policy",
        choices=("best-last-max", "minimal"),
        default="best-last-max",
        help=(
            "Retention policy. best-last-max keeps best.pt, last.pt, and the "
            "max-numbered epoch checkpoint per directory. minimal keeps best.pt "
            "when present, otherwise last.pt, otherwise the max-numbered "
            "checkpoint, plus explicit keep-list exceptions."
        ),
    )
    return parser.parse_args(argv)
|
|
|
|
def relative_to_cwd(path: Path) -> Path:
    """Express *path* relative to the current working directory.

    Relative paths are returned unchanged: they are taken to be relative to
    the cwd already. Absolute paths are made relative to the resolved cwd.
    If that fails lexically, the path's parent is resolved and the
    comparison retried. This lets an absolute root that reaches the cwd
    through a symlink still map to the right relative path.

    Raises:
        ValueError: if *path* is absolute and does not lie under the cwd.
    """
    if not path.is_absolute():
        return path

    cwd = Path.cwd().resolve()
    try:
        return path.relative_to(cwd)
    except ValueError:
        pass

    # Resolve only the parent so the final component keeps its own name even
    # if it happens to be a symlink.
    resolved = path.parent.resolve() / path.name
    try:
        return resolved.relative_to(cwd)
    except ValueError as exc:
        raise ValueError(
            f"{path} is not under the current directory {cwd}; run this "
            "script from the directory that contains the checkpoint root"
        ) from exc
|
|
|
|
def collect_pt_files(root: Path, max_depth: int) -> dict[Path, list[PtFile]]:
    """Scan *root* for ``*.pt`` regular files and group them by directory.

    A single ``find -printf`` call (a GNU find extension) reports path, size
    and mtime for every match. Records are NUL-terminated and decoded with
    ``os.fsdecode``. File names containing newlines or bytes that are not
    valid in the locale encoding are therefore parsed intact. Names
    containing tabs are also handled, because the size and mtime fields are
    split off from the right.

    Args:
        root: Directory to scan. Result paths keep the same form (relative
            or absolute) as *root*, because they are printed by ``find``.
        max_depth: Passed to ``find -maxdepth``.

    Returns:
        Mapping from each parent directory to the files found directly in it.

    Raises:
        FileNotFoundError: If *root* is not an existing directory.
        RuntimeError: If ``find`` exits with a non-zero status.
    """
    if not root.is_dir():
        raise FileNotFoundError(f"checkpoint root does not exist: {root}")

    result = subprocess.run(
        [
            "find",
            str(root),
            "-maxdepth",
            str(max_depth),
            "-type",
            "f",
            "-name",
            "*.pt",
            "-printf",
            # find itself expands \t and \0. NUL cannot occur inside a path,
            # so it is an unambiguous record terminator.
            "%p\\t%s\\t%T@\\0",
        ],
        check=False,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    if result.returncode != 0:
        message = result.stderr.decode(errors="replace").strip()
        raise RuntimeError(message or "find failed")

    by_dir: dict[Path, list[PtFile]] = {}
    for record in os.fsdecode(result.stdout).split("\0"):
        if not record:
            continue  # the empty string after the final terminator
        path_text, size_text, mtime_text = record.rsplit("\t", 2)
        path = Path(path_text)
        item = PtFile(path=path, size=int(size_text), mtime=float(mtime_text))
        by_dir.setdefault(path.parent, []).append(item)

    return by_dir
|
|
|
|
def max_numbered_checkpoint(files: list[PtFile]) -> Path | None:
    """Return the path of the highest-numbered checkpoint in *files*.

    Names matching EPOCH_RE (``epoch_12.pt``, ``epoch-12.pt``, ``epoch12.pt``)
    take priority: if any exist, only they are considered. Otherwise every
    name ending in digits before ``.pt`` (TRAILING_NUM_RE) competes on that
    trailing number. best.pt and last.pt never count.

    Several files may share the highest number (e.g. ``epoch_3.pt`` and
    ``epoch_003.pt``). The lexicographically greatest name, then path, wins,
    so the result does not depend on the order in which ``find`` listed the
    files.

    Returns:
        The chosen path, or ``None`` when no file carries a number.
    """
    epoch_candidates: list[tuple[int, str, Path]] = []
    other_candidates: list[tuple[int, str, Path]] = []

    for item in files:
        name = item.path.name
        match = EPOCH_RE.match(name)
        if match:
            epoch_candidates.append((int(match.group(1)), name, item.path))
            continue

        match = TRAILING_NUM_RE.match(name)
        if match and name not in {"best.pt", "last.pt"}:
            other_candidates.append((int(match.group(2)), name, item.path))

    # Epoch-style names win outright; other numbered names are a fallback.
    pool = epoch_candidates or other_candidates
    if not pool:
        return None
    return max(pool, key=lambda entry: (entry[0], entry[1], str(entry[2])))[2]
|
|
|
|
def build_keep_set(by_dir: dict[Path, list[PtFile]], policy: str) -> set[Path]:
    """Decide which scanned checkpoint files must stay in place.

    Args:
        by_dir: Files grouped by their directory, as built by collect_pt_files.
        policy: Retention policy.
            ``"best-last-max"`` keeps best.pt, last.pt and the
            highest-numbered checkpoint in every directory.
            ``"minimal"`` keeps one file per directory: best.pt, else
            last.pt, else the highest-numbered checkpoint.

    Files whose cwd-relative path appears in EXTRA_KEEP_RELATIVE are kept as
    well, regardless of policy.

    Returns:
        The set of paths (same form as ``PtFile.path``) that must not be moved.

    Raises:
        ValueError: For an unknown *policy*. This is checked before scanning,
            so it is reported even when *by_dir* is empty.
    """
    if policy not in ("best-last-max", "minimal"):
        raise ValueError(f"unknown policy: {policy}")

    extra_keep = {Path(entry) for entry in EXTRA_KEEP_RELATIVE}
    keep: set[Path] = set()

    for files in by_dir.values():
        by_name = {item.path.name: item.path for item in files}

        if policy == "best-last-max":
            for name in ("best.pt", "last.pt"):
                if name in by_name:
                    keep.add(by_name[name])
            max_epoch = max_numbered_checkpoint(files)
            if max_epoch is not None:
                keep.add(max_epoch)
        else:  # "minimal": keep exactly one recovery point when possible
            if "best.pt" in by_name:
                keep.add(by_name["best.pt"])
            elif "last.pt" in by_name:
                keep.add(by_name["last.pt"])
            else:
                max_epoch = max_numbered_checkpoint(files)
                if max_epoch is not None:
                    keep.add(max_epoch)

        keep.update(
            item.path for item in files if relative_to_cwd(item.path) in extra_keep
        )

    return keep
|
|
|
|
def destination_for(source: Path, archive_root: Path) -> Path:
    """Map a checkpoint path to its mirrored location under *archive_root*.

    The source's cwd-relative path (from relative_to_cwd) is appended to
    *archive_root*, so the original directory layout is reproduced inside
    the archive.

    Raises:
        ValueError: If the cwd-relative path contains ``..`` or is absolute.
            Such a path (e.g. from ``--checkpoints-root ../other``) would
            place the file outside *archive_root*.
    """
    relative = relative_to_cwd(source)
    if relative.is_absolute() or ".." in relative.parts:
        raise ValueError(
            f"refusing to archive {source}: its path relative to the current "
            "directory would escape the archive root"
        )
    return archive_root / relative
|
|
|
|
def move_one(source: Path, destination: Path) -> None:
    """Move *source* to *destination* using a same-filesystem rename.

    Missing parent directories of *destination* are created first.

    NOTE(review): the existence check and the rename are two separate
    steps. A target created concurrently in between would still be
    replaced by ``rename``.

    Raises:
        FileExistsError: If *destination* already exists, including when it
            is a dangling symlink.
        RuntimeError: If the rename fails with EXDEV (source and archive on
            different filesystems). The file is not copied in that case.
        OSError: Any other failure from mkdir or rename.
    """
    # Path.exists() follows symlinks and reports False for a dangling one,
    # which rename() would then silently replace; check the link itself too.
    if destination.exists() or destination.is_symlink():
        raise FileExistsError(f"archive target already exists: {destination}")
    destination.parent.mkdir(parents=True, exist_ok=True)
    try:
        source.rename(destination)
    except OSError as exc:
        if exc.errno == errno.EXDEV:
            raise RuntimeError(
                "source and archive are on different filesystems; refusing to "
                "copy-then-delete automatically"
            ) from exc
        raise
|
|
|
|
def main() -> int:
    """Report the archive plan and, with --execute, carry it out.

    Every archive destination is computed and checked before anything is
    moved. Existing targets are listed in both modes, and an --execute run
    refuses to start while any exist. This avoids a run that stops halfway
    and leaves the archive partially populated.

    Returns:
        0 after a dry run or a completed move. 1 when --execute was requested
        but some archive targets already exist; nothing is moved then.
        Errors raised while moving propagate after a progress line is
        printed to stderr.
    """
    args = parse_args()
    archive_root = args.archive_root.resolve()
    by_dir = collect_pt_files(args.checkpoints_root, args.max_depth)
    keep = build_keep_set(by_dir, args.policy)

    all_files = [item for files in by_dir.values() for item in files]
    # Sorted once so the --list-files output and the move order agree.
    candidates = sorted(
        (item for item in all_files if item.path not in keep),
        key=lambda entry: str(entry.path),
    )
    total_size = sum(item.size for item in all_files)
    keep_size = sum(item.size for item in all_files if item.path in keep)
    move_size = sum(item.size for item in candidates)

    print(f"Mode: {'EXECUTE' if args.execute else 'DRY-RUN'}")
    print(f"Policy: {args.policy}")
    print(f"Checkpoint dirs: {len(by_dir)}")
    print(f"Total .pt files: {len(all_files)} ({format_size(total_size)})")
    print(f"Keep files: {len(all_files) - len(candidates)} ({format_size(keep_size)})")
    print(f"Move candidates: {len(candidates)} ({format_size(move_size)})")
    print(f"Archive root: {archive_root}")
    print()

    print("Extra keep-list:")
    for relpath in sorted(EXTRA_KEEP_RELATIVE):
        path = Path.cwd() / relpath
        print(f"  KEEP {relpath} ({'exists' if path.exists() else 'missing'})")
    print()

    per_dir: list[tuple[int, int, Path]] = []
    for directory, files in by_dir.items():
        move_count = sum(1 for item in files if item.path not in keep)
        move_bytes = sum(item.size for item in files if item.path not in keep)
        if move_count:
            per_dir.append((move_bytes, move_count, directory))

    print("Top directories by archived size:")
    for size, count, directory in sorted(per_dir, reverse=True)[:30]:
        print(f"  {format_size(size):>10}  files={count:<3} {relative_to_cwd(directory)}")
    print()

    # Resolve every destination up front so conflicts show up in the dry run
    # and an --execute run never aborts partway through because of one.
    plan = [(item, destination_for(item.path, archive_root)) for item in candidates]
    # is_symlink() catches dangling links, which exists() reports as absent.
    conflicts = [dest for _, dest in plan if dest.exists() or dest.is_symlink()]

    if args.list_files:
        print("Move candidates:")
        for item, destination in plan:
            print(f"  {relative_to_cwd(item.path)} -> {destination}")
        print()

    if conflicts:
        print(f"Archive conflicts: {len(conflicts)} target(s) already exist")
        for destination in conflicts[:30]:
            print(f"  EXISTS {destination}")
        print()

    if not args.execute:
        print("Dry run only. Re-run with --execute to move files.")
        return 0

    if conflicts:
        print(
            "Refusing to move anything until the archive conflicts above are resolved.",
            file=sys.stderr,
        )
        return 1

    moved = 0
    moved_bytes = 0
    try:
        for item, destination in plan:
            move_one(item.path, destination)
            moved += 1
            moved_bytes += item.size
            if moved % 25 == 0:
                print(f"Moved {moved}/{len(candidates)} files ({format_size(moved_bytes)})")
    except BaseException:
        # Report how far we got before re-raising (including Ctrl-C), so the
        # operator knows the archive is partially populated.
        print(
            f"Aborted after moving {moved}/{len(candidates)} files "
            f"({format_size(moved_bytes)}).",
            file=sys.stderr,
        )
        raise

    print(f"Done. Moved {moved} files ({format_size(moved_bytes)}).")
    return 0
|
|
|
|
if __name__ == "__main__":
    # main()'s return value becomes the process exit status.
    raise SystemExit(main())
|
|