diff --git a/CHANGELOG.md b/CHANGELOG.md index 5e790c38cfcd6..384fb6a20e1a0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -153,6 +153,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed truncated backprop through time enablement when set as a property on the LightningModule and not the Trainer ([#8804](https://github.com/PyTorchLightning/pytorch-lightning/pull/8804/)) +- Fixed plateau scheduler stepping on incomplete epoch ([#8861](https://github.com/PyTorchLightning/pytorch-lightning/pull/8861)) + + ## [1.4.0] - 2021-07-27 ### Added diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 52b31b67edfa9..213be0de21e6d 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -227,7 +227,8 @@ def on_run_end(self) -> List[List[STEP_OUTPUT]]: self.trainer.call_hook("on_epoch_end") self.trainer.logger_connector.on_epoch_end() - self.update_lr_schedulers("epoch", update_plateau_schedulers=True) + if self._num_training_batches_reached(self.is_last_batch): + self.update_lr_schedulers("epoch", update_plateau_schedulers=True) epoch_output = self._epoch_output # free memory diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py index c0541f1c142b3..3c8a3d5ae8e68 100644 --- a/tests/trainer/optimization/test_optimizers.py +++ b/tests/trainer/optimization/test_optimizers.py @@ -20,7 +20,6 @@ from pytorch_lightning import Callback, Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.base import EvalModelTemplate from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf @@ -79,7 +78,7 @@ def test_reducelronplateau_with_no_monitor_raises(tmpdir): """ Test exception when a ReduceLROnPlateau is used with no monitor """ - model = EvalModelTemplate() + model = BoringModel() optimizer = optim.Adam(model.parameters()) model.configure_optimizers = lambda: ([optimizer], [optim.lr_scheduler.ReduceLROnPlateau(optimizer)]) trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) @@ -93,7 +92,7 @@ def test_reducelronplateau_with_no_monitor_in_lr_scheduler_dict_raises(tmpdir): """ Test exception when lr_scheduler dict has a ReduceLROnPlateau with no monitor """ - model = EvalModelTemplate() + model = BoringModel() optimizer = optim.Adam(model.parameters()) model.configure_optimizers = lambda: { "optimizer": optimizer, @@ -376,33 +375,47 @@ def configure_optimizers(self): trainer.fit(model) -def test_lr_scheduler_strict(tmpdir): +@pytest.mark.parametrize("complete_epoch", [True, False]) +@mock.patch("torch.optim.lr_scheduler.ReduceLROnPlateau.step") +def test_lr_scheduler_strict(step_mock, tmpdir, complete_epoch): """ Test "strict" support in lr_scheduler dict """ - model = EvalModelTemplate() + model = BoringModel() optimizer = optim.Adam(model.parameters()) scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer) - trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) + max_epochs = 1 if complete_epoch else None + max_steps = None if complete_epoch else 1 + trainer = Trainer(default_root_dir=tmpdir, max_epochs=max_epochs, max_steps=max_steps) model.configure_optimizers = lambda: { "optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "monitor": "giraffe", "strict": True}, } - with pytest.raises( - MisconfigurationException, - match=r"ReduceLROnPlateau conditioned on metric .* which is not available\. Available metrics are:", - ): + + if complete_epoch: + with pytest.raises( + MisconfigurationException, + match=r"ReduceLROnPlateau conditioned on metric .* which is not available\. Available metrics are:", + ): + trainer.fit(model) + else: trainer.fit(model) + step_mock.assert_not_called() + model.configure_optimizers = lambda: { "optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "monitor": "giraffe", "strict": False}, } - with pytest.warns( - RuntimeWarning, match=r"ReduceLROnPlateau conditioned on metric .* which is not available but strict" - ): - trainer.fit(model) + + if complete_epoch: + with pytest.warns( + RuntimeWarning, match=r"ReduceLROnPlateau conditioned on metric .* which is not available but strict" + ): + trainer.fit(model) + + step_mock.assert_not_called() def test_unknown_configure_optimizers_raises(tmpdir):