diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 80af2fbc1a9bb..996b08522d166 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -92,19 +92,19 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         self._device_type = None
 
         #: True if using amp
-        self.use_amp = False
+        self.use_amp: bool = False
 
         #: The precision used
-        self.precision = 32
+        self.precision: int = 32
 
         # optionally can be set by user
         self._example_input_array = None
         self._datamodule = None
         self._results: Optional[Result] = None
-        self._current_fx_name = ''
-        self._running_manual_backward = False
-        self._current_hook_fx_name = None
-        self._current_dataloader_idx = None
+        self._current_fx_name: str = ''
+        self._running_manual_backward: bool = False
+        self._current_hook_fx_name: Optional[str] = None
+        self._current_dataloader_idx: Optional[int] = None
         self._automatic_optimization: bool = True
         self._param_requires_grad_state = dict()
 
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index fa42a75c24829..3731a6d0bd8cb 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1151,6 +1151,10 @@ def call_teardown_hook(self, model: LightningModule) -> None:
         self.teardown(stage=state)
         model.teardown(stage=state)
 
+        model._current_fx_name = ""
+        model._current_hook_fx_name = None
+        model._current_dataloader_idx = None
+
     def _reset_result_and_set_hook_fx_name(self, hook_name: str) -> bool:
         # on_before_zero_grad is called within training_step
         if "batch_start" in hook_name or hook_name in ("on_before_zero_grad", "on_after_backward"):
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 3a35912fa7936..0fb060ef31903 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -2047,3 +2047,33 @@ def test_fit_test_synchronization(tmpdir):
     trainer.fit(model)
     assert os.path.exists(checkpoint.best_model_path), f'Could not find checkpoint at rank {trainer.global_rank}'
     trainer.test()
+
+
+def test_module_current_fx_attributes_reset(tmpdir):
+    """ Ensure that lightning module's attributes related to current hook fx are reset at the end of execution. """
+    model = BoringModel()
+    model.validation_step = None
+    model.training_epoch_end = None
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        checkpoint_callback=False,
+        logger=False,
+        limit_val_batches=0,
+    )
+    trainer.fit(model)
+    assert model._current_fx_name == "", f"_current_fx_name not reset after fit: {model._current_fx_name}"
+    assert (
+        model._current_hook_fx_name is None
+    ), f"_current_hook_fx_name not reset after fit: {model._current_hook_fx_name}"
+    assert (
+        model._current_dataloader_idx is None
+    ), f"_current_dataloader_idx not reset after fit: {model._current_dataloader_idx}"
+    trainer.test(model)
+    assert model._current_fx_name == "", f"_current_fx_name not reset after test: {model._current_fx_name}"
+    assert (
+        model._current_hook_fx_name is None
+    ), f"_current_hook_fx_name not reset after test: {model._current_hook_fx_name}"
+    assert (
+        model._current_dataloader_idx is None
+    ), f"_current_dataloader_idx not reset after test: {model._current_dataloader_idx}"