# ctm-energy-based-halting / verify_dashboard.py
# Commit 00d1de8 (Uday): "Added model training artifact dashboard and saving artifacts"
import os
import json
import subprocess
import time
import shutil
def verify():
    """Run a short CPU training job and check the dashboard artifacts it leaves behind.

    Removes any existing ``logs/scratch`` directory, launches
    ``tasks/image_classification/train_energy.py`` through
    ``pixi run accelerate launch --cpu`` for 5 iterations, and then checks
    that ``status.json``, ``artifacts.zip``, ``losses.png`` and
    ``accuracies.png`` exist under ``logs/scratch`` and that ``index.html``
    in the current working directory mentions the dashboard title and
    ``status.json``. Every check prints a ``[PASS]`` or ``[FAIL]`` line.

    Returns:
        bool: ``True`` when training ran and every check passed, ``False``
        when the training command could not be run, exited non-zero, or at
        least one check failed.
    """
    print("Starting verification...")

    log_dir = 'logs/scratch'

    # Clean up previous logs so stale files cannot satisfy the checks below.
    if os.path.exists(log_dir):
        shutil.rmtree(log_dir)

    # Run training script for a few iterations
    # We use a small model (ff) and cifar10 for speed, with minimal iterations
    cmd = [
        "pixi", "run", "accelerate", "launch", "--cpu", "tasks/image_classification/train_energy.py",
        "--model", "ff",
        "--dataset", "cifar10",
        "--batch_size", "4",
        "--training_iterations", "5",  # Run for 5 iterations
        "--track_every", "2",  # Track every 2 iterations to ensure we get logs
        "--save_every", "2",  # Save every 2 iterations
        "--log_dir", log_dir,
        "--device", "-1",  # Use CPU for verification to avoid GPU issues if any
    ]

    print(f"Running command: {' '.join(cmd)}")
    try:
        subprocess.run(cmd, check=True, capture_output=True)
    except subprocess.CalledProcessError as e:
        print("Training failed!")
        # stderr is captured as bytes; never let a stray byte hide the real error.
        print((e.stderr or b"").decode(errors="replace"))
        return False
    except OSError as e:
        # Raised when the launcher itself cannot be started (e.g. `pixi` is
        # not installed or not executable) -- report it instead of crashing.
        print("Training failed!")
        print(f"Could not launch command: {e}")
        return False

    print("Training finished. Checking files...")
    all_passed = True

    # Check status.json: it must exist AND parse as a JSON object.
    try:
        with open(f'{log_dir}/status.json', 'r', encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError:
        print("[FAIL] status.json missing")
        all_passed = False
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"[FAIL] status.json unreadable: {e}")
        all_passed = False
    else:
        if isinstance(data, dict):
            print("[PASS] status.json exists")
            print(f" - Iteration: {data.get('iteration')}")
            print(f" - Train Loss: {data.get('train_loss')}")
        else:
            print("[FAIL] status.json is not a JSON object")
            all_passed = False

    # Check the archive and the plots: existence only.
    for artifact in ('artifacts.zip', 'losses.png', 'accuracies.png'):
        if os.path.exists(f'{log_dir}/{artifact}'):
            print(f"[PASS] {artifact} exists")
        else:
            print(f"[FAIL] {artifact} missing")
            all_passed = False

    # Check index.html content (simple check)
    try:
        # errors='replace' keeps the substring check working even if the page
        # contains bytes that are not valid UTF-8.
        with open('index.html', 'r', encoding='utf-8', errors='replace') as f:
            content = f.read()
    except FileNotFoundError:
        print("[FAIL] index.html missing")
        all_passed = False
    except OSError as e:
        print(f"[FAIL] index.html unreadable: {e}")
        all_passed = False
    else:
        if 'CTM Training Dashboard' in content and 'status.json' in content:
            print("[PASS] index.html looks correct")
        else:
            print("[FAIL] index.html content incorrect")
            all_passed = False

    return all_passed
if __name__ == "__main__":
    # Script entry point. Exit with a non-zero status when verify() explicitly
    # reports failure (returns False) so shells and CI can detect it; a None
    # return carries no status and is treated as success.
    if verify() is False:
        raise SystemExit(1)