From ee140aff2637e3b94a8542134629a656b41a5952 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 20 Jul 2021 22:46:57 +0200 Subject: [PATCH 01/29] Always use `trainer.call_hook` --- pytorch_lightning/core/lightning.py | 21 +++- .../loops/dataloader/evaluation_loop.py | 5 +- .../loops/epoch/training_epoch_loop.py | 8 +- pytorch_lightning/trainer/callback_hook.py | 4 +- .../trainer/connectors/callback_connector.py | 12 +- .../trainer/connectors/data_connector.py | 4 +- .../logger_connector/fx_validator.py | 17 +-- pytorch_lightning/trainer/data_loading.py | 11 +- pytorch_lightning/trainer/optimizers.py | 2 +- pytorch_lightning/trainer/trainer.py | 38 +++--- .../connectors/test_callback_connector.py | 8 +- .../trainer/logging_/test_logger_connector.py | 112 +++++++++++++++++- 12 files changed, 182 insertions(+), 60 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 90a96819061c0..df8ba49ce2452 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -400,7 +400,26 @@ def log( on_epoch = self.__auto_choose_log_on_epoch(on_epoch) results = self.trainer._results - assert results is not None + if results is None: + caller = inspect.stack()[1][3] + raise MisconfigurationException( + f'You are trying to `self.log()` inside `{caller}` but the loop `ResultCollection` is not registered' + ' yet. This is most likely because you are trying to log in a `predict` hook, but it does' + ' not support logging.' + ) + if self.trainer.lightning_module is None: + # this is to avoid `lightning_module.log` in callback hooks which have the lightning module available as a + # parameter but the reference in the trainer has not been set yet + caller = inspect.stack()[1][3] + raise MisconfigurationException( + f'You are trying to `self.log()` inside `{caller}` but the' + ' `LightningModule` is not registered yet for the trainer.' + ) + if self._current_fx_name is None: + caller = inspect.stack()[1][3] + raise MisconfigurationException( + f'You are trying to `self.log()` inside `{caller}` but it is not managed by the `Trainer` control flow' + ) assert self._current_fx_name is not None FxValidator.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 8eacd73607665..94867750c025e 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -189,11 +189,10 @@ def on_evaluation_start(self, *args: Any, **kwargs: Any) -> None: def on_evaluation_model_eval(self) -> None: """Sets model to eval mode""" - model_ref = self.trainer.lightning_module if self.trainer.testing: - model_ref.on_test_model_eval() + self.trainer.call_hook('on_test_model_eval') else: - model_ref.on_validation_model_eval() + self.trainer.call_hook('on_validation_model_eval') def on_evaluation_model_train(self) -> None: """Sets model to train mode""" diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index a79b58efe9d31..80d874654a4ee 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -273,9 +273,11 @@ def _on_train_epoch_end_hook(self, processed_epoch_output: List[List[STEP_OUTPUT with self.trainer.profiler.profile(hook_name): # first call trainer hook - if hasattr(self.trainer, hook_name): + if hook_name not in ("setup", ) and hasattr(self.trainer, hook_name): trainer_hook = getattr(self.trainer, hook_name) - trainer_hook(processed_epoch_output) + if trainer_hook is not None: + # `train_dataloader` is a function for the `LightningModule` but an attribute for the `Trainer` + trainer_hook(processed_epoch_output) # next call hook in lightningModule model_ref = self.trainer.lightning_module @@ -292,7 +294,7 @@ def _on_train_epoch_end_hook(self, processed_epoch_output: List[List[STEP_OUTPUT model_ref.on_train_epoch_end() # call the accelerator hook - if hasattr(self.trainer.accelerator, hook_name): + if hook_name not in ("setup", "teardown") and hasattr(self.trainer.accelerator, hook_name): accelerator_hook = getattr(self.trainer.accelerator, hook_name) accelerator_hook() diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 63c23d50fa772..4b2a7d7b99565 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -41,10 +41,10 @@ def on_before_accelerator_backend_setup(self, model: 'pl.LightningModule') -> No for callback in self.callbacks: callback.on_before_accelerator_backend_setup(self, model) - def configure_sharded_model(self, model: 'pl.LightningModule') -> None: + def configure_sharded_model(self) -> None: """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" for callback in self.callbacks: - callback.on_configure_sharded_model(self, model) + callback.on_configure_sharded_model(self, self.lightning_module) def setup(self, model: 'pl.LightningModule', stage: Optional[str]) -> None: """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index c4edb4c10017d..d8e866f8b1aca 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -15,7 +15,6 @@ from datetime import timedelta from typing import Dict, List, Optional, Union -import pytorch_lightning as pl from pytorch_lightning.callbacks import Callback, ModelCheckpoint, ProgressBar, ProgressBarBase from pytorch_lightning.callbacks.timer import Timer from pytorch_lightning.utilities import rank_zero_info @@ -136,8 +135,7 @@ def attach_model_logging_functions(self, model): callback.log = model.log callback.log_dict = model.log_dict - @staticmethod - def _attach_model_callbacks(model: 'pl.LightningModule', trainer) -> None: + def _attach_model_callbacks(self) -> None: """ Attaches the callbacks defined in the model. If a callback returned by the model's configure_callback method has the same type as one or several @@ -150,11 +148,11 @@ def _attach_model_callbacks(model: 'pl.LightningModule', trainer) -> None: :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_callbacks`. trainer: The trainer on which the callbacks get attached/merged. """ - model_callbacks = model.configure_callbacks() + model_callbacks = self.trainer.call_hook('configure_callbacks') if not model_callbacks: return model_callback_types = {type(c) for c in model_callbacks} - trainer_callback_types = {type(c) for c in trainer.callbacks} + trainer_callback_types = {type(c) for c in self.trainer.callbacks} override_types = model_callback_types.intersection(trainer_callback_types) if override_types: rank_zero_info( @@ -163,11 +161,11 @@ def _attach_model_callbacks(model: 'pl.LightningModule', trainer) -> None: f" {', '.join(sorted(t.__name__ for t in override_types))}" ) # remove all callbacks with a type that occurs in model callbacks - all_callbacks = [c for c in trainer.callbacks if type(c) not in override_types] + all_callbacks = [c for c in self.trainer.callbacks if type(c) not in override_types] all_callbacks.extend(model_callbacks) all_callbacks = CallbackConnector._reorder_callbacks(all_callbacks) # TODO: connectors refactor: move callbacks list to connector and do not write Trainer state - trainer.callbacks = all_callbacks + self.trainer.callbacks = all_callbacks @staticmethod def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]: diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 115073b7733c9..7ac1118c04acc 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -67,13 +67,13 @@ def get_profiled_train_dataloader(self, train_dataloader): ) return profiled_dl - def prepare_data(self, model): + def prepare_data(self) -> None: # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0 # or in the case where each node needs to do its own manipulation in which case just local_rank=0 if self.can_prepare_data(): if self.trainer.datamodule is not None: self.trainer.datamodule.prepare_data() - model.prepare_data() + self.trainer.call_hook('prepare_data') self.trainer._is_data_prepared = True def can_prepare_data(self): diff --git a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py index 3604574fd1e81..579107d80398e 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py @@ -77,12 +77,15 @@ class FxValidator: training_epoch_end=dict(on_step=(False, ), on_epoch=(True, )), validation_epoch_end=dict(on_step=(False, ), on_epoch=(True, )), test_epoch_end=dict(on_step=(False, ), on_epoch=(True, )), - on_before_batch_transfer=None, - transfer_batch_to_device=None, - on_after_batch_transfer=None, - backward=None, - optimizer_step=None, - # TODO(@carmocca): some {step,epoch}_{start,end} are missing + configure_optimizers=None, + on_train_dataloader=None, + train_dataloader=None, + val_dataloader=None, + prepare_data=None, + configure_callbacks=None, + test_dataloader=None, + on_validation_model_eval=None, + on_test_model_eval=None, ) @classmethod @@ -95,7 +98,7 @@ def check_logging(cls, fx_name: str, on_step: bool, on_epoch: bool) -> None: ) allowed = cls.functions[fx_name] if allowed is None: - raise MisconfigurationException(f"{fx_name} function doesn't support logging using `self.log()`") + raise MisconfigurationException(f"`{fx_name}` function doesn't support logging using `self.log()`") m = "You can't `self.log({}={})` inside `{}`, must be one of {}" if on_step not in allowed["on_step"]: diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 19280e4b05930..a63b71c94eb3d 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -235,7 +235,7 @@ def reset_train_dataloader(self, model: 'pl.LightningModule') -> None: Args: model: The current `LightningModule` """ - self.train_dataloader = self.request_dataloader(model, "train") + self.train_dataloader = self.request_dataloader("train") if self.overfit_batches > 0: if hasattr(self.train_dataloader, 'sampler') and isinstance(self.train_dataloader.sampler, RandomSampler): @@ -332,7 +332,7 @@ def _reset_eval_dataloader( """ # always get the loaders first so we can count how many there are loader_name = f'{mode}_dataloader' - dataloaders = self.request_dataloader(model, mode) + dataloaders = self.request_dataloader(mode) if not isinstance(dataloaders, list): dataloaders = [dataloaders] @@ -341,7 +341,7 @@ def _reset_eval_dataloader( # duplicate it the numb of times needed to match the train loaders if self.overfit_batches > 0: num_loaders = len(dataloaders) - train_dataloader = self.request_dataloader(model, 'train') + train_dataloader = self.request_dataloader('train') dataloaders = [deepcopy(train_dataloader) for _ in range(num_loaders)] self.dev_debugger.track_load_dataloader_call(loader_name, dataloaders=dataloaders) @@ -463,14 +463,13 @@ def reset_train_val_dataloaders(self, model) -> None: if self.val_dataloaders is None: self.reset_val_dataloader(model) - def request_dataloader(self, model: 'pl.LightningModule', stage: str) -> Union[DataLoader, List[DataLoader]]: + def request_dataloader(self, stage: str) -> Union[DataLoader, List[DataLoader]]: """Handles downloading data in the GPU or TPU case. Returns: The dataloader """ - self.call_hook(f"on_{stage}_dataloader") - dataloader = getattr(model, f'{stage}_dataloader')() + dataloader = self.call_hook(f"{stage}_dataloader") if isinstance(dataloader, tuple): dataloader = list(dataloader) self.accelerator.barrier('get_dataloaders') diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py index 80ec5857de287..5172455d64cdb 100644 --- a/pytorch_lightning/trainer/optimizers.py +++ b/pytorch_lightning/trainer/optimizers.py @@ -31,7 +31,7 @@ class TrainerOptimizersMixin(ABC): def init_optimizers(self, model: 'pl.LightningModule') -> Tuple[List, List, List]: self._lightning_optimizers = None - optim_conf = model.configure_optimizers() + optim_conf = self.call_hook("configure_optimizers") if optim_conf is None: rank_zero_warn( '`LightningModule.configure_optimizers` returned `None`, this fit will run with no optimizer', diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f6342b5e8e458..8be2c217ef1f1 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -842,8 +842,8 @@ def _run(self, model: 'pl.LightningModule') -> Optional[Union[_EVALUATE_OUTPUT, self.callback_connector.attach_model_logging_functions(model) # hook - self.data_connector.prepare_data(model) - self.callback_connector._attach_model_callbacks(model, self) + self.data_connector.prepare_data() + self.callback_connector._attach_model_callbacks() # ---------------------------- # SET UP TRAINING @@ -995,20 +995,14 @@ def _pre_training_routine(self): # -------------------------- # Pre-train # -------------------------- - # on pretrain routine start - ref_model = self.lightning_module - - self.on_pretrain_routine_start() - ref_model.on_pretrain_routine_start() + self.call_hook('on_pretrain_routine_start') # print model summary if self.is_global_zero and self.weights_summary is not None and not self.testing: max_depth = ModelSummary.MODES[self.weights_summary] - ref_model.summarize(max_depth=max_depth) + self.lightning_module.summarize(max_depth=max_depth) - # on pretrain routine end - self.on_pretrain_routine_end() - ref_model.on_pretrain_routine_end() + self.call_hook('on_pretrain_routine_end') def _run_train(self) -> None: self._pre_training_routine() @@ -1092,8 +1086,7 @@ def _run_sanity_check(self, ref_model): stage = self.state.stage self.sanity_checking = True - # hook and callback - self.on_sanity_check_start() + self.call_hook('on_sanity_check_start') # reload dataloaders self._evaluation_loop.reload_evaluation_dataloaders() @@ -1102,7 +1095,7 @@ def _run_sanity_check(self, ref_model): with torch.no_grad(): self._evaluation_loop.run() - self.on_sanity_check_end() + self.call_hook('on_sanity_check_end') # reset validation metrics self.logger_connector.reset() @@ -1155,8 +1148,7 @@ def _call_setup_hook(self, model: 'pl.LightningModule') -> None: if self.datamodule is not None: self.datamodule.setup(stage=fn) - self.setup(model, stage=fn) - model.setup(stage=fn) + self.call_hook('setup', stage=fn) self.accelerator.barrier("post_setup") @@ -1168,8 +1160,7 @@ def _call_configure_sharded_model(self, model: 'pl.LightningModule') -> None: model_call_configure_sharded_model_hook = getattr(model, "call_configure_sharded_model_hook", False) if self.accelerator.call_configure_sharded_model_hook and not model_call_configure_sharded_model_hook: with self.accelerator.model_sharded_context(): - model.configure_sharded_model() - self.configure_sharded_model(model) + self.call_hook('configure_sharded_model') model.call_configure_sharded_model_hook = True self.accelerator.call_configure_sharded_model_hook = False @@ -1179,8 +1170,7 @@ def _call_teardown_hook(self, model: 'pl.LightningModule') -> None: if self.datamodule is not None: self.datamodule.teardown(stage=fn) self.profiler.teardown(stage=fn) - self.teardown(stage=fn) - model.teardown(stage=fn) + self.call_hook('teardown', stage=fn) model._current_fx_name = None model._current_dataloader_idx = None @@ -1200,9 +1190,11 @@ def call_hook(self, hook_name: str, *args, **kwargs) -> Any: with self.profiler.profile(hook_name): # first call trainer hook - if hasattr(self, hook_name): + if hook_name not in ("setup", ) and hasattr(self, hook_name): trainer_hook = getattr(self, hook_name) - trainer_hook(*args, **kwargs) + if trainer_hook is not None: + # `train_dataloader` is a function for the `LightningModule` but an attribute for the `Trainer` + trainer_hook(*args, **kwargs) # next call hook in lightningModule output = None @@ -1212,7 +1204,7 @@ def call_hook(self, hook_name: str, *args, **kwargs) -> Any: output = hook_fx(*args, **kwargs) # call the accelerator hook - if hasattr(self.accelerator, hook_name): + if hook_name not in ("setup", "teardown") and hasattr(self.accelerator, hook_name): accelerator_hook = getattr(self.accelerator, hook_name) accelerator_output = accelerator_hook(*args, **kwargs) # Rely on the accelerator output if lightningModule hook returns nothing diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py index 501482d77a240..c23f9f9d499f4 100644 --- a/tests/trainer/connectors/test_callback_connector.py +++ b/tests/trainer/connectors/test_callback_connector.py @@ -41,7 +41,7 @@ def test_checkpoint_callbacks_are_last(tmpdir): model.configure_callbacks.return_value = [] trainer = Trainer(callbacks=[checkpoint1, progress_bar, lr_monitor, checkpoint2]) cb_connector = CallbackConnector(trainer) - cb_connector._attach_model_callbacks(model, trainer) + cb_connector._attach_model_callbacks() assert trainer.callbacks == [progress_bar, lr_monitor, checkpoint1, checkpoint2] # with model-specific callbacks that substitute ones in Trainer @@ -49,7 +49,7 @@ def test_checkpoint_callbacks_are_last(tmpdir): model.configure_callbacks.return_value = [checkpoint1, early_stopping, checkpoint2] trainer = Trainer(callbacks=[progress_bar, lr_monitor, ModelCheckpoint(tmpdir)]) cb_connector = CallbackConnector(trainer) - cb_connector._attach_model_callbacks(model, trainer) + cb_connector._attach_model_callbacks() assert trainer.callbacks == [progress_bar, lr_monitor, early_stopping, checkpoint1, checkpoint2] @@ -96,7 +96,7 @@ def assert_composition(trainer_callbacks, model_callbacks, expected): model.configure_callbacks.return_value = model_callbacks trainer = Trainer(checkpoint_callback=False, progress_bar_refresh_rate=0, callbacks=trainer_callbacks) cb_connector = CallbackConnector(trainer) - cb_connector._attach_model_callbacks(model, trainer) + cb_connector._attach_model_callbacks() assert trainer.callbacks == expected early_stopping = EarlyStopping() @@ -149,6 +149,6 @@ def test_attach_model_callbacks_override_info(caplog): trainer = Trainer(checkpoint_callback=False, callbacks=[EarlyStopping(), LearningRateMonitor(), ProgressBar()]) cb_connector = CallbackConnector(trainer) with caplog.at_level(logging.INFO): - cb_connector._attach_model_callbacks(model, trainer) + cb_connector._attach_model_callbacks() assert "existing callbacks passed to Trainer: EarlyStopping, LearningRateMonitor" in caplog.text diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 27598b40fbd31..86a33005c967d 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial from unittest import mock import pytest @@ -26,9 +27,10 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel, RandomDataset from tests.helpers.runif import RunIf +from tests.models.test_hooks import get_members -def test_fx_validator(tmpdir): +def test_fx_validator(): funcs_name = sorted([f for f in dir(Callback) if not f.startswith('_')]) callbacks_func = [ @@ -145,6 +147,114 @@ def test_fx_validator(tmpdir): validator.check_logging("foo", False, False) +class HookedCallback(Callback): + + def __init__(self, not_supported): + + def call(hook, *args, **_): + # `on_init_{start,end}` do not have the `LightningModule` available + lightning_module = [m for m in args if isinstance(m, LightningModule)] + if not lightning_module: + return + lightning_module = lightning_module[0] + + if hook in not_supported: + with pytest.raises(MisconfigurationException, match='self.log'): + lightning_module.log('anything', 1) + else: + lightning_module.log(hook, 1) + + for h in get_members(Callback): + setattr(self, h, partial(call, h)) + + +class HookedModel(BoringModel): + + def __init__(self, not_supported): + super().__init__() + pl_module_hooks = get_members(LightningModule) + pl_module_hooks.difference_update({ + 'log', + 'log_dict', + # the following are problematic as they do have `self._current_fx_name` defined some times but + # not others depending on where they were called. So we cannot reliably `self.log` in them + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'get_progress_bar_dict', + }) + # remove `nn.Module` hooks + module_hooks = get_members(torch.nn.Module) + pl_module_hooks.difference_update(module_hooks) + + def call(hook, fn, *args, **kwargs): + out = fn(*args, **kwargs) + + if hook in not_supported: + with pytest.raises(MisconfigurationException, match=''): + self.log('anything', 1) + else: + self.log(hook, 1) + return out + + for h in pl_module_hooks: + attr = getattr(self, h) + setattr(self, h, partial(call, h, attr)) + + +def test_fx_validator_integration(tmpdir): + """Tries to log inside all `LightningModule` and `Callback` hooks to check any expected errors""" + not_supported = [ + 'on_before_accelerator_backend_setup', + 'setup', + 'configure_sharded_model', + 'on_configure_sharded_model', + 'configure_optimizers', + 'on_fit_start', + 'on_pretrain_routine_start', + 'on_pretrain_routine_end', + 'on_train_dataloader', + 'train_dataloader', + 'val_dataloader', + 'on_validation_end', + 'on_train_end', + 'on_fit_end', + 'teardown', + 'on_sanity_check_start', + 'on_sanity_check_end', + 'prepare_data', + 'configure_callbacks', + 'test_dataloader', + 'on_validation_model_eval', + 'on_test_model_eval', + 'on_test_end', + 'predict_dataloader', + 'on_predict_model_eval', + 'on_predict_start', + 'on_predict_epoch_start', + 'on_predict_batch_start', + 'predict_step', + 'on_predict_batch_end', + 'on_predict_epoch_end', + 'on_predict_end', + 'summarize', + ] + model = HookedModel(not_supported) + callback = HookedCallback(not_supported) + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=2, + limit_train_batches=1, + limit_val_batches=1, + limit_test_batches=1, + limit_predict_batches=1, + callbacks=callback + ) + trainer.fit(model) + trainer.test(model) + trainer.predict(model) + + @RunIf(min_gpus=2) def test_epoch_results_cache_dp(tmpdir): From 7106af00bd94a04d11ea7d525f4a6c776345bc79 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 20 Jul 2021 22:52:55 +0200 Subject: [PATCH 02/29] Fix breakage --- .../trainer/connectors/logger_connector/fx_validator.py | 4 +++- pytorch_lightning/trainer/data_loading.py | 4 +++- tests/trainer/logging_/test_logger_connector.py | 3 +++ 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py index 579107d80398e..a9342be2b3612 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py @@ -80,10 +80,12 @@ class FxValidator: configure_optimizers=None, on_train_dataloader=None, train_dataloader=None, + on_val_dataloader=None, val_dataloader=None, + on_test_dataloader=None, + test_dataloader=None, prepare_data=None, configure_callbacks=None, - test_dataloader=None, on_validation_model_eval=None, on_test_model_eval=None, ) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index a63b71c94eb3d..6bfa80bc92d68 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -469,7 +469,9 @@ def request_dataloader(self, stage: str) -> Union[DataLoader, List[DataLoader]]: Returns: The dataloader """ - dataloader = self.call_hook(f"{stage}_dataloader") + hook = f"{stage}_dataloader" + self.call_hook("on_" + hook) + dataloader = self.call_hook(hook) if isinstance(dataloader, tuple): dataloader = list(dataloader) self.accelerator.barrier('get_dataloaders') diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 86a33005c967d..ce64bf7fdf6bf 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -215,6 +215,7 @@ def test_fx_validator_integration(tmpdir): 'on_pretrain_routine_end', 'on_train_dataloader', 'train_dataloader', + 'on_val_dataloader', 'val_dataloader', 'on_validation_end', 'on_train_end', @@ -224,10 +225,12 @@ def test_fx_validator_integration(tmpdir): 'on_sanity_check_end', 'prepare_data', 'configure_callbacks', + 'on_test_dataloader', 'test_dataloader', 'on_validation_model_eval', 'on_test_model_eval', 'on_test_end', + 'on_predict_dataloader', 'predict_dataloader', 'on_predict_model_eval', 'on_predict_start', From a370bac930b33f8aef0a31d335c4a36522c67f26 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 20 Jul 2021 23:13:24 +0200 Subject: [PATCH 03/29] Update CHANGELOG --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 71d428b119430..9c630b58b711e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -180,6 +180,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added experimental support for loop specialization ([#8226](https://github.com/PyTorchLightning/pytorch-lightning/pull/8226)) +- Improve coverage of `self.log`-ing in any `LightningModule` or `Callback` hook ([#8498](https://github.com/PyTorchLightning/pytorch-lightning/pull/8498)) + + - Added support for `devices` flag to Trainer ([#8440](https://github.com/PyTorchLightning/pytorch-lightning/pull/8440)) From 46116ee1f187765b445de7264b63649faef406c7 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 20 Jul 2021 23:42:45 +0200 Subject: [PATCH 04/29] Check the exact matches --- pytorch_lightning/core/lightning.py | 4 +- .../logger_connector/fx_validator.py | 4 +- .../trainer/logging_/test_logger_connector.py | 94 ++++++++++--------- 3 files changed, 55 insertions(+), 47 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index df8ba49ce2452..79df2d7a7da28 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -404,8 +404,8 @@ def log( caller = inspect.stack()[1][3] raise MisconfigurationException( f'You are trying to `self.log()` inside `{caller}` but the loop `ResultCollection` is not registered' - ' yet. This is most likely because you are trying to log in a `predict` hook, but it does' - ' not support logging.' + ' yet. This is most likely because you are trying to log in a `predict` hook,' + " but it doesn't support logging." ) if self.trainer.lightning_module is None: # this is to avoid `lightning_module.log` in callback hooks which have the lightning module available as a diff --git a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py index a9342be2b3612..3491b8ea78ef3 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py @@ -95,12 +95,12 @@ def check_logging(cls, fx_name: str, on_step: bool, on_epoch: bool) -> None: """Check if the given function name is allowed to log""" if fx_name not in cls.functions: raise RuntimeError( - f'You are trying to `self.log()` inside `{fx_name}` but it is not implemented.' + f'Logging inside `{fx_name}` is not implemented.' ' Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`' ) allowed = cls.functions[fx_name] if allowed is None: - raise MisconfigurationException(f"`{fx_name}` function doesn't support logging using `self.log()`") + raise MisconfigurationException(f"You can't `self.log()` inside `{fx_name}`") m = "You can't `self.log({}={})` inside `{}`, must be one of {}" if on_step not in allowed["on_step"]: diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index ce64bf7fdf6bf..ed7dd86c1fc35 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -136,14 +136,14 @@ def test_fx_validator(): if allowed: validator.check_logging(fx_name=func_name, on_step=on_step, on_epoch=on_epoch) if not is_start and is_stage: - with pytest.raises(MisconfigurationException, match="You can't"): + with pytest.raises(MisconfigurationException, match="must be one of"): validator.check_logging(fx_name=func_name, on_step=True, on_epoch=on_epoch) else: assert func_name in not_supported - with pytest.raises(MisconfigurationException, match="function doesn't support"): + with pytest.raises(MisconfigurationException, match="You can't"): validator.check_logging(fx_name=func_name, on_step=on_step, on_epoch=on_epoch) - with pytest.raises(RuntimeError, match="`foo` but it is not implemented"): + with pytest.raises(RuntimeError, match="Logging inside `foo` is not implemented"): validator.check_logging("foo", False, False) @@ -159,7 +159,7 @@ def call(hook, *args, **_): lightning_module = lightning_module[0] if hook in not_supported: - with pytest.raises(MisconfigurationException, match='self.log'): + with pytest.raises(MisconfigurationException, match=not_supported[hook]): lightning_module.log('anything', 1) else: lightning_module.log(hook, 1) @@ -191,7 +191,7 @@ def call(hook, fn, *args, **kwargs): out = fn(*args, **kwargs) if hook in not_supported: - with pytest.raises(MisconfigurationException, match=''): + with pytest.raises(MisconfigurationException, match=not_supported[hook]): self.log('anything', 1) else: self.log(hook, 1) @@ -204,44 +204,30 @@ def call(hook, fn, *args, **kwargs): def test_fx_validator_integration(tmpdir): """Tries to log inside all `LightningModule` and `Callback` hooks to check any expected errors""" - not_supported = [ - 'on_before_accelerator_backend_setup', - 'setup', - 'configure_sharded_model', - 'on_configure_sharded_model', - 'configure_optimizers', - 'on_fit_start', - 'on_pretrain_routine_start', - 'on_pretrain_routine_end', - 'on_train_dataloader', - 'train_dataloader', - 'on_val_dataloader', - 'val_dataloader', - 'on_validation_end', - 'on_train_end', - 'on_fit_end', - 'teardown', - 'on_sanity_check_start', - 'on_sanity_check_end', - 'prepare_data', - 'configure_callbacks', - 'on_test_dataloader', - 'test_dataloader', - 'on_validation_model_eval', - 'on_test_model_eval', - 'on_test_end', - 'on_predict_dataloader', - 'predict_dataloader', - 'on_predict_model_eval', - 'on_predict_start', - 'on_predict_epoch_start', - 'on_predict_batch_start', - 'predict_step', - 'on_predict_batch_end', - 'on_predict_epoch_end', - 'on_predict_end', - 'summarize', - ] + not_supported = { + 'on_before_accelerator_backend_setup': 'LightningModule` is not registered yet', + 'setup': "You can't", + 'configure_sharded_model': "You can't", + 'on_configure_sharded_model': "You can't", + 'configure_optimizers': "You can't", + 'on_fit_start': "You can't", + 'on_pretrain_routine_start': "You can't", + 'on_pretrain_routine_end': "You can't", + 'on_train_dataloader': "You can't", + 'train_dataloader': "You can't", + 'on_val_dataloader': "You can't", + 'val_dataloader': "You can't", + 'on_validation_end': "You can't", + 'on_train_end': "You can't", + 'on_fit_end': "You can't", + 'teardown': "You can't", + 'on_sanity_check_start': "You can't", + 'on_sanity_check_end': "You can't", + 'prepare_data': "You can't", + 'configure_callbacks': "You can't", + 'on_validation_model_eval': "You can't", + 'summarize': 'not managed by the `Trainer', + } model = HookedModel(not_supported) callback = HookedCallback(not_supported) trainer = Trainer( @@ -254,7 +240,29 @@ def test_fx_validator_integration(tmpdir): callbacks=callback ) trainer.fit(model) + + not_supported.update({ + 'on_before_accelerator_backend_setup': "You can't", # `lightning_module` ref is now present from the `fit` call + 'on_test_dataloader': "You can't", + 'test_dataloader': "You can't", + 'on_test_model_eval': "You can't", + 'on_test_end': "You can't", + }) trainer.test(model) + + not_supported.update({k: "ResultCollection` is not registered yet" for k in not_supported}) + not_supported.update({ + 'on_predict_dataloader': "ResultCollection` is not registered yet", + 'predict_dataloader': "ResultCollection` is not registered yet", + 'on_predict_model_eval': "ResultCollection` is not registered yet", + 'on_predict_start': "ResultCollection` is not registered yet", + 'on_predict_epoch_start': "ResultCollection` is not registered yet", + 'on_predict_batch_start': "ResultCollection` is not registered yet", + 'predict_step': "ResultCollection` is not registered yet", + 'on_predict_batch_end': "ResultCollection` is not registered yet", + 'on_predict_epoch_end': "ResultCollection` is not registered yet", + 'on_predict_end': "ResultCollection` is not registered yet", + }) trainer.predict(model) From c5e26674426323d63ef4cb6186e729b2f8d54020 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 21 Jul 2021 00:07:11 +0200 Subject: [PATCH 05/29] Better args --- tests/trainer/logging_/test_logger_connector.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index ed7dd86c1fc35..75438d3957f74 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -151,12 +151,12 @@ class HookedCallback(Callback): def __init__(self, not_supported): - def call(hook, *args, **_): - # `on_init_{start,end}` do not have the `LightningModule` available - lightning_module = [m for m in args if isinstance(m, LightningModule)] - if not lightning_module: + def call(hook, trainer, model=None, *_, **__): + lightning_module = trainer.lightning_module or model + if lightning_module is None: + # `on_init_{start,end}` do not have the `LightningModule` available + assert hook in ('on_init_start', 'on_init_end') return - lightning_module = lightning_module[0] if hook in not_supported: with pytest.raises(MisconfigurationException, match=not_supported[hook]): @@ -237,7 +237,7 @@ def test_fx_validator_integration(tmpdir): limit_val_batches=1, limit_test_batches=1, limit_predict_batches=1, - callbacks=callback + callbacks=callback, ) trainer.fit(model) @@ -248,7 +248,7 @@ def test_fx_validator_integration(tmpdir): 'on_test_model_eval': "You can't", 'on_test_end': "You can't", }) - trainer.test(model) + trainer.test(model, verbose=False) not_supported.update({k: "ResultCollection` is not registered yet" for k in not_supported}) not_supported.update({ From 07503a1f48bdc02770293f34efa7c0460be7b8a8 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 21 Jul 2021 00:44:14 +0200 Subject: [PATCH 06/29] Docs and FIXME --- pytorch_lightning/trainer/connectors/callback_connector.py | 5 ----- pytorch_lightning/trainer/trainer.py | 1 + 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index d8e866f8b1aca..d090cb22943a1 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -142,11 +142,6 @@ def _attach_model_callbacks(self) -> None: callbacks already present in the trainer callbacks list, it will replace them. In addition, all :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callbacks will be pushed to the end of the list, ensuring they run last. - - Args: - model: A model which may or may not define new callbacks in - :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_callbacks`. - trainer: The trainer on which the callbacks get attached/merged. """ model_callbacks = self.trainer.call_hook('configure_callbacks') if not model_callbacks: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 8be2c217ef1f1..98aa229f223f0 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -843,6 +843,7 @@ def _run(self, model: 'pl.LightningModule') -> Optional[Union[_EVALUATE_OUTPUT, # hook self.data_connector.prepare_data() + # FIXME: need the `lightning_module` reference to be set self.callback_connector._attach_model_callbacks() # ---------------------------- From fe1585d7684e9e0cc6a38cbf0400b45fadb43ded Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 21 Jul 2021 13:23:13 +0200 Subject: [PATCH 07/29] Remove extra assertion --- pytorch_lightning/core/lightning.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 79df2d7a7da28..79fe7881186be 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -420,7 +420,6 @@ def log( raise MisconfigurationException( f'You are trying to `self.log()` inside `{caller}` but it is not managed by the `Trainer` control flow' ) - assert self._current_fx_name is not None FxValidator.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch) # make sure user doesn't introduce logic for multi-dataloaders From a02cfacaeaab1b59909dda72afce69dced20a3f8 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 21 Jul 2021 14:09:58 +0200 Subject: [PATCH 08/29] Resolve model connection --- pytorch_lightning/trainer/callback_hook.py | 4 +-- pytorch_lightning/trainer/trainer.py | 29 ++++++++++++---------- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 4b2a7d7b99565..6cc081f3853f6 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -36,10 +36,10 @@ class TrainerCallbackHookMixin(ABC): callbacks: List[Callback] = [] lightning_module: 'pl.LightningModule' - def on_before_accelerator_backend_setup(self, model: 'pl.LightningModule') -> None: + def on_before_accelerator_backend_setup(self) -> None: """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" for callback in self.callbacks: - callback.on_before_accelerator_backend_setup(self, model) + callback.on_before_accelerator_backend_setup(self, self.lightning_module) def configure_sharded_model(self) -> None: """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 98aa229f223f0..02fba17228723 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -841,18 +841,19 @@ def _run(self, model: 'pl.LightningModule') -> Optional[Union[_EVALUATE_OUTPUT, # attach model log function to callback self.callback_connector.attach_model_logging_functions(model) + # attach model to the training type plugin + self.accelerator.connect(model) + # hook self.data_connector.prepare_data() - # FIXME: need the `lightning_module` reference to be set self.callback_connector._attach_model_callbacks() # ---------------------------- # SET UP TRAINING # ---------------------------- - self.call_hook("on_before_accelerator_backend_setup", model) - self.accelerator.connect(model) + self.call_hook("on_before_accelerator_backend_setup") self.accelerator.setup_environment() - self._call_setup_hook(model) # allow user to setup lightning_module in accelerator environment + self._call_setup_hook() # allow user to setup lightning_module in accelerator environment # restore modules after setup self.checkpoint_connector.restore_datamodule() @@ -860,8 +861,9 @@ def _run(self, model: 'pl.LightningModule') -> Optional[Union[_EVALUATE_OUTPUT, # restore callback states self.checkpoint_connector.restore_callbacks() - self._call_configure_sharded_model(model) # allow user to setup in model sharded environment - self.accelerator.setup(self, model) # note: this sets up self.lightning_module + self._call_configure_sharded_model() # allow user to setup in model sharded environment + # TODO: `model` can be removed as an argument of `TrainingTypePlugin.setup` as it is already connected + self.accelerator.setup(self, model) # ---------------------------- # INSPECT THE CORE LOOPS @@ -918,7 +920,7 @@ def _run(self, model: 'pl.LightningModule') -> Optional[Union[_EVALUATE_OUTPUT, self.call_hook('on_fit_end') # teardown - self._call_teardown_hook(model) + self._call_teardown_hook() if self.state.status != TrainerStatus.INTERRUPTED: self.state.status = TrainerStatus.FINISHED @@ -1142,7 +1144,7 @@ def __load_ckpt_weights(self, ckpt_path: Optional[str]) -> Optional[str]: self.checkpoint_connector.restore_model_weights(ckpt_path) return ckpt_path - def _call_setup_hook(self, model: 'pl.LightningModule') -> None: + def _call_setup_hook(self) -> None: fn = self.state.fn._setup_fn self.accelerator.barrier("pre_setup") @@ -1153,11 +1155,12 @@ def _call_setup_hook(self, model: 'pl.LightningModule') -> None: self.accelerator.barrier("post_setup") - def _call_configure_sharded_model(self, model: 'pl.LightningModule') -> None: + def _call_configure_sharded_model(self) -> None: # Call configure sharded model hook if accelerator requests. In some cases # we will not call the hook; the hook has initialized the sharded model for example. # used on the model if the user re-create a trainer with resume_from_checkpoint + model = self.lightning_module model_call_configure_sharded_model_hook = getattr(model, "call_configure_sharded_model_hook", False) if self.accelerator.call_configure_sharded_model_hook and not model_call_configure_sharded_model_hook: with self.accelerator.model_sharded_context(): @@ -1165,7 +1168,7 @@ def _call_configure_sharded_model(self, model: 'pl.LightningModule') -> None: model.call_configure_sharded_model_hook = True self.accelerator.call_configure_sharded_model_hook = False - def _call_teardown_hook(self, model: 'pl.LightningModule') -> None: + def _call_teardown_hook(self) -> None: fn = self.state.fn._setup_fn if self.datamodule is not None: @@ -1173,10 +1176,10 @@ def _call_teardown_hook(self, model: 'pl.LightningModule') -> None: self.profiler.teardown(stage=fn) self.call_hook('teardown', stage=fn) - model._current_fx_name = None - model._current_dataloader_idx = None + self.lightning_module._current_fx_name = None + self.lightning_module._current_dataloader_idx = None # these could have become stale if metrics are defined in `setup` - model._metric_attributes = None + self.lightning_module._metric_attributes = None def call_hook(self, hook_name: str, *args, **kwargs) -> Any: # Note this implementation is copy/pasted into the TrainLoop class in TrainingEpochLoop._on_train_epoch_end_hook From f85505588698187d0426e66a54a18149e7eee90f Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 21 Jul 2021 14:43:30 +0200 Subject: [PATCH 09/29] Fixes --- pytorch_lightning/trainer/callback_hook.py | 6 +++--- pytorch_lightning/trainer/trainer.py | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 6cc081f3853f6..c8cc96c398872 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -41,15 +41,15 @@ def on_before_accelerator_backend_setup(self) -> None: for callback in self.callbacks: callback.on_before_accelerator_backend_setup(self, self.lightning_module) - def configure_sharded_model(self) -> None: + def on_configure_sharded_model(self) -> None: """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" for callback in self.callbacks: callback.on_configure_sharded_model(self, self.lightning_module) - def setup(self, model: 'pl.LightningModule', stage: Optional[str]) -> None: + def setup(self, stage: Optional[str]) -> None: """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" for callback in self.callbacks: - callback.setup(self, model, stage=stage) + callback.setup(self, self.lightning_module, stage=stage) def teardown(self, stage: Optional[str] = None) -> None: """Called at the end of fit (train + validate), validate, test, or predict, or tune.""" diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 02fba17228723..7c97d90ce7f42 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1165,6 +1165,7 @@ def _call_configure_sharded_model(self) -> None: if self.accelerator.call_configure_sharded_model_hook and not model_call_configure_sharded_model_hook: with self.accelerator.model_sharded_context(): self.call_hook('configure_sharded_model') + self.call_hook('on_configure_sharded_model') model.call_configure_sharded_model_hook = True self.accelerator.call_configure_sharded_model_hook = False @@ -1194,7 +1195,7 @@ def call_hook(self, hook_name: str, *args, **kwargs) -> Any: with self.profiler.profile(hook_name): # first call trainer hook - if hook_name not in ("setup", ) and hasattr(self, hook_name): + if hasattr(self, hook_name): trainer_hook = getattr(self, hook_name) if trainer_hook is not None: # `train_dataloader` is a function for the `LightningModule` but an attribute for the `Trainer` From b0788a1d85103cb6440652e6f902830855a78199 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 21 Jul 2021 14:51:51 +0200 Subject: [PATCH 10/29] Remove no lightning module check --- pytorch_lightning/core/lightning.py | 8 -------- tests/trainer/logging_/test_logger_connector.py | 2 +- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 79fe7881186be..ea6a2cee5d320 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -407,14 +407,6 @@ def log( ' yet. This is most likely because you are trying to log in a `predict` hook,' " but it doesn't support logging." ) - if self.trainer.lightning_module is None: - # this is to avoid `lightning_module.log` in callback hooks which have the lightning module available as a - # parameter but the reference in the trainer has not been set yet - caller = inspect.stack()[1][3] - raise MisconfigurationException( - f'You are trying to `self.log()` inside `{caller}` but the' - ' `LightningModule` is not registered yet for the trainer.' - ) if self._current_fx_name is None: caller = inspect.stack()[1][3] raise MisconfigurationException( diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 75438d3957f74..81de1f4e7fbd2 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -205,7 +205,7 @@ def call(hook, fn, *args, **kwargs): def test_fx_validator_integration(tmpdir): """Tries to log inside all `LightningModule` and `Callback` hooks to check any expected errors""" not_supported = { - 'on_before_accelerator_backend_setup': 'LightningModule` is not registered yet', + 'on_before_accelerator_backend_setup': "You can't", 'setup': "You can't", 'configure_sharded_model': "You can't", 'on_configure_sharded_model': "You can't", From 12658fe90eca5c9c9a55aabce63dce821ac013e9 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 21 Jul 2021 15:14:43 +0200 Subject: [PATCH 11/29] Check callable --- pytorch_lightning/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 7c97d90ce7f42..4480cf6bae4b8 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1197,7 +1197,7 @@ def call_hook(self, hook_name: str, *args, **kwargs) -> Any: # first call trainer hook if hasattr(self, hook_name): trainer_hook = getattr(self, hook_name) - if trainer_hook is not None: + if callable(trainer_hook): # `train_dataloader` is a function for the `LightningModule` but an attribute for the `Trainer` trainer_hook(*args, **kwargs) From f432fca894047c763659dbe43ab162d4fbe9dca9 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 21 Jul 2021 15:15:02 +0200 Subject: [PATCH 12/29] Fix mock test --- tests/trainer/connectors/test_callback_connector.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py index c23f9f9d499f4..50d2814b08085 100644 --- a/tests/trainer/connectors/test_callback_connector.py +++ b/tests/trainer/connectors/test_callback_connector.py @@ -44,10 +44,9 @@ def test_checkpoint_callbacks_are_last(tmpdir): cb_connector._attach_model_callbacks() assert trainer.callbacks == [progress_bar, lr_monitor, checkpoint1, checkpoint2] - # with model-specific callbacks that substitute ones in Trainer - model = Mock() - model.configure_callbacks.return_value = [checkpoint1, early_stopping, checkpoint2] trainer = Trainer(callbacks=[progress_bar, lr_monitor, ModelCheckpoint(tmpdir)]) + # with model-specific callbacks that substitute ones in Trainer + trainer.call_hook = Mock(return_value=[checkpoint1, early_stopping, checkpoint2]) cb_connector = CallbackConnector(trainer) cb_connector._attach_model_callbacks() assert trainer.callbacks == [progress_bar, lr_monitor, early_stopping, checkpoint1, checkpoint2] @@ -92,9 +91,8 @@ def test_attach_model_callbacks(): """ Test that the callbacks defined in the model and through Trainer get merged correctly. """ def assert_composition(trainer_callbacks, model_callbacks, expected): - model = Mock() - model.configure_callbacks.return_value = model_callbacks trainer = Trainer(checkpoint_callback=False, progress_bar_refresh_rate=0, callbacks=trainer_callbacks) + trainer.call_hook = Mock(return_value=model_callbacks) cb_connector = CallbackConnector(trainer) cb_connector._attach_model_callbacks() assert trainer.callbacks == expected @@ -144,9 +142,9 @@ def assert_composition(trainer_callbacks, model_callbacks, expected): def test_attach_model_callbacks_override_info(caplog): """ Test that the logs contain the info about overriding callbacks returned by configure_callbacks. """ - model = Mock() - model.configure_callbacks.return_value = [LearningRateMonitor(), EarlyStopping()] trainer = Trainer(checkpoint_callback=False, callbacks=[EarlyStopping(), LearningRateMonitor(), ProgressBar()]) + trainer.call_hook = Mock(return_value=[LearningRateMonitor(), EarlyStopping()]) + cb_connector = CallbackConnector(trainer) with caplog.at_level(logging.INFO): cb_connector._attach_model_callbacks() From b066f6fea7e5174bcc93f9f1ff1e98037d9cd04b Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 21 Jul 2021 15:47:21 +0200 Subject: [PATCH 13/29] Add pl_module to call_hook --- .../loops/dataloader/evaluation_loop.py | 5 +- pytorch_lightning/trainer/data_loading.py | 61 +++++++++++-------- pytorch_lightning/trainer/optimizers.py | 7 ++- pytorch_lightning/trainer/trainer.py | 19 +++--- 4 files changed, 50 insertions(+), 42 deletions(-) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 94867750c025e..bb76bfe5fa685 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -169,11 +169,10 @@ def get_max_batches(self) -> List[Union[int, float]]: def reload_evaluation_dataloaders(self) -> None: """Reloads dataloaders if necessary""" - model = self.trainer.lightning_module if self.trainer.testing: - self.trainer.reset_test_dataloader(model) + self.trainer.reset_test_dataloader() elif self.trainer.val_dataloaders is None or self.trainer._should_reload_dl_epoch: - self.trainer.reset_val_dataloader(model) + self.trainer.reset_val_dataloader() def on_evaluation_start(self, *args: Any, **kwargs: Any) -> None: """Runs ``on_{validation/test}_start`` hooks""" diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 6bfa80bc92d68..35f7fdb785289 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -228,14 +228,14 @@ def _get_distributed_sampler( sampler = cls(dataloader.dataset, **kwargs) return sampler - def reset_train_dataloader(self, model: 'pl.LightningModule') -> None: + def reset_train_dataloader(self, model: Optional['pl.LightningModule'] = None) -> None: """Resets the train dataloader and initialises required variables (number of batches, when to validate, etc.). Args: - model: The current `LightningModule` + model: The `LightningModule` if calling this outside of the trainer scope. """ - self.train_dataloader = self.request_dataloader("train") + self.train_dataloader = self.request_dataloader("train", model=model) if self.overfit_batches > 0: if hasattr(self.train_dataloader, 'sampler') and isinstance(self.train_dataloader.sampler, RandomSampler): @@ -318,21 +318,21 @@ def reset_train_dataloader(self, model: 'pl.LightningModule') -> None: def _reset_eval_dataloader( self, - model: 'pl.LightningModule', mode: str, + model: Optional['pl.LightningModule'] = None ) -> Tuple[List[Union[int, float]], List[DataLoader]]: """Generic method to reset a dataloader for evaluation. Args: - model: The current `LightningModule` mode: Either `'val'`, `'test'` or `'predict'` + model: The `LightningModule` if calling this outside of the trainer scope. Returns: Tuple (num_batches, dataloaders) """ # always get the loaders first so we can count how many there are loader_name = f'{mode}_dataloader' - dataloaders = self.request_dataloader(mode) + dataloaders = self.request_dataloader(mode, model=model) if not isinstance(dataloaders, list): dataloaders = [dataloaders] @@ -418,60 +418,67 @@ def _reset_eval_dataloader( return loader_num_batches, dataloaders - def reset_val_dataloader(self, model: 'pl.LightningModule') -> None: + def reset_val_dataloader(self, model: Optional['pl.LightningModule'] = None) -> None: """Resets the validation dataloader and determines the number of batches. Args: - model: The current `LightningModule` + model: The `LightningModule` if called outside of the trainer scope. """ - has_loader = is_overridden('val_dataloader', model) - has_step = is_overridden('validation_step', model) + pl_module = self.lightning_module or model + has_loader = is_overridden('val_dataloader', pl_module) + has_step = is_overridden('validation_step', pl_module) if has_loader and has_step: - self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader(model, 'val') + self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader('val', model=pl_module) - def reset_test_dataloader(self, model) -> None: + def reset_test_dataloader(self, model: Optional['pl.LightningModule'] = None) -> None: """Resets the test dataloader and determines the number of batches. Args: - model: The current `LightningModule` + model: The `LightningModule` if called outside of the trainer scope. """ - has_loader = is_overridden('test_dataloader', model) - has_step = is_overridden('test_step', model) + pl_module = self.lightning_module or model + has_loader = is_overridden('test_dataloader', pl_module) + has_step = is_overridden('test_step', pl_module) if has_loader and has_step: - self.num_test_batches, self.test_dataloaders = self._reset_eval_dataloader(model, 'test') + self.num_test_batches, self.test_dataloaders = self._reset_eval_dataloader('test', model=pl_module) - def reset_predict_dataloader(self, model) -> None: + def reset_predict_dataloader(self, model: Optional['pl.LightningModule'] = None) -> None: """Resets the predict dataloader and determines the number of batches. Args: - model: The current `LightningModule` + model: The `LightningModule` if called outside of the trainer scope. """ - has_loader = is_overridden('predict_dataloader', model) + pl_module = self.lightning_module or model + has_loader = is_overridden('predict_dataloader', pl_module) if has_loader: - self.num_predict_batches, self.predict_dataloaders = self._reset_eval_dataloader(model, 'predict') + self.num_predict_batches, self.predict_dataloaders = self._reset_eval_dataloader('predict', model=pl_module) - def reset_train_val_dataloaders(self, model) -> None: + def reset_train_val_dataloaders(self, model: Optional['pl.LightningModule'] = None) -> None: """ Resets train and val dataloaders if none are attached to the trainer. The val dataloader must be initialized before training loop starts, as the training loop inspects the val dataloader to determine whether to run the evaluation loop. + + Args: + model: The `LightningModule` if called outside of the trainer scope. """ if self.train_dataloader is None: - self.reset_train_dataloader(model) - + self.reset_train_dataloader(model=model) if self.val_dataloaders is None: - self.reset_val_dataloader(model) + self.reset_val_dataloader(model=model) - def request_dataloader(self, stage: str) -> Union[DataLoader, List[DataLoader]]: + def request_dataloader(self, + stage: str, + model: Optional['pl.LightningModule'] = None) -> Union[DataLoader, List[DataLoader]]: """Handles downloading data in the GPU or TPU case. Returns: The dataloader """ hook = f"{stage}_dataloader" - self.call_hook("on_" + hook) - dataloader = self.call_hook(hook) + self.call_hook("on_" + hook, pl_module=model) + dataloader = self.call_hook(hook, pl_module=model) if isinstance(dataloader, tuple): dataloader = list(dataloader) self.accelerator.barrier('get_dataloaders') diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py index 5172455d64cdb..ea96d53cf7309 100644 --- a/pytorch_lightning/trainer/optimizers.py +++ b/pytorch_lightning/trainer/optimizers.py @@ -29,9 +29,10 @@ class TrainerOptimizersMixin(ABC): _lightning_optimizers: Optional[List[LightningOptimizer]] - def init_optimizers(self, model: 'pl.LightningModule') -> Tuple[List, List, List]: + def init_optimizers(self, model: Optional['pl.LightningModule']) -> Tuple[List, List, List]: + pl_module = self.lightning_module or model self._lightning_optimizers = None - optim_conf = self.call_hook("configure_optimizers") + optim_conf = self.call_hook("configure_optimizers", pl_module=pl_module) if optim_conf is None: rank_zero_warn( '`LightningModule.configure_optimizers` returned `None`, this fit will run with no optimizer', @@ -93,7 +94,7 @@ def init_optimizers(self, model: 'pl.LightningModule') -> Tuple[List, List, List ' * A list of the previously described dict format, with an optional "frequency" key (int)' ) - is_manual_optimization = not model.automatic_optimization + is_manual_optimization = not pl_module.automatic_optimization lr_schedulers = self.configure_schedulers(lr_schedulers, monitor, is_manual_optimization) _validate_scheduler_optimizer(optimizers, lr_schedulers) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 4480cf6bae4b8..7d96978a9aaea 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1182,14 +1182,16 @@ def _call_teardown_hook(self) -> None: # these could have become stale if metrics are defined in `setup` self.lightning_module._metric_attributes = None - def call_hook(self, hook_name: str, *args, **kwargs) -> Any: + def call_hook(self, hook_name: str, *args, pl_module: Optional['pl.LightningModule'] = None, **kwargs) -> Any: # Note this implementation is copy/pasted into the TrainLoop class in TrainingEpochLoop._on_train_epoch_end_hook # This was done to manage the deprecation of the `outputs` argument to on_train_epoch_end # If making changes to this function, ensure that those changes are also made to # TrainingEpochLoop._on_train_epoch_end_hook - if self.lightning_module: - prev_fx_name = self.lightning_module._current_fx_name - self.lightning_module._current_fx_name = hook_name + pl_module = self.lightning_module or pl_module + + if pl_module: + prev_fx_name = pl_module._current_fx_name + pl_module._current_fx_name = hook_name # always profile hooks with self.profiler.profile(hook_name): @@ -1203,9 +1205,8 @@ def call_hook(self, hook_name: str, *args, **kwargs) -> Any: # next call hook in lightningModule output = None - model_ref = self.lightning_module - if is_overridden(hook_name, model_ref): - hook_fx = getattr(model_ref, hook_name) + if is_overridden(hook_name, pl_module): + hook_fx = getattr(pl_module, hook_name) output = hook_fx(*args, **kwargs) # call the accelerator hook @@ -1217,9 +1218,9 @@ def call_hook(self, hook_name: str, *args, **kwargs) -> Any: # todo: move this data parallel logic into the data parallel plugin output = accelerator_output if output is None else output - if self.lightning_module: + if pl_module: # restore current_fx when nested context - self.lightning_module._current_fx_name = prev_fx_name + pl_module._current_fx_name = prev_fx_name return output From 7149a6547ccce24cb01ac215c1a972d810ee2e42 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 21 Jul 2021 16:01:56 +0200 Subject: [PATCH 14/29] Do not check is overridden --- pytorch_lightning/trainer/trainer.py | 14 ++++++-------- tests/trainer/test_dataloaders.py | 4 +--- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 7d96978a9aaea..ba0ca4b9d9e5f 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1197,17 +1197,15 @@ def call_hook(self, hook_name: str, *args, pl_module: Optional['pl.LightningModu with self.profiler.profile(hook_name): # first call trainer hook - if hasattr(self, hook_name): - trainer_hook = getattr(self, hook_name) - if callable(trainer_hook): - # `train_dataloader` is a function for the `LightningModule` but an attribute for the `Trainer` - trainer_hook(*args, **kwargs) + callback_fx = getattr(self, hook_name, None) + if callable(callback_fx): + callback_fx(*args, **kwargs) # next call hook in lightningModule output = None - if is_overridden(hook_name, pl_module): - hook_fx = getattr(pl_module, hook_name) - output = hook_fx(*args, **kwargs) + model_fx = getattr(pl_module, hook_name, None) + if callable(model_fx): + output = model_fx(*args, **kwargs) # call the accelerator hook if hook_name not in ("setup", "teardown") and hasattr(self.accelerator, hook_name): diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index d20fc7098d5a5..71ac13c5a3e3d 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -1664,9 +1664,7 @@ def predict_dataloader(self): def test_request_dataloader(tmpdir): - """ - This test asserts dataloader can be modified and properly set to the trainer. - """ + """This test asserts dataloader can be modified and properly set to the trainer.""" class DataLoaderWrapper: From a735468502042fa9a4d6bc8550f8479f812ee77d Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 21 Jul 2021 16:03:39 +0200 Subject: [PATCH 15/29] yapf --- pytorch_lightning/trainer/data_loading.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 35f7fdb785289..fca49b7ab3f7d 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -468,9 +468,11 @@ def reset_train_val_dataloaders(self, model: Optional['pl.LightningModule'] = No if self.val_dataloaders is None: self.reset_val_dataloader(model=model) - def request_dataloader(self, - stage: str, - model: Optional['pl.LightningModule'] = None) -> Union[DataLoader, List[DataLoader]]: + def request_dataloader( + self, + stage: str, + model: Optional['pl.LightningModule'] = None, + ) -> Union[DataLoader, List[DataLoader]]: """Handles downloading data in the GPU or TPU case. Returns: From 15bbcfc4d203a7306548a27d2998c00531b5d8e7 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 23 Jul 2021 15:41:14 +0200 Subject: [PATCH 16/29] Fix test --- tests/trainer/test_trainer.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 683919560c7ae..a00348e079853 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1288,16 +1288,15 @@ class CurrentModel(BoringModel): def setup(self, stage): self.stage = stage - class TrainerSubclass(Trainer): + class CurrentCallback(Callback): - def setup(self, model, stage): + def setup(self, trainer, model, stage): assert model is not None self.stage = stage model = CurrentModel() - - # fit model - trainer = TrainerSubclass(default_root_dir=tmpdir, max_epochs=1, checkpoint_callback=False) + callback = CurrentCallback() + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, checkpoint_callback=False, callbacks=[callback]) if stage == "fit": trainer.fit(model) @@ -1306,8 +1305,8 @@ def setup(self, model, stage): else: trainer.test(model, ckpt_path=None) - assert trainer.stage == stage - assert trainer.lightning_module.stage == stage + assert callback.stage == stage + assert model.stage == stage @pytest.mark.parametrize( From 7a5c09367294fc9faa6e328892d29be97fa84abd Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 23 Jul 2021 15:48:20 +0200 Subject: [PATCH 17/29] Fix test --- pytorch_lightning/trainer/data_loading.py | 2 +- tests/trainer/test_trainer_tricks.py | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 88fa81fc403c1..1808acefcb3cd 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -340,7 +340,7 @@ def _reset_eval_dataloader( # duplicate it the numb of times needed to match the train loaders if self.overfit_batches > 0: num_loaders = len(dataloaders) - train_dataloader = self.request_dataloader('train') + train_dataloader = self.request_dataloader('train', model=model) dataloaders = [deepcopy(train_dataloader) for _ in range(num_loaders)] self.dev_debugger.track_load_dataloader_call(loader_name, dataloaders=dataloaders) diff --git a/tests/trainer/test_trainer_tricks.py b/tests/trainer/test_trainer_tricks.py index 85aa7aa937740..ed59418a12de4 100644 --- a/tests/trainer/test_trainer_tricks.py +++ b/tests/trainer/test_trainer_tricks.py @@ -122,7 +122,7 @@ def test_overfit_batch_limits(tmpdir): # ------------------------------------------------------ # test overfit_batches as percent # ------------------------------------------------------ - loader_num_batches, dataloaders = Trainer(overfit_batches=0.11)._reset_eval_dataloader(model, split) + loader_num_batches, dataloaders = Trainer(overfit_batches=0.11)._reset_eval_dataloader(split, model=model) assert loader_num_batches[0] == num_train_samples # make sure we turned off shuffle for the user @@ -136,23 +136,23 @@ def test_overfit_batch_limits(tmpdir): # ------------------------------------------------------ # test overfit_batches as int # ------------------------------------------------------ - loader_num_batches, dataloaders = Trainer(overfit_batches=1)._reset_eval_dataloader(model, split) + loader_num_batches, dataloaders = Trainer(overfit_batches=1)._reset_eval_dataloader(split, model=model) assert loader_num_batches[0] == 1 - loader_num_batches, dataloaders = Trainer(overfit_batches=5)._reset_eval_dataloader(model, split) + loader_num_batches, dataloaders = Trainer(overfit_batches=5)._reset_eval_dataloader(split, model=model) assert loader_num_batches[0] == 5 # ------------------------------------------------------ # test limit_xxx_batches as percent AND int # ------------------------------------------------------ if split == 'val': - loader_num_batches, dataloaders = Trainer(limit_val_batches=0.1)._reset_eval_dataloader(model, split) + loader_num_batches, dataloaders = Trainer(limit_val_batches=0.1)._reset_eval_dataloader(split, model=model) assert loader_num_batches[0] == int(0.1 * len(val_loader)) - loader_num_batches, dataloaders = Trainer(limit_val_batches=10)._reset_eval_dataloader(model, split) + loader_num_batches, dataloaders = Trainer(limit_val_batches=10)._reset_eval_dataloader(split, model=model) assert loader_num_batches[0] == 10 else: - loader_num_batches, dataloaders = Trainer(limit_test_batches=0.1)._reset_eval_dataloader(model, split) + loader_num_batches, dataloaders = Trainer(limit_test_batches=0.1)._reset_eval_dataloader(split, model=model) assert loader_num_batches[0] == int(0.1 * len(test_loader)) - loader_num_batches, dataloaders = Trainer(limit_test_batches=10)._reset_eval_dataloader(model, split) + loader_num_batches, dataloaders = Trainer(limit_test_batches=10)._reset_eval_dataloader(split, model=model) assert loader_num_batches[0] == 10 From be55e5bd9f714dc359829294492817dda123946c Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 23 Jul 2021 15:49:29 +0200 Subject: [PATCH 18/29] Fix test --- tests/core/test_datamodules.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 6203e93e63e2f..b283348b6c041 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -34,8 +34,6 @@ @mock.patch("pytorch_lightning.trainer.trainer.Trainer.node_rank", new_callable=PropertyMock) @mock.patch("pytorch_lightning.trainer.trainer.Trainer.local_rank", new_callable=PropertyMock) def test_can_prepare_data(local_rank, node_rank): - - model = BoringModel() dm = BoringDataModule() trainer = Trainer() trainer.datamodule = dm @@ -51,7 +49,7 @@ def test_can_prepare_data(local_rank, node_rank): assert trainer.local_rank == 0 assert trainer.data_connector.can_prepare_data() - trainer.data_connector.prepare_data(model) + trainer.data_connector.prepare_data() assert dm.random_full is not None # local rank = 1 (False) @@ -61,7 +59,7 @@ def test_can_prepare_data(local_rank, node_rank): assert trainer.local_rank == 1 assert not trainer.data_connector.can_prepare_data() - trainer.data_connector.prepare_data(model) + trainer.data_connector.prepare_data() assert dm.random_full is None # prepare_data_per_node = False (prepare across all nodes) @@ -73,7 +71,7 @@ def test_can_prepare_data(local_rank, node_rank): local_rank.return_value = 0 assert trainer.data_connector.can_prepare_data() - trainer.data_connector.prepare_data(model) + trainer.data_connector.prepare_data() assert dm.random_full is not None # global rank = 1 (False) @@ -83,14 +81,14 @@ def test_can_prepare_data(local_rank, node_rank): local_rank.return_value = 0 assert not trainer.data_connector.can_prepare_data() - trainer.data_connector.prepare_data(model) + trainer.data_connector.prepare_data() assert dm.random_full is None node_rank.return_value = 0 local_rank.return_value = 1 assert not trainer.data_connector.can_prepare_data() - trainer.data_connector.prepare_data(model) + trainer.data_connector.prepare_data() assert dm.random_full is None # 2 dm From 1b88471c1605f33358f7117eb5492dbc28bcc35b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 23 Jul 2021 13:50:49 +0000 Subject: [PATCH 19/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/source/governance.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/governance.rst b/docs/source/governance.rst index 5c29f7d0da544..4114ccdb8a818 100644 --- a/docs/source/governance.rst +++ b/docs/source/governance.rst @@ -39,7 +39,7 @@ Board Alumni ------ -- Jeff Yang (`ydcjeff `_) +- Jeff Yang (`ydcjeff `_) - Jeff Ling (`jeffling `_) - Teddy Koker (`teddykoker `_) - Nate Raw (`nateraw `_) From f5b3ea8dd96cef9e6cdf69c2ce48f713c1fd18cc Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 5 Aug 2021 16:24:03 +0200 Subject: [PATCH 20/29] Minor changes --- pytorch_lightning/utilities/distributed.py | 2 +- tests/trainer/connectors/test_callback_connector.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 71292cf8a75b2..f6b76eeda4f0d 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -16,7 +16,7 @@ import os from functools import wraps from platform import python_version -from typing import Any, Callable, List, Optional, Tuple, Type, Union +from typing import Any, Callable, List, Optional, Tuple, Union import torch from torch.nn.parallel.distributed import DistributedDataParallel diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py index e8f6cd3b3adc9..edc2d56fa1a67 100644 --- a/tests/trainer/connectors/test_callback_connector.py +++ b/tests/trainer/connectors/test_callback_connector.py @@ -138,7 +138,6 @@ def test_attach_model_callbacks_override_info(caplog): """Test that the logs contain the info about overriding callbacks returned by configure_callbacks.""" trainer = Trainer(checkpoint_callback=False, callbacks=[EarlyStopping(), LearningRateMonitor(), ProgressBar()]) trainer.call_hook = Mock(return_value=[LearningRateMonitor(), EarlyStopping()]) - cb_connector = CallbackConnector(trainer) with caplog.at_level(logging.INFO): cb_connector._attach_model_callbacks() From 8390f20253c3a7b4a3de73b623fa750a8e53c4d0 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 6 Aug 2021 12:46:52 +0200 Subject: [PATCH 21/29] Remove reference --- tests/trainer/connectors/test_callback_connector.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py index edc2d56fa1a67..512df307f6450 100644 --- a/tests/trainer/connectors/test_callback_connector.py +++ b/tests/trainer/connectors/test_callback_connector.py @@ -37,10 +37,7 @@ def test_checkpoint_callbacks_are_last(tmpdir): progress_bar = ProgressBar() # no model callbacks - model = Mock() - model.configure_callbacks.return_value = [] trainer = Trainer(callbacks=[checkpoint1, progress_bar, lr_monitor, checkpoint2]) - trainer.model = model cb_connector = CallbackConnector(trainer) cb_connector._attach_model_callbacks() assert trainer.callbacks == [progress_bar, lr_monitor, checkpoint1, checkpoint2] @@ -87,8 +84,10 @@ def test_attach_model_callbacks(): """Test that the callbacks defined in the model and through Trainer get merged correctly.""" def assert_composition(trainer_callbacks, model_callbacks, expected): + model = Mock() + model.configure_callbacks.return_value = model_callbacks trainer = Trainer(checkpoint_callback=False, progress_bar_refresh_rate=0, callbacks=trainer_callbacks) - trainer.call_hook = Mock(return_value=model_callbacks) + trainer.model = model cb_connector = CallbackConnector(trainer) cb_connector._attach_model_callbacks() assert trainer.callbacks == expected From 2edeb4f89966a1bf360e645896a6ed52cf730fc6 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 6 Aug 2021 12:47:04 +0200 Subject: [PATCH 22/29] Keep is_overridden check --- pytorch_lightning/trainer/trainer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 5a71defcfd925..fd6ef13bf7b27 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1239,9 +1239,10 @@ def call_hook( # next call hook in lightningModule output = None - model_fx = getattr(pl_module, hook_name, None) - if callable(model_fx): - output = model_fx(*args, **kwargs) + if is_overridden(hook_name, pl_module): + model_fx = getattr(pl_module, hook_name) + if callable(model_fx): + output = model_fx(*args, **kwargs) # call the accelerator hook if hook_name not in ("setup", "teardown") and hasattr(self.accelerator, hook_name): From b3bed31bbe1017ef50cb29de2edac5d1db42edbf Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 6 Aug 2021 12:50:09 +0200 Subject: [PATCH 23/29] Ooops --- tests/trainer/connectors/test_callback_connector.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py index 512df307f6450..9734198855077 100644 --- a/tests/trainer/connectors/test_callback_connector.py +++ b/tests/trainer/connectors/test_callback_connector.py @@ -84,10 +84,8 @@ def test_attach_model_callbacks(): """Test that the callbacks defined in the model and through Trainer get merged correctly.""" def assert_composition(trainer_callbacks, model_callbacks, expected): - model = Mock() - model.configure_callbacks.return_value = model_callbacks trainer = Trainer(checkpoint_callback=False, progress_bar_refresh_rate=0, callbacks=trainer_callbacks) - trainer.model = model + trainer.call_hook = Mock(return_value=model_callbacks) cb_connector = CallbackConnector(trainer) cb_connector._attach_model_callbacks() assert trainer.callbacks == expected From 9a358315066c17b0d2fea98fc1a0a90fbfa7c36c Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 6 Aug 2021 13:06:10 +0200 Subject: [PATCH 24/29] Remove stack inspection --- pytorch_lightning/core/lightning.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 4befbd7532e63..ac1c235bafd32 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -406,16 +406,14 @@ def log( results = self.trainer._results if results is None: - caller = inspect.stack()[1][3] raise MisconfigurationException( - f"You are trying to `self.log()` inside `{caller}` but the loop `ResultCollection` is not registered" + "You are trying to `self.log()` but the loop `ResultCollection` is not registered" " yet. This is most likely because you are trying to log in a `predict` hook," " but it doesn't support logging." ) if self._current_fx_name is None: - caller = inspect.stack()[1][3] raise MisconfigurationException( - f"You are trying to `self.log()` inside `{caller}` but it is not managed by the `Trainer` control flow" + "You are trying to `self.log()` but it is not managed by the `Trainer` control flow" ) FxValidator.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch) From 47ca3bf836c0b8170b8ea7dab7fa51b379899b40 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 6 Aug 2021 13:27:33 +0200 Subject: [PATCH 25/29] Revert test changes. Remove mock usage --- .../connectors/test_callback_connector.py | 26 ++++++++++++++----- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py index 9734198855077..43158865f9e75 100644 --- a/tests/trainer/connectors/test_callback_connector.py +++ b/tests/trainer/connectors/test_callback_connector.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from unittest.mock import Mock import torch -from pytorch_lightning import Callback, Trainer +from pytorch_lightning import Callback, LightningModule, Trainer from pytorch_lightning.callbacks import ( EarlyStopping, GradientAccumulationScheduler, @@ -36,15 +35,24 @@ def test_checkpoint_callbacks_are_last(tmpdir): lr_monitor = LearningRateMonitor() progress_bar = ProgressBar() - # no model callbacks + # no model reference trainer = Trainer(callbacks=[checkpoint1, progress_bar, lr_monitor, checkpoint2]) cb_connector = CallbackConnector(trainer) cb_connector._attach_model_callbacks() assert trainer.callbacks == [progress_bar, lr_monitor, checkpoint1, checkpoint2] - trainer = Trainer(callbacks=[progress_bar, lr_monitor, ModelCheckpoint(tmpdir)]) + # no model callbacks + model = LightningModule() + model.configure_callbacks = lambda: [] + trainer.model = model + cb_connector._attach_model_callbacks() + assert trainer.callbacks == [progress_bar, lr_monitor, checkpoint1, checkpoint2] + # with model-specific callbacks that substitute ones in Trainer - trainer.call_hook = Mock(return_value=[checkpoint1, early_stopping, checkpoint2]) + model = LightningModule() + model.configure_callbacks = lambda: [checkpoint1, early_stopping, checkpoint2] + trainer = Trainer(callbacks=[progress_bar, lr_monitor, ModelCheckpoint(tmpdir)]) + trainer.model = model cb_connector = CallbackConnector(trainer) cb_connector._attach_model_callbacks() assert trainer.callbacks == [progress_bar, lr_monitor, early_stopping, checkpoint1, checkpoint2] @@ -84,8 +92,10 @@ def test_attach_model_callbacks(): """Test that the callbacks defined in the model and through Trainer get merged correctly.""" def assert_composition(trainer_callbacks, model_callbacks, expected): + model = LightningModule() + model.configure_callbacks = lambda: model_callbacks trainer = Trainer(checkpoint_callback=False, progress_bar_refresh_rate=0, callbacks=trainer_callbacks) - trainer.call_hook = Mock(return_value=model_callbacks) + trainer.model = model cb_connector = CallbackConnector(trainer) cb_connector._attach_model_callbacks() assert trainer.callbacks == expected @@ -133,8 +143,10 @@ def assert_composition(trainer_callbacks, model_callbacks, expected): def test_attach_model_callbacks_override_info(caplog): """Test that the logs contain the info about overriding callbacks returned by configure_callbacks.""" + model = LightningModule() + model.configure_callbacks = lambda: [LearningRateMonitor(), EarlyStopping()] trainer = Trainer(checkpoint_callback=False, callbacks=[EarlyStopping(), LearningRateMonitor(), ProgressBar()]) - trainer.call_hook = Mock(return_value=[LearningRateMonitor(), EarlyStopping()]) + trainer.model = model cb_connector = CallbackConnector(trainer) with caplog.at_level(logging.INFO): cb_connector._attach_model_callbacks() From 8172cc9b2bc8f3e02cb2e80a7b380ab99ac95870 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 6 Aug 2021 14:28:55 +0200 Subject: [PATCH 26/29] Check no trainer --- pytorch_lightning/core/lightning.py | 5 +++++ tests/trainer/logging_/test_logger_connector.py | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index ac1c235bafd32..0d2d7d622379f 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -404,6 +404,11 @@ def log( on_step = self.__auto_choose_log_on_step(on_step) on_epoch = self.__auto_choose_log_on_epoch(on_epoch) + if self.trainer is None: + raise MisconfigurationException( + "You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet." + " This is most likely because you logging before the model is passed to the `Trainer`" + ) results = self.trainer._results if results is None: raise MisconfigurationException( diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index a1e4205345368..ed7711b32ffda 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -209,6 +209,7 @@ def call(hook, fn, *args, **kwargs): def test_fx_validator_integration(tmpdir): """Tries to log inside all `LightningModule` and `Callback` hooks to check any expected errors""" not_supported = { + None: "`self.trainer` reference is not registered", "on_before_accelerator_backend_setup": "You can't", "setup": "You can't", "configure_sharded_model": "You can't", @@ -233,6 +234,10 @@ def test_fx_validator_integration(tmpdir): "summarize": "not managed by the `Trainer", } model = HookedModel(not_supported) + + with pytest.raises(MisconfigurationException, match=not_supported[None]): + model.log("foo", 1) + callback = HookedCallback(not_supported) trainer = Trainer( default_root_dir=tmpdir, From dc7e12e77c9308b018ba7145a2170e066609f7b2 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 6 Aug 2021 14:30:08 +0200 Subject: [PATCH 27/29] Wording --- pytorch_lightning/core/lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 0d2d7d622379f..242ea9be4a179 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -407,7 +407,7 @@ def log( if self.trainer is None: raise MisconfigurationException( "You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet." - " This is most likely because you logging before the model is passed to the `Trainer`" + " This is most likely because the model hasn't been passed to the `Trainer`" ) results = self.trainer._results if results is None: From 8123c435cd4e437b3b30477309cfb70367a6bce7 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 6 Aug 2021 14:30:53 +0200 Subject: [PATCH 28/29] Period --- pytorch_lightning/core/lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 242ea9be4a179..b59a4c738d299 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -414,7 +414,7 @@ def log( raise MisconfigurationException( "You are trying to `self.log()` but the loop `ResultCollection` is not registered" " yet. This is most likely because you are trying to log in a `predict` hook," - " but it doesn't support logging." + " but it doesn't support logging" ) if self._current_fx_name is None: raise MisconfigurationException( From 87e81e2fdbbe09f65c6ec93bc4c0c3da38910603 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 9 Aug 2021 18:01:56 +0200 Subject: [PATCH 29/29] Revert "Keep is_overridden check" This reverts commit 2edeb4f89966a1bf360e645896a6ed52cf730fc6. --- pytorch_lightning/trainer/trainer.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 9c569eda8cc84..06ac730a680d4 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1238,10 +1238,9 @@ def call_hook( # next call hook in lightningModule output = None - if is_overridden(hook_name, pl_module): - model_fx = getattr(pl_module, hook_name) - if callable(model_fx): - output = model_fx(*args, **kwargs) + model_fx = getattr(pl_module, hook_name, None) + if callable(model_fx): + output = model_fx(*args, **kwargs) # call the accelerator hook if hook_name not in ("setup", "teardown") and hasattr(self.accelerator, hook_name):