Fix: passing wrong strings for scheduler interval doesn't throw an error (#5923)

* Raise if scheduler interval not 'step' or 'epoch'

* Add test for unknown 'interval' value in scheduler

* Use BoringModel instead of EvalModelTemplate

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Fix import order

* Apply yapf in test_datamodules

* Add missing imports to test_datamodules

* Fix too long comment

* Update pytorch_lightning/trainer/optimizers.py

* Fix unused imports and exception message

* Fix failing test

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
3 people authored Feb 12, 2021
1 parent ae19c97 commit 309ce7a
Showing 2 changed files with 24 additions and 0 deletions.
6 changes: 6 additions & 0 deletions pytorch_lightning/trainer/optimizers.py
@@ -117,6 +117,12 @@ def configure_schedulers(self, schedulers: list, monitor: Optional[str] = None):
                     raise MisconfigurationException(
                         'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler'
                     )
+                if 'interval' in scheduler and scheduler['interval'] not in ('step', 'epoch'):
+                    raise MisconfigurationException(
+                        f'The "interval" key in lr scheduler dict must be "step" or "epoch"'
+                        f' but is "{scheduler["interval"]}"'
+                    )
+
                 scheduler['reduce_on_plateau'] = isinstance(
                     scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau
                 )
18 changes: 18 additions & 0 deletions tests/trainer/optimization/test_optimizers.py
@@ -459,6 +459,24 @@ def test_unknown_configure_optimizers_raises(tmpdir):
     trainer.fit(model)
 
 
+def test_lr_scheduler_with_unknown_interval_raises(tmpdir):
+    """
+    Test exception when lr_scheduler dict has unknown interval param value
+    """
+    model = BoringModel()
+    optimizer = torch.optim.Adam(model.parameters())
+    model.configure_optimizers = lambda: {
+        'optimizer': optimizer,
+        'lr_scheduler': {
+            'scheduler': torch.optim.lr_scheduler.StepLR(optimizer, 1),
+            'interval': "incorrect_unknown_value"
+        },
+    }
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+    with pytest.raises(MisconfigurationException, match=r'The "interval" key in lr scheduler dict must be'):
+        trainer.fit(model)
+
+
 def test_lr_scheduler_with_extra_keys_warns(tmpdir):
     """
     Test warning when lr_scheduler dict has extra keys
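
For reference, a configuration that satisfies the new check sets "interval" to one of the two accepted values, 'step' or 'epoch'. A minimal sketch of such a configure_optimizers override, assuming an otherwise complete LightningModule; the optimizer, learning rate, and StepLR settings are illustrative and not taken from the diff:

import torch
from pytorch_lightning import LightningModule


class LitModel(LightningModule):
    # training_step, dataloaders, etc. omitted for brevity.

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': torch.optim.lr_scheduler.StepLR(optimizer, step_size=1),
                # After this change, any value other than 'step' or 'epoch' here
                # raises MisconfigurationException instead of being silently accepted.
                'interval': 'epoch',
            },
        }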
