Commit 16633a6
bug fix: restore_optimizers correctly handles non-mapping values in optimizer.state.values() (#11757)
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
1 parent 445f0e6 · commit 16633a6

File tree

2 files changed: +10 −4 lines

CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -573,6 +573,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - The `RichProgressBar` now correctly shows the `on_epoch` logged values on train epoch end ([#11689](https://github.com/PyTorchLightning/pytorch-lightning/pull/11689))


+- Fixed `restore_optimizers` for mapping states ([#11757](https://github.com/PyTorchLightning/pytorch-lightning/pull/11757))
+
+
 - Fixed check for available modules ([#11526](https://github.com/PyTorchLightning/pytorch-lightning/pull/11526))

pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 7 additions & 4 deletions

@@ -282,10 +282,13 @@ def restore_optimizers(self) -> None:
             # move optimizer to GPU 1 weight at a time
             # avoids OOM
             if self.trainer.root_gpu is not None:
-                for state in optimizer.state.values():
-                    for k, v in state.items():
-                        if isinstance(v, torch.Tensor):
-                            state[k] = v.cuda(self.trainer.root_gpu)
+                for param, state in optimizer.state.items():
+                    if isinstance(state, dict):
+                        for k, v in state.items():
+                            if isinstance(v, torch.Tensor):
+                                state[k] = v.cuda(self.trainer.root_gpu)
+                    elif isinstance(state, torch.Tensor):
+                        optimizer.state[param] = state.cuda(self.trainer.root_gpu)

     def restore_lr_schedulers(self) -> None:
         """Restores the learning rate scheduler states from the pre-loaded checkpoint."""
