From e5dfdf34f958d2eddccb205558763ae7aeedc0cc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Sat, 16 Oct 2021 18:36:25 +0200
Subject: [PATCH] Avoid deprecation warning after #9901 (#9951)

---
 pytorch_lightning/accelerators/gpu.py           |  1 +
 pytorch_lightning/plugins/training_type/ipu.py  |  2 +-
 .../training_type/training_type_plugin.py       |  2 +-
 pytorch_lightning/trainer/trainer.py            | 17 ++++++++++++++++-
 tests/loops/test_training_loop.py               |  2 +-
 5 files changed, 20 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py
index b33903c2d60c9..44b29efe6f2bc 100644
--- a/pytorch_lightning/accelerators/gpu.py
+++ b/pytorch_lightning/accelerators/gpu.py
@@ -46,6 +46,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
         return super().setup(trainer)
 
     def on_train_start(self) -> None:
+        super().on_train_start()
         # clear cache before training
         torch.cuda.empty_cache()
 
diff --git a/pytorch_lightning/plugins/training_type/ipu.py b/pytorch_lightning/plugins/training_type/ipu.py
index daa704e8a8243..b6728b0551081 100644
--- a/pytorch_lightning/plugins/training_type/ipu.py
+++ b/pytorch_lightning/plugins/training_type/ipu.py
@@ -285,7 +285,7 @@ def on_test_end(self):
     def on_predict_end(self):
         self._detach_models()
 
-    def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
+    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
         # Updates optimizer stats if LR scheduler modified the optimizer state
         optimizer = self.lightning_module.trainer.optimizers[0]
         self.poptorch_models[RunningStage.TRAINING].setOptimizer(optimizer)
diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index cf36a3502702d..9c53069063a52 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -345,7 +345,7 @@ def on_predict_end(self):
         """Called when predict ends."""
         pass
 
-    def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
+    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
         """Called in the training loop before anything happens for that batch."""
         pass
 
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index be0a7728edddc..e6d8ccde91d71 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1401,8 +1401,15 @@ def call_hook(
         if callable(model_fx):
             output = model_fx(*args, **kwargs)
 
+        # *Bad code alert*
+        # The `Accelerator` mostly calls the `TrainingTypePlugin` but some of those calls are deprecated.
+        # The following logic selectively chooses which hooks are called on each object.
+        # In the case of `setup` and `teardown`, the hooks on the `LightningModule` should not call the hooks of the
+        # same name in these objects as they are meant to be managed outside of the `LightningModule` lifecycle.
+        # All of this should be fixed by #8506
+
         # call the accelerator hook
-        if hook_name not in ("setup", "teardown") and hasattr(self.accelerator, hook_name):
+        if hook_name in ("on_train_start",) and hasattr(self.accelerator, hook_name):
             accelerator_hook = getattr(self.accelerator, hook_name)
             accelerator_output = accelerator_hook(*args, **kwargs)
             # Rely on the accelerator output if lightningModule hook returns nothing
@@ -1410,6 +1417,14 @@ def call_hook(
             # todo: move this data parallel logic into the data parallel plugin
             output = accelerator_output if output is None else output
 
+        # call the ttp hook
+        if hook_name not in ("setup", "teardown", "on_train_start") and hasattr(
+            self.training_type_plugin, hook_name
+        ):
+            ttp_hook = getattr(self.training_type_plugin, hook_name)
+            ttp_output = ttp_hook(*args, **kwargs)
+            output = ttp_output if output is None else output
+
         if pl_module:
             # restore current_fx when nested context
             pl_module._current_fx_name = prev_fx_name
diff --git a/tests/loops/test_training_loop.py b/tests/loops/test_training_loop.py
index d491db3bbc91c..ebfe0d4762806 100644
--- a/tests/loops/test_training_loop.py
+++ b/tests/loops/test_training_loop.py
@@ -86,7 +86,7 @@ def run_training(**trainer_kwargs):
 @pytest.mark.parametrize(["max_epochs", "batch_idx_"], [(2, 5), (3, 8), (4, 12)])
 def test_on_train_batch_start_return_minus_one(max_epochs, batch_idx_, tmpdir):
     class CurrentModel(BoringModel):
-        def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+        def on_train_batch_start(self, batch, batch_idx):
             if batch_idx == batch_idx_:
                 return -1
 
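Note: the sketch below is a minimal, self-contained illustration of the dispatch rule this patch introduces in `Trainer.call_hook`; the `Dummy*` classes are hypothetical stand-ins, not the real Lightning `Trainer`/`Accelerator`/`TrainingTypePlugin` API. The idea: `on_train_start` is the only hook still routed through the accelerator, while every other hook except `setup`/`teardown` is sent directly to the training type plugin, which avoids the deprecated accelerator pass-through calls.

# Minimal sketch of the hook-dispatch rule added in this patch.
# The Dummy* classes are hypothetical stand-ins for illustration only.


class DummyAccelerator:
    def on_train_start(self):
        print("accelerator.on_train_start")


class DummyTrainingTypePlugin:
    def on_train_start(self):
        print("ttp.on_train_start")

    def on_train_batch_start(self, batch, batch_idx, dataloader_idx=0):
        print(f"ttp.on_train_batch_start(batch_idx={batch_idx})")


class DummyTrainer:
    def __init__(self):
        self.accelerator = DummyAccelerator()
        self.training_type_plugin = DummyTrainingTypePlugin()

    def call_hook(self, hook_name, *args, **kwargs):
        output = None
        # `on_train_start` is the only hook still routed through the accelerator ...
        if hook_name in ("on_train_start",) and hasattr(self.accelerator, hook_name):
            accelerator_output = getattr(self.accelerator, hook_name)(*args, **kwargs)
            output = accelerator_output if output is None else output
        # ... every other hook except `setup`/`teardown` goes straight to the
        # training type plugin, bypassing the deprecated accelerator pass-throughs
        if hook_name not in ("setup", "teardown", "on_train_start") and hasattr(
            self.training_type_plugin, hook_name
        ):
            ttp_output = getattr(self.training_type_plugin, hook_name)(*args, **kwargs)
            output = ttp_output if output is None else output
        return output


trainer = DummyTrainer()
trainer.call_hook("on_train_start")                      # handled by the accelerator
trainer.call_hook("on_train_batch_start", {"x": 1}, 0)   # handled by the training type plugin

The `dataloader_idx: int = 0` default added to the plugin signatures serves the same goal: callers that pass only `(batch, batch_idx)`, matching the updated `LightningModule` hook signature exercised in the test, keep working without hitting the deprecation path.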