RFTSystems commited on
Commit
ceee644
·
verified ·
1 Parent(s): bb44a6f

Update train_dclr_model.py

Browse files
Files changed (1) hide show
  1. train_dclr_model.py +75 -0
train_dclr_model.py CHANGED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision
5
+ import torchvision.transforms as transforms
6
+ from torch.utils.data import DataLoader
7
+
8
+ # Import the DCLR optimizer from the local file
9
+ from dclr_optimizer import DCLR
10
+
11
+ # === Simple CNN Model Definition ===
12
+ class SimpleCNN(nn.Module):
13
+ def __init__(self):
14
+ super(SimpleCNN, self).__init__()
15
+ self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
16
+ self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
17
+ self.pool = nn.MaxPool2d(2, 2)
18
+ self.fc1 = nn.Linear(64 * 8 * 8, 512)
19
+ self.fc2 = nn.Linear(512, 10)
20
+
21
+ def forward(self, x):
22
+ x = self.pool(F.relu(self.conv1(x)))
23
+ x = self.pool(F.relu(self.conv2(x)))
24
+ x = x.view(-1, 64 * 8 * 8)
25
+ x = F.relu(self.fc1(x))
26
+ return self.fc2(x)
27
+
28
+ # === CIFAR-10 Data Loading ===
29
+ transform = transforms.Compose([transforms.ToTensor()])
30
+
31
+ train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
32
+ train_loader = DataLoader(train_set, batch_size=128, shuffle=True)
33
+
34
+ # === Training Configuration ===
35
+ model = SimpleCNN()
36
+
37
+ # Instantiate DCLR with best-tuned hyperparameters
38
+ best_lr = 0.1
39
+ best_lambda = 0.1
40
+ optimizer = DCLR(model.parameters(), lr=best_lr, lambda_=best_lambda, verbose=False)
41
+
42
+ criterion = nn.CrossEntropyLoss()
43
+ extended_epochs = 20
44
+
45
+ print(f"Starting training for SimpleCNN with DCLR (lr={best_lr}, lambda_={best_lambda}) for {extended_epochs} epochs...")
46
+
47
+ # === Training Loop ===
48
+ for epoch in range(extended_epochs):
49
+ model.train()
50
+ running_loss = 0.0
51
+ correct = 0
52
+ total = 0
53
+ for batch_idx, (inputs, labels) in enumerate(train_loader):
54
+ optimizer.zero_grad()
55
+ outputs = model(inputs)
56
+ loss = criterion(outputs, labels)
57
+ loss.backward()
58
+
59
+ # DCLR requires output_activations for its step method
60
+ optimizer.step(output_activations=outputs)
61
+
62
+ running_loss += loss.item()
63
+ _, predicted = outputs.max(1)
64
+ total += labels.size(0)
65
+ correct += predicted.eq(labels).sum().item()
66
+
67
+ epoch_loss = running_loss / len(train_loader)
68
+ epoch_acc = 100.0 * correct / total
69
+ print(f"Epoch {epoch+1}/{extended_epochs} - Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%")
70
+
71
+ print("Training complete.")
72
+
73
+ # === Save the Trained Model ===
74
+ torch.save(model.state_dict(), 'simple_cnn_dclr_tuned.pth')
75
+ print("Model saved to simple_cnn_dclr_tuned.pth")