Commit

Fix attribute error on _NotYetLoadedTensor after loading checkpoint into quantized model with `_lazy_load()` (Lightning-AI#20121)
awaelchli authored and ammyk9 committed Aug 6, 2024
1 parent 55935ea commit 231228b
Showing 3 changed files with 11 additions and 2 deletions.
3 changes: 2 additions & 1 deletion src/lightning/fabric/CHANGELOG.md
@@ -35,7 +35,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

-
- Fixed an attribute error when loading a checkpoint into a quantized model using the `_lazy_load()` function ([#20121](https://github.com/Lightning-AI/lightning/pull/20121))


-

2 changes: 1 addition & 1 deletion src/lightning/fabric/utilities/load.py
@@ -160,7 +160,7 @@ def __getattr__(self, name: str) -> Any:
return getattr(self.metatensor, name)

# materializing these is needed for quantization (see lit-gpt)
if name in {"contiguous", "cuda", "half"}:
if name in {"contiguous", "cuda", "half", "data"}:
return getattr(self._load_tensor(), name)

raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
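
For context, `_NotYetLoadedTensor.__getattr__` answers metadata queries from a meta tensor and only materializes the real tensor for attributes that need actual storage; bitsandbytes quantization touches `param.data`, which is why `"data"` joins the materialize-on-access set. Below is a minimal, self-contained sketch of that pattern — an illustrative `LazyTensor`, not Lightning's actual class; the metadata set and the zero-filled `_load_tensor` are simplifications for the sake of the example.

import torch


class LazyTensor:
    """Illustrative stand-in for a lazily loaded tensor (not Lightning's class)."""

    def __init__(self, shape, dtype=torch.float32):
        # Only shape/dtype metadata lives on a meta tensor; no storage is allocated yet.
        self.metatensor = torch.empty(shape, dtype=dtype, device="meta")

    def _load_tensor(self):
        # The real class reads the tensor from the checkpoint file here;
        # zeros keep this sketch self-contained.
        return torch.zeros(self.metatensor.shape, dtype=self.metatensor.dtype)

    def __getattr__(self, name):
        # Metadata queries are served from the meta tensor without loading anything.
        if name in {"dtype", "shape", "ndim"}:
            return getattr(self.metatensor, name)
        # These need real storage, so materialize on demand; quantization code
        # accesses `.data`, hence its addition to this set.
        if name in {"contiguous", "cuda", "half", "data"}:
            return getattr(self._load_tensor(), name)
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")


t = LazyTensor((4, 4))
print(t.shape)       # served from metadata, nothing materialized
print(t.data.sum())  # `.data` triggers _load_tensor() and returns a real tensor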
8 changes: 8 additions & 0 deletions tests/tests_fabric/plugins/precision/test_bitsandbytes.py
@@ -22,6 +22,7 @@
from lightning.fabric.connector import _Connector
from lightning.fabric.plugins.precision.bitsandbytes import _BITSANDBYTES_AVAILABLE, BitsandbytesPrecision
from lightning.fabric.utilities.init import _materialize_meta_tensors
from lightning.fabric.utilities.load import _lazy_load

from tests_fabric.helpers.runif import RunIf

@@ -264,3 +265,10 @@ def forward(self, x):
assert model.linear.weight.shape == (128, 1)
# Shapes match during forward (weight is being dequantized during forward)
model(torch.randn(2, 16, device=fabric.device))

# Test with lazy load (LitGPT uses this)
# TODO: Replace `_lazy_load` with `torch.load(..., mmap=True)` in LitGPT
state_dict = _lazy_load(tmp_path / "checkpoint.pt")
model.load_state_dict(state_dict)
assert model.linear.weight.dtype == torch.uint8
assert model.linear.weight.shape == (128, 1)
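
The TODO above refers to PyTorch's memory-mapped loading, which addresses the same concern as `_lazy_load`: avoiding a full in-memory copy of the checkpoint before the state dict is consumed. A rough sketch of that alternative, assuming PyTorch >= 2.1 and an illustrative model and checkpoint path (both placeholders, not part of this commit):

import torch
from torch import nn

# Placeholder model and checkpoint path for illustration only.
model = nn.Linear(16, 16)
torch.save(model.state_dict(), "checkpoint.pt")

# Memory-map the checkpoint instead of reading every tensor into RAM up front;
# tensors stay file-backed until load_state_dict copies them into the model.
state_dict = torch.load("checkpoint.pt", mmap=True, weights_only=True)
model.load_state_dict(state_dict)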
