Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX [quantization / ESM] Fix ESM 8bit / 4bit with bitsandbytes #29329

Merged
merged 3 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/transformers/models/esm/modeling_esm.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def forward(
if head_mask is not None:
attention_probs = attention_probs * head_mask

context_layer = torch.matmul(attention_probs, value_layer)
context_layer = torch.matmul(attention_probs.to(value_layer.dtype), value_layer)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was needed to perform correctly inference otherwise you get dtype mismatch

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what do we get if we don't do this fix ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You get a dtype mismatch :/


context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/quantizers/quantizer_bnb_4bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def check_quantized_param(
import bitsandbytes as bnb

module, tensor_name = get_module_from_name(model, param_name)
if isinstance(module._parameters[tensor_name], bnb.nn.Params4bit):
if tensor_name in module._parameters and isinstance(module._parameters[tensor_name], bnb.nn.Params4bit):
younesbelkada marked this conversation as resolved.
Show resolved Hide resolved
# Add here check for loaded components' dtypes once serialization is implemented
return True
elif isinstance(module, bnb.nn.Linear4bit) and tensor_name == "bias":
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/quantizers/quantizer_bnb_8bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def check_quantized_param(
import bitsandbytes as bnb

module, tensor_name = get_module_from_name(model, param_name)
if isinstance(module._parameters[tensor_name], bnb.nn.Int8Params):
if tensor_name in module._parameters and isinstance(module._parameters[tensor_name], bnb.nn.Int8Params):
younesbelkada marked this conversation as resolved.
Show resolved Hide resolved
if self.pre_quantized:
if param_name.replace("weight", "SCB") not in state_dict.keys():
raise ValueError("Missing quantization component `SCB`")
Expand Down
20 changes: 17 additions & 3 deletions tests/models/esm/test_modeling_esm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import unittest

from transformers import EsmConfig, is_torch_available
from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_device
from transformers.testing_utils import TestCasePlus, require_bitsandbytes, require_torch, slow, torch_device

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
Expand Down Expand Up @@ -303,9 +303,9 @@ def test_resize_tokens_embeddings(self):
pass


@slow
@require_torch
class EsmModelIntegrationTest(TestCasePlus):
@slow
def test_inference_masked_lm(self):
with torch.no_grad():
model = EsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
Expand All @@ -323,7 +323,6 @@ def test_inference_masked_lm(self):
)
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))

@slow
def test_inference_no_head(self):
with torch.no_grad():
model = EsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D")
Expand All @@ -336,3 +335,18 @@ def test_inference_no_head(self):
[[[0.1444, 0.5413, 0.3248], [0.3034, 0.0053, 0.3108], [0.3228, -0.2499, 0.3415]]]
)
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))

@require_bitsandbytes
def test_inference_bitsandbytes(self):
model = EsmForMaskedLM.from_pretrained("facebook/esm2_t36_3B_UR50D", load_in_8bit=True)

input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])
# Just test if inference works
with torch.no_grad():
_ = model(input_ids)[0]

model = EsmForMaskedLM.from_pretrained("facebook/esm2_t36_3B_UR50D", load_in_4bit=True)

input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])
# Just test if inference works
_ = model(input_ids)[0]
Loading