Spatial-BEATs / scripts /archive_old_checkpoints.py
dieKarotte's picture
Add files using upload-large-folder tool
4fdc640 verified
raw
history blame contribute delete
9.19 kB
#!/usr/bin/env python3
"""Archive older checkpoint .pt files while keeping key recovery points.
Default behavior is a dry run. Use --execute to actually move files.
"""
from __future__ import annotations
import argparse
import errno
import re
import subprocess
import sys
from dataclasses import dataclass
from pathlib import Path
# Default destination root for archived checkpoints; used as the default value
# of the --archive-root command-line option.
DEFAULT_ARCHIVE_ROOT = Path(
"/apdcephfs_cq10/share_1603164/user/schmittzhu/code/ckpts"
)
# Checkpoints that are always kept regardless of the retention policy.
# Entries are compared against scanned paths expressed relative to the current
# working directory, so they must be written relative to that directory.
EXTRA_KEEP_RELATIVE = {
# run_ov1_v10_phase1_cls.sh defaults to this v9 ep3 checkpoint.
"checkpoints/spatial_beats_ov1_local_spatial_v9_ov123_exp/"
"03_ov123_top4/epoch_0003.pt",
}
# Matches epoch-style file names such as "epoch_0003.pt", "epoch-3.pt" or
# "epoch3.pt"; group 1 captures the epoch number.
EPOCH_RE = re.compile(r"^epoch[_-]?(\d+)\.pt$")
# Matches any "*.pt" name whose stem ends in digits; group 2 captures the
# trailing digits. Used as a fallback when a directory has no epoch-style names.
TRAILING_NUM_RE = re.compile(r"^(.*?)(\d+)\.pt$")
@dataclass(frozen=True)
class PtFile:
    """Immutable record describing one checkpoint file found on disk."""

    # Location of the file as reported by the directory scan.
    path: Path
    # File size in bytes.
    size: int
    # Last modification time, in seconds since the epoch.
    mtime: float
def format_size(num_bytes: int) -> str:
    """Render a byte count as a human-readable string using binary units.

    Values below 1024 are printed verbatim with a "B" suffix. Larger values
    are scaled by successive divisions by 1024 and printed with one decimal
    place; "PiB" is the largest unit, so anything beyond it stays in PiB.
    """
    scaled = float(num_bytes)
    if scaled < 1024:
        return f"{num_bytes} B"
    for suffix in ("KiB", "MiB", "GiB", "TiB"):
        scaled /= 1024
        if scaled < 1024:
            return f"{scaled:.1f} {suffix}"
    # Still >= 1024 TiB: report in the largest supported unit.
    return f"{scaled / 1024:.1f} PiB"
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the archiving script.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the behaviour callers that pass
            no argument get.

    Returns:
        Namespace with ``checkpoints_root``, ``archive_root``, ``execute``,
        ``max_depth``, ``list_files`` and ``policy`` attributes.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Move old checkpoints to an archive directory, preserving original "
            "relative paths. Keeps best.pt, last.pt, the max-numbered epoch "
            "checkpoint per directory, and explicit keep-list exceptions."
        )
    )
    parser.add_argument(
        "--checkpoints-root",
        type=Path,
        default=Path("checkpoints"),
        help="Checkpoint root to scan. Default: checkpoints",
    )
    parser.add_argument(
        "--archive-root",
        type=Path,
        default=DEFAULT_ARCHIVE_ROOT,
        help=f"Archive root. Default: {DEFAULT_ARCHIVE_ROOT}",
    )
    parser.add_argument(
        "--execute",
        action="store_true",
        help="Actually move files. Without this flag, only prints a dry run.",
    )
    parser.add_argument(
        "--max-depth",
        type=int,
        default=3,
        # The value is handed to `find <checkpoints-root> -maxdepth N`, so the
        # depth is counted from the checkpoint root, not from the repo root.
        help=(
            "Maximum depth below --checkpoints-root to scan (passed to "
            "find -maxdepth). Default: 3, matching the current checkpoint "
            "layout."
        ),
    )
    parser.add_argument(
        "--list-files",
        action="store_true",
        help="Print every move candidate, not just summary and top directories.",
    )
    parser.add_argument(
        "--policy",
        choices=("best-last-max", "minimal"),
        default="best-last-max",
        # Keep this text in sync with the retention logic: the minimal policy
        # falls back to the max-numbered checkpoint when a directory has
        # neither best.pt nor last.pt.
        help=(
            "Retention policy. best-last-max keeps best.pt, last.pt, and the "
            "max-numbered epoch checkpoint per directory. minimal keeps best.pt "
            "when present, otherwise last.pt, otherwise the max-numbered "
            "checkpoint. Both policies also keep explicit keep-list exceptions."
        ),
    )
    return parser.parse_args(argv)
def relative_to_cwd(path: Path) -> Path:
    """Express ``path`` relative to the current working directory.

    Relative paths are returned unchanged. Absolute paths are rewritten
    relative to the resolved working directory.

    Raises:
        ValueError: If ``path`` is absolute but not located under the
            resolved working directory.
    """
    if path.is_absolute():
        working_dir = Path.cwd().resolve()
        return path.relative_to(working_dir)
    return path
def collect_pt_files(root: Path, max_depth: int) -> dict[Path, list[PtFile]]:
    """Scan ``root`` for ``*.pt`` files and group them by parent directory.

    The scan is delegated to the external ``find`` command, which prints one
    tab-separated record (path, size, modification time) per file.

    Args:
        root: Directory to scan.
        max_depth: Value passed to ``find -maxdepth``.

    Returns:
        Mapping from each directory containing checkpoints to its files.

    Raises:
        FileNotFoundError: If ``root`` is not an existing directory.
        RuntimeError: If ``find`` exits with a non-zero status.
    """
    if not root.is_dir():
        raise FileNotFoundError(f"checkpoint root does not exist: {root}")

    command = [
        "find",
        str(root),
        "-maxdepth",
        str(max_depth),
        "-type",
        "f",
        "-name",
        "*.pt",
        "-printf",
        "%p\t%s\t%T@\n",
    ]
    completed = subprocess.run(
        command,
        check=False,
        capture_output=True,
        text=True,
    )
    if completed.returncode != 0:
        raise RuntimeError(completed.stderr.strip() or "find failed")

    grouped: dict[Path, list[PtFile]] = {}
    for record in completed.stdout.splitlines():
        # Split from the right so tabs inside the path itself are preserved.
        raw_path, raw_size, raw_mtime = record.rsplit("\t", 2)
        location = Path(raw_path)
        entry = PtFile(path=location, size=int(raw_size), mtime=float(raw_mtime))
        grouped.setdefault(location.parent, []).append(entry)
    return grouped
def max_numbered_checkpoint(files: list[PtFile]) -> Path | None:
    """Pick the highest-numbered checkpoint among ``files``.

    Names matching ``EPOCH_RE`` take precedence. Only when no such name
    exists are other names ending in digits (per ``TRAILING_NUM_RE``)
    considered. On equal numbers the file listed first wins.

    Returns:
        Path of the selected checkpoint, or ``None`` when no file name
        carries a number.
    """
    epoch_named: list[tuple[int, Path]] = []
    loosely_numbered: list[tuple[int, Path]] = []
    for entry in files:
        filename = entry.path.name
        epoch_match = EPOCH_RE.match(filename)
        if epoch_match is not None:
            epoch_named.append((int(epoch_match.group(1)), entry.path))
        elif filename not in {"best.pt", "last.pt"}:
            trailing_match = TRAILING_NUM_RE.match(filename)
            if trailing_match is not None:
                loosely_numbered.append((int(trailing_match.group(2)), entry.path))

    # Epoch-style names are preferred; the generic pool is only a fallback.
    for pool in (epoch_named, loosely_numbered):
        if pool:
            return max(pool, key=lambda numbered: numbered[0])[1]
    return None
def build_keep_set(by_dir: dict[Path, list[PtFile]], policy: str) -> set[Path]:
    """Compute the set of checkpoint paths that must not be archived.

    Args:
        by_dir: Checkpoint files grouped by their parent directory.
        policy: ``"best-last-max"`` keeps best.pt, last.pt and the
            max-numbered checkpoint of each directory. ``"minimal"`` keeps
            best.pt if present, else last.pt, else the max-numbered
            checkpoint.

    Returns:
        Paths to keep, including every scanned file listed in
        ``EXTRA_KEEP_RELATIVE``.

    Raises:
        ValueError: If ``policy`` is not recognised (only detected once a
            directory is actually processed).
    """
    pinned = {Path(relative) for relative in EXTRA_KEEP_RELATIVE}
    kept: set[Path] = set()
    for entries in by_dir.values():
        name_to_path = {entry.path.name: entry.path for entry in entries}
        if policy == "best-last-max":
            kept.update(
                name_to_path[fixed]
                for fixed in ("best.pt", "last.pt")
                if fixed in name_to_path
            )
            newest = max_numbered_checkpoint(entries)
            if newest is not None:
                kept.add(newest)
        elif policy == "minimal":
            # Preference order: best.pt, then last.pt, then highest number.
            chosen = name_to_path.get("best.pt", name_to_path.get("last.pt"))
            if chosen is None:
                chosen = max_numbered_checkpoint(entries)
            if chosen is not None:
                kept.add(chosen)
        else:
            raise ValueError(f"unknown policy: {policy}")
        # Explicit exceptions are kept under every policy.
        kept.update(
            entry.path
            for entry in entries
            if relative_to_cwd(entry.path) in pinned
        )
    return kept
def destination_for(source: Path, archive_root: Path) -> Path:
    """Return where ``source`` lands inside ``archive_root``.

    The path of ``source`` relative to the current working directory is
    re-created beneath ``archive_root``.
    """
    relative_source = relative_to_cwd(source)
    return archive_root / relative_source
def move_one(source: Path, destination: Path) -> None:
    """Move ``source`` to ``destination`` with a plain rename.

    Missing parent directories of ``destination`` are created first.

    Raises:
        FileExistsError: If ``destination`` already exists.
        RuntimeError: If the rename fails because source and destination
            are on different filesystems; no copy fallback is attempted.
        OSError: Any other rename failure is propagated unchanged.
    """
    if destination.exists():
        raise FileExistsError(f"archive target already exists: {destination}")

    destination.parent.mkdir(parents=True, exist_ok=True)
    try:
        source.rename(destination)
    except OSError as exc:
        if exc.errno != errno.EXDEV:
            raise
        # Cross-device move: a rename cannot do this, and copying is refused.
        raise RuntimeError(
            "source and archive are on different filesystems; refusing to "
            "copy-then-delete automatically"
        ) from exc
def main() -> int:
    """Run the archiving workflow and return the process exit code.

    Scans the checkpoint root, decides which files to keep, prints a
    summary, and moves the remaining files into the archive root only when
    ``--execute`` was given. Returns 0 on completion; failures surface as
    exceptions raised by the helpers.
    """
    args = parse_args()
    archive_root = args.archive_root.resolve()
    by_dir = collect_pt_files(args.checkpoints_root, args.max_depth)
    keep = build_keep_set(by_dir, args.policy)

    all_files: list[PtFile] = []
    for files in by_dir.values():
        all_files.extend(files)
    kept_files = [item for item in all_files if item.path in keep]
    candidates = [item for item in all_files if item.path not in keep]

    total_size = sum(item.size for item in all_files)
    keep_size = sum(item.size for item in kept_files)
    move_size = sum(item.size for item in candidates)

    mode = "EXECUTE" if args.execute else "DRY-RUN"
    print(f"Mode: {mode}")
    print(f"Policy: {args.policy}")
    print(f"Checkpoint dirs: {len(by_dir)}")
    print(f"Total .pt files: {len(all_files)} ({format_size(total_size)})")
    print(f"Keep files: {len(kept_files)} ({format_size(keep_size)})")
    print(f"Move candidates: {len(candidates)} ({format_size(move_size)})")
    print(f"Archive root: {archive_root}")
    print()

    print("Extra keep-list:")
    for relpath in sorted(EXTRA_KEEP_RELATIVE):
        status = "exists" if (Path.cwd() / relpath).exists() else "missing"
        print(f" KEEP {relpath} ({status})")
    print()

    # Per-directory totals of what would be moved: (bytes, file count, dir).
    per_dir: list[tuple[int, int, Path]] = []
    for directory, files in by_dir.items():
        movable_sizes = [item.size for item in files if item.path not in keep]
        if movable_sizes:
            per_dir.append((sum(movable_sizes), len(movable_sizes), directory))

    print("Top directories by archived size:")
    for size, count, directory in sorted(per_dir, reverse=True)[:30]:
        print(f" {format_size(size):>10} files={count:<3} {relative_to_cwd(directory)}")
    print()

    # The listing and the actual move share one deterministic order.
    ordered = sorted(candidates, key=lambda entry: str(entry.path))

    if args.list_files:
        print("Move candidates:")
        for item in ordered:
            target = destination_for(item.path, archive_root)
            print(f" {relative_to_cwd(item.path)} -> {target}")
        print()

    if not args.execute:
        print("Dry run only. Re-run with --execute to move files.")
        return 0

    moved_bytes = 0
    for moved, item in enumerate(ordered, start=1):
        move_one(item.path, destination_for(item.path, archive_root))
        moved_bytes += item.size
        # Periodic progress line so long runs show activity.
        if moved % 25 == 0:
            print(f"Moved {moved}/{len(candidates)} files ({format_size(moved_bytes)})")
    print(f"Done. Moved {len(ordered)} files ({format_size(moved_bytes)}).")
    return 0
# Script entry point: propagate main()'s return value as the exit status.
if __name__ == "__main__":
    raise SystemExit(main())