import json
import os
import shutil
import subprocess
import sys
import time


def _check_exists(path, label):
    """Print a PASS/FAIL line for *label* and return whether *path* exists."""
    if os.path.exists(path):
        print(f"[PASS] {label} exists")
        return True
    print(f"[FAIL] {label} missing")
    return False


def _check_status(path):
    """Check that the status file at *path* exists and holds a JSON object.

    On success the ``iteration`` and ``train_loss`` entries are echoed.
    Returns True on success, False otherwise.
    """
    if not os.path.exists(path):
        print("[FAIL] status.json missing")
        return False
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError, so a
        # truncated or corrupt file is reported instead of crashing the run.
        print(f"[FAIL] status.json could not be parsed: {e}")
        return False
    if not isinstance(data, dict):
        print("[FAIL] status.json does not contain a JSON object")
        return False
    print("[PASS] status.json exists")
    print(f"  - Iteration: {data.get('iteration')}")
    print(f"  - Train Loss: {data.get('train_loss')}")
    return True


def _check_index(path):
    """Check that the dashboard page at *path* contains the expected markers."""
    if not os.path.exists(path):
        print("[FAIL] index.html missing")
        return False
    with open(path, "r", encoding="utf-8", errors="replace") as f:
        content = f.read()
    if "CTM Training Dashboard" in content and "status.json" in content:
        print("[PASS] index.html looks correct")
        return True
    print("[FAIL] index.html content incorrect")
    return False


def verify(log_dir="logs/scratch"):
    """Run a short training job and check that the expected outputs exist.

    Any previous *log_dir* is removed, the training script is launched for a
    few iterations, and then the log directory is checked for ``status.json``,
    ``artifacts.zip``, ``losses.png`` and ``accuracies.png``. ``index.html``
    in the current working directory is checked for its expected markers.

    Args:
        log_dir: Directory passed to the training script via ``--log_dir``
            and inspected afterwards. Defaults to ``"logs/scratch"``.

    Returns:
        True if training ran and every check passed, False otherwise.
    """
    print("Starting verification...")

    # Clean up previous logs
    if os.path.exists(log_dir):
        shutil.rmtree(log_dir)

    # Run training script for a few iterations.
    # We use a small model (ff) and cifar10 for speed, with minimal iterations.
    cmd = [
        "pixi", "run", "accelerate", "launch", "--cpu",
        "tasks/image_classification/train_energy.py",
        "--model", "ff",
        "--dataset", "cifar10",
        "--batch_size", "4",
        "--training_iterations", "5",  # Run for 5 iterations
        "--track_every", "2",  # Track every 2 iterations to ensure we get logs
        "--save_every", "2",  # Save every 2 iterations
        "--log_dir", log_dir,
        "--device", "-1",  # Use CPU for verification to avoid GPU issues if any
    ]

    print(f"Running command: {' '.join(cmd)}")
    try:
        subprocess.run(cmd, check=True, capture_output=True)
    except subprocess.CalledProcessError as e:
        print("Training failed!")
        # stderr is raw bytes; never let a decoding error mask the real failure.
        print(e.stderr.decode(errors="replace"))
        return False
    except OSError as e:
        # Raised when the launcher itself cannot be started (e.g. not on PATH).
        print("Training failed!")
        print(f"Could not launch command: {e}")
        return False

    print("Training finished. Checking files...")

    # Evaluate every check (no short-circuit) so the report is always complete.
    results = [
        _check_status(os.path.join(log_dir, "status.json")),
        _check_exists(os.path.join(log_dir, "artifacts.zip"), "artifacts.zip"),
        _check_exists(os.path.join(log_dir, "losses.png"), "losses.png"),
        _check_exists(os.path.join(log_dir, "accuracies.png"), "accuracies.png"),
        # Check index.html content (simple check)
        _check_index("index.html"),
    ]
    return all(results)


if __name__ == "__main__":
    # Propagate the outcome as the process exit status so CI can detect failures.
    sys.exit(0 if verify() else 1)