fix plateau scheduler stepping on incomplete epoch (#8861)
awaelchli authored Aug 13, 2021
1 parent fec4f28 commit 4b6aaee
Showing 3 changed files with 32 additions and 15 deletions.
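Before this fix, `TrainingEpochLoop.on_run_end` stepped epoch-interval plateau schedulers unconditionally, so a `ReduceLROnPlateau` could be stepped at the end of an *incomplete* epoch, e.g. when `max_steps` halts training mid-epoch, before a full epoch's worth of the monitored metric exists. The following is a minimal, hypothetical sketch of the kind of setup the fix affects; the model, the `train_loss` metric name, and the random data are illustrative and not part of this commit.

```python
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer


class PlateauModel(LightningModule):  # illustrative model, not from the PR
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = nn.functional.mse_loss(self.layer(x), y)
        # Log with on_epoch=True so an epoch-level metric exists for the scheduler.
        self.log("train_loss", loss, on_epoch=True)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
        # Plateau schedulers are stepped once per epoch with the monitored metric.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "monitor": "train_loss"},
        }


dataset = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
# max_steps=1 stops training partway through the first epoch; with this fix
# the epoch-end plateau-scheduler update is skipped for that partial epoch.
trainer = Trainer(max_steps=1)
trainer.fit(PlateauModel(), DataLoader(dataset, batch_size=8))
```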
CHANGELOG.md (3 additions & 0 deletions)

@@ -153,6 +153,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed truncated backprop through time enablement when set as a property on the LightningModule and not the Trainer ([#8804](https://github.com/PyTorchLightning/pytorch-lightning/pull/8804/))
 
 
+- Fixed plateau scheduler stepping on incomplete epoch ([#8861](https://github.com/PyTorchLightning/pytorch-lightning/pull/8861))
+
+
 ## [1.4.0] - 2021-07-27
 
 ### Added
pytorch_lightning/loops/epoch/training_epoch_loop.py (2 additions & 1 deletion)

@@ -227,7 +227,8 @@ def on_run_end(self) -> List[List[STEP_OUTPUT]]:
         self.trainer.call_hook("on_epoch_end")
         self.trainer.logger_connector.on_epoch_end()
 
-        self.update_lr_schedulers("epoch", update_plateau_schedulers=True)
+        if self._num_training_batches_reached(self.is_last_batch):
+            self.update_lr_schedulers("epoch", update_plateau_schedulers=True)
 
         epoch_output = self._epoch_output
         # free memory
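The guard above is the entire fix: the epoch-end update of plateau schedulers now runs only when `_num_training_batches_reached(self.is_last_batch)` reports that the epoch actually processed all of its training batches. A framework-free sketch of the same pattern in plain PyTorch follows; `run_epoch`, `stop_after`, and the toy data are illustrative, not Lightning internals.

```python
import torch
from torch import nn, optim

model = nn.Linear(4, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)


def run_epoch(batches, stop_after=None):
    """Train over batches; return (mean_loss, completed)."""
    losses = []
    for i, (x, y) in enumerate(batches):
        if stop_after is not None and i >= stop_after:
            return sum(losses) / max(len(losses), 1), False  # stopped early
        loss = nn.functional.mse_loss(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return sum(losses) / max(len(losses), 1), True


batches = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(10)]
mean_loss, completed = run_epoch(batches, stop_after=3)

# The fix in miniature: only feed the plateau scheduler when the epoch is
# complete; a mean over a partial epoch could trigger spurious LR reductions.
if completed:
    scheduler.step(mean_loss)
```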
tests/trainer/optimization/test_optimizers.py (27 additions & 14 deletions)

@@ -20,7 +20,6 @@
 from pytorch_lightning import Callback, Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.base import EvalModelTemplate
 from tests.helpers.boring_model import BoringModel
 from tests.helpers.runif import RunIf
 

@@ -79,7 +78,7 @@ def test_reducelronplateau_with_no_monitor_raises(tmpdir):
     """
     Test exception when a ReduceLROnPlateau is used with no monitor
     """
-    model = EvalModelTemplate()
+    model = BoringModel()
     optimizer = optim.Adam(model.parameters())
     model.configure_optimizers = lambda: ([optimizer], [optim.lr_scheduler.ReduceLROnPlateau(optimizer)])
     trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)

@@ -93,7 +92,7 @@ def test_reducelronplateau_with_no_monitor_in_lr_scheduler_dict_raises(tmpdir):
     """
     Test exception when lr_scheduler dict has a ReduceLROnPlateau with no monitor
     """
-    model = EvalModelTemplate()
+    model = BoringModel()
     optimizer = optim.Adam(model.parameters())
     model.configure_optimizers = lambda: {
         "optimizer": optimizer,

@@ -376,33 +375,47 @@ def configure_optimizers(self):
     trainer.fit(model)
 
 
-def test_lr_scheduler_strict(tmpdir):
+@pytest.mark.parametrize("complete_epoch", [True, False])
+@mock.patch("torch.optim.lr_scheduler.ReduceLROnPlateau.step")
+def test_lr_scheduler_strict(step_mock, tmpdir, complete_epoch):
     """
     Test "strict" support in lr_scheduler dict
     """
-    model = EvalModelTemplate()
+    model = BoringModel()
     optimizer = optim.Adam(model.parameters())
     scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
-    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
+    max_epochs = 1 if complete_epoch else None
+    max_steps = None if complete_epoch else 1
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=max_epochs, max_steps=max_steps)
 
     model.configure_optimizers = lambda: {
         "optimizer": optimizer,
         "lr_scheduler": {"scheduler": scheduler, "monitor": "giraffe", "strict": True},
     }
-    with pytest.raises(
-        MisconfigurationException,
-        match=r"ReduceLROnPlateau conditioned on metric .* which is not available\. Available metrics are:",
-    ):
+
+    if complete_epoch:
+        with pytest.raises(
+            MisconfigurationException,
+            match=r"ReduceLROnPlateau conditioned on metric .* which is not available\. Available metrics are:",
+        ):
+            trainer.fit(model)
+    else:
         trainer.fit(model)
 
+    step_mock.assert_not_called()
+
     model.configure_optimizers = lambda: {
         "optimizer": optimizer,
         "lr_scheduler": {"scheduler": scheduler, "monitor": "giraffe", "strict": False},
     }
-    with pytest.warns(
-        RuntimeWarning, match=r"ReduceLROnPlateau conditioned on metric .* which is not available but strict"
-    ):
-        trainer.fit(model)
+
+    if complete_epoch:
+        with pytest.warns(
+            RuntimeWarning, match=r"ReduceLROnPlateau conditioned on metric .* which is not available but strict"
+        ):
+            trainer.fit(model)
+
+    step_mock.assert_not_called()
 
 
 def test_unknown_configure_optimizers_raises(tmpdir):
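The reworked test is parametrized over `complete_epoch` and patches `ReduceLROnPlateau.step`, so in both branches it can assert the scheduler never steps: with a complete epoch, the misconfigured `"giraffe"` monitor raises or warns before any step; with `max_steps=1`, the new guard skips the epoch-end update entirely. A standalone sketch of that mocking pattern is below; `check_scheduler_never_steps` is a hypothetical helper, not part of the test suite.

```python
from unittest import mock

from torch import nn, optim


@mock.patch("torch.optim.lr_scheduler.ReduceLROnPlateau.step")
def check_scheduler_never_steps(step_mock):
    # Patching the class method records every call to scheduler.step(...).
    model = nn.Linear(2, 1)
    optimizer = optim.SGD(model.parameters(), lr=0.1)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
    # ... training that ends mid-epoch would go here, never stepping `scheduler` ...
    step_mock.assert_not_called()  # fails loudly if anything called .step()


check_scheduler_never_steps()
```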
