Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix state dict loading in bitsandbytes plugin when checkpoint is already quantized #19886

Merged
merged 5 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/lightning/fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

-
- Fixed a matrix shape mismatch issue when running a model loaded from a quantized checkpoint (bitsandbytes) ([#19886](https://github.com/Lightning-AI/lightning/pull/19886))

-

Expand Down
12 changes: 6 additions & 6 deletions src/lightning/fabric/plugins/precision/bitsandbytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,9 +233,9 @@ def quantize_(self, weight: Optional[torch.Tensor] = None, device: Optional[torc
"""Inplace quantize."""
if weight is None:
weight = self.weight.data
if weight.data.type == torch.int8:
# already quantized
return
lantiga marked this conversation as resolved.
Show resolved Hide resolved
if weight.data.dtype == torch.int8:
# already quantized
return
assert isinstance(self.weight, bnb.nn.Int8Params)
self.weight = self.quantize(self.weight, weight, device)

Expand Down Expand Up @@ -317,9 +317,9 @@ def quantize_(self, weight: Optional[torch.Tensor] = None, device: Optional[torc
"""Inplace quantize."""
if weight is None:
weight = self.weight.data
if weight.data.type == torch.uint8:
# already quantized
return
if weight.data.dtype == torch.uint8:
# already quantized
return
assert isinstance(self.weight, bnb.nn.Params4bit)
self.weight = self.quantize(self.weight, weight, device)

Expand Down
2 changes: 1 addition & 1 deletion src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

-
- Fixed a matrix shape mismatch issue when running a model loaded from a quantized checkpoint (bitsandbytes) ([#19886](https://github.com/Lightning-AI/lightning/pull/19886))


- Fixed `WandbLogger.log_hyperparameters()` raising an error if hyperparameters are not JSON serializable ([#19769](https://github.com/Lightning-AI/pytorch-lightning/pull/19769))
Expand Down
34 changes: 34 additions & 0 deletions tests/tests_fabric/plugins/precision/test_bitsandbytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,3 +230,37 @@ def __init__(self):
assert not keys.missing_keys
assert model.l.weight.device.type == "cuda"
assert model.l.weight.dtype == expected


@RunIf(min_cuda_gpus=1, min_torch="2.1")
@pytest.mark.skipif(not _BITSANDBYTES_AVAILABLE, reason="bitsandbytes unavailable")
def test_load_quantized_checkpoint(tmp_path):
"""Test that a checkpoint saved from a quantized model can be loaded back into a quantized model."""

class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(16, 16, bias=False)

def forward(self, x):
return self.linear(x)

fabric = Fabric(accelerator="cuda", devices=1, plugins=BitsandbytesPrecision("nf4-dq"))
model = Model()
model = fabric.setup(model)
model(torch.randn(2, 16, device=fabric.device))
state_dict = model.state_dict()
# The checkpoint contains quantized weights
assert state_dict["linear.weight"].dtype == torch.uint8
assert state_dict["linear.weight"].shape == (128, 1)
torch.save(state_dict, tmp_path / "checkpoint.pt")

fabric = Fabric(accelerator="cuda", devices=1, plugins=BitsandbytesPrecision("nf4-dq"))
model = Model()
model = fabric.setup(model)
state_dict = torch.load(tmp_path / "checkpoint.pt")
model.load_state_dict(state_dict)
assert model.linear.weight.dtype == torch.uint8
assert model.linear.weight.shape == (128, 1)
# Shapes match during forward (weight is being dequantized during forward)
model(torch.randn(2, 16, device=fabric.device))
Loading