Fix deriving default map location when there is extra state (#17812)
awaelchli authored Jun 12, 2023
1 parent 9ff7d71 commit 34de253
Showing 3 changed files with 25 additions and 4 deletions.
src/lightning/pytorch/CHANGELOG.md (3 additions, 0 deletions)

@@ -122,6 +122,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed validation of parameters of `plugins.precision.MixedPrecisionPlugin` ([#17687](https://github.com/Lightning-AI/lightning/pull/17687))
 
 
+- Fixed deriving default map location in `LightningModule.load_from_checkpoint` when there is extra state ([#17812](https://github.com/Lightning-AI/lightning/pull/17812))
+
+
 ## [2.0.3] - 2023-06-07
 
 ### Changed
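For context on the "extra state" in the commit title: `torch.nn.Module` provides `get_extra_state`/`set_extra_state` hooks that let a module store arbitrary, non-tensor objects in its state dict under a key ending in `_extra_state`. A minimal sketch of how such an entry ends up alongside the tensors (the module below is hypothetical, not part of this commit), which is exactly the situation the saving.py change below guards against:

    import torch
    from torch import nn

    class ModuleWithExtraState(nn.Module):
        """Hypothetical module using nn.Module's extra-state hooks."""

        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(2, 2)

        def get_extra_state(self):
            return {"schema_version": 1}  # arbitrary non-tensor payload

        def set_extra_state(self, state):
            pass  # would restore the payload on load; a no-op in this sketch

    state_dict = ModuleWithExtraState().state_dict()
    for key, value in state_dict.items():
        print(key, type(value).__name__)
    # A module's extra state is saved before its submodules' parameters, so the
    # "_extra_state" entry (a plain dict) can be the *first* value in the state
    # dict, and list(state_dict.values())[0].device then raises AttributeError.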
src/lightning/pytorch/core/saving.py (6 additions, 4 deletions)

@@ -25,6 +25,7 @@
 from typing import Any, Callable, Dict, IO, Optional, Type, Union
 from warnings import warn
 
+import torch
 import yaml
 from lightning_utilities.core.apply_func import apply_to_collection

@@ -86,13 +87,14 @@ def _load_from_checkpoint(
     if issubclass(cls, pl.LightningDataModule):
         return _load_state(cls, checkpoint, **kwargs)
     if issubclass(cls, pl.LightningModule):
-        storage = _load_state(cls, checkpoint, strict=strict, **kwargs)
+        model = _load_state(cls, checkpoint, strict=strict, **kwargs)
         state_dict = checkpoint["state_dict"]
         if not state_dict:
             raise ValueError(f"The state dict in {checkpoint_path!r} contains no parameters.")
-        map_location = list(state_dict.values())[0].device
-        assert isinstance(storage, pl.LightningModule)
-        return storage.to(map_location)
+
+        device = next((t for t in state_dict.values() if isinstance(t, torch.Tensor)), torch.tensor(0)).device
+        assert isinstance(model, pl.LightningModule)
+        return model.to(device)
 
     raise NotImplementedError(f"Unsupported {cls}")
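The fixed derivation picks the device of the first tensor value in the state dict and falls back to a fresh CPU tensor when no tensor is present at all. A standalone sketch of that logic (`_derive_device` is a hypothetical helper name, for illustration only):

    import torch

    def _derive_device(state_dict):
        # First tensor's device wins; torch.tensor(0) supplies a CPU fallback
        # when the state dict contains no tensors at all.
        return next(
            (t for t in state_dict.values() if isinstance(t, torch.Tensor)),
            torch.tensor(0),
        ).device

    # The old logic, list(state_dict.values())[0].device, raises AttributeError
    # here because the first value is the plain dict from get_extra_state:
    print(_derive_device({"_extra_state": {"extra": "state"}}))  # cpu (fallback)
    print(_derive_device({"_extra_state": {}, "w": torch.zeros(2, 2)}))  # the tensor's device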

tests/tests_pytorch/core/test_saving.py (16 additions, 0 deletions)

@@ -55,3 +55,19 @@ def test_load_from_checkpoint_map_location_cpu_to_gpu(tmp_path, map_location):
create_boring_checkpoint(tmp_path, BoringModel(), accelerator="cpu")
model = BoringModel.load_from_checkpoint(f"{tmp_path}/checkpoint.ckpt", map_location=map_location)
assert model.device.type == "cuda"


@RunIf(min_cuda_gpus=1)
def test_load_from_checkpoint_device_placement_with_extra_state(tmp_path):
"""Test that the device gets chosen based on the device of the saved tensors in the checkpoint."""

class ExtraStateModel(BoringModel):
def get_extra_state(self):
return {"extra": "state"} # state without tensors

def set_extra_state(self, state):
pass

create_boring_checkpoint(tmp_path, ExtraStateModel(), accelerator="cuda")
model = ExtraStateModel.load_from_checkpoint(f"{tmp_path}/checkpoint.ckpt", map_location=None)
assert model.device.type == "cuda"
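For reference, a hedged sketch of the user-facing behavior this test pins down (`MyLitModel` and the path are placeholders, assuming a checkpoint saved on a CUDA device): with no `map_location`, the model lands on the device the checkpoint's tensors were saved on, even when the first state-dict entry is non-tensor extra state; an explicit `map_location` still moves everything, as the cpu-to-gpu test above shows.

    # MyLitModel stands in for any LightningModule subclass.
    # Device derived from the checkpoint's tensors (e.g. cuda:0 if saved on GPU):
    model = MyLitModel.load_from_checkpoint("path/to/checkpoint.ckpt")

    # An explicit map_location moves the tensors, and the model follows them:
    model = MyLitModel.load_from_checkpoint("path/to/checkpoint.ckpt", map_location="cpu")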
