From 01109cdf0c44a150c262b65e70a7e1e64003cf93 Mon Sep 17 00:00:00 2001
From: "Xinyao(Alvin) Sun"
Date: Sat, 22 May 2021 20:30:28 -0600
Subject: [PATCH] Fix/mismatched toggle optimizer (#7563)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* fix: avoid potential mismatched toggling of optimizer

Refs #7405

chore: update CHANGELOG

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix: resolve a conflict

chore: update changelog

* feat: add a test that fails in master

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typo in tests/trainer/optimization/test_multiple_optimizers.py

Co-authored-by: ananthsub

* Polish tests/trainer/optimization/test_multiple_optimizers.py

Co-authored-by: Carlos Mocholí

* Polish tests/trainer/optimization/test_multiple_optimizers.py

Co-authored-by: Carlos Mocholí

* fix: change placeholder in optimizer_step from positional args to keyword args

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: ananthsub
Co-authored-by: Carlos Mocholí
---
 CHANGELOG.md                                  |  1 +
 pytorch_lightning/trainer/training_loop.py    |  8 +--
 .../optimization/test_multiple_optimizers.py  | 65 +++++++++++++++++++
 3 files changed, 69 insertions(+), 5 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 44914061b8275..9179711b5cd00 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -40,6 +40,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
+- Moved the call to `untoggle_optimizer(opt_idx)` out of the closure function ([#7563](https://github.com/PyTorchLightning/pytorch-lightning/pull/7563))
 
 - Changed the `Trainer`'s `checkpoint_callback` argument to allow only boolean values ([#7539](https://github.com/PyTorchLightning/pytorch-lightning/pull/7539))
 
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 84d69765c7c36..a555146875eb5 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -724,7 +724,6 @@ def _run_optimization(self, batch_idx, split_idx, split_batch, opt_idx=0, optimi
             # -------------------
             # calculate loss (train step + train step end)
             # -------------------
-
             # automatic_optimization=True: perform ddp sync only when performing optimizer_step
             # automatic_optimization=False: don't block synchronization here
             with self.block_ddp_sync_behaviour():
@@ -737,6 +736,9 @@ def _run_optimization(self, batch_idx, split_idx, split_batch, opt_idx=0, optimi
         else:
             if self.trainer.lightning_module.automatic_optimization:
                 self.optimizer_step(optimizer, opt_idx, batch_idx, closure)
+                if len(self.trainer.optimizers) > 1:
+                    # revert back to previous state
+                    self.trainer.lightning_module.untoggle_optimizer(opt_idx)
             else:
                 result = self.training_step(split_batch, batch_idx, opt_idx, self._hiddens)
 
@@ -837,10 +839,6 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer,
                     "training_step returned None. If this was on purpose, ignore this warning..."
                 )
 
-        if len(self.trainer.optimizers) > 1:
-            # revert back to previous state
-            self.trainer.lightning_module.untoggle_optimizer(opt_idx)
-
         return result
 
     def _check_finite(self, loss: torch.Tensor) -> None:
diff --git a/tests/trainer/optimization/test_multiple_optimizers.py b/tests/trainer/optimization/test_multiple_optimizers.py
index 24b32c8725963..aba3b53248a57 100644
--- a/tests/trainer/optimization/test_multiple_optimizers.py
+++ b/tests/trainer/optimization/test_multiple_optimizers.py
@@ -168,3 +168,68 @@ def training_step(self, batch, batch_idx):
 
     with pytest.raises(ValueError, match='`training_step` is missing the `optimizer_idx`'):
         trainer.fit(TestModel())
+
+
+def test_custom_optimizer_step_with_multiple_optimizers(tmpdir):
+    """
+    This test ensures custom optimizer_step works,
+    even when optimizer.step is not called for a particular optimizer.
+    """
+
+    class TestModel(BoringModel):
+        training_step_called = [0, 0]
+        optimizer_step_called = [0, 0]
+
+        def __init__(self):
+            super().__init__()
+            self.layer_a = torch.nn.Linear(32, 2)
+            self.layer_b = torch.nn.Linear(32, 2)
+
+        def configure_optimizers(self):
+            opt_a = torch.optim.SGD(self.layer_a.parameters(), lr=0.001)
+            opt_b = torch.optim.SGD(self.layer_b.parameters(), lr=0.001)
+            return opt_a, opt_b
+
+        def training_step(self, batch, batch_idx, optimizer_idx):
+            self.training_step_called[optimizer_idx] += 1
+            x = self.layer_a(batch[0]) if (optimizer_idx == 0) else self.layer_b(batch[0])
+            loss = torch.nn.functional.mse_loss(x, torch.ones_like(x))
+            return loss
+
+        def training_epoch_end(self, outputs) -> None:
+            # outputs should be an array with an entry per optimizer
+            assert len(outputs) == 2
+
+        def optimizer_step(
+            self,
+            epoch,
+            batch_idx,
+            optimizer,
+            optimizer_idx,
+            optimizer_closure,
+            **_,
+        ):
+            # update first optimizer every step
+            if optimizer_idx == 0:
+                self.optimizer_step_called[optimizer_idx] += 1
+                optimizer.step(closure=optimizer_closure)
+
+            # update second optimizer every 2 steps
+            if optimizer_idx == 1:
+                if batch_idx % 2 == 0:
+                    self.optimizer_step_called[optimizer_idx] += 1
+                    optimizer.step(closure=optimizer_closure)
+
+    model = TestModel()
+    model.val_dataloader = None
+
+    trainer = pl.Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=4,
+        max_epochs=1,
+        log_every_n_steps=1,
+        weights_summary=None,
+    )
+    trainer.fit(model)
+    assert model.training_step_called == [4, 2]
+    assert model.optimizer_step_called == [4, 2]
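
Why the call has to move: a custom `optimizer_step` may skip `optimizer.step(closure=...)`
for a given optimizer (as the new test's second optimizer does on odd batches), in which
case the closure never runs. Any cleanup placed inside the closure is then skipped, and the
parameters frozen by `toggle_optimizer` would stay frozen for the next optimizer's update.
The snippet below is a minimal, self-contained sketch of the fixed pattern in plain PyTorch;
`toggle`, `untoggle` and `custom_optimizer_step` are hypothetical stand-ins for
`LightningModule.toggle_optimizer` / `untoggle_optimizer` and a user-overridden
`optimizer_step`, not Lightning source.

    import torch

    layer_a = torch.nn.Linear(2, 2)
    layer_b = torch.nn.Linear(2, 2)
    params = [list(layer_a.parameters()), list(layer_b.parameters())]
    optimizers = [
        torch.optim.SGD(layer_a.parameters(), lr=0.1),
        torch.optim.SGD(layer_b.parameters(), lr=0.1),
    ]

    def toggle(opt_idx):
        # stand-in for LightningModule.toggle_optimizer: freeze the other optimizer's params
        for i, group in enumerate(params):
            for p in group:
                p.requires_grad = (i == opt_idx)

    def untoggle(opt_idx):
        # stand-in for LightningModule.untoggle_optimizer: restore requires_grad everywhere
        for group in params:
            for p in group:
                p.requires_grad = True

    def custom_optimizer_step(optimizer, opt_idx, batch_idx, closure):
        # like the test above: the second optimizer only steps on every other batch,
        # so its closure (and anything hidden inside it) is simply never called
        if opt_idx == 0 or batch_idx % 2 == 0:
            optimizer.step(closure=closure)

    x = torch.randn(4, 2)
    for batch_idx in range(4):
        for opt_idx, optimizer in enumerate(optimizers):
            toggle(opt_idx)

            def closure():
                optimizer.zero_grad()
                layer = layer_a if opt_idx == 0 else layer_b
                loss = layer(x).pow(2).mean()
                loss.backward()
                return loss

            custom_optimizer_step(optimizer, opt_idx, batch_idx, closure)
            # the fix: untoggle unconditionally after optimizer_step returns, outside the
            # closure, so a skipped step cannot leak a toggled requires_grad state
            untoggle(opt_idx)

    # with the untoggle outside the closure, no parameter is left frozen
    assert all(p.requires_grad for group in params for p in group)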