RFTSystems committed on
Commit
d137713
·
verified ·
1 Parent(s): 5f795a8

Update train_dclr_model.py

Browse files
Files changed (1) hide show
  1. train_dclr_model.py +46 -3
train_dclr_model.py CHANGED
@@ -4,6 +4,7 @@ import torch.nn.functional as F
4
  import torchvision
5
  import torchvision.transforms as transforms
6
  from torch.utils.data import DataLoader
 
7
 
8
  # Import the DCLR optimizer from the local file
9
  from dclr_optimizer import DCLR
@@ -26,15 +27,20 @@ class SimpleCNN(nn.Module):
26
  return self.fc2(x)
27
 
28
  # === CIFAR-10 Data Loading ===
29
- transform = transforms.Compose([transforms.ToTensor()])
 
 
 
30
 
31
  train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
32
  train_loader = DataLoader(train_set, batch_size=128, shuffle=True)
33
 
 
 
 
34
  # === Training Configuration ===
35
  model = SimpleCNN()
36
 
37
- # Instantiate DCLR with best-tuned hyperparameters
38
  best_lr = 0.1
39
  best_lambda = 0.1
40
  optimizer = DCLR(model.parameters(), lr=best_lr, lambda_=best_lambda, verbose=False)
@@ -44,6 +50,8 @@ extended_epochs = 20
44
 
45
  print(f"Starting training for SimpleCNN with DCLR (lr={best_lr}, lambda_={best_lambda}) for {extended_epochs} epochs...")
46
 
 
 
47
  # === Training Loop ===
48
  for epoch in range(extended_epochs):
49
  model.train()
@@ -66,10 +74,45 @@ for epoch in range(extended_epochs):
66
 
67
  epoch_loss = running_loss / len(train_loader)
68
  epoch_acc = 100.0 * correct / total
 
 
69
  print(f"Epoch {epoch+1}/{extended_epochs} - Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%")
70
 
71
  print("Training complete.")
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  # === Save the Trained Model ===
74
  torch.save(model.state_dict(), 'simple_cnn_dclr_tuned.pth')
75
- print("Model saved to simple_cnn_dclr_tuned.pth")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import torchvision
5
  import torchvision.transforms as transforms
6
  from torch.utils.data import DataLoader
7
+ import matplotlib.pyplot as plt
8
 
9
  # Import the DCLR optimizer from the local file
10
  from dclr_optimizer import DCLR
 
27
  return self.fc2(x)
28
 
29
  # === CIFAR-10 Data Loading ===
30
+ transform = transforms.Compose([
31
+ transforms.ToTensor(),
32
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
33
+ ])
34
 
35
  train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
36
  train_loader = DataLoader(train_set, batch_size=128, shuffle=True)
37
 
38
+ test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
39
+ test_loader = DataLoader(test_set, batch_size=128, shuffle=False)
40
+
41
  # === Training Configuration ===
42
  model = SimpleCNN()
43
 
 
44
  best_lr = 0.1
45
  best_lambda = 0.1
46
  optimizer = DCLR(model.parameters(), lr=best_lr, lambda_=best_lambda, verbose=False)
 
50
 
51
  print(f"Starting training for SimpleCNN with DCLR (lr={best_lr}, lambda_={best_lambda}) for {extended_epochs} epochs...")
52
 
53
+ losses, accs = [], []
54
+
55
  # === Training Loop ===
56
  for epoch in range(extended_epochs):
57
  model.train()
 
74
 
75
  epoch_loss = running_loss / len(train_loader)
76
  epoch_acc = 100.0 * correct / total
77
+ losses.append(epoch_loss)
78
+ accs.append(epoch_acc)
79
  print(f"Epoch {epoch+1}/{extended_epochs} - Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%")
80
 
81
  print("Training complete.")
82
 
83
+ # === Evaluate on Test Set ===
84
+ model.eval()
85
+ correct = 0
86
+ total = 0
87
+ with torch.no_grad():
88
+ for inputs, labels in test_loader:
89
+ outputs = model(inputs)
90
+ _, predicted = outputs.max(1)
91
+ total += labels.size(0)
92
+ correct += predicted.eq(labels).sum().item()
93
+
94
+ test_acc = 100.0 * correct / total
95
+ print(f"Final Test Accuracy: {test_acc:.2f}%")
96
+
97
  # === Save the Trained Model ===
98
  torch.save(model.state_dict(), 'simple_cnn_dclr_tuned.pth')
99
+ print("Model saved to simple_cnn_dclr_tuned.pth")
100
+
101
+ # === Save Training Performance Plot ===
102
+ plt.figure()
103
+ plt.plot(range(1, extended_epochs+1), losses, label='Loss')
104
+ plt.plot(range(1, extended_epochs+1), accs, label='Accuracy')
105
+ plt.xlabel('Epoch')
106
+ plt.ylabel('Value')
107
+ plt.legend()
108
+ plt.title('Training Performance on CIFAR-10')
109
+ plt.savefig('training_performance.png')
110
+ print("Training performance plot saved to training_performance.png")
111
+
112
+ # === Save Final Test Accuracy Plot ===
113
+ plt.figure()
114
+ plt.bar(['CIFAR-10'], [test_acc])
115
+ plt.ylabel('Accuracy (%)')
116
+ plt.title('Final Test Accuracy')
117
+ plt.savefig('final_test_accuracy.png')
118
+ print("Final test accuracy plot saved to final_test_accuracy.png")