RFTSystems committed
Commit 24c834f · verified · 1 Parent(s): 1faeebc

Update train_dclr_model.py

Files changed (1)
  1. train_dclr_model.py +82 -202
train_dclr_model.py CHANGED
@@ -1,4 +1,3 @@
- import os
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
@@ -6,7 +5,7 @@ import torchvision
  import torchvision.transforms as transforms
  from torch.utils.data import DataLoader
  import matplotlib.pyplot as plt
- from datetime import datetime
+ import os

  # Import the DCLR optimizer from the local file
  from dclr_optimizer import DCLR
@@ -28,211 +27,92 @@ class SimpleCNN(nn.Module):
          x = F.relu(self.fc1(x))
          return self.fc2(x)

- # === Self-contained Lion optimizer (no external dependency) ===
- class Lion(torch.optim.Optimizer):
-     """
-     Minimal Lion optimizer implementation (Chen et al., 2023).
-     Uses sign of momentum with weight decay. Works for standard use-cases.
-     """
-     def __init__(self, params, lr=1e-3, betas=(0.9, 0.99), weight_decay=0.0):
-         defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
-         super().__init__(params, defaults)
-
-     @torch.no_grad()
-     def step(self):
-         for group in self.param_groups:
-             lr = group['lr']
-             beta1, beta2 = group['betas']
-             wd = group['weight_decay']
-
-             for p in group['params']:
-                 if p.grad is None:
-                     continue
-                 grad = p.grad
-
-                 # weight decay
-                 if wd != 0:
-                     grad = grad.add(p, alpha=wd)
-
-                 state = self.state[p]
-                 if len(state) == 0:
-                     state['exp_avg'] = torch.zeros_like(p)
-
-                 exp_avg = state['exp_avg']
-                 # Update momentum
-                 exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)
-
-                 # Parameter update: sign of momentum + sign of gradient blend
-                 update = exp_avg.mul(beta1).add(grad, alpha=1 - beta1)
-                 p.add_(torch.sign(update), alpha=-lr)
-
  # === CIFAR-10 Data Loading ===
- transform_train = transforms.Compose([
-     transforms.RandomCrop(32, padding=4),
-     transforms.RandomHorizontalFlip(),
+ transform = transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
  ])

- transform_test = transforms.Compose([
-     transforms.ToTensor(),
-     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
- ])
+ train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
+ train_loader = DataLoader(train_set, batch_size=128, shuffle=True)
+
+ test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
+ test_loader = DataLoader(test_set, batch_size=128, shuffle=False)
+
+ # === Training Configuration ===
+ model = SimpleCNN()
+ optimizer = DCLR(model.parameters(), lr=0.1, lambda_=0.1, verbose=False)
+ criterion = nn.CrossEntropyLoss()
+ epochs = 20
+
+ print(f"Starting training with DCLR for {epochs} epochs...")
+
+ losses, accs = [], []

- train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
- train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)
-
- test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
- test_loader = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=2)
-
- # === Utility: Train and evaluate with a given optimizer ===
- def train_and_evaluate(optimizer_name, optimizer_ctor, optimizer_kwargs, epochs=20, save_prefix=""):
-     model = SimpleCNN()
-     criterion = nn.CrossEntropyLoss()
-     optimizer = optimizer_ctor(model.parameters(), **optimizer_kwargs)
-
-     losses = []
-     accs = []
-
-     print(f"Starting training [{optimizer_name}] for {epochs} epochs...")
-     for epoch in range(epochs):
-         model.train()
-         running_loss = 0.0
-         correct = 0
-         total = 0
-         for inputs, labels in train_loader:
-             optimizer.zero_grad()
-             outputs = model(inputs)
-             loss = criterion(outputs, labels)
-             loss.backward()
-
-             # DCLR requires output_activations argument
-             if optimizer_name.lower() == "dclr":
-                 if hasattr(optimizer, "step"):
-                     optimizer.step(output_activations=outputs)
-                 else:
-                     raise RuntimeError("DCLR optimizer missing step(output_activations=...)")
-             else:
-                 optimizer.step()
-
-             running_loss += loss.item()
-             _, predicted = outputs.max(1)
-             total += labels.size(0)
-             correct += predicted.eq(labels).sum().item()
-
-         epoch_loss = running_loss / len(train_loader)
-         epoch_acc = 100.0 * correct / total
-         losses.append(epoch_loss)
-         accs.append(epoch_acc)
-         print(f"[{optimizer_name}] Epoch {epoch+1}/{epochs} - Loss: {epoch_loss:.4f}, Acc: {epoch_acc:.2f}%")
-
-     print(f"Training complete for [{optimizer_name}]. Evaluating on test set...")
-     model.eval()
+ # === Training Loop ===
+ for epoch in range(epochs):
+     model.train()
+     running_loss = 0.0
      correct = 0
      total = 0
-     with torch.no_grad():
-         for inputs, labels in test_loader:
-             outputs = model(inputs)
-             _, predicted = outputs.max(1)
-             total += labels.size(0)
-             correct += predicted.eq(labels).sum().item()
-
-     test_acc = 100.0 * correct / total
-     print(f"[{optimizer_name}] Final Test Accuracy: {test_acc:.2f}%")
-
-     # Save artifacts with optimizer-specific names
-     if save_prefix == "":
-         save_prefix = optimizer_name.lower()
-
-     # Model weights
-     weights_path = f"{save_prefix}_simple_cnn.pth"
-     torch.save(model.state_dict(), weights_path)
-     print(f"[{optimizer_name}] Model saved to {weights_path}")
-
-     # Training performance plot
-     plt.figure()
-     plt.plot(range(1, epochs+1), losses, label='Loss')
-     plt.plot(range(1, epochs+1), accs, label='Accuracy')
-     plt.xlabel('Epoch')
-     plt.ylabel('Value')
-     plt.legend()
-     plt.title(f'Training Performance on CIFAR-10 ({optimizer_name})')
-     perf_path = f"{save_prefix}_training_performance.png"
-     plt.savefig(perf_path)
-     print(f"[{optimizer_name}] Training performance plot saved to {perf_path}")
-
-     # Final test accuracy plot
-     plt.figure()
-     plt.bar([optimizer_name], [test_acc])
-     plt.ylabel('Accuracy (%)')
-     plt.title(f'Final Test Accuracy ({optimizer_name})')
-     acc_plot_path = f"{save_prefix}_final_test_accuracy.png"
-     plt.savefig(acc_plot_path)
-     print(f"[{optimizer_name}] Final test accuracy plot saved to {acc_plot_path}")
-
-     # Final test accuracy number
-     acc_txt_path = f"{save_prefix}_final_test_accuracy.txt"
-     with open(acc_txt_path, "w") as f:
-         f.write(f"{test_acc:.2f}")
-     print(f"[{optimizer_name}] Final test accuracy saved to {acc_txt_path}")
-
-     return {
-         "optimizer": optimizer_name,
-         "test_acc": test_acc,
-         "weights_path": weights_path,
-         "perf_plot_path": perf_path,
-         "acc_plot_path": acc_plot_path,
-         "acc_txt_path": acc_txt_path,
-         "losses": losses,
-         "accs": accs,
-     }
-
- # === Run benchmarks for DCLR vs Adam vs Lion ===
- def main():
-     os.makedirs("artifacts", exist_ok=True)
-     os.chdir("artifacts") # keep outputs organized
-
-     epochs = 20
-
-     # DCLR (using your tuned hyperparams)
-     dclr_results = train_and_evaluate(
-         optimizer_name="DCLR",
-         optimizer_ctor=lambda params, lr, lambda_, verbose=False: DCLR(params, lr=lr, lambda_=lambda_, verbose=verbose),
-         optimizer_kwargs={"lr": 0.1, "lambda_": 0.1, "verbose": False},
-         epochs=epochs,
-         save_prefix="dclr"
-     )
-
-     # Adam
-     adam_results = train_and_evaluate(
-         optimizer_name="Adam",
-         optimizer_ctor=lambda params, lr: torch.optim.Adam(params, lr=lr),
-         optimizer_kwargs={"lr": 0.001},
-         epochs=epochs,
-         save_prefix="adam"
-     )
-
-     # Lion
-     lion_results = train_and_evaluate(
-         optimizer_name="Lion",
-         optimizer_ctor=lambda params, lr, betas, weight_decay: Lion(params, lr=lr, betas=betas, weight_decay=weight_decay),
-         optimizer_kwargs={"lr": 0.001, "betas": (0.9, 0.99), "weight_decay": 0.0},
-         epochs=epochs,
-         save_prefix="lion"
-     )
-
-     # Combined benchmark ledger
-     ledger_path = "benchmark_results.txt"
-     with open(ledger_path, "w") as f:
-         f.write(f"Run timestamp: {datetime.utcnow().isoformat()}Z\n")
-         f.write(f"DCLR: {dclr_results['test_acc']:.2f}%\n")
-         f.write(f"Adam: {adam_results['test_acc']:.2f}%\n")
-         f.write(f"Lion: {lion_results['test_acc']:.2f}%\n")
-     print(f"Benchmark results saved to {ledger_path}")
-
-     # Symlink or copy DCLR artifacts to legacy names for existing app (optional)
-     # If your current app expects specific filenames at repo root, you can create copies:
-     # For a clean setup, prefer reading from artifacts/ in app.py.
-
- if __name__ == "__main__":
-     main()
+     for inputs, labels in train_loader:
+         optimizer.zero_grad()
+         outputs = model(inputs)
+         loss = criterion(outputs, labels)
+         loss.backward()
+         optimizer.step(output_activations=outputs)
+
+         running_loss += loss.item()
+         _, predicted = outputs.max(1)
+         total += labels.size(0)
+         correct += predicted.eq(labels).sum().item()
+
+     epoch_loss = running_loss / len(train_loader)
+     epoch_acc = 100.0 * correct / total
+     losses.append(epoch_loss)
+     accs.append(epoch_acc)
+     print(f"Epoch {epoch+1}/{epochs} - Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%")
+
+ print("Training complete.")
+
+ # === Evaluate on Test Set ===
+ model.eval()
+ correct = 0
+ total = 0
+ with torch.no_grad():
+     for inputs, labels in test_loader:
+         outputs = model(inputs)
+         _, predicted = outputs.max(1)
+         total += labels.size(0)
+         correct += predicted.eq(labels).sum().item()
+
+ test_acc = 100.0 * correct / total
+ print(f"Final Test Accuracy: {test_acc:.2f}%")
+
+ # === Save the Trained Model ===
+ torch.save(model.state_dict(), 'simple_cnn_dclr_tuned.pth')
+ print("Model saved to simple_cnn_dclr_tuned.pth")
+
+ # === Save Training Performance Plot ===
+ plt.figure()
+ plt.plot(range(1, epochs+1), losses, label='Loss')
+ plt.plot(range(1, epochs+1), accs, label='Accuracy')
+ plt.xlabel('Epoch')
+ plt.ylabel('Value')
+ plt.legend()
+ plt.title('Training Performance on CIFAR-10 (DCLR)')
+ plt.savefig('training_performance.png')
+ print("Training performance plot saved to training_performance.png")
+
+ # === Save Final Test Accuracy Plot ===
+ plt.figure()
+ plt.bar(['CIFAR-10'], [test_acc])
+ plt.ylabel('Accuracy (%)')
+ plt.title('Final Test Accuracy (DCLR)')
+ plt.savefig('final_test_accuracy.png')
+ print("Final test accuracy plot saved to final_test_accuracy.png")
+
+ # === Save Final Test Accuracy Number ===
+ with open("final_test_accuracy.txt", "w") as f:
+     f.write(f"{test_acc:.2f}")
+ print("Final test accuracy saved to final_test_accuracy.txt")