diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 9655d7b59ef34..3500fe9303067 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -97,6 +97,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added automatic process cleanup to avoid zombie child processes and stalls when exceptions are raised ([#18218](https://github.com/Lightning-AI/lightning/pull/18218))
 
+- Added `LightningOptimizer.refresh()` to update the `__dict__` in case the optimizer it wraps has changed its internal state ([#18280](https://github.com/Lightning-AI/lightning/pull/18280))
+
+
 ### Changed
 
 - Removed the limitation to call `self.trainer.model.parameters()` in `LightningModule.configure_optimizers()` ([#17309](https://github.com/Lightning-AI/lightning/pull/17309))
@@ -203,6 +206,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Ensure that the closure running inside the optimizer step has gradients enabled, even if the optimizer step has it disabled ([#18268](https://github.com/Lightning-AI/lightning/pull/18268))
 
+- Fixed an issue that could cause the `LightningOptimizer` wrapper returned by `LightningModule.optimizers()` to have a different internal state than the optimizer it wraps ([#18280](https://github.com/Lightning-AI/lightning/pull/18280))
+
+
+
 ## [2.0.5] - 2023-07-07
 
 ### Fixed
diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py
index a4237a7352682..00f2f6bec401c 100644
--- a/src/lightning/pytorch/core/module.py
+++ b/src/lightning/pytorch/core/module.py
@@ -158,6 +158,8 @@ def optimizers(self, use_pl_optimizer: bool = True) -> MODULE_OPTIMIZERS:
             opts: MODULE_OPTIMIZERS = self._fabric_optimizers
         elif use_pl_optimizer:
            opts = self.trainer.strategy._lightning_optimizers
+            for opt in opts:
+                opt.refresh()
         else:
             opts = self.trainer.optimizers
 
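A note on the `optimizers()` change above: the wrapper's copied `__dict__` already shares objects with the wrapped optimizer, so in-place edits such as `param_groups[0]["lr"] = ...` were visible through it; the view only goes stale when the wrapped optimizer rebinds an attribute, which is what `Optimizer.load_state_dict` does with `param_groups` and `state`. The sketch below is illustrative only, not part of the patch, and assumes `torch` plus a `lightning` build that already contains this change:

```python
# Illustrative sketch (not part of the patch): why `optimizers()` now refreshes the wrappers.
import torch
from lightning.pytorch.core.optimizer import LightningOptimizer

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
wrapper = LightningOptimizer(optimizer)

# In-place mutation stays in sync: the copied `__dict__` references the same list object.
optimizer.param_groups[0]["lr"] = 0.5
assert wrapper.param_groups[0]["lr"] == 0.5

# `load_state_dict` rebinds `param_groups` on the wrapped optimizer, so the wrapper's
# copied reference goes stale until it is refreshed.
optimizer.load_state_dict(torch.optim.SGD(model.parameters(), lr=0.9).state_dict())
assert wrapper.param_groups[0]["lr"] == 0.5  # stale view
wrapper.refresh()                            # what `optimizers()` now does for each wrapper
assert wrapper.param_groups[0]["lr"] == 0.9  # back in sync
```

Inside the Trainer, user code reaches the wrappers through `LightningModule.optimizers()`, so refreshing them there is enough to keep the user-facing view consistent.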
diff --git a/src/lightning/pytorch/core/optimizer.py b/src/lightning/pytorch/core/optimizer.py
index e90ff0be84a6b..a9ac0797b599a 100644
--- a/src/lightning/pytorch/core/optimizer.py
+++ b/src/lightning/pytorch/core/optimizer.py
@@ -35,12 +35,15 @@ def do_nothing_closure() -> None:
 
 class LightningOptimizer:
     """This class is used to wrap the user optimizers and handle properly the backward and optimizer_step logic across
-    accelerators, AMP, accumulate_grad_batches."""
+    accelerators, AMP, accumulate_grad_batches.
+
+    Note: The purpose of this wrapper is only to define new methods and redirect the `.step()` call. The internal
+    state ``__dict__`` is not kept in sync with the internal state of the original optimizer, but the Trainer never
+    relies on the internal state of the wrapper.
+
+    """
 
     def __init__(self, optimizer: Optimizer):
-        # copy most of the `Optimizer` methods into this instance. `__del__` is skipped in case the optimizer has
-        # implemented custom logic which we would not want to call on destruction of the `LightningOptimizer`
-        self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("step", "__del__")}
         self.__class__ = type("Lightning" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})
 
         self._optimizer = optimizer
@@ -49,20 +52,12 @@ def __init__(self, optimizer: Optimizer):
         self._on_before_step = do_nothing_closure
         self._on_after_step = do_nothing_closure
 
+        self.refresh()
+
     @property
     def optimizer(self) -> Optimizer:
         return self._optimizer
 
-    @classmethod
-    def _to_lightning_optimizer(
-        cls, optimizer: Union[Optimizer, "LightningOptimizer"], strategy: "pl.strategies.Strategy"
-    ) -> "LightningOptimizer":
-        # the user could return a `LightningOptimizer` from `configure_optimizers`, see test:
-        # tests/core/test_lightning_optimizer.py::test_lightning_optimizer[False]
-        lightning_optimizer = optimizer if isinstance(optimizer, LightningOptimizer) else cls(optimizer)
-        lightning_optimizer._strategy = proxy(strategy)
-        return lightning_optimizer
-
     @contextmanager
     def toggle_model(self, sync_grad: bool = True) -> Generator[None, None, None]:
         """This function is just a helper for advanced users.
@@ -86,6 +81,15 @@ def toggle_model(self, sync_grad: bool = True) -> Generator[None, None, None]:
             yield
             lightning_module.untoggle_optimizer(self)
 
+    def refresh(self) -> None:
+        """Refreshes the ``__dict__`` so that it matches the internal state of the wrapped optimizer.
+
+        This is only needed to present the user with an updated view in case they inspect the state of this wrapper.
+        """
+        # copy most of the `Optimizer` methods into this instance. `__del__` is skipped in case the optimizer has
+        # implemented custom logic which we would not want to call on destruction of the `LightningOptimizer`
+        self.__dict__.update({k: v for k, v in self.optimizer.__dict__.items() if k not in ("step", "__del__")})
+
     def step(self, closure: Optional[Callable[[], Any]] = None, **kwargs: Any) -> Any:
         """Performs a single optimization step (parameter update).
 
@@ -161,6 +165,16 @@ def closure_dis():
 
         return step_output
 
+    @classmethod
+    def _to_lightning_optimizer(
+        cls, optimizer: Union[Optimizer, "LightningOptimizer"], strategy: "pl.strategies.Strategy"
+    ) -> "LightningOptimizer":
+        # the user could return a `LightningOptimizer` from `configure_optimizers`, see test:
+        # tests/core/test_lightning_optimizer.py::test_lightning_optimizer[False]
+        lightning_optimizer = optimizer if isinstance(optimizer, LightningOptimizer) else cls(optimizer)
+        lightning_optimizer._strategy = proxy(strategy)
+        return lightning_optimizer
+
 
 def _init_optimizers_and_lr_schedulers(
     model: "pl.LightningModule",
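One detail worth calling out in `refresh()`: it uses `self.__dict__.update(...)` rather than the wholesale assignment the old `__init__` used, because by the time it runs the wrapper already carries its own attributes (`_optimizer`, `_strategy`, the step hooks) that must not be wiped. Below is a minimal standalone sketch of the mirroring pattern; `Wrapper` and `Target` are made-up names for illustration, not Lightning classes:

```python
# Minimal standalone sketch of the "mirror the wrapped object's __dict__" pattern.
# `Wrapper` and `Target` are made-up names, not Lightning classes.


class Wrapper:
    _SKIP = ("step", "__del__")  # the same keys `LightningOptimizer.refresh()` skips

    def __init__(self, target):
        self._target = target
        self.refresh()

    def refresh(self):
        # `update` (not assignment) so the wrapper's own attributes such as `_target` survive.
        self.__dict__.update({k: v for k, v in self._target.__dict__.items() if k not in self._SKIP})


class Target:
    def __init__(self):
        self.lr = 0.1


target = Target()
wrapper = Wrapper(target)
target.lr = 0.5           # rebinding on the target is not seen by the wrapper's copy...
assert wrapper.lr == 0.1
wrapper.refresh()         # ...until the copy is refreshed
assert wrapper.lr == 0.5
```

As with `LightningOptimizer`, the copy only diverges when the target rebinds an attribute; shared mutable objects remain visible through both.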
diff --git a/tests/tests_pytorch/core/test_lightning_optimizer.py b/tests/tests_pytorch/core/test_lightning_optimizer.py
index 3d0ee4d7a9498..f6c2593f3962b 100644
--- a/tests/tests_pytorch/core/test_lightning_optimizer.py
+++ b/tests/tests_pytorch/core/test_lightning_optimizer.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from copy import deepcopy
 from unittest.mock import DEFAULT, Mock, patch
 
 import pytest
@@ -160,6 +161,32 @@
     assert optimizer.state == lightning_optimizer.state
 
 
+def test_state_mutation():
+    model = torch.nn.Linear(3, 4)
+    optimizer0 = torch.optim.Adam(model.parameters(), lr=0.1)
+    lightning_optimizer0 = LightningOptimizer(optimizer0)
+
+    optimizer0.param_groups[0]["lr"] = 1.0
+    assert lightning_optimizer0.param_groups[0]["lr"] == 1.0
+
+    # Load state into the unwrapped optimizer
+    state_dict0 = deepcopy(optimizer0.state_dict())
+    optimizer1 = torch.optim.Adam(model.parameters(), lr=100)
+    lightning_optimizer1 = LightningOptimizer(optimizer1)
+    optimizer1.load_state_dict(state_dict0)
+
+    # LightningOptimizer needs to be refreshed to see the new state
+    assert lightning_optimizer1.param_groups[0]["lr"] != 1.0
+    lightning_optimizer1.refresh()
+    assert lightning_optimizer1.param_groups[0]["lr"] == 1.0
+
+    # Load state into wrapped optimizer
+    optimizer2 = torch.optim.Adam(model.parameters(), lr=100)
+    lightning_optimizer2 = LightningOptimizer(optimizer2)
+    lightning_optimizer2.load_state_dict(state_dict0)
+    assert lightning_optimizer2.param_groups[0]["lr"] == 1.0
+
+
 def test_lightning_optimizer_automatic_optimization_optimizer_zero_grad(tmpdir):
     """Test overriding zero_grad works in automatic_optimization."""
 
@@ -296,7 +323,15 @@ def test_lightning_optimizer_keeps_hooks():
 
 def test_params_groups_and_state_are_accessible(tmpdir):
     class TestModel(BoringModel):
+        def on_train_start(self):
+            # Update the learning rate manually on the unwrapped optimizer
+            assert not isinstance(self.trainer.optimizers[0], LightningOptimizer)
+            self.trainer.optimizers[0].param_groups[0]["lr"] = 2.0
+
         def training_step(self, batch, batch_idx):
+            opt = self.optimizers()
+            assert opt.param_groups[0]["lr"] == 2.0
+
             loss = self.step(batch)
             self.__loss = loss
             return loss
diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py
index 47a9425b68961..f02fa4e4f67cd 100644
--- a/tests/tests_pytorch/strategies/test_fsdp.py
+++ b/tests/tests_pytorch/strategies/test_fsdp.py
@@ -509,16 +509,17 @@ def test_set_timeout(init_process_group_mock):
 
 @RunIf(min_torch="1.12")
 def test_fsdp_strategy_load_optimizer_states_multiple():
     strategy = FSDPStrategy(parallel_devices=[torch.device("cpu")])
+    spec = torch.optim.Optimizer
 
     # More states than optimizers configured
-    strategy.optimizers = [Mock()]
-    checkpoint = {"optimizer_states": [Mock(), Mock()]}
+    strategy.optimizers = [Mock(spec=spec)]
+    checkpoint = {"optimizer_states": [Mock(spec=spec), Mock(spec=spec)]}
     with pytest.raises(RuntimeError, match="1 optimizers but the checkpoint contains 2 optimizers to load"):
         strategy.load_optimizer_state_dict(checkpoint)
 
     # Fewer states than optimizers configured
-    strategy.optimizers = [Mock(), Mock()]
-    checkpoint = {"optimizer_states": [Mock()]}
+    strategy.optimizers = [Mock(spec=spec), Mock(spec=spec)]
+    checkpoint = {"optimizer_states": [Mock(spec=spec)]}
     with pytest.raises(RuntimeError, match="2 optimizers but the checkpoint contains 1 optimizers to load"):
         strategy.load_optimizer_state_dict(checkpoint)
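On the `test_fsdp.py` change: a bare `Mock` standing in for an optimizer will fabricate whatever attribute is asked of it, while `Mock(spec=torch.optim.Optimizer)` restricts the mock to the real `Optimizer` interface and makes it pass `isinstance` checks. The snippet below only illustrates that general `unittest.mock` behaviour; the exact code path that motivated spec'ing these mocks is an assumption on my part:

```python
# Illustration of the `spec=` difference (not part of the patch).
from unittest.mock import Mock

import torch

loose = Mock()
strict = Mock(spec=torch.optim.Optimizer)

loose.not_a_real_attribute                 # a bare Mock invents attributes on access
strict.state_dict()                        # the real Optimizer API is still mocked...
try:
    strict.not_a_real_attribute            # ...but anything outside that API raises
except AttributeError as err:
    print(err)

# A spec'd mock also passes isinstance checks against the spec class.
print(isinstance(strict, torch.optim.Optimizer))  # True
print(isinstance(loose, torch.optim.Optimizer))   # False
```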