From c7d0f4c3a29bd5524e0b66f9196f123b64d1587a Mon Sep 17 00:00:00 2001
From: Rohit Gupta
Date: Tue, 5 Jan 2021 11:37:28 +0530
Subject: [PATCH] Add a check for optimizer attached to lr_scheduler (#5338)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* add a check for scheduler and optimizer

* pep

* Apply suggestions from code review

Co-authored-by: Adrian Wälchli
Co-authored-by: Adrian Wälchli
---
 CHANGELOG.md                            |  3 +++
 pytorch_lightning/trainer/optimizers.py |  9 +++++++++
 tests/trainer/test_optimizers.py        | 17 +++++++++++++++++
 3 files changed, 29 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9801f56f6f1bf..8af6984ac98ca 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,9 +9,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+
+- Added a check for optimizer attached to lr_scheduler ([#5338](https://github.com/PyTorchLightning/pytorch-lightning/pull/5338))
 
 - Added `resume_from_checkpoint` accept non-existing file path ([#4402](https://github.com/PyTorchLightning/pytorch-lightning/pull/4402))
 
+
 
 ### Changed
 
diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py
index 479d401720261..974ee898ff00b 100644
--- a/pytorch_lightning/trainer/optimizers.py
+++ b/pytorch_lightning/trainer/optimizers.py
@@ -75,7 +75,9 @@ def init_optimizers(self, model: LightningModule) -> Tuple[List, List, List]:
                 ' * {"optimizer": `torch.optim.Optimizer`, (optional) "lr_scheduler": `torch.optim.lr_scheduler`}\n'
                 ' * A list of the previously described dict format, with an optional "frequency" key (int)'
             )
+
         lr_schedulers = self.configure_schedulers(lr_schedulers, monitor=monitor)
+        _validate_scheduler_optimizer(optimizers, lr_schedulers)
 
         return optimizers, lr_schedulers, optimizer_frequencies
 
@@ -183,3 +185,10 @@ def zero_grad(self):
 
     def __repr__(self):
         return 'No Optimizer'
+
+
+def _validate_scheduler_optimizer(optimizers, lr_schedulers):
+    if any(sch['scheduler'].optimizer not in optimizers for sch in lr_schedulers):
+        raise MisconfigurationException(
+            "Some schedulers are attached to an optimizer that wasn't returned from `configure_optimizers`."
+        )
diff --git a/tests/trainer/test_optimizers.py b/tests/trainer/test_optimizers.py
index 52e085b2b7b8c..e9a422dfb4711 100644
--- a/tests/trainer/test_optimizers.py
+++ b/tests/trainer/test_optimizers.py
@@ -483,3 +483,20 @@ def test_lr_scheduler_with_no_actual_scheduler_raises(tmpdir):
     trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
     with pytest.raises(MisconfigurationException, match='The lr scheduler dict must have the key "scheduler"'):
         trainer.fit(model)
+
+
+def test_invalid_optimizer_in_scheduler(tmpdir):
+    """
+    Test exception when the optimizer attached to the lr_scheduler wasn't returned
+    """
+    class InvalidOptimizerModel(BoringModel):
+        def configure_optimizers(self):
+            opt1 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+            opt2 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+            lr_scheduler = torch.optim.lr_scheduler.StepLR(opt2, step_size=1)
+            return [opt1], [lr_scheduler]
+
+    model = InvalidOptimizerModel()
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+    with pytest.raises(MisconfigurationException, match="attached to an optimizer that wasn't returned"):
+        trainer.fit(model)
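
For reference, the check added above only passes when every scheduler in `lr_schedulers` wraps an optimizer that `configure_optimizers` also returns. Below is a minimal sketch of a configuration that satisfies the check; the `ValidOptimizerModel` name and the `self.layer` module are illustrative assumptions in the style of `BoringModel`, not part of the patch.

import torch
from pytorch_lightning import LightningModule


class ValidOptimizerModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def configure_optimizers(self):
        opt = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        # The scheduler is attached to `opt`, and `opt` is returned below,
        # so `_validate_scheduler_optimizer` does not raise.
        lr_scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=1)
        return [opt], [lr_scheduler]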