|
|
import os |
|
|
import json |
|
|
import subprocess |
|
|
import time |
|
|
import shutil |
|
|
|
|
|
def _run_training(cmd):
    """Run the training command; print diagnostics and return True on success."""
    print(f"Running command: {' '.join(cmd)}")
    try:
        subprocess.run(cmd, check=True, capture_output=True)
    except OSError as e:
        # e.g. FileNotFoundError when `pixi` is not installed / not on PATH.
        print("Training failed!")
        print(f"Could not launch {cmd[0]!r}: {e}")
        return False
    except subprocess.CalledProcessError as e:
        print("Training failed!")
        # Training scripts often log to stdout, so surface both streams.
        # errors="replace" keeps a non-UTF-8 byte from masking the real failure.
        stdout = (e.stdout or b"").decode(errors="replace")
        stderr = (e.stderr or b"").decode(errors="replace")
        if stdout:
            print(stdout)
        print(stderr)
        return False
    return True


def _check_file(path, label):
    """Print a PASS/FAIL line for whether ``path`` exists and return the result."""
    if os.path.exists(path):
        print(f"[PASS] {label} exists")
        return True
    print(f"[FAIL] {label} missing")
    return False


def _check_status(path):
    """Check that ``path`` exists and holds a JSON object; print its progress fields."""
    if not _check_file(path, "status.json"):
        return False
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:  # ValueError covers json.JSONDecodeError
        print(f"[FAIL] status.json unreadable: {e}")
        return False
    if not isinstance(data, dict):
        print("[FAIL] status.json is not a JSON object")
        return False
    print(f" - Iteration: {data.get('iteration')}")
    print(f" - Train Loss: {data.get('train_loss')}")
    return True


def _check_dashboard(path):
    """Check that the dashboard HTML exists and references the status file."""
    if not os.path.exists(path):
        print("[FAIL] index.html missing")
        return False
    with open(path, "r", encoding="utf-8", errors="replace") as f:
        content = f.read()
    if "CTM Training Dashboard" in content and "status.json" in content:
        print("[PASS] index.html looks correct")
        return True
    print("[FAIL] index.html content incorrect")
    return False


def verify():
    """Run a tiny CPU training job and check that it produced the expected outputs.

    Deletes ``logs/scratch`` and launches
    ``tasks/image_classification/train_energy.py`` through
    ``pixi run accelerate launch --cpu`` for 5 iterations. It then checks for
    ``status.json``, ``artifacts.zip``, ``losses.png`` and ``accuracies.png``
    in the log directory, and for a ``index.html`` dashboard in the current
    working directory. All paths are relative to the current working directory.

    Returns:
        True if training succeeded and every check passed, False otherwise.
        PASS/FAIL details are printed to stdout.
    """
    print("Starting verification...")

    log_dir = "logs/scratch"
    # Start from a clean directory so artifacts from a previous run cannot
    # produce false PASS results.
    if os.path.exists(log_dir):
        shutil.rmtree(log_dir)

    cmd = [
        "pixi", "run", "accelerate", "launch", "--cpu", "tasks/image_classification/train_energy.py",
        "--model", "ff",
        "--dataset", "cifar10",
        "--batch_size", "4",
        "--training_iterations", "5",
        "--track_every", "2",
        "--save_every", "2",
        "--log_dir", log_dir,
        "--device", "-1",
    ]

    if not _run_training(cmd):
        return False

    print("Training finished. Checking files...")

    # A list (not a short-circuiting `and` chain) so every check runs and
    # reports, even after an earlier one fails.
    results = [
        _check_status(os.path.join(log_dir, "status.json")),
        _check_file(os.path.join(log_dir, "artifacts.zip"), "artifacts.zip"),
        _check_file(os.path.join(log_dir, "losses.png"), "losses.png"),
        _check_file(os.path.join(log_dir, "accuracies.png"), "accuracies.png"),
        _check_dashboard("index.html"),
    ]
    return all(results)
|
|
|
|
|
if __name__ == "__main__":
    # Exit non-zero only when verify() explicitly reports failure, so CI can
    # detect a broken run. `is False` keeps the exit code at 0 for a verify()
    # that returns None.
    raise SystemExit(1 if verify() is False else 0)
|
|
|