diff --git a/docs/source/datamodules.rst b/docs/source/datamodules.rst index 5ac7ef6c1d..77c06f9c80 100644 --- a/docs/source/datamodules.rst +++ b/docs/source/datamodules.rst @@ -14,7 +14,8 @@ Step 4, also needs special care to make sure that it's only done on 1 GPU in a m In addition, there are other challenges such as models that are built using information from the dataset such as needing to know image dimensions or number of classes. -A datamodule simplifies all of these parts and integrates seamlessly into Lightning. +A datamodule simplifies all of these parts and has been integrated directly into Lightning in version 0.9.0. +You can view the documentation for the datamodule in the `Pytorch Lightning docs here. `_ .. code-block:: python @@ -92,7 +93,7 @@ Use this to build your own consistent train, validation, test splits. Example:: - from pl_bolts.datamodules import LightningDataModule + from pytorch_lightning import LightningDataModule class MyDataModule(LightningDataModule): @@ -157,12 +158,6 @@ or:: for b in dataloader: ... -DataModule class -^^^^^^^^^^^^^^^^ - -.. autoclass:: pl_bolts.datamodules.lightning_datamodule.LightningDataModule - :noindex: - ------------- DummyDataset diff --git a/pl_bolts/datamodules/__init__.py b/pl_bolts/datamodules/__init__.py index a67e85c2f5..48dcb36f42 100644 --- a/pl_bolts/datamodules/__init__.py +++ b/pl_bolts/datamodules/__init__.py @@ -1,6 +1,5 @@ from pl_bolts.datamodules.cifar10_datamodule import CIFAR10DataModule, TinyCIFAR10DataModule from pl_bolts.datamodules.imagenet_datamodule import ImagenetDataModule -from pl_bolts.datamodules.lightning_datamodule import LightningDataModule from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule from pl_bolts.datamodules.sklearn_datamodule import SklearnDataset, SklearnDataModule, TensorDataset, TensorDataModule from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index 01c921a666..09a11c907c 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -1,11 +1,11 @@ from typing import Optional, Sequence +from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader, random_split from torchvision import transforms as transform_lib from torchvision.datasets import CIFAR10 from pl_bolts.datamodules.cifar10_dataset import TrialCIFAR10 -from pl_bolts.datamodules.lightning_datamodule import LightningDataModule from pl_bolts.transforms.dataset_normalizations import cifar10_normalization @@ -19,6 +19,7 @@ def __init__( data_dir, val_split=5000, num_workers=16, + batch_size=32, *args, **kwargs, ): @@ -54,6 +55,7 @@ def __init__( data_dir: where to save/load the data val_split: how many of the training images to use for the validation split num_workers: how many workers to use for loading data + batch_size: number of examples per training/eval step """ super().__init__(*args, **kwargs) self.dims = (3, 32, 32) @@ -61,6 +63,7 @@ def __init__( self.data_dir = data_dir self.val_split = val_split self.num_workers = num_workers + self.batch_size = batch_size @property def num_classes(self): @@ -77,12 +80,9 @@ def prepare_data(self): self.DATASET(self.data_dir, train=True, download=True, transform=transform_lib.ToTensor(), **self.extra_args) self.DATASET(self.data_dir, train=False, download=True, transform=transform_lib.ToTensor(), **self.extra_args) - def train_dataloader(self, batch_size): + def 
train_dataloader(self): """ CIFAR train set removes a subset to use for validation - - Args: - batch_size: size of batch """ transforms = self.default_transforms() if self.train_transforms is None else self.train_transforms @@ -91,7 +91,7 @@ def train_dataloader(self, batch_size): dataset_train, _ = random_split(dataset, [train_length - self.val_split, self.val_split]) loader = DataLoader( dataset_train, - batch_size=batch_size, + batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, drop_last=True, @@ -99,12 +99,9 @@ def train_dataloader(self, batch_size): ) return loader - def val_dataloader(self, batch_size): + def val_dataloader(self): """ CIFAR10 val set uses a subset of the training set for validation - - Args: - batch_size: size of batch """ transforms = self.default_transforms() if self.val_transforms is None else self.val_transforms @@ -113,7 +110,7 @@ def val_dataloader(self, batch_size): _, dataset_val = random_split(dataset, [train_length - self.val_split, self.val_split]) loader = DataLoader( dataset_val, - batch_size=batch_size, + batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, pin_memory=True, @@ -121,20 +118,16 @@ def val_dataloader(self, batch_size): ) return loader - def test_dataloader(self, batch_size): + def test_dataloader(self): """ CIFAR10 test set uses the test split - - Args: - batch_size: size of batch - transforms: custom transforms """ transforms = self.default_transforms() if self.test_transforms is None else self.test_transforms dataset = self.DATASET(self.data_dir, train=False, download=False, transform=transforms, **self.extra_args) loader = DataLoader( dataset, - batch_size=batch_size, + batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, drop_last=True, diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py index 8f02812b55..85dc17a6cb 100644 --- a/pl_bolts/datamodules/fashion_mnist_datamodule.py +++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py @@ -1,9 +1,8 @@ +from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader, random_split from torchvision import transforms as transform_lib from torchvision.datasets import FashionMNIST -from pl_bolts.datamodules.lightning_datamodule import LightningDataModule - class FashionMNISTDataModule(LightningDataModule): diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py index 0ccd0a450d..48adb2ceaf 100644 --- a/pl_bolts/datamodules/imagenet_datamodule.py +++ b/pl_bolts/datamodules/imagenet_datamodule.py @@ -1,10 +1,10 @@ import os +from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader from torchvision import transforms as transform_lib from pl_bolts.datamodules.imagenet_dataset import UnlabeledImagenet -from pl_bolts.datamodules.lightning_datamodule import LightningDataModule from pl_bolts.transforms.dataset_normalizations import imagenet_normalization diff --git a/pl_bolts/datamodules/lightning_datamodule.py b/pl_bolts/datamodules/lightning_datamodule.py deleted file mode 100644 index 49a6ef7728..0000000000 --- a/pl_bolts/datamodules/lightning_datamodule.py +++ /dev/null @@ -1,284 +0,0 @@ -import inspect -from abc import abstractmethod -from argparse import ArgumentParser, Namespace -from typing import Union, List, Tuple, Any - -from pytorch_lightning.utilities import rank_zero_warn, parsing -from torch.utils.data import DataLoader - - -class LightningDataModule(object): # 
pragma: no cover - """ - A DataModule standardizes the training, val, test splits, data preparation and transforms. - The main advantage is consistent data splits and transforms across models. - - Example:: - - class MyDataModule(LightningDataModule): - - def __init__(self): - super().__init__() - - def prepare_data(self): - # download, split, etc... - - def train_dataloader(self): - train_split = Dataset(...) - return DataLoader(train_split) - - def val_dataloader(self): - val_split = Dataset(...) - return DataLoader(val_split) - - def test_dataloader(self): - test_split = Dataset(...) - return DataLoader(test_split) - - A DataModule implements 4 key methods - - 1. **prepare_data** (things to do on 1 GPU not on every GPU in distributed mode) - 2. **train_dataloader** the training dataloader. - 3. **val_dataloader** the val dataloader. - 4. **test_dataloader** the test dataloader. - - - This allows you to share a full dataset without explaining what the splits, transforms or download - process is. - """ - name: str = ... - - def __init__( - self, - train_transforms=None, - val_transforms=None, - test_transforms=None, - ): - super().__init__() - self._train_transforms = train_transforms - self._val_transforms = val_transforms - self._test_transforms = test_transforms - self.dims = () - - @property - def train_transforms(self): - return self._train_transforms - - @train_transforms.setter - def train_transforms(self, t): - self._train_transforms = t - - @property - def val_transforms(self): - return self._val_transforms - - @val_transforms.setter - def val_transforms(self, t): - self._val_transforms = t - - @property - def test_transforms(self): - return self._test_transforms - - @test_transforms.setter - def test_transforms(self, t): - self._test_transforms = t - - def size(self, dim=None) -> Union[Tuple, int]: - """ - Return the dimension of each input - Either as a tuple or list of tuples - """ - if dim is not None: - return self.dims[dim] - - return self.dims - - @abstractmethod - def prepare_data(self, *args, **kwargs): - """ - Use this to download and prepare data. - In distributed (GPU, TPU), this will only be called once. - This is called before requesting the dataloaders: - - .. warning:: Do not assign anything to the model in this step since this will only be called on 1 GPU. - - Pseudocode:: - - model.prepare_data() - model.train_dataloader() - model.val_dataloader() - model.test_dataloader() - - Example:: - - def prepare_data(self): - download_imagenet() - clean_imagenet() - cache_imagenet() - """ - - @abstractmethod - def train_dataloader(self, *args, **kwargs) -> DataLoader: - """ - Implement a PyTorch DataLoader for training. - - Return: - Single PyTorch :class:`~torch.utils.data.DataLoader`. - - Note: - Lightning adds the correct sampler for distributed and arbitrary hardware. - There is no need to set it yourself. - - Example:: - - def train_dataloader(self): - dataset = MNIST(root=PATH, train=True, transform=transforms.ToTensor(), download=False) - loader = torch.utils.data.DataLoader(dataset=dataset) - return loader - - """ - rank_zero_warn('`train_dataloader` must be implemented to be used with the Lightning Trainer') - - @abstractmethod - def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: - r""" - Implement a PyTorch DataLoader for training. - - Return: - Single PyTorch :class:`~torch.utils.data.DataLoader`. - - Note: - Lightning adds the correct sampler for distributed and arbitrary hardware. - There is no need to set it yourself. 
- - Note: - You can also return a list of DataLoaders - - Example:: - - def val_dataloader(self): - dataset = MNIST(root=PATH, train=False, transform=transforms.ToTensor(), download=False) - loader = torch.utils.data.DataLoader(dataset=dataset, shuffle=False) - return loader - """ - - @abstractmethod - def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: - r""" - Implement a PyTorch DataLoader for training. - - Return: - Single PyTorch :class:`~torch.utils.data.DataLoader`. - - Note: - Lightning adds the correct sampler for distributed and arbitrary hardware. - There is no need to set it yourself. - - Note: - You can also return a list of DataLoaders - - Example:: - - def test_dataloader(self): - dataset = MNIST(root=PATH, train=False, transform=transforms.ToTensor(), download=False) - loader = torch.utils.data.DataLoader(dataset=dataset, shuffle=False) - return loader - """ - - @classmethod - def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser: - r"""Extends existing argparse by default `LightningDataModule` attributes. - """ - parser = ArgumentParser(parents=[parent_parser], add_help=False,) - added_args = [x.dest for x in parser._actions] - - blacklist = ['kwargs'] - depr_arg_names = blacklist + added_args - depr_arg_names = set(depr_arg_names) - - allowed_types = (str, float, int, bool) - - # TODO: get "help" from docstring :) - for arg, arg_types, arg_default in (at for at in cls.get_init_arguments_and_types() - if at[0] not in depr_arg_names): - arg_types = [at for at in allowed_types if at in arg_types] - if not arg_types: - # skip argument with not supported type - continue - arg_kwargs = {} - if bool in arg_types: - arg_kwargs.update(nargs="?") - # if the only arg type is bool - if len(arg_types) == 1: - # redefine the type for ArgParser needed - def use_type(x): - return bool(parsing.str_to_bool(x)) - else: - # filter out the bool as we need to use more general - use_type = [at for at in arg_types if at is not bool][0] - else: - use_type = arg_types[0] - - if arg_default == inspect._empty: - arg_default = None - - parser.add_argument( - f'--{arg}', - dest=arg, - default=arg_default, - type=use_type, - help=f'autogenerated by plb.{cls.__name__}', - **arg_kwargs, - ) - - return parser - - @classmethod - def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs): - """ - Create an instance from CLI arguments. - - Args: - args: The parser or namespace to take arguments from. Only known arguments will be - parsed and passed to the :class:`LightningDataModule`. - **kwargs: Additional keyword arguments that may override ones in the parser or namespace. - These must be valid Trainer arguments. - - Example:: - - parser = ArgumentParser(add_help=False) - parser = LightningDataModule.add_argparse_args(parser) - module = LightningDataModule.from_argparse_args(args) - """ - if isinstance(args, ArgumentParser): - args = cls.parse_argparser(args) - params = vars(args) - - # we only want to pass in valid Trainer args, the rest may be user specific - valid_kwargs = inspect.signature(cls.__init__).parameters - trainer_kwargs = dict((name, params[name]) for name in valid_kwargs if name in params) - trainer_kwargs.update(**kwargs) - - return cls(**trainer_kwargs) - - @classmethod - def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]: - r"""Scans the Trainer signature and returns argument names, types and default values. 
- - Returns: - List with tuples of 3 values: - (argument name, set with argument types, argument default value). - """ - trainer_default_params = inspect.signature(cls).parameters - name_type_default = [] - for arg in trainer_default_params: - arg_type = trainer_default_params[arg].annotation - arg_default = trainer_default_params[arg].default - try: - arg_types = tuple(arg_type.__args__) - except AttributeError: - arg_types = (arg_type,) - - name_type_default.append((arg, arg_types, arg_default)) - - return name_type_default diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index 8068c39a24..0113aa50ca 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -1,9 +1,8 @@ +from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader, random_split from torchvision import transforms as transform_lib from torchvision.datasets import MNIST -from pl_bolts.datamodules.lightning_datamodule import LightningDataModule - class MNISTDataModule(LightningDataModule): diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py index 2da8e505f8..977c1279b2 100644 --- a/pl_bolts/datamodules/sklearn_datamodule.py +++ b/pl_bolts/datamodules/sklearn_datamodule.py @@ -4,8 +4,7 @@ import numpy as np from sklearn.utils import shuffle as sk_shuffle from torch.utils.data import Dataset, DataLoader - -from pl_bolts.datamodules.lightning_datamodule import LightningDataModule +from pytorch_lightning import LightningDataModule class SklearnDataset(Dataset): diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py index 35611cdb34..6cf933fbbc 100644 --- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py +++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py @@ -1,10 +1,10 @@ import os +from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader from torchvision import transforms as transform_lib from pl_bolts.datamodules.imagenet_dataset import UnlabeledImagenet -from pl_bolts.datamodules.lightning_datamodule import LightningDataModule from pl_bolts.transforms.dataset_normalizations import imagenet_normalization diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py index f3d16c84d4..ea7675178c 100644 --- a/pl_bolts/datamodules/stl10_datamodule.py +++ b/pl_bolts/datamodules/stl10_datamodule.py @@ -1,9 +1,9 @@ +from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader, random_split from torchvision import transforms as transform_lib from torchvision.datasets import STL10 from pl_bolts.datamodules.concat_dataset import ConcatDataset -from pl_bolts.datamodules.lightning_datamodule import LightningDataModule from pl_bolts.transforms.dataset_normalizations import stl10_normalization diff --git a/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py b/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py index 9f99e7cbe3..01c9a12199 100644 --- a/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py +++ b/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py @@ -2,10 +2,10 @@ from argparse import ArgumentParser import torch -from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning import LightningDataModule, LightningModule, Trainer from torch.nn import functional as F -from pl_bolts.datamodules import MNISTDataModule, LightningDataModule +from pl_bolts.datamodules import 
MNISTDataModule from pl_bolts.models.autoencoders.basic_ae.components import AEEncoder from pl_bolts.models.autoencoders.basic_vae.components import Decoder @@ -46,7 +46,9 @@ def __init__( # link default data if datamodule is None: datamodule = MNISTDataModule(data_dir=self.hparams.data_dir, num_workers=self.hparams.num_workers) + self.datamodule = datamodule + self.img_dim = self.datamodule.size() self.encoder = self.init_encoder(self.hparams.hidden_dim, self.hparams.latent_dim, @@ -122,15 +124,6 @@ def configure_optimizers(self): def prepare_data(self): self.datamodule.prepare_data() - def train_dataloader(self): - return self.datamodule.train_dataloader(self.hparams.batch_size) - - def val_dataloader(self): - return self.datamodule.val_dataloader(self.hparams.batch_size) - - def test_dataloader(self): - return self.datamodule.test_dataloader(self.hparams.batch_size) - @staticmethod def add_model_specific_args(parent_parser): parser = ArgumentParser(parents=[parent_parser], add_help=False) diff --git a/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py b/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py index dbc46904d4..85bd7fa48e 100644 --- a/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py +++ b/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py @@ -3,7 +3,7 @@ import torch import torchvision -from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning import LightningDataModule, LightningModule, Trainer from torch import distributions from torch.nn import functional as F @@ -25,7 +25,7 @@ def __init__( batch_size: int = 32, learning_rate: float = 0.001, data_dir: str = '.', - datamodule: pl_bolts.datamodules.LightningDataModule = None, + datamodule: LightningDataModule = None, pretrained: str = None, **kwargs ): @@ -232,15 +232,6 @@ def configure_optimizers(self): def prepare_data(self): self.datamodule.prepare_data() - def train_dataloader(self): - return self.datamodule.train_dataloader(self.hparams.batch_size) - - def val_dataloader(self): - return self.datamodule.val_dataloader(self.hparams.batch_size) - - def test_dataloader(self): - return self.datamodule.test_dataloader(self.hparams.batch_size) - @staticmethod def add_model_specific_args(parent_parser): parser = ArgumentParser(parents=[parent_parser], add_help=False) diff --git a/pl_bolts/models/gans/basic/basic_gan_module.py b/pl_bolts/models/gans/basic/basic_gan_module.py index f306bd1a5a..4dc7ed5e7f 100644 --- a/pl_bolts/models/gans/basic/basic_gan_module.py +++ b/pl_bolts/models/gans/basic/basic_gan_module.py @@ -2,11 +2,11 @@ from collections import OrderedDict import torch -from pytorch_lightning import Trainer, LightningModule, Callback +from pytorch_lightning import Trainer, LightningDataModule, LightningModule, Callback from pytorch_lightning.callbacks import ModelCheckpoint from torch.nn import functional as F -from pl_bolts.datamodules import MNISTDataModule, LightningDataModule, STL10DataModule +from pl_bolts.datamodules import MNISTDataModule, STL10DataModule from pl_bolts.callbacks import LatentDimInterpolator from pl_bolts.models.gans.basic.components import Generator, Discriminator import os @@ -184,9 +184,6 @@ def configure_optimizers(self): def prepare_data(self): self.datamodule.prepare_data() - def train_dataloader(self): - return self.datamodule.train_dataloader(self.hparams.batch_size) - @staticmethod def add_model_specific_args(parent_parser): parser = ArgumentParser(parents=[parent_parser], add_help=False) diff --git 
a/pl_bolts/models/self_supervised/amdim/amdim_module.py b/pl_bolts/models/self_supervised/amdim/amdim_module.py index e32bdd5f31..0d6757bc38 100644 --- a/pl_bolts/models/self_supervised/amdim/amdim_module.py +++ b/pl_bolts/models/self_supervised/amdim/amdim_module.py @@ -18,7 +18,7 @@ class AMDIM(pl.LightningModule): def __init__( self, - datamodule: Union[str, pl_bolts.datamodules.LightningDataModule] = 'cifar10', + datamodule: Union[str, pl.LightningDataModule] = 'cifar10', encoder: Union[str, torch.nn.Module, pl.LightningModule] = 'amdim_encoder', contrastive_task: Union[FeatureMapContrastiveTask] = FeatureMapContrastiveTask('01, 02, 11'), image_channels: int = 3, diff --git a/pl_bolts/models/self_supervised/cpc/cpc_module.py b/pl_bolts/models/self_supervised/cpc/cpc_module.py index a188dc7a24..61d09a8e98 100644 --- a/pl_bolts/models/self_supervised/cpc/cpc_module.py +++ b/pl_bolts/models/self_supervised/cpc/cpc_module.py @@ -39,7 +39,7 @@ class CPCV2(pl.LightningModule): def __init__( self, - datamodule: pl_bolts.datamodules.LightningDataModule = None, + datamodule: pl.LightningDataModule = None, encoder: Union[str, torch.nn.Module, pl.LightningModule] = 'cpc_encoder', patch_size: int = 8, patch_overlap: int = 4, @@ -128,7 +128,11 @@ def __init__( # link data if datamodule is None: - datamodule = CIFAR10DataModule(self.hparams.data_dir, num_workers=self.hparams.num_workers) + datamodule = CIFAR10DataModule( + self.hparams.data_dir, + num_workers=self.hparams.num_workers, + batch_size=batch_size + ) datamodule.train_transforms = CPCTrainTransformsCIFAR10() datamodule.val_transforms = CPCEvalTransformsCIFAR10() self.datamodule = datamodule @@ -315,14 +319,6 @@ def configure_optimizers(self): def prepare_data(self): self.datamodule.prepare_data() - def train_dataloader(self): - loader = self.datamodule.train_dataloader(self.hparams.batch_size) - return loader - - def val_dataloader(self): - loader = self.datamodule.val_dataloader(self.hparams.batch_size) - return loader - @staticmethod def add_model_specific_args(parent_parser): parser = ArgumentParser(parents=[parent_parser], add_help=False) diff --git a/pl_bolts/models/self_supervised/moco/moco2_module.py b/pl_bolts/models/self_supervised/moco/moco2_module.py index cb5f9da3c1..81e3236af8 100644 --- a/pl_bolts/models/self_supervised/moco/moco2_module.py +++ b/pl_bolts/models/self_supervised/moco/moco2_module.py @@ -38,7 +38,7 @@ def __init__(self, learning_rate: float = 0.03, momentum: float = 0.9, weight_decay: float = 1e-4, - datamodule: pl_bolts.datamodules.LightningDataModule = None, + datamodule: pl.LightningDataModule = None, data_dir: str = './', batch_size: int = 256, use_mlp: bool = False, @@ -320,14 +320,6 @@ def configure_optimizers(self): def prepare_data(self): self.datamodule.prepare_data() - def train_dataloader(self): - loader = self.datamodule.train_dataloader(self.hparams.batch_size) - return loader - - def val_dataloader(self): - loader = self.datamodule.val_dataloader(self.hparams.batch_size) - return loader - @staticmethod def add_model_specific_args(parent_parser): from test_tube import HyperOptArgumentParser diff --git a/pl_bolts/models/self_supervised/simclr/simclr_module.py b/pl_bolts/models/self_supervised/simclr/simclr_module.py index 7f4933451e..c86c993933 100644 --- a/pl_bolts/models/self_supervised/simclr/simclr_module.py +++ b/pl_bolts/models/self_supervised/simclr/simclr_module.py @@ -46,8 +46,8 @@ def forward(self, x): class SimCLR(pl.LightningModule): def __init__(self, - datamodule: 
pl_bolts.datamodules.LightningDataModule = None, - data_dir: str = '', + datamodule: pl.LightningDataModule = None, + data_dir: str = './', learning_rate: float = 0.00006, weight_decay: float = 0.0005, input_height: int = 32, @@ -117,11 +117,12 @@ def __init__(self, # init default datamodule if datamodule is None: - datamodule = CIFAR10DataModule(data_dir, num_workers=num_workers) + datamodule = CIFAR10DataModule(data_dir, num_workers=num_workers, batch_size=batch_size) datamodule.train_transforms = SimCLRTrainDataTransform(input_height) datamodule.val_transforms = SimCLREvalDataTransform(input_height) self.datamodule = datamodule + self.loss_func = self.init_loss() self.encoder = self.init_encoder() self.projection = self.init_projection() @@ -234,12 +235,6 @@ def validation_epoch_end(self, outputs: list): def prepare_data(self): self.datamodule.prepare_data() - def train_dataloader(self): - return self.datamodule.train_dataloader(self.hparams.batch_size) - - def val_dataloader(self): - return self.datamodule.val_dataloader(self.hparams.batch_size) - def configure_optimizers(self): if self.hparams.optimizer == 'adam': optimizer = torch.optim.Adam( diff --git a/pl_bolts/models/vision/image_gpt/igpt_module.py b/pl_bolts/models/vision/image_gpt/igpt_module.py index 79a26ebeee..35d7d51caa 100644 --- a/pl_bolts/models/vision/image_gpt/igpt_module.py +++ b/pl_bolts/models/vision/image_gpt/igpt_module.py @@ -6,7 +6,6 @@ import numpy as np from pl_bolts.datamodules import FashionMNISTDataModule, ImagenetDataModule from pl_bolts.models.vision.image_gpt.gpt2 import GPT2 -from pl_bolts.datamodules import LightningDataModule def _shape_input(x): @@ -19,7 +18,7 @@ def _shape_input(x): class ImageGPT(pl.LightningModule): def __init__( self, - datamodule: LightningDataModule = None, + datamodule: pl.LightningDataModule = None, embed_dim: int = 16, heads: int = 2, layers: int = 2, @@ -229,15 +228,6 @@ def test_epoch_end(self, outs): def prepare_data(self): self.datamodule.prepare_data() - def train_dataloader(self): - return self.datamodule.train_dataloader() - - def val_dataloader(self): - return self.datamodule.val_dataloader() - - def test_dataloader(self): - return self.datamodule.test_dataloader() - @staticmethod def add_model_specific_args(parent_parser): parser = ArgumentParser(parents=[parent_parser], add_help=False) diff --git a/requirements.txt b/requirements.txt index df55815fcb..bc7674b02b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -pytorch-lightning>=0.8.4 +pytorch-lightning>=0.9.0rc3 torch>=1.4 torchvision>=0.5 scikit-learn>=0.23 diff --git a/tests/models/test_autoencoders.py b/tests/models/test_autoencoders.py index 7e2205db34..045c814511 100644 --- a/tests/models/test_autoencoders.py +++ b/tests/models/test_autoencoders.py @@ -1,6 +1,9 @@ import pytorch_lightning as pl +import torch from pl_bolts.models.autoencoders import VAE, AE +from pl_bolts.models.autoencoders.basic_ae import AEEncoder +from pl_bolts.models.autoencoders.basic_vae import Encoder, Decoder from tests import reset_seed @@ -23,3 +26,40 @@ def test_ae(tmpdir): trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir) trainer.fit(model) trainer.test(model) + + +def test_basic_ae_encoder(tmpdir): + reset_seed() + + hidden_dim = 128 + latent_dim = 2 + width = height = 28 + batch_size = 16 + channels = 1 + + encoder = AEEncoder(hidden_dim, latent_dim, width, height) + x = torch.randn(batch_size, channels, width, height) + z = encoder(x) + + assert z.shape == (batch_size, latent_dim) + + +def 
test_basic_vae_components(tmpdir): + reset_seed() + + hidden_dim = 128 + latent_dim = 2 + width = height = 28 + batch_size = 16 + channels = 1 + + enc = Encoder(hidden_dim, latent_dim, channels, width, height) + x = torch.randn(batch_size, channels, width, height) + mu, sigma = enc(x) + + assert mu.shape == sigma.shape + + dec = Decoder(hidden_dim, latent_dim, width, height, channels) + decoded_x = dec(mu) + + assert decoded_x.view(-1).shape == x.view(-1).shape diff --git a/tests/models/test_self_supervised.py b/tests/models/test_self_supervised.py index 562cfb9292..95679d1a5d 100644 --- a/tests/models/test_self_supervised.py +++ b/tests/models/test_self_supervised.py @@ -12,7 +12,7 @@ def test_cpcv2(tmpdir): reset_seed() - datamodule = CIFAR10DataModule(data_dir=tmpdir, num_workers=0) + datamodule = CIFAR10DataModule(data_dir=tmpdir, num_workers=0, batch_size=2) datamodule.train_transforms = CPCTrainTransformsCIFAR10() datamodule.val_transforms = CPCEvalTransformsCIFAR10() @@ -38,7 +38,7 @@ def test_amdim(tmpdir): def test_moco(tmpdir): reset_seed() - datamodule = CIFAR10DataModule(tmpdir, num_workers=0) + datamodule = CIFAR10DataModule(tmpdir, num_workers=0, batch_size=2) datamodule.train_transforms = Moco2TrainCIFAR10Transforms() datamodule.val_transforms = Moco2EvalCIFAR10Transforms() @@ -53,7 +53,7 @@ def test_moco(tmpdir): def test_simclr(tmpdir): reset_seed() - datamodule = CIFAR10DataModule(tmpdir, num_workers=0) + datamodule = CIFAR10DataModule(tmpdir, num_workers=0, batch_size=2) datamodule.train_transforms = SimCLRTrainDataTransform(32) datamodule.val_transforms = SimCLREvalDataTransform(32) diff --git a/tests/optimizers/__init__.py b/tests/optimizers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/optimizers/test_layer_adaptive_scaling.py b/tests/optimizers/test_layer_adaptive_scaling.py new file mode 100644 index 0000000000..5c8e40ba8a --- /dev/null +++ b/tests/optimizers/test_layer_adaptive_scaling.py @@ -0,0 +1,45 @@ +import pytest + +from pl_bolts.models import LitMNIST +from pl_bolts.optimizers.layer_adaptive_scaling import LARS, REQUIRED +from tests import reset_seed + + +def test_lars_lr_greater_than_zero(tmpdir): + reset_seed() + + model = LitMNIST() + with pytest.raises(ValueError, match='Invalid learning rate.*'): + opt = LARS(model.parameters(), lr=-0.5) + + opt = LARS(model.parameters(), lr=0.003) + + +def test_lars_momentum_greater_than_zero(tmpdir): + reset_seed() + + model = LitMNIST() + with pytest.raises(ValueError, match='Invalid momentum.*'): + opt = LARS(model.parameters(), lr=0.003, momentum=-0.01) + + opt = LARS(model.parameters(), lr=0.003, momentum=0.01) + + +def test_lars_weight_decay_greater_than_zero(tmpdir): + reset_seed() + + model = LitMNIST() + with pytest.raises(ValueError, match='Invalid weight_decay.*'): + opt = LARS(model.parameters(), lr=0.003, weight_decay=-0.01) + + opt = LARS(model.parameters(), lr=0.003, weight_decay=0.01) + + +def test_lars_eta_greater_than_zero(tmpdir): + reset_seed() + + model = LitMNIST() + with pytest.raises(ValueError, match='Invalid LARS coefficient.*'): + opt = LARS(model.parameters(), lr=0.003, eta=-0.01) + + opt = LARS(model.parameters(), lr=0.003, eta=0.01)
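
Example usage after this change (a minimal sketch, not part of the diff; it assumes the
pytorch-lightning>=0.9.0rc3 pin from requirements.txt above)::

    from pl_bolts.datamodules import CIFAR10DataModule

    # batch_size is now a constructor argument of the datamodule, so the
    # train/val/test_dataloader() hooks no longer take it as a parameter.
    dm = CIFAR10DataModule(data_dir='.', num_workers=4, batch_size=32)
    dm.prepare_data()                     # downloads the CIFAR10 train and test sets if needed

    train_loader = dm.train_dataloader()  # previously dm.train_dataloader(batch_size)
    val_loader = dm.val_dataloader()
    test_loader = dm.test_dataloader()

    for x, y in train_loader:
        ...                               # each batch holds dm.batch_size examples

Keeping batch_size on the datamodule matches the argument-free dataloader hooks of the
LightningDataModule interface that now lives in pytorch_lightning itself, which is why the
bolts models above drop their train_dataloader/val_dataloader/test_dataloader wrappers.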