Spaces:
Sleeping
Sleeping
Commit
·
987f571
1
Parent(s):
7dc7452
feature: Added manual training and PyTorch Lightning training loops
Browse files- src/{main.py → train.py} +26 -5
src/{main.py → train.py}
RENAMED
|
@@ -1,3 +1,5 @@
|
|
|
|
|
|
|
|
| 1 |
from torch import nn, optim
|
| 2 |
from torch.utils.data import random_split
|
| 3 |
import pytorch_lightning as pl
|
|
@@ -22,13 +24,32 @@ def main():
|
|
| 22 |
validate_dataloader = DataLoader(validate_data, num_workers=2)
|
| 23 |
test_dataloader = DataLoader(test_data, num_workers=8) # My CPU has 8 cores
|
| 24 |
|
| 25 |
-
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
-
|
| 29 |
-
|
| 30 |
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
trainer.test(model=pl_net, dataloaders=test_dataloader)
|
| 33 |
|
| 34 |
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
from torch import nn, optim
|
| 4 |
from torch.utils.data import random_split
|
| 5 |
import pytorch_lightning as pl
|
|
|
|
| 24 |
validate_dataloader = DataLoader(validate_data, num_workers=2)
|
| 25 |
test_dataloader = DataLoader(test_data, num_workers=8) # My CPU has 8 cores
|
| 26 |
|
| 27 |
+
net = CNN(input_channels=1, num_classes=10).to("cuda")
|
| 28 |
+
opt = optim.Adam(net.parameters(), lr=1e-4)
|
| 29 |
+
loss_fn = nn.CrossEntropyLoss()
|
| 30 |
+
max_epochs = 10
|
| 31 |
+
for i in range(max_epochs):
|
| 32 |
+
for idx, batch in enumerate(train_dataloader):
|
| 33 |
+
x, y = batch
|
| 34 |
+
x = x.to("cuda")
|
| 35 |
+
y = y.to("cuda")
|
| 36 |
+
|
| 37 |
+
y_pred = net(x).reshape(1, -1)
|
| 38 |
+
loss = loss_fn(y_pred, y)
|
| 39 |
+
|
| 40 |
+
opt.zero_grad()
|
| 41 |
+
loss.backward()
|
| 42 |
+
opt.step()
|
| 43 |
|
| 44 |
+
if idx % 1000 == 0:
|
| 45 |
+
print(f"Loss: {loss.item()} ({idx} / {len(train_dataloader)})")
|
| 46 |
|
| 47 |
+
torch.save(net, "../checkpoints/pytorch/version_1.pt")
|
| 48 |
+
|
| 49 |
+
# grayscale channels = 1, mnist num_labels = 10
|
| 50 |
+
trainer = pl.Trainer(limit_train_batches=100, max_epochs=10, default_root_dir="../checkpoints")
|
| 51 |
+
pl_net = LitTrainer(CNN(input_channels=1, num_classes=10))
|
| 52 |
+
trainer.fit(pl_net, train_dataloader, validate_dataloader)
|
| 53 |
trainer.test(model=pl_net, dataloaders=test_dataloader)
|
| 54 |
|
| 55 |
|