Uday committed on
Commit
00d1de8
·
1 Parent(s): 32e3089

Added model training artifact dashboard and saving artifacts

Browse files
index.html CHANGED
@@ -1,27 +1,296 @@
1
  <!DOCTYPE html>
2
- <html>
3
  <head>
4
- <title>CTM Training Status</title>
 
 
5
  <style>
 
 
 
 
 
 
 
 
 
6
  body {
7
- font-family: sans-serif;
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  text-align: center;
9
- padding: 50px;
10
  }
 
11
  h1 {
12
- color: #333;
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  }
14
- p {
 
 
 
 
 
 
 
15
  color: #666;
16
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  </style>
18
  </head>
19
  <body>
20
- <h1>Training in Progress</h1>
21
- <p>
22
- The Continuous Thought Machine energy-based halting experiment is
23
- currently training.
24
- </p>
25
- <p>Please check the <strong>Logs</strong> tab for real-time updates.</p>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  </body>
27
  </html>
 
1
  <!DOCTYPE html>
2
+ <html lang="en">
3
  <head>
4
+ <meta charset="UTF-8" />
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
6
+ <title>CTM Training Dashboard</title>
7
  <style>
8
+ :root {
9
+ --bg-color: #f4f4f9;
10
+ --card-bg: #ffffff;
11
+ --text-color: #333;
12
+ --accent-color: #4a90e2;
13
+ --success-color: #2ecc71;
14
+ --font-family: "Segoe UI", Tahoma, Geneva, Verdana, sans-serif;
15
+ }
16
+
17
  body {
18
+ font-family: var(--font-family);
19
+ background-color: var(--bg-color);
20
+ color: var(--text-color);
21
+ margin: 0;
22
+ padding: 20px;
23
+ line-height: 1.6;
24
+ }
25
+
26
+ .container {
27
+ max-width: 1200px;
28
+ margin: 0 auto;
29
+ }
30
+
31
+ header {
32
  text-align: center;
33
+ margin-bottom: 40px;
34
  }
35
+
36
  h1 {
37
+ color: var(--accent-color);
38
+ margin-bottom: 10px;
39
+ }
40
+
41
+ .status-card {
42
+ background: var(--card-bg);
43
+ padding: 20px;
44
+ border-radius: 8px;
45
+ box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
46
+ margin-bottom: 30px;
47
+ display: grid;
48
+ grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
49
+ gap: 20px;
50
+ text-align: center;
51
  }
52
+
53
+ .metric {
54
+ display: flex;
55
+ flex-direction: column;
56
+ }
57
+
58
+ .metric-label {
59
+ font-size: 0.9em;
60
  color: #666;
61
  }
62
+
63
+ .metric-value {
64
+ font-size: 1.5em;
65
+ font-weight: bold;
66
+ color: var(--text-color);
67
+ }
68
+
69
+ .plots-grid {
70
+ display: grid;
71
+ grid-template-columns: repeat(auto-fit, minmax(500px, 1fr));
72
+ gap: 20px;
73
+ margin-bottom: 30px;
74
+ }
75
+
76
+ .plot-card {
77
+ background: var(--card-bg);
78
+ padding: 15px;
79
+ border-radius: 8px;
80
+ box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
81
+ }
82
+
83
+ .plot-card img {
84
+ width: 100%;
85
+ height: auto;
86
+ border-radius: 4px;
87
+ }
88
+
89
+ .artifacts-section {
90
+ background: var(--card-bg);
91
+ padding: 20px;
92
+ border-radius: 8px;
93
+ box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
94
+ text-align: center;
95
+ margin-bottom: 30px;
96
+ }
97
+
98
+ .btn {
99
+ display: inline-block;
100
+ padding: 10px 20px;
101
+ background-color: var(--accent-color);
102
+ color: white;
103
+ text-decoration: none;
104
+ border-radius: 5px;
105
+ margin: 0 10px;
106
+ transition: background-color 0.3s;
107
+ }
108
+
109
+ .btn:hover {
110
+ background-color: #357abd;
111
+ }
112
+
113
+ .gallery {
114
+ display: grid;
115
+ grid-template-columns: repeat(auto-fill, minmax(200px, 1fr));
116
+ gap: 15px;
117
+ margin-top: 20px;
118
+ }
119
+
120
+ .gallery img {
121
+ width: 100%;
122
+ border-radius: 4px;
123
+ box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1);
124
+ }
125
+
126
+ footer {
127
+ text-align: center;
128
+ margin-top: 50px;
129
+ color: #888;
130
+ font-size: 0.9em;
131
+ }
132
+
133
+ #last-updated {
134
+ font-size: 0.8em;
135
+ color: #999;
136
+ margin-top: 5px;
137
+ }
138
  </style>
139
  </head>
140
  <body>
141
+ <div class="container">
142
+ <header>
143
+ <h1>CTM Training Dashboard</h1>
144
+ <p>Real-time monitoring of Energy-Based Halting Experiment</p>
145
+ <div id="last-updated">Waiting for data...</div>
146
+ </header>
147
+
148
+ <div class="status-card" id="metrics-container">
149
+ <div class="metric">
150
+ <span class="metric-label">Iteration</span>
151
+ <span class="metric-value" id="iter">--</span>
152
+ </div>
153
+ <div class="metric">
154
+ <span class="metric-label">Epoch</span>
155
+ <span class="metric-value" id="epoch">--</span>
156
+ </div>
157
+ <div class="metric">
158
+ <span class="metric-label">Train Loss</span>
159
+ <span class="metric-value" id="train-loss">--</span>
160
+ </div>
161
+ <div class="metric">
162
+ <span class="metric-label">Test Loss</span>
163
+ <span class="metric-value" id="test-loss">--</span>
164
+ </div>
165
+ <div class="metric">
166
+ <span class="metric-label">Train Acc</span>
167
+ <span class="metric-value" id="train-acc">--</span>
168
+ </div>
169
+ <div class="metric">
170
+ <span class="metric-label">Test Acc</span>
171
+ <span class="metric-value" id="test-acc">--</span>
172
+ </div>
173
+ </div>
174
+
175
+ <div class="plots-grid">
176
+ <div class="plot-card">
177
+ <h3>Loss History</h3>
178
+ <img
179
+ id="loss-plot"
180
+ src="logs/scratch/losses.png"
181
+ alt="Loss Plot"
182
+ onerror="this.src='https://via.placeholder.com/600x400?text=Waiting+for+Plots'"
183
+ />
184
+ </div>
185
+ <div class="plot-card">
186
+ <h3>Accuracy History</h3>
187
+ <img
188
+ id="acc-plot"
189
+ src="logs/scratch/accuracies.png"
190
+ alt="Accuracy Plot"
191
+ onerror="this.src='https://via.placeholder.com/600x400?text=Waiting+for+Plots'"
192
+ />
193
+ </div>
194
+ </div>
195
+
196
+ <div class="artifacts-section">
197
+ <h2>Artifacts & Downloads</h2>
198
+ <p>Download the latest model checkpoints and full logs.</p>
199
+ <a href="logs/scratch/artifacts.zip" class="btn"
200
+ >Download All Artifacts (.zip)</a
201
+ >
202
+ <a href="logs/scratch/checkpoint.pt" class="btn"
203
+ >Download Checkpoint (.pt)</a
204
+ >
205
+ </div>
206
+
207
+ <div class="artifacts-section">
208
+ <h2>Attention Visualization</h2>
209
+ <p>Latest generated attention maps from the model.</p>
210
+ <div class="gallery" id="gif-gallery">
211
+ <!-- GIFs will be injected here -->
212
+ <img
213
+ src="logs/scratch/0_attention.gif"
214
+ onerror="this.style.display='none'"
215
+ alt="Attention Map"
216
+ />
217
+ </div>
218
+ </div>
219
+ </div>
220
+
221
+ <footer>
222
+ <p>Continuous Thought Machine Experiment</p>
223
+ </footer>
224
+
225
+ <script>
226
+ const LOG_DIR = "logs/scratch";
227
+
228
+ async function updateDashboard() {
229
+ try {
230
+ // Fetch status.json
231
+ const response = await fetch(
232
+ `${LOG_DIR}/status.json?t=${new Date().getTime()}`
233
+ );
234
+ if (!response.ok) throw new Error("Status file not found");
235
+
236
+ const data = await response.json();
237
+
238
+ // Update Metrics
239
+ document.getElementById(
240
+ "iter"
241
+ ).textContent = `${data.iteration} / ${data.total_iterations}`;
242
+ document.getElementById("epoch").textContent = data.epoch;
243
+ document.getElementById("train-loss").textContent = parseFloat(
244
+ data.train_loss
245
+ ).toFixed(4);
246
+ document.getElementById("test-loss").textContent = parseFloat(
247
+ data.test_loss
248
+ ).toFixed(4);
249
+
250
+ // Handle Accuracy (could be array or float)
251
+ const formatAcc = (acc) => {
252
+ if (Array.isArray(acc)) {
253
+ return (acc[acc.length - 1] * 100).toFixed(2) + "%";
254
+ }
255
+ return (acc * 100).toFixed(2) + "%";
256
+ };
257
+
258
+ document.getElementById("train-acc").textContent = formatAcc(
259
+ data.train_accuracy
260
+ );
261
+ document.getElementById("test-acc").textContent = formatAcc(
262
+ data.test_accuracy
263
+ );
264
+
265
+ // Update Timestamp
266
+ document.getElementById(
267
+ "last-updated"
268
+ ).textContent = `Last updated: ${new Date().toLocaleTimeString()}`;
269
+
270
+ // Refresh Images
271
+ const timestamp = new Date().getTime();
272
+ document.getElementById(
273
+ "loss-plot"
274
+ ).src = `${LOG_DIR}/losses.png?t=${timestamp}`;
275
+ document.getElementById(
276
+ "acc-plot"
277
+ ).src = `${LOG_DIR}/accuracies.png?t=${timestamp}`;
278
+
279
+ // Refresh Gallery (simple approach: try to reload the known gif)
280
+ const gallery = document.getElementById("gif-gallery");
281
+ gallery.innerHTML = `<img src="${LOG_DIR}/0_attention.gif?t=${timestamp}" onerror="this.style.display='none'" alt="Attention Map">`;
282
+ } catch (error) {
283
+ console.log("Waiting for training to start...", error);
284
+ document.getElementById("last-updated").textContent =
285
+ "Waiting for training to start...";
286
+ }
287
+ }
288
+
289
+ // Update every 30 seconds
290
+ setInterval(updateDashboard, 30000);
291
+
292
+ // Initial call
293
+ updateDashboard();
294
+ </script>
295
  </body>
296
  </html>
tasks/image_classification/train_energy.py CHANGED
@@ -1,6 +1,8 @@
1
  import argparse
2
  import os
3
  import random
 
 
4
 
5
  import matplotlib.pyplot as plt
6
  import numpy as np
@@ -292,7 +294,7 @@ if __name__=='__main__':
292
  elif args.model == 'ff':
293
  model = FFBaseline(
294
  d_model=args.d_model,
295
- d_input=args.d_input,
296
  out_dims=args.out_dims,
297
  dropout=args.dropout,
298
  )
@@ -718,6 +720,27 @@ if __name__=='__main__':
718
 
719
 
720
  # Save model checkpoint (conditional metrics)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
721
  # Save model checkpoint (conditional metrics)
722
  if (bi % args.save_every == 0 or bi == args.training_iterations - 1) and bi != start_iter:
723
  if accelerator.is_main_process:
@@ -744,6 +767,12 @@ if __name__=='__main__':
744
 
745
  accelerator.save(checkpoint_data, f'{args.log_dir}/checkpoint.pt')
746
 
 
 
 
 
 
 
747
  # Push to Hub
748
  if args.push_to_hub and args.hub_model_id:
749
  if bi % (args.save_every * 5) == 0: # Upload less frequently
@@ -753,7 +782,7 @@ if __name__=='__main__':
753
  repo_id=args.hub_model_id,
754
  token=args.hub_token,
755
  commit_message=f"Training checkpoint {bi}",
756
- ignore_patterns=["*.pt"],
757
  )
758
  except Exception as e:
759
  print(f"Failed to upload to hub: {e}")
 
1
  import argparse
2
  import os
3
  import random
4
+ import json
5
+ import shutil
6
 
7
  import matplotlib.pyplot as plt
8
  import numpy as np
 
294
  elif args.model == 'ff':
295
  model = FFBaseline(
296
  d_model=args.d_model,
297
+ backbone_type=args.backbone_type,
298
  out_dims=args.out_dims,
299
  dropout=args.dropout,
300
  )
 
720
 
721
 
722
  # Save model checkpoint (conditional metrics)
723
+ # Save status.json for the dashboard
724
+ if (bi % args.track_every == 0 or bi == args.training_iterations - 1) and bi != start_iter:
725
+ status_data = {
726
+ 'iteration': bi,
727
+ 'total_iterations': args.training_iterations,
728
+ 'epoch': bi // len(trainloader),
729
+ 'train_loss': train_losses[-1] if train_losses else 0.0,
730
+ 'test_loss': test_losses[-1] if test_losses else 0.0,
731
+ 'train_accuracy': train_accuracies[-1] if train_accuracies else 0.0, # Might be array for CTM
732
+ 'test_accuracy': test_accuracies[-1] if test_accuracies else 0.0, # Might be array for CTM
733
+ 'learning_rate': current_lr,
734
+ }
735
+ # Handle numpy arrays for JSON serialization
736
+ def convert_to_serializable(obj):
737
+ if isinstance(obj, np.ndarray):
738
+ return obj.tolist()
739
+ return obj
740
+
741
+ with open(f'{args.log_dir}/status.json', 'w') as f:
742
+ json.dump(status_data, f, default=convert_to_serializable)
743
+
744
  # Save model checkpoint (conditional metrics)
745
  if (bi % args.save_every == 0 or bi == args.training_iterations - 1) and bi != start_iter:
746
  if accelerator.is_main_process:
 
767
 
768
  accelerator.save(checkpoint_data, f'{args.log_dir}/checkpoint.pt')
769
 
770
+ # Zip artifacts
771
+ try:
772
+ shutil.make_archive(f'{args.log_dir}/artifacts', 'zip', args.log_dir)
773
+ except Exception as e:
774
+ print(f"Failed to zip artifacts: {e}")
775
+
776
  # Push to Hub
777
  if args.push_to_hub and args.hub_model_id:
778
  if bi % (args.save_every * 5) == 0: # Upload less frequently
 
782
  repo_id=args.hub_model_id,
783
  token=args.hub_token,
784
  commit_message=f"Training checkpoint {bi}",
785
+ ignore_patterns=[], # Upload everything including .pt and .zip
786
  )
787
  except Exception as e:
788
  print(f"Failed to upload to hub: {e}")
verify_dashboard.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import subprocess
4
+ import time
5
+ import shutil
6
+
7
+ def verify():
8
+ print("Starting verification...")
9
+
10
+ # Clean up previous logs
11
+ if os.path.exists('logs/scratch'):
12
+ shutil.rmtree('logs/scratch')
13
+
14
+ # Run training script for a few iterations
15
+ # We use a small model (ff) and cifar10 for speed, with minimal iterations
16
+ cmd = [
17
+ "pixi", "run", "accelerate", "launch", "--cpu", "tasks/image_classification/train_energy.py",
18
+ "--model", "ff",
19
+ "--dataset", "cifar10",
20
+ "--batch_size", "4",
21
+ "--training_iterations", "5", # Run for 5 iterations
22
+ "--track_every", "2", # Track every 2 iterations to ensure we get logs
23
+ "--save_every", "2", # Save every 2 iterations
24
+ "--log_dir", "logs/scratch",
25
+ "--device", "-1" # Use CPU for verification to avoid GPU issues if any
26
+ ]
27
+
28
+ print(f"Running command: {' '.join(cmd)}")
29
+ try:
30
+ subprocess.run(cmd, check=True, capture_output=True)
31
+ except subprocess.CalledProcessError as e:
32
+ print("Training failed!")
33
+ print(e.stderr.decode())
34
+ return
35
+
36
+ print("Training finished. Checking files...")
37
+
38
+ # Check status.json
39
+ if os.path.exists('logs/scratch/status.json'):
40
+ print("[PASS] status.json exists")
41
+ with open('logs/scratch/status.json', 'r') as f:
42
+ data = json.load(f)
43
+ print(f" - Iteration: {data.get('iteration')}")
44
+ print(f" - Train Loss: {data.get('train_loss')}")
45
+ else:
46
+ print("[FAIL] status.json missing")
47
+
48
+ # Check artifacts.zip
49
+ if os.path.exists('logs/scratch/artifacts.zip'):
50
+ print("[PASS] artifacts.zip exists")
51
+ else:
52
+ print("[FAIL] artifacts.zip missing")
53
+
54
+ # Check plots
55
+ if os.path.exists('logs/scratch/losses.png'):
56
+ print("[PASS] losses.png exists")
57
+ else:
58
+ print("[FAIL] losses.png missing")
59
+
60
+ if os.path.exists('logs/scratch/accuracies.png'):
61
+ print("[PASS] accuracies.png exists")
62
+ else:
63
+ print("[FAIL] accuracies.png missing")
64
+
65
+ # Check index.html content (simple check)
66
+ if os.path.exists('index.html'):
67
+ with open('index.html', 'r') as f:
68
+ content = f.read()
69
+ if 'CTM Training Dashboard' in content and 'status.json' in content:
70
+ print("[PASS] index.html looks correct")
71
+ else:
72
+ print("[FAIL] index.html content incorrect")
73
+ else:
74
+ print("[FAIL] index.html missing")
75
+
76
+ if __name__ == "__main__":
77
+ verify()