Add a check for optimizer attached to lr_scheduler (#5338)
* add a check for scheduler and optimizer

* pep

* Apply suggestions from code review

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

rohitgr7 and awaelchli authored Jan 5, 2021
1 parent f740245 commit c7d0f4c
Showing 3 changed files with 29 additions and 0 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -9,9 +9,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added


- Added a check for optimizer attached to lr_scheduler ([#5338](https://github.com/PyTorchLightning/pytorch-lightning/pull/5338))
- Added `resume_from_checkpoint` to accept a non-existing file path ([#4402](https://github.com/PyTorchLightning/pytorch-lightning/pull/4402))



### Changed


9 changes: 9 additions & 0 deletions pytorch_lightning/trainer/optimizers.py
@@ -75,7 +75,9 @@ def init_optimizers(self, model: LightningModule) -> Tuple[List, List, List]:
' * {"optimizer": `torch.optim.Optimizer`, (optional) "lr_scheduler": `torch.optim.lr_scheduler`}\n'
' * A list of the previously described dict format, with an optional "frequency" key (int)'
)

lr_schedulers = self.configure_schedulers(lr_schedulers, monitor=monitor)
_validate_scheduler_optimizer(optimizers, lr_schedulers)

return optimizers, lr_schedulers, optimizer_frequencies

@@ -183,3 +185,10 @@ def zero_grad(self):

    def __repr__(self):
        return 'No Optimizer'


def _validate_scheduler_optimizer(optimizers, lr_schedulers):
    if any(sch['scheduler'].optimizer not in optimizers for sch in lr_schedulers):
        raise MisconfigurationException(
            "Some schedulers are attached to an optimizer that wasn't returned from `configure_optimizers`."
        )
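
For reference, a minimal sketch (not part of this diff) of a `configure_optimizers` that satisfies the new check, assuming a module with a `self.layer` like the test's BoringModel: the scheduler wraps an optimizer that is also returned, so `sch['scheduler'].optimizer` is found in the `optimizers` list.

def configure_optimizers(self):
    # Hypothetical example: StepLR is built from `opt`, and `opt` itself is returned,
    # so _validate_scheduler_optimizer finds the scheduler's optimizer in `optimizers`.
    opt = torch.optim.SGD(self.layer.parameters(), lr=0.1)
    sch = torch.optim.lr_scheduler.StepLR(opt, step_size=1)
    return [opt], [sch]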
17 changes: 17 additions & 0 deletions tests/trainer/test_optimizers.py
@@ -483,3 +483,20 @@ def test_lr_scheduler_with_no_actual_scheduler_raises(tmpdir):
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    with pytest.raises(MisconfigurationException, match='The lr scheduler dict must have the key "scheduler"'):
        trainer.fit(model)


def test_invalid_optimizer_in_scheduler(tmpdir):
    """
    Test that an exception is raised when a scheduler is attached to an optimizer
    that wasn't returned from `configure_optimizers`.
    """
    class InvalidOptimizerModel(BoringModel):
        def configure_optimizers(self):
            opt1 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
            opt2 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
            lr_scheduler = torch.optim.lr_scheduler.StepLR(opt2, step_size=1)
            return [opt1], [lr_scheduler]

    model = InvalidOptimizerModel()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    with pytest.raises(MisconfigurationException, match="attached to an optimizer that wasn't returned"):
        trainer.fit(model)
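
For contrast, a hedged sketch of the valid counterpart (not part of this commit): the scheduler is attached to the optimizer that is returned, so `trainer.fit` is expected to complete without raising.

def test_valid_optimizer_in_scheduler(tmpdir):
    """Hypothetical counterpart: the scheduler wraps an optimizer that is returned."""
    class ValidOptimizerModel(BoringModel):
        def configure_optimizers(self):
            opt = torch.optim.SGD(self.layer.parameters(), lr=0.1)
            lr_scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=1)
            return [opt], [lr_scheduler]

    model = ValidOptimizerModel()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.fit(model)  # no MisconfigurationException expected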
