Merge pull request #19 from kevin/main
(MNIST Tutorial) Fixing iterator usage resulting in slow training
- examples/01_mnist.ipynb +6 -2
examples/01_mnist.ipynb
CHANGED
@@ -578,7 +578,7 @@
 "def train(model, trainloader, testloader, iterations, test_every, device):\n",
 "\n",
 "    optimizer = torch.optim.AdamW(params=list(model.parameters()), lr=0.0001, eps=1e-8)\n",
-"\n",
+"    iterator = iter(trainloader)\n",
 "    model.train()\n",
 "    \n",
 "    train_losses = []\n",
@@ -595,7 +595,11 @@
 "    test_accuracy = None\n",
 "    for stepi in range(iterations):\n",
 "\n",
-"
+"        try:\n",
+"            inputs, targets = next(iterator)\n",
+"        except StopIteration:\n",
+"            iterator = iter(trainloader)\n",
+"            inputs, targets = next(iterator)\n",
 "        inputs, targets = inputs.to(device), targets.to(device)\n",
 "        predictions, certainties, _ = model(inputs, track=False)\n",
 "        train_loss, where_most_certain = get_loss(predictions, certainties, targets)\n",
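For readers following the change outside the notebook: the removed line at old line 598 is truncated in this diff view, but the shape of the fix suggests the previous code rebuilt the DataLoader iterator inside the training loop. The patch instead creates the iterator once before the step loop and only rebuilds it when an epoch is exhausted (StopIteration). Below is a minimal, self-contained sketch of that pattern; the dataset, model, and hyperparameters here are placeholders for illustration, not the tutorial's actual ones.

import torch
from torch.utils.data import DataLoader, TensorDataset

# Placeholder data and model purely for illustration.
dataset = TensorDataset(torch.randn(256, 28 * 28), torch.randint(0, 10, (256,)))
trainloader = DataLoader(dataset, batch_size=32, shuffle=True)
model = torch.nn.Sequential(torch.nn.Linear(28 * 28, 10))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, eps=1e-8)
loss_fn = torch.nn.CrossEntropyLoss()

iterations = 100
iterator = iter(trainloader)           # build the iterator once, before the loop
for stepi in range(iterations):
    try:
        inputs, targets = next(iterator)
    except StopIteration:              # epoch exhausted: rebuild and keep going
        iterator = iter(trainloader)
        inputs, targets = next(iterator)

    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    optimizer.step()

The reason this matters for speed: constructing a fresh iterator from a DataLoader is not free. With shuffle=True it re-shuffles the dataset, and with num_workers > 0 it spins up new worker processes, so recreating the iterator every step pays that startup cost on each batch instead of once per epoch.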