diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index 8ffc4721a9f9f..c0f623dda4730 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -35,7 +35,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
--
+- Fixed an attribute error when loading a checkpoint into a quantized model using the `_lazy_load()` function ([#20121](https://github.com/Lightning-AI/lightning/pull/20121))
+
 
 -
 
diff --git a/src/lightning/fabric/utilities/load.py b/src/lightning/fabric/utilities/load.py
index 9862cc2bd981e..2fa5c6e96db21 100644
--- a/src/lightning/fabric/utilities/load.py
+++ b/src/lightning/fabric/utilities/load.py
@@ -160,7 +160,7 @@ def __getattr__(self, name: str) -> Any:
             return getattr(self.metatensor, name)
 
         # materializing these is needed for quantization (see lit-gpt)
-        if name in {"contiguous", "cuda", "half"}:
+        if name in {"contiguous", "cuda", "half", "data"}:
            return getattr(self._load_tensor(), name)
 
         raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
diff --git a/tests/tests_fabric/plugins/precision/test_bitsandbytes.py b/tests/tests_fabric/plugins/precision/test_bitsandbytes.py
index d5311616828ee..b8b9020b201a7 100644
--- a/tests/tests_fabric/plugins/precision/test_bitsandbytes.py
+++ b/tests/tests_fabric/plugins/precision/test_bitsandbytes.py
@@ -22,6 +22,7 @@
 from lightning.fabric.connector import _Connector
 from lightning.fabric.plugins.precision.bitsandbytes import _BITSANDBYTES_AVAILABLE, BitsandbytesPrecision
 from lightning.fabric.utilities.init import _materialize_meta_tensors
+from lightning.fabric.utilities.load import _lazy_load
 from tests_fabric.helpers.runif import RunIf
 
 
@@ -264,3 +265,10 @@ def forward(self, x):
     assert model.linear.weight.shape == (128, 1)
     # Shapes match during forward (weight is being dequantized during forward)
     model(torch.randn(2, 16, device=fabric.device))
+
+    # Test with lazy load (LitGPT uses this)
+    # TODO: Replace `_lazy_load` with `torch.load(..., mmap=True)` in LitGPT
+    state_dict = _lazy_load(tmp_path / "checkpoint.pt")
+    model.load_state_dict(state_dict)
+    assert model.linear.weight.dtype == torch.uint8
+    assert model.linear.weight.shape == (128, 1)
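
For context (not part of the diff): a minimal sketch of what adding `"data"` to the materialization whitelist enables. It assumes a CPU-only environment and a throwaway `checkpoint.pt` path, and it replaces the bitsandbytes-quantized model from the test with a direct `.data` access, since that is the attribute quantized parameter classes read during `load_state_dict`.

```python
# Sketch only: illustrates the `.data` access on a lazily loaded tensor.
# `checkpoint.pt` is a throwaway file created here, not a real model checkpoint.
import torch

from lightning.fabric.utilities.load import _lazy_load

# Save a small state dict to disk as a stand-in for a model checkpoint.
torch.save({"weight": torch.randn(4, 4)}, "checkpoint.pt")

# `_lazy_load` returns tensor proxies that only materialize on demand.
state_dict = _lazy_load("checkpoint.pt")

# Before this fix, `data` was not in the proxy's whitelist of materializing
# attributes, so this access raised AttributeError; now it loads the tensor.
weight_data = state_dict["weight"].data
assert isinstance(weight_data, torch.Tensor)
```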