Commit

Fix state dict loading in bitsandbytes plugin when checkpoint is already quantized (#19886)

* bugfix

* add test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update

* add chlog

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
awaelchli and pre-commit-ci[bot] authored May 21, 2024
1 parent b1bb3f3 commit 7e87ce0
Showing 4 changed files with 42 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/lightning/fabric/CHANGELOG.md
@@ -50,7 +50,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Fixed

--
+- Fixed a matrix shape mismatch issue when running a model loaded from a quantized checkpoint (bitsandbytes) ([#19886](https://github.com/Lightning-AI/lightning/pull/19886))

 -

12 changes: 6 additions & 6 deletions src/lightning/fabric/plugins/precision/bitsandbytes.py
@@ -233,9 +233,9 @@ def quantize_(self, weight: Optional[torch.Tensor] = None, device: Optional[torc
"""Inplace quantize."""
if weight is None:
weight = self.weight.data
if weight.data.type == torch.int8:
# already quantized
return
if weight.data.dtype == torch.int8:
# already quantized
return
assert isinstance(self.weight, bnb.nn.Int8Params)
self.weight = self.quantize(self.weight, weight, device)

@@ -317,9 +317,9 @@ def quantize_(self, weight: Optional[torch.Tensor] = None, device: Optional[torc
"""Inplace quantize."""
if weight is None:
weight = self.weight.data
if weight.data.type == torch.uint8:
# already quantized
return
if weight.data.dtype == torch.uint8:
# already quantized
return
assert isinstance(self.weight, bnb.nn.Params4bit)
self.weight = self.quantize(self.weight, weight, device)

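The root cause is the same in both hunks: `torch.Tensor.type` is a method, not the tensor's dtype, so `weight.data.type == torch.int8` compares a bound method against a dtype and is always `False`. The early return for already-quantized weights therefore never fired, and loading a quantized state dict re-quantized the packed int8/uint8 storage, which is consistent with the shape mismatch described in the changelog entries. A minimal standalone sketch of the difference (plain PyTorch, independent of this commit):

```python
import torch

t = torch.zeros(4, dtype=torch.int8)

# `Tensor.type` is a method; comparing it to a dtype is always False.
print(t.type == torch.int8)   # False (bound method vs. dtype)
print(t.type())               # 'torch.CharTensor' -- a type string, not a dtype

# The corrected check compares dtypes and detects quantized storage.
print(t.dtype == torch.int8)  # True
```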
2 changes: 1 addition & 1 deletion src/lightning/pytorch/CHANGELOG.md
@@ -45,7 +45,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Fixed

--
+- Fixed a matrix shape mismatch issue when running a model loaded from a quantized checkpoint (bitsandbytes) ([#19886](https://github.com/Lightning-AI/lightning/pull/19886))


 - Fixed `WandbLogger.log_hyperparameters()` raising an error if hyperparameters are not JSON serializable ([#19769](https://github.com/Lightning-AI/pytorch-lightning/pull/19769))
34 changes: 34 additions & 0 deletions tests/tests_fabric/plugins/precision/test_bitsandbytes.py
@@ -230,3 +230,37 @@ def __init__(self):
     assert not keys.missing_keys
     assert model.l.weight.device.type == "cuda"
     assert model.l.weight.dtype == expected
+
+
+@RunIf(min_cuda_gpus=1, min_torch="2.1")
+@pytest.mark.skipif(not _BITSANDBYTES_AVAILABLE, reason="bitsandbytes unavailable")
+def test_load_quantized_checkpoint(tmp_path):
+    """Test that a checkpoint saved from a quantized model can be loaded back into a quantized model."""
+
+    class Model(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = torch.nn.Linear(16, 16, bias=False)
+
+        def forward(self, x):
+            return self.linear(x)
+
+    fabric = Fabric(accelerator="cuda", devices=1, plugins=BitsandbytesPrecision("nf4-dq"))
+    model = Model()
+    model = fabric.setup(model)
+    model(torch.randn(2, 16, device=fabric.device))
+    state_dict = model.state_dict()
+    # The checkpoint contains quantized weights
+    assert state_dict["linear.weight"].dtype == torch.uint8
+    assert state_dict["linear.weight"].shape == (128, 1)
+    torch.save(state_dict, tmp_path / "checkpoint.pt")
+
+    fabric = Fabric(accelerator="cuda", devices=1, plugins=BitsandbytesPrecision("nf4-dq"))
+    model = Model()
+    model = fabric.setup(model)
+    state_dict = torch.load(tmp_path / "checkpoint.pt")
+    model.load_state_dict(state_dict)
+    assert model.linear.weight.dtype == torch.uint8
+    assert model.linear.weight.shape == (128, 1)
+    # Shapes match during forward (weight is being dequantized during forward)
+    model(torch.randn(2, 16, device=fabric.device))
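For context on the asserted `(128, 1)` uint8 shape: nf4 stores two 4-bit codes per byte, so the 16×16 float weight (256 elements) packs into 128 bytes of `Params4bit` storage. A quick sanity check of that arithmetic (a sketch of the packing math, not something the test inspects directly):

```python
# Arithmetic behind the (128, 1) assertion, assuming nf4 packs
# two 4-bit codes into each uint8 byte of Params4bit storage.
n_elements = 16 * 16                # 256 weights in the 16x16 linear layer
packed_bytes = n_elements * 4 // 8  # 4 bits per weight -> 128 bytes
assert packed_bytes == 128          # stored as a (128, 1) uint8 tensor
```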
