From 28e0e441b7174107361de17bb580d37663dc0854 Mon Sep 17 00:00:00 2001 From: danblae <43992405+danblae@users.noreply.github.com> Date: Mon, 14 Nov 2022 09:56:48 +0000 Subject: [PATCH] iter.next() to next(iter) The new pytorch version does not allow the old command and causes the code to return errors when running. It is a small fix, that will especially help beginners. --- 09_dataloader.py | 4 ++-- 13_feedforward.py | 2 +- 14_cnn.py | 2 +- 16_tensorboard.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/09_dataloader.py b/09_dataloader.py index 18c76e01bd0814..88247da95a16b3 100644 --- a/09_dataloader.py +++ b/09_dataloader.py @@ -66,7 +66,7 @@ def __len__(self): # convert to an iterator and look at one random sample dataiter = iter(train_loader) -data = dataiter.next() +data = next(dataiter) features, labels = data print(features, labels) @@ -97,6 +97,6 @@ def __len__(self): # look at one random sample dataiter = iter(train_loader) -data = dataiter.next() +data = next(dataiter) inputs, targets = data print(inputs.shape, targets.shape) diff --git a/13_feedforward.py b/13_feedforward.py index b38f9eb7c28261..29507b28edbf06 100644 --- a/13_feedforward.py +++ b/13_feedforward.py @@ -35,7 +35,7 @@ shuffle=False) examples = iter(test_loader) -example_data, example_targets = examples.next() +example_data, example_targets = next(examples) for i in range(6): plt.subplot(2,3,i+1) diff --git a/14_cnn.py b/14_cnn.py index ca0b4cbf7bb4bb..a4abbd1a3a21b7 100644 --- a/14_cnn.py +++ b/14_cnn.py @@ -45,7 +45,7 @@ def imshow(img): # get some random training images dataiter = iter(train_loader) -images, labels = dataiter.next() +images, labels = next(dataiter) # show images imshow(torchvision.utils.make_grid(images)) diff --git a/16_tensorboard.py b/16_tensorboard.py index 0bd67dc6f8bd4b..21baa49bd304f5 100644 --- a/16_tensorboard.py +++ b/16_tensorboard.py @@ -43,7 +43,7 @@ shuffle=False) examples = iter(test_loader) -example_data, example_targets = examples.next() +example_data, example_targets = next(examples) for i in range(6): plt.subplot(2,3,i+1)