diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 79bd3438f95d0..4b98fd19633e3 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -148,6 +148,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added support for returning an object of type `Mapping` from `LightningModule.training_step()` ([#18657](https://github.com/Lightning-AI/lightning/pull/18657))
 
+- Added the hook `LightningModule.on_validation_model_zero_grad()` to allow overriding the behavior of zeroing the gradients before entering the validation loop ([#18710](https://github.com/Lightning-AI/lightning/pull/18710))
+
+
 ### Changed
 
 - Removed the limitation to call `self.trainer.model.parameters()` in `LightningModule.configure_optimizers()` ([#17309](https://github.com/Lightning-AI/lightning/pull/17309))
@@ -289,6 +292,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed numerical issues when reducing values in low precision with `self.log` ([#18686](https://github.com/Lightning-AI/lightning/pull/18686))
 
+- Fixed an issue that would cause the gradients to be erased if validation happened in the middle of a gradient accumulation phase ([#18710](https://github.com/Lightning-AI/lightning/pull/18710))
+
 
 ## [2.0.9] - 2023-09-14
 
diff --git a/src/lightning/pytorch/core/hooks.py b/src/lightning/pytorch/core/hooks.py
index b1a8404e21e78..1f8ff226c4a43 100644
--- a/src/lightning/pytorch/core/hooks.py
+++ b/src/lightning/pytorch/core/hooks.py
@@ -19,6 +19,7 @@
 from torch import Tensor
 from torch.optim.optimizer import Optimizer
 
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
 from lightning.pytorch.utilities import move_data_to_device
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS
@@ -151,6 +152,11 @@ def on_predict_batch_end(self, outputs: Optional[Any], batch: Any, batch_idx: in
 
         """
 
+    def on_validation_model_zero_grad(self) -> None:
+        """Called by the training loop to release gradients before entering the validation loop."""
+        zero_grad_kwargs = {} if _TORCH_GREATER_EQUAL_2_0 else {"set_to_none": True}
+        self.zero_grad(**zero_grad_kwargs)
+
     def on_validation_model_eval(self) -> None:
         """Sets the model to eval during the val loop."""
         self.trainer.model.eval()
diff --git a/src/lightning/pytorch/loops/evaluation_loop.py b/src/lightning/pytorch/loops/evaluation_loop.py
index 4c7d488d0760a..269beede34205 100644
--- a/src/lightning/pytorch/loops/evaluation_loop.py
+++ b/src/lightning/pytorch/loops/evaluation_loop.py
@@ -239,9 +239,7 @@ def on_run_start(self) -> None:
         """Runs the ``_on_evaluation_model_eval``, ``_on_evaluation_start`` and ``_on_evaluation_epoch_start``
         hooks."""
         self._verify_dataloader_idx_requirement()
-
         self._on_evaluation_model_eval()
-        self.trainer.lightning_module.zero_grad()
         self._on_evaluation_start()
         self._on_evaluation_epoch_start()
 
diff --git a/src/lightning/pytorch/loops/prediction_loop.py b/src/lightning/pytorch/loops/prediction_loop.py
index 4fd27866c8fdb..b4c676024a355 100644
--- a/src/lightning/pytorch/loops/prediction_loop.py
+++ b/src/lightning/pytorch/loops/prediction_loop.py
@@ -191,10 +191,7 @@ def reset(self) -> None:
     def on_run_start(self) -> None:
         """Calls ``_on_predict_model_eval``, ``_on_predict_start`` and ``_on_predict_epoch_start`` hooks."""
         self._verify_dataloader_idx_requirement()
-
-        trainer = self.trainer
-        call._call_lightning_module_hook(trainer, "on_predict_model_eval")
-        trainer.lightning_module.zero_grad()
+        call._call_lightning_module_hook(self.trainer, "on_predict_model_eval")
         self._on_predict_start()
         self._on_predict_epoch_start()
 
diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py
index 272c27936a192..8b925d8f8311d 100644
--- a/src/lightning/pytorch/loops/training_epoch_loop.py
+++ b/src/lightning/pytorch/loops/training_epoch_loop.py
@@ -277,6 +277,11 @@ def on_advance_end(self, data_fetcher: _DataFetcher) -> None:
             self.trainer.validating = True
             # save and reset this state in case validation runs inside training loop (val_check_interval<1.0)
             first_loop_iter = self.trainer._logger_connector._first_loop_iter
+
+            if not self._should_accumulate():
+                # clear gradients to not leave any unused memory during validation
+                call._call_lightning_module_hook(self.trainer, "on_validation_model_zero_grad")
+
             self.val_loop.run()
             self.trainer.training = True
             self.trainer._logger_connector._first_loop_iter = first_loop_iter
diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py b/src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py
index e9d5280d6fd5d..545749bc5b321 100644
--- a/src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py
+++ b/src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py
@@ -141,6 +141,7 @@ class _LogOptions(TypedDict):
         "test_dataloader": None,
         "prepare_data": None,
         "configure_callbacks": None,
+        "on_validation_model_zero_grad": None,
         "on_validation_model_eval": None,
         "on_test_model_eval": None,
         "on_validation_model_train": None,
diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py
index aaec128581448..0df54710dc117 100644
--- a/src/lightning/pytorch/trainer/trainer.py
+++ b/src/lightning/pytorch/trainer/trainer.py
@@ -1022,6 +1022,9 @@ def _run_stage(self) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]:
         # wait for all to join if on distributed
         self.strategy.barrier("run-stage")
 
+        zero_grad_kwargs = {} if _TORCH_GREATER_EQUAL_2_0 else {"set_to_none": True}
+        self.lightning_module.zero_grad(**zero_grad_kwargs)
+
         if self.evaluating:
             return self._evaluation_loop.run()
         if self.predicting:
diff --git a/tests/tests_pytorch/loops/test_loops.py b/tests/tests_pytorch/loops/test_loops.py
index a23e602366c78..120d7949f658b 100644
--- a/tests/tests_pytorch/loops/test_loops.py
+++ b/tests/tests_pytorch/loops/test_loops.py
@@ -851,3 +851,34 @@ def _get_iterator(self):
         3,  # teardown on epoch 2, workers from epoch 2 get destroyed
     ]
     assert val_dataloader.shutdown_workers_epochs == expected
+
+
+def test_validation_during_gradient_accumulation_window(tmp_path):
+    """Test that gradients don't get erased when the validation interval falls within the gradient accumulation
+    phase."""
+
+    class ValidationModel(BoringModel):
+        def on_validation_start(self):
+            batch_idx = self.trainer.fit_loop.epoch_loop.batch_progress.current.completed
+            grad_expected = batch_idx % self.trainer.accumulate_grad_batches != 0
+            if grad_expected:
+                assert batch_idx in (2, 4)
+                assert all(p.grad is not None for p in self.parameters())
+            else:
+                assert batch_idx == 6
+                assert all(p.grad is None for p in self.parameters())
+            self.ran_assert = True
+
+    model = ValidationModel()
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        limit_train_batches=6,
+        limit_val_batches=1,
+        accumulate_grad_batches=3,
+        # validation happens in the middle of the first two accumulations, and at the end of the third
+        val_check_interval=2,
+        max_epochs=1,
+        num_sanity_val_steps=0,
+    )
+    trainer.fit(model)
+    assert model.ran_assert
diff --git a/tests/tests_pytorch/models/test_hooks.py b/tests/tests_pytorch/models/test_hooks.py
index 91a2666babcb4..1c546c69394ab 100644
--- a/tests/tests_pytorch/models/test_hooks.py
+++ b/tests/tests_pytorch/models/test_hooks.py
@@ -18,6 +18,7 @@
 import pytest
 import torch
 
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
 from lightning.pytorch import Callback, LightningDataModule, LightningModule, Trainer, __version__
 from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset
 from torch import Tensor
@@ -465,11 +466,11 @@ def training_step(self, batch, batch_idx):
         {"name": "configure_optimizers"},
         {"name": "Callback.on_fit_start", "args": (trainer, model)},
         {"name": "on_fit_start"},
+        {"name": "zero_grad", **({} if _TORCH_GREATER_EQUAL_2_0 else {"kwargs": {"set_to_none": True}})},
         {"name": "Callback.on_sanity_check_start", "args": (trainer, model)},
         {"name": "val_dataloader"},
         {"name": "train", "args": (False,)},
         {"name": "on_validation_model_eval"},
-        {"name": "zero_grad"},
         {"name": "Callback.on_validation_start", "args": (trainer, model)},
         {"name": "on_validation_start"},
         *model._eval_epoch("validation", trainer, model, val_batches, "x", device=device),
@@ -486,9 +487,10 @@ def training_step(self, batch, batch_idx):
         {"name": "Callback.on_train_epoch_start", "args": (trainer, model)},
         {"name": "on_train_epoch_start"},
         *model._train_batch(trainer, model, train_batches, device=device, **kwargs),
+        {"name": "zero_grad", **({} if _TORCH_GREATER_EQUAL_2_0 else {"kwargs": {"set_to_none": True}})},
+        {"name": "on_validation_model_zero_grad"},
         {"name": "train", "args": (False,)},
         {"name": "on_validation_model_eval"},
-        {"name": "zero_grad"},
         {"name": "Callback.on_validation_start", "args": (trainer, model)},
         {"name": "on_validation_start"},
         *model._eval_epoch("validation", trainer, model, val_batches, "x", device=device),
@@ -566,6 +568,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmpdir):
         {"name": "configure_optimizers"},
         {"name": "Callback.on_fit_start", "args": (trainer, model)},
         {"name": "on_fit_start"},
+        {"name": "zero_grad", **({} if _TORCH_GREATER_EQUAL_2_0 else {"kwargs": {"set_to_none": True}})},
         {"name": "train_dataloader"},
         {"name": "train", "args": (True,)},
         {"name": "Callback.on_train_start", "args": (trainer, model)},
@@ -644,6 +647,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmpdir):
         {"name": "configure_optimizers"},
         {"name": "Callback.on_fit_start", "args": (trainer, model)},
         {"name": "on_fit_start"},
+        {"name": "zero_grad", **({} if _TORCH_GREATER_EQUAL_2_0 else {"kwargs": {"set_to_none": True}})},
         {"name": "train_dataloader"},
         {"name": "train", "args": (True,)},
         {"name": "Callback.on_train_start", "args": (trainer, model)},
@@ -690,7 +694,6 @@ def test_trainer_model_hook_system_eval(tmpdir, batches, verb, noun, dataloader,
         {"name": f"{dataloader}_dataloader"},
         {"name": "train", "args": (False,)},
         {"name": f"on_{noun}_model_eval"},
-        {"name": "zero_grad"},
         {"name": f"Callback.on_{noun}_start", "args": (trainer, model)},
         {"name": f"on_{noun}_start"},
         *model._eval_epoch(noun, trainer, model, batches, key, trainer.strategy.root_device),
@@ -705,6 +708,7 @@ def test_trainer_model_hook_system_eval(tmpdir, batches, verb, noun, dataloader,
         {"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": verb}},
         {"name": "setup", "kwargs": {"stage": verb}},
         {"name": "configure_model"},
+        {"name": "zero_grad", **({} if _TORCH_GREATER_EQUAL_2_0 else {"kwargs": {"set_to_none": True}})},
         *(hooks if batches else []),
         {"name": "Callback.teardown", "args": (trainer, model), "kwargs": {"stage": verb}},
         {"name": "teardown", "kwargs": {"stage": verb}},
@@ -727,10 +731,10 @@ def test_trainer_model_hook_system_predict(tmpdir):
         {"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": "predict"}},
         {"name": "setup", "kwargs": {"stage": "predict"}},
         {"name": "configure_model"},
+        {"name": "zero_grad", **({} if _TORCH_GREATER_EQUAL_2_0 else {"kwargs": {"set_to_none": True}})},
         {"name": "predict_dataloader"},
         {"name": "train", "args": (False,)},
         {"name": "on_predict_model_eval"},
-        {"name": "zero_grad"},
         {"name": "Callback.on_predict_start", "args": (trainer, model)},
         {"name": "on_predict_start"},
         {"name": "Callback.on_predict_epoch_start", "args": (trainer, model)},
diff --git a/tests/tests_pytorch/trainer/logging_/test_logger_connector.py b/tests/tests_pytorch/trainer/logging_/test_logger_connector.py
index b2baf7e4635f6..d7d29a857a9be 100644
--- a/tests/tests_pytorch/trainer/logging_/test_logger_connector.py
+++ b/tests/tests_pytorch/trainer/logging_/test_logger_connector.py
@@ -211,6 +211,7 @@ def test_fx_validator_integration(tmpdir):
         "on_sanity_check_end": "You can't",
         "prepare_data": "You can't",
         "configure_callbacks": "You can't",
+        "on_validation_model_zero_grad": "You can't",
         "on_validation_model_eval": "You can't",
         "on_validation_model_train": "You can't",
         "lr_scheduler_step": "You can't",