From 12a85b3f74d300fe673ca70090fc6e629ae6bd3a Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sun, 7 Mar 2021 01:54:23 +0100 Subject: [PATCH 1/9] Pass {fit,validate,test,predict} to setup() --- pytorch_lightning/callbacks/base.py | 4 +- pytorch_lightning/core/datamodule.py | 58 ++++++++----- pytorch_lightning/core/hooks.py | 8 +- pytorch_lightning/trainer/callback_hook.py | 24 +++--- pytorch_lightning/trainer/model_hooks.py | 6 -- pytorch_lightning/trainer/states.py | 16 ++-- pytorch_lightning/trainer/trainer.py | 57 ++++++------- tests/callbacks/test_callbacks.py | 35 ++++---- tests/core/test_datamodules.py | 97 ++++++++++++---------- tests/helpers/boring_model.py | 4 +- 10 files changed, 168 insertions(+), 141 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index d53acf0f7030d..494d94cf446de 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -34,11 +34,11 @@ def on_before_accelerator_backend_setup(self, trainer, pl_module: LightningModul pass def setup(self, trainer, pl_module: LightningModule, stage: str) -> None: - """Called when fit or test begins""" + """Called when fit, validate, test, predict, or tune begins""" pass def teardown(self, trainer, pl_module: LightningModule, stage: str) -> None: - """Called when fit or test ends""" + """Called when fit, validate, test, predict, or tune ends""" pass def on_init_start(self, trainer) -> None: diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 29b93abe3e6a1..31c05e3bcc4c4 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -55,10 +55,10 @@ def __call__(cls, *args, **kwargs): def track_data_hook_calls(fn): """A decorator that checks if prepare_data/setup have been called. 
- - When dm.prepare_data() is called, dm.has_prepared_data gets set to True - - When dm.setup('fit') is called, dm.has_setup_fit gets set to True - - When dm.setup('test') is called, dm.has_setup_test gets set to True - - When dm.setup() is called without stage arg, both dm.has_setup_fit and dm.has_setup_test get set to True + - When ``dm.prepare_data()`` is called, ``dm.has_prepared_data`` gets set to True + - When ``dm.setup()``, ``dm.has_setup_{fit,validate,test}`` get set to True + - When ``dm.setup(stage)`` is called, where stage is any of ``{fit,validate,test,predict}`` + it's corresponding `dm_has_setup_{stage}` gets set to True Args: fn (function): Function that will be tracked to see if it has been called. @@ -77,15 +77,15 @@ def wrapped_fn(*args, **kwargs): if fn.__name__ == "setup": # Get stage either by grabbing from args or checking kwargs. - # If not provided, set call status of 'fit' and 'test' to True. + # If not provided, set call status of 'fit', 'validate', and 'test' to True. 
# We do this so __attach_datamodule in trainer.py doesn't mistakenly call setup('test') on trainer.test() stage = args[1] if len(args) > 1 else kwargs.get("stage", None) - if stage == "fit" or stage is None: - obj._has_setup_fit = True - - if stage == "test" or stage is None: - obj._has_setup_test = True + if stage is None: + for s in ("fit", "validate", "test"): + setattr(obj, f"_has_setup_{s}", True) + else: + setattr(obj, f"_has_setup_{stage}", True) if fn.__name__ == "prepare_data": obj._has_prepared_data = True @@ -156,7 +156,9 @@ def __init__( # Private attrs to keep track of whether or not data hooks have been called yet self._has_prepared_data = False self._has_setup_fit = False + self._has_setup_validate = False self._has_setup_test = False + self._has_setup_predict = False @property def train_transforms(self): @@ -214,32 +216,50 @@ def size(self, dim=None) -> Union[Tuple, int]: return self.dims @property - def has_prepared_data(self): - """Return bool letting you know if datamodule.prepare_data() has been called or not. + def has_prepared_data(self) -> bool: + """Return bool letting you know if ``datamodule.prepare_data()`` has been called or not. Returns: - bool: True if datamodule.prepare_data() has been called. False by default. + bool: True if ``datamodule.prepare_data()`` has been called. False by default. """ return self._has_prepared_data @property - def has_setup_fit(self): - """Return bool letting you know if datamodule.setup('fit') has been called or not. + def has_setup_fit(self) -> bool: + """Return bool letting you know if ``datamodule.setup('fit')`` has been called or not. Returns: - bool: True if datamodule.setup('fit') has been called. False by default. + bool: True ``if datamodule.setup('fit')`` has been called. False by default. """ return self._has_setup_fit @property - def has_setup_test(self): - """Return bool letting you know if datamodule.setup('test') has been called or not. 
+ def has_setup_validate(self) -> bool: + """Return bool letting you know if ``datamodule.setup('validate')`` has been called or not. + + Returns: + bool: True if ``datamodule.setup('validate')`` has been called. False by default. + """ + return self._has_setup_validate + + @property + def has_setup_test(self) -> bool: + """Return bool letting you know if ``datamodule.setup('test')`` has been called or not. Returns: - bool: True if datamodule.setup('test') has been called. False by default. + bool: True if ``datamodule.setup('test')`` has been called. False by default. """ return self._has_setup_test + @property + def has_setup_predict(self) -> bool: + """Return bool letting you know if ``datamodule.setup('predict')`` has been called or not. + + Returns: + bool: True if ``datamodule.setup('predict')`` has been called. False by default. + """ + return self._has_setup_predict + @abstractmethod def prepare_data(self, *args, **kwargs): pass diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 604803365298c..a6567e3d52f0f 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -27,12 +27,12 @@ class ModelHooks: def setup(self, stage: str) -> None: """ - Called at the beginning of fit and test. + Called at the beginning of fit (train + validate), validate, test, predict, or tune. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP. Args: - stage: either 'fit' or 'test' + stage: either ``'fit'``, ``'validate'``, ``'test'``, ``'predict'``, or ``'tune'`` Example:: @@ -55,10 +55,10 @@ def setup(stage): def teardown(self, stage: str) -> None: """ - Called at the end of fit and test. + Called at the end of fit (train + validate), validate, test, predict, or tune. 
Args: - stage: either 'fit' or 'test' + stage: either ``'fit'``, ``'validate'``, ``'test'``, ``'predict'``, or ``'tune'`` """ def on_fit_start(self) -> None: diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 8f9fc3ad930b0..71433429f7c03 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -29,18 +29,18 @@ class TrainerCallbackHookMixin(ABC): callbacks: List[Callback] = [] lightning_module: LightningModule - def on_before_accelerator_backend_setup(self, model): - """Called in the beginning of fit and test""" + def on_before_accelerator_backend_setup(self, model: LightningModule) -> None: + """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" for callback in self.callbacks: callback.on_before_accelerator_backend_setup(self, model) - def setup(self, model, stage: str): - """Called in the beginning of fit and test""" + def setup(self, model: LightningModule, stage: str) -> None: + """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" for callback in self.callbacks: callback.setup(self, model, stage) - def teardown(self, stage: str): - """Called at the end of fit and test""" + def teardown(self, stage: str) -> None: + """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" for callback in self.callbacks: callback.teardown(self, self.lightning_module, stage) @@ -124,15 +124,15 @@ def on_train_end(self): for callback in self.callbacks: callback.on_train_end(self, self.lightning_module) - def on_pretrain_routine_start(self, model): - """Called when the train begins.""" + def on_pretrain_routine_start(self) -> None: + """Called when the pre-train routine begins.""" for callback in self.callbacks: - callback.on_pretrain_routine_start(self, model) + callback.on_pretrain_routine_start(self, self.lightning_module) - def on_pretrain_routine_end(self, 
model): - """Called when the train ends.""" + def on_pretrain_routine_end(self) -> None: + """Called when the pre-train routine ends.""" for callback in self.callbacks: - callback.on_pretrain_routine_end(self, model) + callback.on_pretrain_routine_end(self, self.lightning_module) def on_batch_start(self): """Called when the training batch begins.""" diff --git a/pytorch_lightning/trainer/model_hooks.py b/pytorch_lightning/trainer/model_hooks.py index 7e3d6cc78320c..e98ebf088a8dc 100644 --- a/pytorch_lightning/trainer/model_hooks.py +++ b/pytorch_lightning/trainer/model_hooks.py @@ -22,12 +22,6 @@ class TrainerModelHooksMixin(ABC): lightning_module: LightningModule - def is_function_implemented(self, f_name, model=None): - if model is None: - model = self.lightning_module - f_op = getattr(model, f_name, None) - return callable(f_op) - def has_arg(self, f_name, arg_name): model = self.lightning_module f_op = getattr(model, f_name, None) diff --git a/pytorch_lightning/trainer/states.py b/pytorch_lightning/trainer/states.py index d0c2ded659f67..2688fb6754977 100644 --- a/pytorch_lightning/trainer/states.py +++ b/pytorch_lightning/trainer/states.py @@ -27,14 +27,14 @@ class TrainerState(LightningEnum): >>> TrainerState.FINISHED == 'finished' True """ - INITIALIZING = 'INITIALIZING' # trainer creation - FITTING = 'FITTING' # trainer.fit() - VALIDATING = 'VALIDATING' # trainer.validate() - TESTING = 'TESTING' # trainer.test() - PREDICTING = 'PREDICTING' # trainer.predict() - TUNING = 'TUNING' # trainer.tune() - FINISHED = 'FINISHED' - INTERRUPTED = 'INTERRUPTED' + INITIALIZING = 'initializing' # trainer creation + FITTING = 'fit' # trainer.fit() + VALIDATING = 'validate' # trainer.validate() + TESTING = 'test' # trainer.test() + PREDICTING = 'predict' # trainer.predict() + TUNING = 'tune' # trainer.tune() + FINISHED = 'finished' + INTERRUPTED = 'interrupted' @property def stopped(self) -> bool: diff --git a/pytorch_lightning/trainer/trainer.py 
b/pytorch_lightning/trainer/trainer.py index cc1964f07039b..7cd666b17ca7b 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -443,7 +443,7 @@ def fit( # ---------------------------- self.call_setup_hook(model) self.call_hook("on_before_accelerator_backend_setup", model) - self.accelerator.setup(self, model) + self.accelerator.setup(self, model) # note: this sets up self.lightning_module self.setup_trainer(model) # ---------------------------- @@ -473,7 +473,8 @@ def fit( # TRAIN # ---------------------------- # hook - self.call_hook("on_fit_start") + if self.state == TrainerState.FITTING: + self.call_hook("on_fit_start") # plugin will setup fitting (e.g. ddp will launch child processes) self.pre_dispatch() @@ -488,12 +489,11 @@ def fit( # POST-Training CLEAN UP # ---------------------------- # hook - self.call_hook('on_fit_end') + if self.state == TrainerState.FITTING: + self.call_hook('on_fit_end') - # hook - self.teardown('fit') - if self.is_function_implemented('teardown'): - model.teardown('fit') + # teardown + self.call_teardown_hook(model) if self.state != TrainerState.INTERRUPTED: self.state = TrainerState.FINISHED @@ -541,9 +541,8 @@ def _pre_training_routine(self): # on pretrain routine start ref_model = self.lightning_module - self.on_pretrain_routine_start(ref_model) - if self.is_function_implemented("on_pretrain_routine_start"): - ref_model.on_pretrain_routine_start() + self.on_pretrain_routine_start() + ref_model.on_pretrain_routine_start() # print model summary if self.is_global_zero and self.weights_summary is not None and not self.testing: @@ -556,9 +555,8 @@ def _pre_training_routine(self): self.checkpoint_connector.restore_weights() # on pretrain routine end - self.on_pretrain_routine_end(ref_model) - if self.is_function_implemented("on_pretrain_routine_end"): - ref_model.on_pretrain_routine_end() + self.on_pretrain_routine_end() + ref_model.on_pretrain_routine_end() def run_train(self) -> None: @@ 
-880,8 +878,6 @@ def test( self.__evaluate_using_weights(model, ckpt_path=ckpt_path, dataloaders=test_dataloaders) ) - self.teardown('test') - assert self.state.stopped self.testing = False @@ -929,10 +925,6 @@ def __evaluate_using_weights( # run test results = self.fit(model) - # teardown - if self.is_function_implemented('teardown', model=model): - model.teardown('test') - return results def __evaluate_given_model(self, model, dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None): @@ -944,10 +936,6 @@ def __evaluate_given_model(self, model, dataloaders: Optional[Union[DataLoader, # sets up testing so we short circuit to eval results = self.fit(model) - # teardown - if self.is_function_implemented('teardown', model=model): - model.teardown('test') - return results def predict( @@ -1035,17 +1023,26 @@ def tune( assert self.state.stopped self.tuning = False - def call_setup_hook(self, model): - # call setup after the ddp process has connected - stage_name = 'test' if self.evaluating else 'fit' + def call_setup_hook(self, model: LightningModule) -> None: + assert self.state.running, f"TrainerState: {self.state}" + state = TrainerState.FITTING if self.state == TrainerState.TUNING else self.state + state = state.value if self.datamodule is not None: - called = getattr(self.datamodule, f'has_setup_{stage_name}') + called = getattr(self.datamodule, f'has_setup_{state}') if not called: - self.datamodule.setup(stage_name) + self.datamodule.setup(state) + + self.setup(model, state) + model.setup(state) + + def call_teardown_hook(self, model: LightningModule) -> None: + assert self.state.running, f"TrainerState: {self.state}" + state = TrainerState.FITTING if self.state == TrainerState.TUNING else self.state + state = state.value - self.setup(model, stage_name) - model.setup(stage_name) + self.teardown(state) + model.teardown(state) def _reset_result_and_set_hook_fx_name(self, hook_name): # on_before_zero_grad is called within training_step diff --git 
a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index 8a25ecc9f983b..2426348f770bf 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -19,29 +19,20 @@ @mock.patch("torch.save") # need to mock torch.save or we get pickle error -def test_trainer_callback_system(_, tmpdir): - """Test the callback system.""" +def test_trainer_callback_system_fit(_, tmpdir): + """Test the callback system for fit.""" model = BoringModel() - callback_mock = MagicMock() - - trainer_options = dict( + trainer = Trainer( default_root_dir=tmpdir, callbacks=[callback_mock], max_epochs=1, limit_val_batches=1, limit_train_batches=3, - limit_test_batches=2, progress_bar_refresh_rate=0, ) - # no call yet - callback_mock.assert_not_called() - - # fit model - trainer = Trainer(**trainer_options) - # check that only the to calls exists assert trainer.callbacks[0] == callback_mock assert callback_mock.method_calls == [ @@ -49,6 +40,7 @@ def test_trainer_callback_system(_, tmpdir): call.on_init_end(trainer), ] + # fit model trainer.fit(model) assert callback_mock.method_calls == [ @@ -104,8 +96,20 @@ def test_trainer_callback_system(_, tmpdir): call.teardown(trainer, model, 'fit'), ] - callback_mock.reset_mock() - trainer = Trainer(**trainer_options) + +def test_trainer_callback_system_test(tmpdir): + """Test the callback system for test.""" + + model = BoringModel() + callback_mock = MagicMock() + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[callback_mock], + max_epochs=1, + limit_test_batches=2, + progress_bar_refresh_rate=0, + ) + trainer.test(model) assert callback_mock.method_calls == [ @@ -113,7 +117,6 @@ def test_trainer_callback_system(_, tmpdir): call.on_init_end(trainer), call.setup(trainer, model, 'test'), call.on_before_accelerator_backend_setup(trainer, model), - call.on_fit_start(trainer, model), call.on_test_start(trainer, model), call.on_test_epoch_start(trainer, model), call.on_test_batch_start(trainer, model, 
ANY, 0, 0), @@ -123,8 +126,6 @@ def test_trainer_callback_system(_, tmpdir): call.on_test_epoch_end(trainer, model), call.on_epoch_end(trainer, model), call.on_test_end(trainer, model), - call.on_fit_end(trainer, model), - call.teardown(trainer, model, 'fit'), call.teardown(trainer, model, 'test'), ] diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 866bffcdd7441..e1b4301842ecd 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -17,6 +17,7 @@ from unittest import mock from unittest.mock import PropertyMock +import pytest import torch import torch.nn.functional as F @@ -108,13 +109,13 @@ def prepare_data(self, *args, **kwargs): dm.prepare_data() -def test_base_datamodule(tmpdir): +def test_helper_boringdatamodule(tmpdir): dm = BoringDataModule() dm.prepare_data() dm.setup() -def test_base_datamodule_with_verbose_setup(tmpdir): +def test_helper_boringdatamodule_with_verbose_setup(tmpdir): dm = BoringDataModule() dm.prepare_data() dm.setup('fit') @@ -123,55 +124,67 @@ def test_base_datamodule_with_verbose_setup(tmpdir): def test_data_hooks_called(tmpdir): dm = BoringDataModule() - assert dm.has_prepared_data is False - assert dm.has_setup_fit is False - assert dm.has_setup_test is False + assert not dm.has_prepared_data + assert not dm.has_setup_fit + assert not dm.has_setup_test + assert not dm.has_setup_validate + assert not dm.has_setup_predict dm.prepare_data() - assert dm.has_prepared_data is True - assert dm.has_setup_fit is False - assert dm.has_setup_test is False + assert dm.has_prepared_data + assert not dm.has_setup_fit + assert not dm.has_setup_test + assert not dm.has_setup_validate + assert not dm.has_setup_predict dm.setup() - assert dm.has_prepared_data is True - assert dm.has_setup_fit is True - assert dm.has_setup_test is True + assert dm.has_prepared_data + assert dm.has_setup_fit + assert dm.has_setup_test + assert dm.has_setup_validate + assert not dm.has_setup_predict -def 
test_data_hooks_called_verbose(tmpdir): +@pytest.mark.parametrize("use_kwarg", (False, True)) +def test_data_hooks_called_verbose(tmpdir, use_kwarg): dm = BoringDataModule() - assert dm.has_prepared_data is False - assert dm.has_setup_fit is False - assert dm.has_setup_test is False + assert not dm.has_prepared_data + assert not dm.has_setup_fit + assert not dm.has_setup_test dm.prepare_data() - assert dm.has_prepared_data is True - assert dm.has_setup_fit is False - assert dm.has_setup_test is False - - dm.setup('fit') - assert dm.has_prepared_data is True - assert dm.has_setup_fit is True - assert dm.has_setup_test is False - - dm.setup('test') - assert dm.has_prepared_data is True - assert dm.has_setup_fit is True - assert dm.has_setup_test is True - - -def test_data_hooks_called_with_stage_kwarg(tmpdir): - dm = BoringDataModule() - dm.prepare_data() - assert dm.has_prepared_data is True - - dm.setup(stage='fit') - assert dm.has_setup_fit is True - assert dm.has_setup_test is False - - dm.setup(stage='test') - assert dm.has_setup_fit is True - assert dm.has_setup_test is True + assert dm.has_prepared_data + assert not dm.has_setup_fit + assert not dm.has_setup_test + assert not dm.has_setup_predict + + dm.setup(stage='fit') if use_kwarg else dm.setup('fit') + assert dm.has_prepared_data + assert dm.has_setup_fit + assert not dm.has_setup_validate + assert not dm.has_setup_test + assert not dm.has_setup_predict + + dm.setup(stage='validate') if use_kwarg else dm.setup('validate') + assert dm.has_prepared_data + assert dm.has_setup_fit + assert dm.has_setup_validate + assert not dm.has_setup_test + assert not dm.has_setup_predict + + dm.setup(stage='test') if use_kwarg else dm.setup('test') + assert dm.has_prepared_data + assert dm.has_setup_fit + assert dm.has_setup_validate + assert dm.has_setup_test + assert not dm.has_setup_predict + + dm.setup(stage='predict') if use_kwarg else dm.setup('predict') + assert dm.has_prepared_data + assert dm.has_setup_fit + 
assert dm.has_setup_validate + assert dm.has_setup_test + assert dm.has_setup_predict def test_dm_add_argparse_args(tmpdir): diff --git a/tests/helpers/boring_model.py b/tests/helpers/boring_model.py index ea26310a45315..6ef2518bbef11 100644 --- a/tests/helpers/boring_model.py +++ b/tests/helpers/boring_model.py @@ -151,9 +151,11 @@ def prepare_data(self): def setup(self, stage: Optional[str] = None): if stage == "fit" or stage is None: self.random_train = Subset(self.random_full, indices=range(64)) - self.random_val = Subset(self.random_full, indices=range(64, 128)) self.dims = self.random_train[0].shape + if stage in ("fit", "validate") or stage is None: + self.random_val = Subset(self.random_full, indices=range(64, 128)) + if stage == "test" or stage is None: self.random_test = Subset(self.random_full, indices=range(128, 192)) self.dims = getattr(self, "dims", self.random_test[0].shape) From d49ccd1b9fb0afadaa28c4bead0e0cb7e5b1fc91 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sun, 7 Mar 2021 02:43:12 +0100 Subject: [PATCH 2/9] Fix doctest --- pytorch_lightning/trainer/states.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/states.py b/pytorch_lightning/trainer/states.py index 2688fb6754977..33a2326c518d5 100644 --- a/pytorch_lightning/trainer/states.py +++ b/pytorch_lightning/trainer/states.py @@ -21,10 +21,10 @@ class TrainerState(LightningEnum): functions such as `trainer.fit()` and `trainer.test(). 
>>> # you can compare the type with a string - >>> TrainerState.FITTING == 'FITTING' + >>> TrainerState.FITTING == 'fit' True >>> # which is case insensitive - >>> TrainerState.FINISHED == 'finished' + >>> TrainerState.FINISHED == 'FINISHED' True """ INITIALIZING = 'initializing' # trainer creation From 23db13507878a60cb17844f6133c5b7adf9fa9ca Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sun, 7 Mar 2021 03:24:29 +0100 Subject: [PATCH 3/9] stage: Optional[str] = None --- pytorch_lightning/callbacks/base.py | 6 +++--- pytorch_lightning/core/hooks.py | 8 ++++---- pytorch_lightning/trainer/callback_hook.py | 6 +++--- pytorch_lightning/trainer/trainer.py | 18 ++++++++++-------- tests/models/test_hooks.py | 16 ++++++---------- 5 files changed, 26 insertions(+), 28 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 494d94cf446de..0ba1fd4ff7785 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -17,7 +17,7 @@ """ import abc -from typing import Any, Dict +from typing import Any, Dict, Optional from pytorch_lightning.core.lightning import LightningModule @@ -33,11 +33,11 @@ def on_before_accelerator_backend_setup(self, trainer, pl_module: LightningModul """Called before accelerator is being setup""" pass - def setup(self, trainer, pl_module: LightningModule, stage: str) -> None: + def setup(self, trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None: """Called when fit, validate, test, predict, or tune begins""" pass - def teardown(self, trainer, pl_module: LightningModule, stage: str) -> None: + def teardown(self, trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None: """Called when fit, validate, test, predict, or tune ends""" pass diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index a6567e3d52f0f..9826f9d44ac2c 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ 
-25,14 +25,14 @@ class ModelHooks: """Hooks to be used in LightningModule.""" - def setup(self, stage: str) -> None: + def setup(self, stage: Optional[str] = None) -> None: """ Called at the beginning of fit (train + validate), validate, test, predict, or tune. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP. Args: - stage: either ``'fit'``, ``'validate'``, ``'test'``, ``'predict'``, or ``'tune'`` + stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'`` Example:: @@ -53,12 +53,12 @@ def setup(stage): """ - def teardown(self, stage: str) -> None: + def teardown(self, stage: Optional[str] = None) -> None: """ Called at the end of fit (train + validate), validate, test, predict, or tune. Args: - stage: either ``'fit'``, ``'validate'``, ``'test'``, ``'predict'``, or ``'tune'`` + stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'`` """ def on_fit_start(self) -> None: diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 71433429f7c03..f174cd725bd36 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -15,7 +15,7 @@ from abc import ABC from copy import deepcopy from inspect import signature -from typing import Any, Callable, Dict, List, Type +from typing import Any, Callable, Dict, List, Type, Optional from pytorch_lightning.callbacks import Callback from pytorch_lightning.core.lightning import LightningModule @@ -34,12 +34,12 @@ def on_before_accelerator_backend_setup(self, model: LightningModule) -> None: for callback in self.callbacks: callback.on_before_accelerator_backend_setup(self, model) - def setup(self, model: LightningModule, stage: str) -> None: + def setup(self, model: LightningModule, stage: Optional[str]) -> None: """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" for callback in 
self.callbacks: callback.setup(self, model, stage) - def teardown(self, stage: str) -> None: + def teardown(self, stage: Optional[str] = None) -> None: """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" for callback in self.callbacks: callback.teardown(self, self.lightning_module, stage) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 7cd666b17ca7b..d58de7d803146 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1031,18 +1031,20 @@ def call_setup_hook(self, model: LightningModule) -> None: if self.datamodule is not None: called = getattr(self.datamodule, f'has_setup_{state}') if not called: - self.datamodule.setup(state) + self.datamodule.setup(stage=state) - self.setup(model, state) - model.setup(state) + self.setup(model, stage=state) + model.setup(stage=state) def call_teardown_hook(self, model: LightningModule) -> None: - assert self.state.running, f"TrainerState: {self.state}" - state = TrainerState.FITTING if self.state == TrainerState.TUNING else self.state - state = state.value + if self.state.running: + state = TrainerState.FITTING if self.state == TrainerState.TUNING else self.state + state = state.value + else: + state = None - self.teardown(state) - model.teardown(state) + self.teardown(stage=state) + model.teardown(stage=state) def _reset_result_and_set_hook_fx_name(self, hook_name): # on_before_zero_grad is called within training_step diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 1a7803800b384..7c53925bd7cc4 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -404,7 +404,7 @@ def on_test_end(self): self.called.append(inspect.currentframe().f_code.co_name) super().on_test_end() - def teardown(self, stage: str): + def teardown(self, stage=None): self.called.append(inspect.currentframe().f_code.co_name) super().teardown(stage) @@ -420,12 +420,12 @@ def teardown(self, 
stage: str): limit_train_batches=2, limit_test_batches=1, progress_bar_refresh_rate=0, + weights_summary=None, ) assert model.called == [] trainer.fit(model) - expected = [ 'on_fit_start', 'on_pretrain_routine_start', @@ -469,11 +469,10 @@ def teardown(self, stage: str): assert model.called == expected - model2 = HookedModel() - trainer.test(model2) + model = HookedModel() + trainer.test(model, verbose=False) expected = [ - 'on_fit_start', 'on_test_model_eval', 'on_test_start', 'on_test_epoch_start', @@ -483,9 +482,6 @@ def teardown(self, stage: str): 'on_epoch_end', 'on_test_end', 'on_test_model_train', - 'on_fit_end', - 'teardown', # for 'fit' - 'teardown', # for 'test' + 'teardown', ] - - assert model2.called == expected + assert model.called == expected From 84f5fdb8e6b6b4254a0f635281c5356756d00ba5 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sun, 7 Mar 2021 03:26:48 +0100 Subject: [PATCH 4/9] Trailing whitespace --- tests/core/test_datamodules.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index e1b4301842ecd..ab51a87329e2f 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -131,17 +131,17 @@ def test_data_hooks_called(tmpdir): assert not dm.has_setup_predict dm.prepare_data() - assert dm.has_prepared_data + assert dm.has_prepared_data assert not dm.has_setup_fit assert not dm.has_setup_test assert not dm.has_setup_validate assert not dm.has_setup_predict dm.setup() - assert dm.has_prepared_data - assert dm.has_setup_fit - assert dm.has_setup_test - assert dm.has_setup_validate + assert dm.has_prepared_data + assert dm.has_setup_fit + assert dm.has_setup_test + assert dm.has_setup_validate assert not dm.has_setup_predict @@ -153,21 +153,21 @@ def test_data_hooks_called_verbose(tmpdir, use_kwarg): assert not dm.has_setup_test dm.prepare_data() - assert dm.has_prepared_data + assert dm.has_prepared_data 
assert not dm.has_setup_fit assert not dm.has_setup_test assert not dm.has_setup_predict dm.setup(stage='fit') if use_kwarg else dm.setup('fit') assert dm.has_prepared_data - assert dm.has_setup_fit + assert dm.has_setup_fit assert not dm.has_setup_validate assert not dm.has_setup_test assert not dm.has_setup_predict dm.setup(stage='validate') if use_kwarg else dm.setup('validate') - assert dm.has_prepared_data - assert dm.has_setup_fit + assert dm.has_prepared_data + assert dm.has_setup_fit assert dm.has_setup_validate assert not dm.has_setup_test assert not dm.has_setup_predict @@ -180,11 +180,11 @@ def test_data_hooks_called_verbose(tmpdir, use_kwarg): assert not dm.has_setup_predict dm.setup(stage='predict') if use_kwarg else dm.setup('predict') - assert dm.has_prepared_data - assert dm.has_setup_fit - assert dm.has_setup_validate - assert dm.has_setup_test - assert dm.has_setup_predict + assert dm.has_prepared_data + assert dm.has_setup_fit + assert dm.has_setup_validate + assert dm.has_setup_test + assert dm.has_setup_predict def test_dm_add_argparse_args(tmpdir): From 188b9feae8b386114d792113758afb51fa9c5931 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sun, 7 Mar 2021 03:50:52 +0100 Subject: [PATCH 5/9] Update docs and CHANGELOG --- CHANGELOG.md | 3 +++ docs/source/extensions/datamodules.rst | 12 ++++++------ docs/source/starter/introduction_guide.rst | 8 ++++---- docs/source/starter/new-project.rst | 2 +- 4 files changed, 14 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5f8f7a08b089b..f6ef0d56b3792 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,6 +38,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Changed `trainer.evaluating` to return `True` if validating or testing ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) +- Changed `setup()` stage argument to take any of `{fit,validate,test,predict}` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386)) + + ### Deprecated diff --git a/docs/source/extensions/datamodules.rst b/docs/source/extensions/datamodules.rst index a6c083dc61fcf..85134fda06fa2 100644 --- a/docs/source/extensions/datamodules.rst +++ b/docs/source/extensions/datamodules.rst @@ -80,7 +80,7 @@ The equivalent DataModule just organizes the same exact code, but makes it reusa self.data_dir = data_dir self.batch_size = batch_size - def setup(self, stage=None): + def setup(self, stage: Optional[str] = None): self.mnist_test = MNIST(self.data_dir, train=False) mnist_full = MNIST(self.data_dir, train=True) self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000]) @@ -138,7 +138,7 @@ Here's a more realistic, complex DataModule that shows how much more reusable th MNIST(self.data_dir, train=True, download=True) MNIST(self.data_dir, train=False, download=True) - def setup(self, stage=None): + def setup(self, stage: Optional[str] = None): # Assign train/val datasets for use in dataloaders if stage == 'fit' or stage is None: @@ -382,12 +382,12 @@ still ensures the method runs on the correct devices) dm = MNISTDataModule() dm.prepare_data() - dm.setup('fit') + dm.setup(stage='fit') model = Model(num_classes=dm.num_classes, width=dm.width, vocab=dm.vocab) trainer.fit(model, dm) - dm.setup('test') + dm.setup(stage='test') trainer.test(datamodule=dm) ---------------- @@ -403,7 +403,7 @@ You can of course use DataModules in plain PyTorch code as well. dm.prepare_data() # splits/transforms - dm.setup('fit') + dm.setup(stage='fit') # use data for batch in dm.train_dataloader(): @@ -412,7 +412,7 @@ You can of course use DataModules in plain PyTorch code as well. ... 
# lazy load test data - dm.setup('test') + dm.setup(stage='test') for batch in dm.test_dataloader(): ... diff --git a/docs/source/starter/introduction_guide.rst b/docs/source/starter/introduction_guide.rst index 2ee31304299e0..c65894367a39e 100644 --- a/docs/source/starter/introduction_guide.rst +++ b/docs/source/starter/introduction_guide.rst @@ -240,7 +240,7 @@ In this case, it's better to group the full definition of a dataset into a `Data tokenize() build_vocab() - def setup(self): + def setup(self, stage: Optional[str] = None): # called on every GPU vocab = load_vocab() self.vocab_size = len(vocab) @@ -310,8 +310,8 @@ An alternative to using a DataModule is to defer initialization of the models mo download_data() tokenize() - def setup(self, step): - # step is either 'fit' or 'test' 90% of the time not relevant + def setup(self, stage: Optional[str] = None): + # stage is either 'fit', 'validate', 'test', or 'predict'. 90% of the time not relevant data = load_data() num_classes = data.classes self.l1 = nn.Linear(..., num_classes) @@ -598,7 +598,7 @@ In this method we do all the preparation we need to do once (instead of on every MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()) MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()) - def setup(self, stage): + def setup(self, stage: Optional[str] = None): # transform transform=transforms.Compose([transforms.ToTensor()]) mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transform) diff --git a/docs/source/starter/new-project.rst b/docs/source/starter/new-project.rst index 0f1362616a9b1..23f91914063d9 100644 --- a/docs/source/starter/new-project.rst +++ b/docs/source/starter/new-project.rst @@ -651,7 +651,7 @@ Make your data code reusable by organizing it into a :class:`~pytorch_lightning.
MNIST(os.getcwd(), train=False, download=True) # OPTIONAL, called for every GPU/machine (assigning state is OK) - def setup(self, stage): + def setup(self, stage: Optional[str] = None): # transforms transform=transforms.Compose([ transforms.ToTensor(), From 37473f0c549590c5c342b53c564af7affb9fb05b Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sun, 7 Mar 2021 03:52:14 +0100 Subject: [PATCH 6/9] Mention teardown --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f6ef0d56b3792..327f923a79ff1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,7 +38,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed `trainer.evaluating` to return `True` if validating or testing ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) -- Changed `setup()` stage argument to take any of `{fit,validate,test,predict}` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386)) +- Changed `setup()` and `teardown()` stage argument to take any of `{fit,validate,test,predict}` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386)) ### Deprecated From 0a30abf931ec5a5f1127bdf98514df6f489cb735 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sun, 7 Mar 2021 03:59:49 +0100 Subject: [PATCH 7/9] Self-review --- pytorch_lightning/core/datamodule.py | 20 ++++++++++---------- pytorch_lightning/trainer/callback_hook.py | 2 +- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 31c05e3bcc4c4..1b6852c071fe1 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -57,8 +57,8 @@ def track_data_hook_calls(fn): - When ``dm.prepare_data()`` is called, ``dm.has_prepared_data`` gets set to True - When ``dm.setup()``, ``dm.has_setup_{fit,validate,test}`` get set to True - - When ``dm.setup(stage)`` is called, 
where stage is any of ``{fit,validate,test,predict}`` - it's corresponding `dm_has_setup_{stage}` gets set to True + - When ``dm.setup(stage)`` is called, where stage is any of ``{fit,validate,test,predict}``. + Its corresponding ``dm.has_setup_{stage}`` attribute gets set to True Args: fn (function): Function that will be tracked to see if it has been called. @@ -226,37 +226,37 @@ def has_prepared_data(self) -> bool: @property def has_setup_fit(self) -> bool: - """Return bool letting you know if ``datamodule.setup('fit')`` has been called or not. + """Return bool letting you know if ``datamodule.setup(stage='fit')`` has been called or not. Returns: - bool: True ``if datamodule.setup('fit')`` has been called. False by default. + bool: True if ``datamodule.setup(stage='fit')`` has been called. False by default. """ return self._has_setup_fit @property def has_setup_validate(self) -> bool: - """Return bool letting you know if ``datamodule.setup('validate')`` has been called or not. + """Return bool letting you know if ``datamodule.setup(stage='validate')`` has been called or not. Returns: - bool: True if ``datamodule.setup('validate')`` has been called. False by default. + bool: True if ``datamodule.setup(stage='validate')`` has been called. False by default. """ return self._has_setup_validate @property def has_setup_test(self) -> bool: - """Return bool letting you know if ``datamodule.setup('test')`` has been called or not. + """Return bool letting you know if ``datamodule.setup(stage='test')`` has been called or not. Returns: - bool: True if ``datamodule.setup('test')`` has been called. False by default. + bool: True if ``datamodule.setup(stage='test')`` has been called. False by default. """ return self._has_setup_test @property def has_setup_predict(self) -> bool: - """Return bool letting you know if ``datamodule.setup('predict')`` has been called or not. + """Return bool letting you know if ``datamodule.setup(stage='predict')`` has been called or not.
Returns: - bool: True if ``datamodule.setup('predict')`` has been called. False by default. + bool: True if ``datamodule.setup(stage='predict')`` has been called. False by default. """ return self._has_setup_predict diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index f174cd725bd36..5aa9f1a44276b 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -40,7 +40,7 @@ def setup(self, model: LightningModule, stage: Optional[str]) -> None: callback.setup(self, model, stage) def teardown(self, stage: Optional[str] = None) -> None: - """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" + """Called at the end of fit (train + validate), validate, test, predict, or tune.""" for callback in self.callbacks: callback.teardown(self, self.lightning_module, stage) From 0e9d69c35356824d5cc1b8e986c850ad71de50af Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sun, 7 Mar 2021 14:39:30 +0100 Subject: [PATCH 8/9] Address Borda's comments --- docs/source/conf.py | 1 + pytorch_lightning/trainer/model_hooks.py | 10 +++++++++- pytorch_lightning/trainer/trainer.py | 3 +-- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 813d5ee978821..ccf824bb37d9b 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -371,6 +371,7 @@ def package_list_from_file(file): doctest_global_setup = """ import importlib import os +from typing import Optional import torch from torch import nn import pytorch_lightning as pl diff --git a/pytorch_lightning/trainer/model_hooks.py b/pytorch_lightning/trainer/model_hooks.py index e98ebf088a8dc..b924675d8505c 100644 --- a/pytorch_lightning/trainer/model_hooks.py +++ b/pytorch_lightning/trainer/model_hooks.py @@ -14,6 +14,7 @@ import inspect from abc import ABC +from typing import Optional from pytorch_lightning.core.lightning import LightningModule @@
-22,7 +23,14 @@ class TrainerModelHooksMixin(ABC): lightning_module: LightningModule - def has_arg(self, f_name, arg_name): + def is_function_implemented(self, f_name: str, model: Optional[LightningModule] = None) -> bool: + # note: currently unused - kept as it is public + if model is None: + model = self.lightning_module + f_op = getattr(model, f_name, None) + return callable(f_op) + + def has_arg(self, f_name: str, arg_name: str) -> bool: model = self.lightning_module f_op = getattr(model, f_name, None) return arg_name in inspect.signature(f_op).parameters diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index d58de7d803146..45fc40731b545 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1025,8 +1025,8 @@ def tune( def call_setup_hook(self, model: LightningModule) -> None: assert self.state.running, f"TrainerState: {self.state}" + # 'fit' is passed for `trainer.tune()` as there aren't "tune_dataloaders" state = TrainerState.FITTING if self.state == TrainerState.TUNING else self.state - state = state.value if self.datamodule is not None: called = getattr(self.datamodule, f'has_setup_{state}') @@ -1039,7 +1039,6 @@ def call_setup_hook(self, model: LightningModule) -> None: def call_teardown_hook(self, model: LightningModule) -> None: if self.state.running: state = TrainerState.FITTING if self.state == TrainerState.TUNING else self.state - state = state.value else: state = None From 60a479e6de9170be356c11cc85f64006ea020681 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sun, 7 Mar 2021 18:11:33 +0100 Subject: [PATCH 9/9] Update CHANGELOG --- CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 327f923a79ff1..b787b35dbaace 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -101,6 +101,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fixed DP reduction with collection ([#6324](https://github.com/PyTorchLightning/pytorch-lightning/pull/6324)) +- Fixed `.teardown(stage='fit')` getting called during `trainer.test` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386)) + + +- Fixed `.on_fit_{start,end}()` getting called during `trainer.test` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386)) + + - Fixed PyTorch Profiler with `emit_nvtx` ([#6260](https://github.com/PyTorchLightning/pytorch-lightning/pull/6260))