diff --git a/pytorch_lightning/__init__.py b/pytorch_lightning/__init__.py index 09783e6d18382..cee123f4dc78e 100644 --- a/pytorch_lightning/__init__.py +++ b/pytorch_lightning/__init__.py @@ -7,8 +7,10 @@ __copyright__ = 'Copyright (c) 2018-2020, %s.' % __author__ __homepage__ = 'https://github.com/PyTorchLightning/pytorch-lightning' # this has to be simple string, see: https://github.com/pypa/twine/issues/522 -__docs__ = "PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers." \ - " Scale your models. Write less boilerplate." +__docs__ = ( + "PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers." + " Scale your models. Write less boilerplate." +) __long_docs__ = """ Lightning is a way to organize your PyTorch code to decouple the science code from the engineering. It's more of a style-guide than a framework. @@ -47,10 +49,11 @@ if __LIGHTNING_SETUP__: import sys # pragma: no-cover + sys.stdout.write(f'Partial import of `{__name__}` during the build process.\n') # pragma: no-cover # We are not importing the rest of the lightning during the build process, as it may not be compiled yet else: - from pytorch_lightning.core import LightningModule, data_loader + from pytorch_lightning.core import LightningDataModule, LightningModule, data_loader from pytorch_lightning.callbacks import Callback from pytorch_lightning.trainer import Trainer from pytorch_lightning.utilities.seed import seed_everything @@ -59,13 +62,14 @@ __all__ = [ 'Trainer', + 'LightningDataModule', 'LightningModule', 'Callback', 'data_loader', 'seed_everything', 'metrics', 'EvalResult', - 'TrainResult' + 'TrainResult', ] # necessary for regular bolts imports. Skip exception since bolts is not always installed diff --git a/pytorch_lightning/core/__init__.py b/pytorch_lightning/core/__init__.py index e7a4322d689c7..ae040baf4909e 100644 --- a/pytorch_lightning/core/__init__.py +++ b/pytorch_lightning/core/__init__.py @@ -336,8 +336,9 @@ def training_step(self, batch, batch_idx): """ +from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.decorators import data_loader from pytorch_lightning.core.lightning import LightningModule -__all__ = ['LightningModule', 'data_loader'] +__all__ = ['LightningDataModule', 'LightningModule', 'data_loader'] # __call__ = __all__ diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py new file mode 100644 index 0000000000000..953c14a96d330 --- /dev/null +++ b/pytorch_lightning/core/datamodule.py @@ -0,0 +1,314 @@ +import inspect +from abc import abstractmethod +from argparse import ArgumentParser, Namespace +from typing import Any, List, Tuple, Union + +from torch.utils.data import DataLoader + +from pytorch_lightning.utilities import parsing, rank_zero_only, rank_zero_warn + + +class _DataModuleWrapper(type): + def __call__(cls, *args, **kwargs): + """A wrapper for LightningDataModule that: + + 1. Runs user defined subclass's __init__ + 2. Assures prepare_data() runs on rank 0 + """ + + # Wrap cls's prepare_data function with rank_zero_only + cls.prepare_data = rank_zero_only(cls.prepare_data) + + # Get instance of LightningDataModule by mocking its __init__ via __call__ + obj = type.__call__(cls, *args, **kwargs) + + return obj + + +class LightningDataModule(object, metaclass=_DataModuleWrapper): # pragma: no cover + """ + A DataModule standardizes the training, val, test splits, data preparation and transforms. 
+ The main advantage is consistent data splits, data preparation and transforms across models. + + Example:: + + class MyDataModule(LightningDataModule): + def __init__(self): + super().__init__() + def prepare_data(self): + # download, split, etc... + # only called on 1 GPU/TPU in distributed + def setup(self): + # make assignments here (val/train/test split) + # called on every process in DDP + def train_dataloader(self): + train_split = Dataset(...) + return DataLoader(train_split) + def val_dataloader(self): + val_split = Dataset(...) + return DataLoader(val_split) + def test_dataloader(self): + test_split = Dataset(...) + return DataLoader(test_split) + + A DataModule implements 5 key methods: + + * **prepare_data** (things to do on 1 GPU/TPU not on every GPU/TPU in distributed mode). + * **setup** (things to do on every accelerator in distributed mode). + * **train_dataloader** the training dataloader. + * **val_dataloader** the val dataloader(s). + * **test_dataloader** the test dataloader(s). + + + This allows you to share a full dataset without explaining how to download, + split transform and process the data + + """ + + name: str = ... + + def __init__( + self, train_transforms=None, val_transforms=None, test_transforms=None, + ): + super().__init__() + self._train_transforms = train_transforms + self._val_transforms = val_transforms + self._test_transforms = test_transforms + self.dims = () + + @property + def train_transforms(self): + """ + Optional transforms (or collection of transforms) you can apply to train dataset + """ + return self._train_transforms + + @train_transforms.setter + def train_transforms(self, t): + self._train_transforms = t + + @property + def val_transforms(self): + """ + Optional transforms (or collection of transforms) you can apply to validation dataset + """ + return self._val_transforms + + @val_transforms.setter + def val_transforms(self, t): + self._val_transforms = t + + @property + def test_transforms(self): + """ + Optional transforms (or collection of transforms) you can apply to test dataset + """ + return self._test_transforms + + @test_transforms.setter + def test_transforms(self, t): + self._test_transforms = t + + def size(self, dim=None) -> Union[Tuple, int]: + """ + Return the dimension of each input either as a tuple or list of tuples. + """ + + if dim is not None: + return self.dims[dim] + + return self.dims + + @abstractmethod + def prepare_data(self, *args, **kwargs): + """ + Use this to download and prepare data. + In distributed (GPU, TPU), this will only be called once. + + .. warning:: Do not assign anything to the datamodule in this step since this will only be called on 1 GPU. + + Pseudocode:: + + dm.prepare_data() + dm.setup() + + Example:: + + def prepare_data(self): + download_imagenet() + clean_imagenet() + cache_imagenet() + """ + + @abstractmethod + def setup(self, *args, **kwargs): + """ + Use this to load your data from file, split it, etc. You are safe to make state assignments here. + This hook is called on every process when using DDP. + + Example:: + + def setup(self): + data = load_data(...) + self.train_ds, self.val_ds, self.test_ds = split_data(data) + """ + + @abstractmethod + def train_dataloader(self, *args, **kwargs) -> DataLoader: + """ + Implement a PyTorch DataLoader for training. + Return: + Single PyTorch :class:`~torch.utils.data.DataLoader`. + Note: + Lightning adds the correct sampler for distributed and arbitrary hardware. + There is no need to set it yourself. 
+ + Example:: + + def train_dataloader(self): + dataset = MNIST(root=PATH, train=True, transform=transforms.ToTensor(), download=False) + loader = torch.utils.data.DataLoader(dataset=dataset) + return loader + + """ + rank_zero_warn('`train_dataloader` must be implemented to be used with the Lightning Trainer') + + @abstractmethod + def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: + r""" + Implement a PyTorch DataLoader for training. + Return: + Single PyTorch :class:`~torch.utils.data.DataLoader`. + Note: + Lightning adds the correct sampler for distributed and arbitrary hardware. + There is no need to set it yourself. + Note: + You can also return a list of DataLoaders + + Example:: + + def val_dataloader(self): + dataset = MNIST(root=PATH, train=False, transform=transforms.ToTensor(), download=False) + loader = torch.utils.data.DataLoader(dataset=dataset, shuffle=False) + return loader + """ + + @abstractmethod + def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: + r""" + Implement a PyTorch DataLoader for training. + Return: + Single PyTorch :class:`~torch.utils.data.DataLoader`. + Note: + Lightning adds the correct sampler for distributed and arbitrary hardware. + There is no need to set it yourself. + Note: + You can also return a list of DataLoaders + + Example:: + + def test_dataloader(self): + dataset = MNIST(root=PATH, train=False, transform=transforms.ToTensor(), download=False) + loader = torch.utils.data.DataLoader(dataset=dataset, shuffle=False) + return loader + """ + + @classmethod + def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser: + r"""Extends existing argparse by default `LightningDataModule` attributes. + """ + parser = ArgumentParser(parents=[parent_parser], add_help=False,) + added_args = [x.dest for x in parser._actions] + + blacklist = ['kwargs'] + depr_arg_names = blacklist + added_args + depr_arg_names = set(depr_arg_names) + + allowed_types = (str, float, int, bool) + + # TODO: get "help" from docstring :) + for arg, arg_types, arg_default in ( + at for at in cls.get_init_arguments_and_types() if at[0] not in depr_arg_names + ): + arg_types = [at for at in allowed_types if at in arg_types] + if not arg_types: + # skip argument with not supported type + continue + arg_kwargs = {} + if bool in arg_types: + arg_kwargs.update(nargs="?") + # if the only arg type is bool + if len(arg_types) == 1: + # redefine the type for ArgParser needed + def use_type(x): + return bool(parsing.str_to_bool(x)) + + else: + # filter out the bool as we need to use more general + use_type = [at for at in arg_types if at is not bool][0] + else: + use_type = arg_types[0] + + if arg_default == inspect._empty: + arg_default = None + + parser.add_argument( + f'--{arg}', + dest=arg, + default=arg_default, + type=use_type, + help=f'autogenerated by plb.{cls.__name__}', + **arg_kwargs, + ) + + return parser + + @classmethod + def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs): + """ + Create an instance from CLI arguments. + + Args: + args: The parser or namespace to take arguments from. Only known arguments will be + parsed and passed to the :class:`LightningDataModule`. + **kwargs: Additional keyword arguments that may override ones in the parser or namespace. + These must be valid DataModule arguments. 
+ + Example:: + + parser = ArgumentParser(add_help=False) + parser = LightningDataModule.add_argparse_args(parser) + module = LightningDataModule.from_argparse_args(args) + + """ + if isinstance(args, ArgumentParser): + args = cls.parse_argparser(args) + params = vars(args) + + # we only want to pass in valid DataModule args, the rest may be user specific + valid_kwargs = inspect.signature(cls.__init__).parameters + datamodule_kwargs = dict((name, params[name]) for name in valid_kwargs if name in params) + datamodule_kwargs.update(**kwargs) + + return cls(**datamodule_kwargs) + + @classmethod + def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]: + r"""Scans the DataModule signature and returns argument names, types and default values. + Returns: + List with tuples of 3 values: + (argument name, set with argument types, argument default value). + """ + datamodule_default_params = inspect.signature(cls.__init__).parameters + name_type_default = [] + for arg in datamodule_default_params: + arg_type = datamodule_default_params[arg].annotation + arg_default = datamodule_default_params[arg].default + try: + arg_types = tuple(arg_type.__args__) + except AttributeError: + arg_types = (arg_type,) + + name_type_default.append((arg, arg_types, arg_default)) + + return name_type_default diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index ce747c66be411..c753ce2fcd5db 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -4,7 +4,7 @@ import re from abc import ABC, abstractmethod from argparse import Namespace -from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import torch import torch.distributed as torch_distrib @@ -18,10 +18,10 @@ from pytorch_lightning.core.grads import GradInformation from pytorch_lightning.core.hooks import ModelHooks from pytorch_lightning.core.memory import ModelSummary -from pytorch_lightning.core.saving import ModelIO, PRIMITIVE_TYPES, ALLOWED_CONFIG_TYPES -from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin +from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args try: @@ -33,7 +33,6 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, ModelHooks, Module): - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -70,6 +69,7 @@ def __init__(self, *args, **kwargs): # optionally can be set by user self._example_input_array = None + self._datamodule = None @property def example_input_array(self) -> Any: @@ -79,6 +79,14 @@ def example_input_array(self) -> Any: def example_input_array(self, example: Any) -> None: self._example_input_array = example + @property + def datamodule(self) -> Any: + return self._datamodule + + @datamodule.setter + def datamodule(self, datamodule: Any) -> None: + self._datamodule = datamodule + @property def on_gpu(self): """ @@ -159,9 +167,7 @@ def forward(self, batch): """ - def training_step(self, *args, **kwargs) -> Union[ - int, Dict[str, Union[Tensor, Dict[str, Tensor]]] - ]: + def training_step(self, *args, **kwargs) -> 
Union[int, Dict[str, Union[Tensor, Dict[str, Tensor]]]]: r""" Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger. @@ -252,8 +258,7 @@ def training_end(self, *args, **kwargs): """ def training_epoch_end( - self, - outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Tensor]]]] + self, outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Tensor]]]] ) -> Dict[str, Dict[str, Tensor]]: """Called at the end of the training epoch with the outputs of all training steps. @@ -328,9 +333,7 @@ def training_epoch_end(self, outputs): return results """ - def training_step_end(self, *args, **kwargs) -> Dict[ - str, Union[Tensor, Dict[str, Tensor]] - ]: + def training_step_end(self, *args, **kwargs) -> Dict[str, Union[Tensor, Dict[str, Tensor]]]: """ Use this when training with dp or ddp2 because :meth:`training_step` will operate on only part of the batch. However, this is still optional @@ -549,8 +552,7 @@ def validation_end(self, outputs): """ def validation_epoch_end( - self, - outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Tensor]]]] + self, outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Tensor]]]] ) -> Dict[str, Dict[str, Tensor]]: """ Called at the end of the validation epoch with the outputs of all validation steps. @@ -776,8 +778,7 @@ def test_end(self, outputs): """ def test_epoch_end( - self, - outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Tensor]]]] + self, outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Tensor]]]] ) -> Dict[str, Dict[str, Tensor]]: """ Called at the end of a test epoch with the output of all test steps. @@ -853,11 +854,7 @@ def test_epoch_end(self, outputs): return results """ - def configure_ddp( - self, - model: 'LightningModule', - device_ids: List[int] - ) -> DistributedDataParallel: + def configure_ddp(self, model: 'LightningModule', device_ids: List[int]) -> DistributedDataParallel: r""" Override to init DDP in your own way or with your own wrapper. The only requirements are that: @@ -887,11 +884,7 @@ def configure_ddp(self, model, device_ids): return model """ - model = LightningDistributedDataParallel( - model, - device_ids=device_ids, - find_unused_parameters=True - ) + model = LightningDistributedDataParallel(model, device_ids=device_ids, find_unused_parameters=True) return model def _init_slurm_connection(self) -> None: @@ -927,12 +920,7 @@ def _init_slurm_connection(self) -> None: root_node = self.trainer.resolve_root_node_address(root_node) os.environ['MASTER_ADDR'] = root_node - def init_ddp_connection( - self, - global_rank: int, - world_size: int, - is_slurm_managing_tasks: bool = True - ) -> None: + def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True) -> None: """ Override to define your custom way of setting up a distributed environment. @@ -959,19 +947,17 @@ def init_ddp_connection( log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}") if 'WORLD_SIZE' in os.environ and int(os.environ['WORLD_SIZE']) != world_size: - rank_zero_warn(f"WORLD_SIZE environment variable ({os.environ['WORLD_SIZE']}) " - f"is not equal to the computed world size ({world_size}). Ignored.") + rank_zero_warn( + f"WORLD_SIZE environment variable ({os.environ['WORLD_SIZE']}) " + f"is not equal to the computed world size ({world_size}). Ignored." 
+ ) torch_backend = "nccl" if self.trainer.on_gpu else "gloo" log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank+1}/{world_size}") torch_distrib.init_process_group(torch_backend, rank=global_rank, world_size=world_size) def configure_apex( - self, - amp: object, - model: 'LightningModule', - optimizers: List[Optimizer], - amp_level: str + self, amp: object, model: 'LightningModule', optimizers: List[Optimizer], amp_level: str ) -> Tuple['LightningModule', List[Optimizer]]: r""" Override to init AMP your own way. @@ -1001,9 +987,9 @@ def configure_apex(self, amp, model, optimizers, amp_level): return model, optimizers - def configure_optimizers(self) -> Optional[Union[ - Optimizer, Sequence[Optimizer], Dict, Sequence[Dict], Tuple[List, List] - ]]: + def configure_optimizers( + self, + ) -> Optional[Union[Optimizer, Sequence[Optimizer], Dict, Sequence[Dict], Tuple[List, List]]]: r""" Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need one. But in the case of GANs or similar you might have multiple. @@ -1120,15 +1106,15 @@ def configure_optimizers(self): rank_zero_warn('`configure_optimizers` must be implemented to be used with the Lightning Trainer') def optimizer_step( - self, - epoch: int, - batch_idx: int, - optimizer: Optimizer, - optimizer_idx: int, - second_order_closure: Optional[Callable] = None, - on_tpu: bool = False, - using_native_amp: bool = False, - using_lbfgs: bool = False, + self, + epoch: int, + batch_idx: int, + optimizer: Optimizer, + optimizer_idx: int, + second_order_closure: Optional[Callable] = None, + on_tpu: bool = False, + using_native_amp: bool = False, + using_lbfgs: bool = False, ) -> None: r""" Override this method to adjust the default way the @@ -1205,11 +1191,7 @@ def optimizer_step(self, current_epoch, batch_idx, optimizer, else: optimizer.step() - def optimizer_zero_grad(self, - epoch: int, - batch_idx: int, - optimizer: Optimizer, - optimizer_idx: int): + def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int): optimizer.zero_grad() def tbptt_split_batch(self, batch: Tensor, split_size: int) -> list: @@ -1264,11 +1246,11 @@ def tbptt_split_batch(self, batch, split_size): batch_split = [] for i, x in enumerate(batch): if isinstance(x, torch.Tensor): - split_x = x[:, t:t + split_size] + split_x = x[:, t : t + split_size] elif isinstance(x, collections.Sequence): split_x = [None] * len(x) for batch_idx in range(len(x)): - split_x[batch_idx] = x[batch_idx][t:t + split_size] + split_x[batch_idx] = x[batch_idx][t : t + split_size] batch_split.append(split_x) @@ -1374,8 +1356,11 @@ def tng_dataloader(self): # todo: remove in v1.0.0 Deprecated in v0.5.0. Use :meth:`train_dataloader` instead. Will be removed in 1.0.0. """ output = self.train_dataloader() - rank_zero_warn("`tng_dataloader` has been renamed to `train_dataloader` since v0.5.0." - " and this method will be removed in v1.0.0", DeprecationWarning) + rank_zero_warn( + "`tng_dataloader` has been renamed to `train_dataloader` since v0.5.0." 
+ " and this method will be removed in v1.0.0", + DeprecationWarning, + ) return output def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]: @@ -1581,9 +1566,7 @@ def get_progress_bar_dict(self) -> Dict[str, Union[int, str]]: # call .item() only once but store elements without graphs running_train_loss = self.trainer.running_loss.mean() avg_training_loss = running_train_loss.cpu().item() if running_train_loss is not None else float('NaN') - tqdm_dict = { - 'loss': '{:.3f}'.format(avg_training_loss) - } + tqdm_dict = {'loss': '{:.3f}'.format(avg_training_loss)} if self.trainer.truncated_bptt_steps is not None: tqdm_dict['split_idx'] = self.trainer.split_idx @@ -1604,8 +1587,11 @@ def get_tqdm_dict(self) -> Dict[str, Union[int, str]]: Deprecated since v0.7.3. Use :meth:`get_progress_bar_dict` instead. """ - rank_zero_warn("`get_tqdm_dict` was renamed to `get_progress_bar_dict` in v0.7.3" - " and this method will be removed in v1.0.0", DeprecationWarning) + rank_zero_warn( + "`get_tqdm_dict` was renamed to `get_progress_bar_dict` in v0.7.3" + " and this method will be removed in v1.0.0", + DeprecationWarning, + ) return self.get_progress_bar_dict() @classmethod diff --git a/pytorch_lightning/trainer/model_hooks.py b/pytorch_lightning/trainer/model_hooks.py index ec2d1274ef47f..207cdddda0488 100644 --- a/pytorch_lightning/trainer/model_hooks.py +++ b/pytorch_lightning/trainer/model_hooks.py @@ -1,11 +1,11 @@ import inspect from abc import ABC, abstractmethod +from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.lightning import LightningModule class TrainerModelHooksMixin(ABC): - def is_function_implemented(self, f_name, model=None): if model is None: model = self.get_model() @@ -15,7 +15,9 @@ def is_function_implemented(self, f_name, model=None): def is_overridden(self, method_name: str, model: LightningModule = None) -> bool: if model is None: model = self.get_model() - super_object = LightningModule + # if you pass DataModule instead of None or a LightningModule, we use LightningDataModule as super + # TODO - refector this function to accept model_name, instance, parent so it makes more sense + super_object = LightningModule if not isinstance(model, LightningDataModule) else LightningDataModule # assert model, 'no model passes' diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index bee332f831cce..eef61fc7dbd4b 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1,45 +1,46 @@ import inspect import os +import warnings from argparse import ArgumentParser, Namespace -from typing import Union, Optional, List, Dict, Tuple, Iterable, Any +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch import torch.distributed as torch_distrib import torch.multiprocessing as mp from torch.utils.data import DataLoader -from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, Callback +from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint +from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.memory import ModelSummary +from pytorch_lightning.core.step_result import EvalResult from pytorch_lightning.loggers import LightningLoggerBase -from pytorch_lightning.profiler import SimpleProfiler, PassThroughProfiler, BaseProfiler -from pytorch_lightning.trainer.auto_mix_precision import TrainerAMPMixin, 
NATIVE_AMP_AVALAIBLE +from pytorch_lightning.profiler import BaseProfiler, PassThroughProfiler, SimpleProfiler +from pytorch_lightning.trainer.auto_mix_precision import NATIVE_AMP_AVALAIBLE, TrainerAMPMixin from pytorch_lightning.trainer.callback_config import TrainerCallbackConfigMixin from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin -from pytorch_lightning.trainer.deprecated_api import ( - TrainerDeprecatedAPITillVer0_9, TrainerDeprecatedAPITillVer0_10) +from pytorch_lightning.trainer.deprecated_api import TrainerDeprecatedAPITillVer0_9, TrainerDeprecatedAPITillVer0_10 from pytorch_lightning.trainer.distrib_data_parallel import TrainerDDPMixin -from pytorch_lightning.trainer.distrib_parts import ( - TrainerDPMixin, _parse_gpu_ids, determine_root_gpu_device, pick_multiple_gpus, _parse_tpu_cores) +from pytorch_lightning.trainer.distrib_parts import (TrainerDPMixin, _parse_gpu_ids, _parse_tpu_cores, + determine_root_gpu_device, pick_multiple_gpus) from pytorch_lightning.trainer.evaluation_loop import TrainerEvaluationLoopMixin from pytorch_lightning.trainer.logging import TrainerLoggingMixin +from pytorch_lightning.trainer.lr_finder import TrainerLRFinderMixin from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin from pytorch_lightning.trainer.supporters import TensorRunningAccum from pytorch_lightning.trainer.training_io import TrainerIOMixin from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin -from pytorch_lightning.trainer.lr_finder import TrainerLRFinderMixin -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities import rank_zero_warn, parsing, rank_zero_info, rank_zero_only +from pytorch_lightning.utilities import parsing, rank_zero_info, rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.debugging import InternalDebugger -from pytorch_lightning.core.step_result import EvalResult -import warnings +from pytorch_lightning.utilities.exceptions import MisconfigurationException # warnings to ignore in trainer -warnings.filterwarnings('ignore', message='torch.distributed.reduce_op is deprecated, ' - 'please use torch.distributed.ReduceOp instead') +warnings.filterwarnings( + 'ignore', message='torch.distributed.reduce_op is deprecated, ' 'please use torch.distributed.ReduceOp instead' +) try: from apex import amp @@ -134,6 +135,7 @@ class Trainer( >>> len(test_outputs) 25 """ + DEPRECATED_IN_0_9 = ('use_amp', 'show_progress_bar', 'training_tqdm_dict', 'num_tpu_cores') def __init__( @@ -191,7 +193,7 @@ def __init__( val_percent_check: float = None, # backward compatible, todo: remove in v0.10.0 test_percent_check: float = None, # backward compatible, todo: remove in v0.10.0 train_percent_check: float = None, # backward compatible, todo: remove in v0.10.0 - overfit_pct: float = None # backward compatible, todo: remove in v1.0.0 + overfit_pct: float = None, # backward compatible, todo: remove in v1.0.0 ): r""" @@ -459,16 +461,18 @@ def __init__( self.check_val_every_n_epoch = check_val_every_n_epoch if not isinstance(track_grad_norm, (int, float)) and track_grad_norm != 'inf': - raise MisconfigurationException( - "track_grad_norm can be an int, a float or 'inf' (infinity norm).") + raise MisconfigurationException("track_grad_norm can 
be an int, a float or 'inf' (infinity norm).") self.track_grad_norm = float(track_grad_norm) self.on_gpu = True if (gpus and torch.cuda.is_available()) else False # tpu config if num_tpu_cores is not None: - rank_zero_warn("Argument `num_tpu_cores` is now set by `tpu_cores` since v0.7.6" - " and this argument will be removed in v0.9.0", DeprecationWarning) + rank_zero_warn( + "Argument `num_tpu_cores` is now set by `tpu_cores` since v0.7.6" + " and this argument will be removed in v0.9.0", + DeprecationWarning, + ) if tpu_cores is None: tpu_cores = num_tpu_cores @@ -491,8 +495,11 @@ def __init__( self.num_sanity_val_steps = float("inf") if num_sanity_val_steps == -1 else num_sanity_val_steps # Backward compatibility, TODO: remove in v0.9.0 if print_nan_grads: - rank_zero_warn("Argument `print_nan_grads` has no effect and will be removed in v0.9.0." - " NaN grads will be printed automatically when detected.", DeprecationWarning) + rank_zero_warn( + "Argument `print_nan_grads` has no effect and will be removed in v0.9.0." + " NaN grads will be printed automatically when detected.", + DeprecationWarning, + ) self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch @@ -510,8 +517,9 @@ def __init__( if self.fast_dev_run: self.num_sanity_val_steps = 0 self.max_epochs = 1 - rank_zero_info('Running in fast_dev_run mode: will run a full train,' - ' val and test loop using a single batch') + rank_zero_info( + 'Running in fast_dev_run mode: will run a full train,' ' val and test loop using a single batch' + ) # configure profiler if profiler is True: @@ -571,8 +579,11 @@ def __init__( # how much of the data to use # TODO: remove in 0.10.0 if overfit_pct is not None: - rank_zero_warn("Argument `overfit_pct` is now set by `overfit_batches` since v0.8.0" - " and this argument will be removed in v0.10.0", DeprecationWarning) + rank_zero_warn( + "Argument `overfit_pct` is now set by `overfit_batches` since v0.8.0" + " and this argument will be removed in v0.10.0", + DeprecationWarning, + ) overfit_batches = overfit_pct # convert floats to ints @@ -580,20 +591,29 @@ def __init__( # TODO: remove in 0.10.0 if val_percent_check is not None: - rank_zero_warn("Argument `val_percent_check` is now set by `limit_val_batches` since v0.8.0" - " and this argument will be removed in v0.10.0", DeprecationWarning) + rank_zero_warn( + "Argument `val_percent_check` is now set by `limit_val_batches` since v0.8.0" + " and this argument will be removed in v0.10.0", + DeprecationWarning, + ) limit_val_batches = val_percent_check # TODO: remove in 0.10.0 if test_percent_check is not None: - rank_zero_warn("Argument `test_percent_check` is now set by `limit_test_batches` since v0.8.0" - " and this argument will be removed in v0.10.0", DeprecationWarning) + rank_zero_warn( + "Argument `test_percent_check` is now set by `limit_test_batches` since v0.8.0" + " and this argument will be removed in v0.10.0", + DeprecationWarning, + ) limit_test_batches = test_percent_check # TODO: remove in 0.10.0 if train_percent_check is not None: - rank_zero_warn("Argument `train_percent_check` is now set by `limit_train_batches` since v0.8.0" - " and this argument will be removed in v0.10.0", DeprecationWarning) + rank_zero_warn( + "Argument `train_percent_check` is now set by `limit_train_batches` since v0.8.0" + " and this argument will be removed in v0.10.0", + DeprecationWarning, + ) limit_train_batches = train_percent_check self.limit_test_batches = _determine_limit_batches(limit_test_batches) @@ -610,8 +630,11 @@ def __init__( # 
Backward compatibility, TODO: remove in v0.9.0 if use_amp is not None: - rank_zero_warn("Argument `use_amp` is now set by `precision` since v0.7.0" - " and this method will be removed in v0.9.0", DeprecationWarning) + rank_zero_warn( + "Argument `use_amp` is now set by `precision` since v0.7.0" + " and this method will be removed in v0.9.0", + DeprecationWarning, + ) self.precision = 16 if use_amp else 32 self.amp_level = amp_level @@ -751,7 +774,7 @@ def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser: ...} """ - parser = ArgumentParser(parents=[parent_parser], add_help=False, ) + parser = ArgumentParser(parents=[parent_parser], add_help=False,) blacklist = ['kwargs'] depr_arg_names = cls.get_deprecated_arg_names() + blacklist @@ -759,8 +782,9 @@ def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser: allowed_types = (str, float, int, bool) # TODO: get "help" from docstring :) - for arg, arg_types, arg_default in (at for at in cls.get_init_arguments_and_types() - if at[0] not in depr_arg_names): + for arg, arg_types, arg_default in ( + at for at in cls.get_init_arguments_and_types() if at[0] not in depr_arg_names + ): arg_types = [at for at in allowed_types if at in arg_types] if not arg_types: # skip argument with not supported type @@ -773,6 +797,7 @@ def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser: # redefine the type for ArgParser needed def use_type(x): return bool(parsing.str_to_bool(x)) + else: # filter out the bool as we need to use more general use_type = [at for at in arg_types if at is not bool][0] @@ -891,17 +916,18 @@ def disable_validation(self) -> bool: @property def enable_validation(self) -> bool: """ Check if we should run validation during training. """ - val_loop_enabled = (self.is_overridden('validation_step') and self.limit_val_batches > 0) + val_loop_enabled = self.is_overridden('validation_step') and self.limit_val_batches > 0 return val_loop_enabled or self.fast_dev_run # ----------------------------- # MODEL TRAINING # ----------------------------- def fit( - self, - model: LightningModule, - train_dataloader: Optional[DataLoader] = None, - val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None + self, + model: LightningModule, + train_dataloader: Optional[DataLoader] = None, + val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + datamodule: Optional[LightningDataModule] = None, ): r""" Runs the full optimization routine. 
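Note: a minimal sketch of how the new `datamodule` argument to `Trainer.fit` is meant to be called, assuming a user-defined `MyDataModule` (a `LightningDataModule` subclass) and `LitModel` (a `LightningModule`); both class names are hypothetical and not part of this patch::

    from pytorch_lightning import Trainer

    dm = MyDataModule()      # hypothetical LightningDataModule subclass
    model = LitModel()       # hypothetical LightningModule

    trainer = Trainer(max_epochs=1)
    # Pass either a datamodule or explicit dataloaders, not both; combining them
    # raises MisconfigurationException (see the check added in fit() below).
    trainer.fit(model, datamodule=dm)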
@@ -948,8 +974,15 @@ def fit( if hasattr(model, 'hparams'): parsing.clean_namespace(model.hparams) + # If you supply a datamodule you can't supply train_dataloader or val_dataloaders + if (train_dataloader or val_dataloaders) and datamodule: + raise MisconfigurationException( + 'You cannot pass train_dataloader or val_dataloaders to trainer.fit if you supply a datamodule' + ) + # set up the passed in dataloaders (if needed) self.__attach_dataloaders(model, train_dataloader, val_dataloaders) + self.__attach_datamodule(model, datamodule) # check that model is configured correctly self.check_model_configuration(model) @@ -1086,7 +1119,7 @@ def __run_ddp_spawn(self, model, nprocs): smp = mp.get_context('spawn') q = smp.SimpleQueue() - mp.spawn(self.ddp_train, nprocs=nprocs, args=(q, model, )) + mp.spawn(self.ddp_train, nprocs=nprocs, args=(q, model,)) # restore main state with best weights best_path = q.get() @@ -1122,6 +1155,20 @@ def __attach_dataloaders(self, model, train_dataloader=None, val_dataloaders=Non if test_dataloaders is not None: model.test_dataloader = _PatchDataLoader(test_dataloaders) + def __attach_datamodule(self, model, datamodule=None): + + # We use datamodule if it's been provided on .fit or .test, otherwise we check model for it + datamodule = datamodule or getattr(model, 'datamodule', None) + + # If we have a datamodule, attach necessary hooks + dataloaders + if datamodule: + if self.is_overridden('train_dataloader', datamodule): + model.train_dataloader = datamodule.train_dataloader + if self.is_overridden('val_dataloader', datamodule): + model.val_dataloader = datamodule.val_dataloader + if self.is_overridden('test_dataloader', datamodule): + model.test_dataloader = datamodule.test_dataloader + def run_pretrain_routine(self, model: LightningModule): """Sanity check a few things before starting actual training. @@ -1168,9 +1215,7 @@ def run_pretrain_routine(self, model: LightningModule): if self.weights_summary in ModelSummary.MODES: ref_model.summarize(mode=self.weights_summary) else: - raise MisconfigurationException( - "weights_summary can be None, " + ", ".join(ModelSummary.MODES) - ) + raise MisconfigurationException("weights_summary can be None, " + ", ".join(ModelSummary.MODES)) # track model now. 
# if cluster resets state, the model will update with the saved weights @@ -1211,8 +1256,9 @@ def run_pretrain_routine(self, model: LightningModule): self.train() def _run_sanity_check(self, ref_model, model): - should_sanity_check = self.is_overridden('validation_step') and self.num_sanity_val_steps > 0 \ - and self.limit_val_batches > 0 + should_sanity_check = ( + self.is_overridden('validation_step') and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0 + ) # run tiny validation (if validation defined) # to make sure program won't crash during val @@ -1226,10 +1272,7 @@ def _run_sanity_check(self, ref_model, model): num_loaders = len(self.val_dataloaders) max_batches = [self.num_sanity_val_steps] * num_loaders - eval_results = self._evaluate(model, - self.val_dataloaders, - max_batches, - False) + eval_results = self._evaluate(model, self.val_dataloaders, max_batches, False) # allow no returns from eval if eval_results is not None and len(eval_results) > 0: @@ -1247,11 +1290,12 @@ def _run_sanity_check(self, ref_model, model): self.running_sanity_check = False def test( - self, - model: Optional[LightningModule] = None, - test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, - ckpt_path: Optional[str] = 'best', - verbose: bool = True + self, + model: Optional[LightningModule] = None, + test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + ckpt_path: Optional[str] = 'best', + verbose: bool = True, + datamodule: Optional[LightningDataModule] = None, ): r""" @@ -1315,6 +1359,15 @@ def test( if self.global_rank != 0: return + # If you supply a datamodule you can't supply train_dataloader or val_dataloaders + if test_dataloaders and datamodule: + raise MisconfigurationException( + 'You cannot pass test_dataloaders to trainer.test if you supply a datamodule' + ) + + # Attach datamodule to get setup/prepare_data added to model before the call to it below + self.__attach_datamodule(model or self.get_model(), datamodule) + self.setup('test') if model is not None: @@ -1334,7 +1387,8 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders): # if user requests the best checkpoint but we don't have it, error if ckpt_path == 'best' and self.checkpoint_callback.save_top_k <= 0: raise MisconfigurationException( - 'ckpt_path is "best", but ModelCheckpoint is not configured to save the best model.') + 'ckpt_path is "best", but ModelCheckpoint is not configured to save the best model.' + ) # load best weights if ckpt_path is not None: @@ -1343,8 +1397,10 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders): ckpt_path = self.checkpoint_callback.best_model_path if len(ckpt_path) == 0: - rank_zero_warn(f'.test() found no path for the best weights, {ckpt_path}. Please ' - f'specify a path for a checkpoint .test(ckpt_path=PATH)') + rank_zero_warn( + f'.test() found no path for the best weights, {ckpt_path}. Please ' + f'specify a path for a checkpoint .test(ckpt_path=PATH)' + ) return {} ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage) @@ -1407,50 +1463,59 @@ def check_model_configuration(self, model: LightningModule): if not self.is_overridden('training_step', model): raise MisconfigurationException( 'No `training_step()` method defined. Lightning `Trainer` expects as minimum a' - ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.') + ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.' 
+ ) if not self.is_overridden('train_dataloader', model): raise MisconfigurationException( 'No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a' - ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.') + ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.' + ) if not self.is_overridden('configure_optimizers', model): raise MisconfigurationException( 'No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a' - ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.') + ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.' + ) # Check val_dataloader, validation_step and validation_epoch_end if self.is_overridden('val_dataloader', model): if not self.is_overridden('validation_step', model): - raise MisconfigurationException('You have passed in a `val_dataloader()`' - ' but have not defined `validation_step()`.') + raise MisconfigurationException( + 'You have passed in a `val_dataloader()`' ' but have not defined `validation_step()`.' + ) else: if not self.is_overridden('validation_epoch_end', model): rank_zero_warn( 'You have defined a `val_dataloader()` and have defined a `validation_step()`,' ' you may also want to define `validation_epoch_end()` for accumulating stats.', - RuntimeWarning + RuntimeWarning, ) else: if self.is_overridden('validation_step', model): - raise MisconfigurationException('You have defined `validation_step()`,' - ' but have not passed in a `val_dataloader()`.') + raise MisconfigurationException( + 'You have defined `validation_step()`,' ' but have not passed in a `val_dataloader()`.' + ) # Check test_dataloader, test_step and test_epoch_end if self.is_overridden('test_dataloader', model): if not self.is_overridden('test_step', model): - raise MisconfigurationException('You have passed in a `test_dataloader()`' - ' but have not defined `test_step()`.') + raise MisconfigurationException( + 'You have passed in a `test_dataloader()`' ' but have not defined `test_step()`.' + ) else: if not self.is_overridden('test_epoch_end', model): rank_zero_warn( 'You have defined a `test_dataloader()` and have defined a `test_step()`, you may also want to' - ' define `test_epoch_end()` for accumulating stats.', RuntimeWarning + ' define `test_epoch_end()` for accumulating stats.', + RuntimeWarning, ) else: if self.testing and self.is_overridden('test_step', model): - raise MisconfigurationException('You have defined `test_step()` but did not' - ' implement `test_dataloader` nor passed in `.test(test_dataloader)`.') + raise MisconfigurationException( + 'You have defined `test_step()` but did not' + ' implement `test_dataloader` nor passed in `.test(test_dataloader)`.' + ) def barrier(self, name): if self.use_ddp or self.use_ddp2: @@ -1471,6 +1536,7 @@ class _PatchDataLoader(object): dataloader: Dataloader object to return when called. 
""" + def __init__(self, dataloader: Union[List[DataLoader], DataLoader]): self.dataloader = dataloader diff --git a/tests/base/datamodules.py b/tests/base/datamodules.py new file mode 100644 index 0000000000000..49de3cef56527 --- /dev/null +++ b/tests/base/datamodules.py @@ -0,0 +1,30 @@ +from torch.utils.data import random_split, DataLoader + +from pytorch_lightning import LightningDataModule +from tests.base.datasets import MNIST + + +class MNISTDataModule(LightningDataModule): + + def __init__(self, data_dir: str = './'): + super(MNISTDataModule, self).__init__() + self.data_dir = data_dir + + def prepare_data(self): + MNIST(self.data_dir, train=True, download=True) + MNIST(self.data_dir, train=False, download=True) + + def setup(self): + mnist_full = MNIST(self.data_dir, train=True, download=False) + self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000]) + self.dims = tuple(self.mnist_train[0][0].shape) + self.mnist_test = MNIST(self.data_dir, train=False, download=False) + + def train_dataloader(self): + return DataLoader(self.mnist_train, batch_size=32) + + def val_dataloader(self): + return DataLoader(self.mnist_val, batch_size=32) + + def test_dataloader(self): + return DataLoader(self.mnist_test, batch_size=32) diff --git a/tests/trainer/test_datamodules.py b/tests/trainer/test_datamodules.py new file mode 100644 index 0000000000000..40bda9e0e62c7 --- /dev/null +++ b/tests/trainer/test_datamodules.py @@ -0,0 +1,41 @@ +import pickle +from argparse import ArgumentParser + +from pytorch_lightning import Trainer +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.base import EvalModelTemplate +from tests.base.datamodules import MNISTDataModule + + +def test_base_datamodule(tmpdir): + dm = MNISTDataModule() + dm.prepare_data() + dm.setup() + + +def test_dm_add_argparse_args(tmpdir): + parser = ArgumentParser() + parser = MNISTDataModule.add_argparse_args(parser) + args = parser.parse_args(['--data_dir', './my_data']) + assert args.data_dir == './my_data' + + +def test_dm_init_from_argparse_args(tmpdir): + parser = ArgumentParser() + parser = MNISTDataModule.add_argparse_args(parser) + args = parser.parse_args(['--data_dir', './my_data']) + dm = MNISTDataModule.from_argparse_args(args) + dm.prepare_data() + dm.setup() + + +def test_dm_pickle_after_init(tmpdir): + dm = MNISTDataModule() + pickle.dumps(dm) + + +def test_dm_pickle_after_setup(tmpdir): + dm = MNISTDataModule() + dm.prepare_data() + dm.setup() + pickle.dumps(dm)