From a08a6903239476a60c8f8ac8ebe325029a15aeab Mon Sep 17 00:00:00 2001 From: hecoding Date: Tue, 3 Nov 2020 16:06:39 +0100 Subject: [PATCH 1/2] bugfix: batch_size for MNISTDataModule --- pl_bolts/datamodules/mnist_datamodule.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index 7229462e97..220f8cb26e 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -76,6 +76,7 @@ def __init__( self.num_workers = num_workers self.normalize = normalize self.seed = seed + self.batch_size = batch_size @property def num_classes(self): @@ -92,12 +93,11 @@ def prepare_data(self): MNIST(self.data_dir, train=True, download=True, transform=transform_lib.ToTensor()) MNIST(self.data_dir, train=False, download=True, transform=transform_lib.ToTensor()) - def train_dataloader(self, batch_size=32, transforms=None): + def train_dataloader(self, transforms=None): """ MNIST train set removes a subset to use for validation Args: - batch_size: size of batch transforms: custom transforms """ transforms = transforms or self.train_transforms or self._default_transforms() @@ -109,7 +109,7 @@ def train_dataloader(self, batch_size=32, transforms=None): ) loader = DataLoader( dataset_train, - batch_size=batch_size, + batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, drop_last=True, @@ -117,12 +117,11 @@ def train_dataloader(self, batch_size=32, transforms=None): ) return loader - def val_dataloader(self, batch_size=32, transforms=None): + def val_dataloader(self, transforms=None): """ MNIST val set uses a subset of the training set for validation Args: - batch_size: size of batch transforms: custom transforms """ transforms = transforms or self.val_transforms or self._default_transforms() @@ -133,7 +132,7 @@ def val_dataloader(self, batch_size=32, transforms=None): ) loader = DataLoader( dataset_val, - batch_size=batch_size, + batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, drop_last=True, @@ -141,19 +140,19 @@ def val_dataloader(self, batch_size=32, transforms=None): ) return loader - def test_dataloader(self, batch_size=32, transforms=None): + def test_dataloader(self, transforms=None): """ MNIST test set uses the test split Args: - batch_size: size of batch transforms: custom transforms """ transforms = transforms or self.val_transforms or self._default_transforms() dataset = MNIST(self.data_dir, train=False, download=False, transform=transforms) loader = DataLoader( - dataset, batch_size=batch_size, shuffle=False, num_workers=self.num_workers, drop_last=True, pin_memory=True + dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, drop_last=True, + pin_memory=True ) return loader From 580e0082e2c6209a7877fbbd445bca6e564ee034 Mon Sep 17 00:00:00 2001 From: hecoding Date: Thu, 5 Nov 2020 15:54:36 +0100 Subject: [PATCH 2/2] fix MNISTDataModule *_dataloader() signatures --- pl_bolts/datamodules/mnist_datamodule.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index 220f8cb26e..a741203167 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -93,14 +93,14 @@ def prepare_data(self): MNIST(self.data_dir, train=True, download=True, transform=transform_lib.ToTensor()) MNIST(self.data_dir, train=False, download=True, transform=transform_lib.ToTensor()) - def train_dataloader(self, transforms=None): + def train_dataloader(self): """ MNIST train set removes a subset to use for validation Args: transforms: custom transforms """ - transforms = transforms or self.train_transforms or self._default_transforms() + transforms = self.default_transforms() if self.train_transforms is None else self.train_transforms dataset = MNIST(self.data_dir, train=True, download=False, transform=transforms) train_length = len(dataset) @@ -117,14 +117,14 @@ def train_dataloader(self, transforms=None): ) return loader - def val_dataloader(self, transforms=None): + def val_dataloader(self): """ MNIST val set uses a subset of the training set for validation Args: transforms: custom transforms """ - transforms = transforms or self.val_transforms or self._default_transforms() + transforms = self.default_transforms() if self.val_transforms is None else self.val_transforms dataset = MNIST(self.data_dir, train=True, download=False, transform=transforms) train_length = len(dataset) _, dataset_val = random_split( @@ -140,14 +140,14 @@ def val_dataloader(self, transforms=None): ) return loader - def test_dataloader(self, transforms=None): + def test_dataloader(self): """ MNIST test set uses the test split Args: transforms: custom transforms """ - transforms = transforms or self.val_transforms or self._default_transforms() + transforms = self.default_transforms() if self.test_transforms is None else self.test_transforms dataset = MNIST(self.data_dir, train=False, download=False, transform=transforms) loader = DataLoader( @@ -156,7 +156,7 @@ def test_dataloader(self, transforms=None): ) return loader - def _default_transforms(self): + def default_transforms(self): if self.normalize: mnist_transforms = transform_lib.Compose( [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]