LukeDarlow commited on
Commit
7ae2b6b
·
2 Parent(s): 5e7a518 a25053b

Merge pull request #19 from kevin/main

Browse files

(MNIST Tutorial) Fixing iterator usage resulting in slow training

Files changed (1) hide show
  1. examples/01_mnist.ipynb +6 -2
examples/01_mnist.ipynb CHANGED
@@ -578,7 +578,7 @@
578
  "def train(model, trainloader, testloader, iterations, test_every, device):\n",
579
  "\n",
580
  " optimizer = torch.optim.AdamW(params=list(model.parameters()), lr=0.0001, eps=1e-8)\n",
581
- "\n",
582
  " model.train()\n",
583
  " \n",
584
  " train_losses = []\n",
@@ -595,7 +595,11 @@
595
  " test_accuracy = None\n",
596
  " for stepi in range(iterations):\n",
597
  "\n",
598
- " inputs, targets = next(iter(trainloader))\n",
 
 
 
 
599
  " inputs, targets = inputs.to(device), targets.to(device)\n",
600
  " predictions, certainties, _ = model(inputs, track=False)\n",
601
  " train_loss, where_most_certain = get_loss(predictions, certainties, targets)\n",
 
578
  "def train(model, trainloader, testloader, iterations, test_every, device):\n",
579
  "\n",
580
  " optimizer = torch.optim.AdamW(params=list(model.parameters()), lr=0.0001, eps=1e-8)\n",
581
+ " iterator = iter(trainloader)\n",
582
  " model.train()\n",
583
  " \n",
584
  " train_losses = []\n",
 
595
  " test_accuracy = None\n",
596
  " for stepi in range(iterations):\n",
597
  "\n",
598
+ " try:\n",
599
+ " inputs, targets = next(iterator)\n",
600
+ " except StopIteration:\n",
601
+ " iterator = iter(trainloader)\n",
602
+ " inputs, targets = next(iterator)\n",
603
  " inputs, targets = inputs.to(device), targets.to(device)\n",
604
  " predictions, certainties, _ = model(inputs, track=False)\n",
605
  " train_loss, where_most_certain = get_loss(predictions, certainties, targets)\n",