Add bitsandbytes support for gpt2 models #24504

Merged (5 commits) on Jun 28, 2023
26 changes: 20 additions & 6 deletions src/transformers/utils/bitsandbytes.py
@@ -12,6 +12,8 @@
import torch
import torch.nn as nn

from ..pytorch_utils import Conv1D

if is_accelerate_available():
from accelerate import init_empty_weights
from accelerate.utils import find_tied_parameters
@@ -84,6 +86,11 @@ class `Int8Params` from `bitsandbytes`.
else:
new_value = torch.tensor(value, device="cpu")

# Support models using `Conv1D` in place of `nn.Linear` (e.g. gpt2) by transposing the weight matrix prior to quantization.
# Since weights are saved in the correct "orientation", we skip transposing when loading.
if issubclass(module.source_cls, Conv1D) and fp16_statistics is None:
new_value = new_value.T

kwargs = old_value.__dict__
if is_8bit:
new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(device)
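For context on the transpose: `Conv1D` (used throughout gpt2) stores its weight as `(in_features, out_features)`, the transpose of `nn.Linear`'s `(out_features, in_features)`, so the tensor has to be flipped before it is wrapped in `Int8Params`/`Params4bit`. A minimal sketch of the orientation difference (layer sizes are arbitrary):

```python
import torch
from torch import nn
from transformers.pytorch_utils import Conv1D

linear = nn.Linear(768, 3072)   # weight stored as (out_features, in_features)
conv1d = Conv1D(3072, 768)      # weight stored as (in_features, out_features)
print(linear.weight.shape)      # torch.Size([3072, 768])
print(conv1d.weight.shape)      # torch.Size([768, 3072])

# Both compute the same affine map once the Conv1D weight is transposed
x = torch.randn(2, 768)
with torch.no_grad():
    linear.weight.copy_(conv1d.weight.T)
    linear.bias.copy_(conv1d.bias)
assert torch.allclose(linear(x), conv1d(x), atol=1e-5)
```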
@@ -122,14 +129,20 @@ def _replace_with_bnb_linear(
current_key_name = []
current_key_name.append(name)

if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
if (isinstance(module, nn.Linear) or isinstance(module, Conv1D)) and name not in modules_to_not_convert:
# Check if the current key is not in the `modules_to_not_convert`
if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
with init_empty_weights():
if isinstance(module, Conv1D):
in_features, out_features = module.weight.shape
else:
in_features = module.in_features
out_features = module.out_features

if quantization_config.quantization_method() == "llm_int8":
model._modules[name] = bnb.nn.Linear8bitLt(
module.in_features,
module.out_features,
in_features,
out_features,
module.bias is not None,
has_fp16_weights=quantization_config.llm_int8_has_fp16_weight,
threshold=quantization_config.llm_int8_threshold,
@@ -143,14 +156,16 @@
pass
else:
model._modules[name] = bnb.nn.Linear4bit(
module.in_features,
module.out_features,
in_features,
out_features,
module.bias is not None,
quantization_config.bnb_4bit_compute_dtype,
compress_statistics=quantization_config.bnb_4bit_use_double_quant,
quant_type=quantization_config.bnb_4bit_quant_type,
)
has_been_replaced = True
# Store the module class in case we need to transpose the weight later
model._modules[name].source_cls = type(module)
Contributor (review comment): Very nice!

# Force requires grad to False to avoid unexpected errors
model._modules[name].requires_grad_(False)
if len(list(module.children())) > 0:
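To make the shape handling above concrete, here is a hand-rolled sketch of converting a single Conv1D the way this helper does for the int8 path (the layer sizes and threshold are illustrative; bitsandbytes and accelerate are required):

```python
import bitsandbytes as bnb
from accelerate import init_empty_weights
from transformers.pytorch_utils import Conv1D

module = Conv1D(3072, 768)  # gpt2-style projection, weight shaped (in_features, out_features)

# Conv1D exposes no in_features/out_features attributes, so read them off the weight shape
if isinstance(module, Conv1D):
    in_features, out_features = module.weight.shape
else:
    in_features, out_features = module.in_features, module.out_features

with init_empty_weights():
    replacement = bnb.nn.Linear8bitLt(
        in_features,
        out_features,
        module.bias is not None,
        has_fp16_weights=False,
        threshold=6.0,
    )
replacement.source_cls = type(module)  # remembered so the weight can be transposed on load
replacement.requires_grad_(False)
```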
@@ -200,7 +215,6 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name
if not has_been_replaced:
logger.warning(
"You are loading your model in 8bit or 4bit but no linear modules were found in your model."
" this can happen for some architectures such as gpt2 that uses Conv1D instead of Linear layers."
" Please double check your model architecture, or submit an issue on github if you think this is"
" a bug."
)
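With Conv1D handled, loading gpt2 checkpoints in 8-bit or 4-bit no longer hits the warning above. A minimal usage sketch (the model name and device map are illustrative, and this needs a CUDA GPU with bitsandbytes installed):

```python
from transformers import AutoModelForCausalLM
from transformers.pytorch_utils import Conv1D

model = AutoModelForCausalLM.from_pretrained("gpt2", load_in_8bit=True, device_map="auto")

# gpt2's Conv1D projections are now swapped for bnb Linear8bitLt layers
layer = model.transformer.h[0].mlp.c_fc
print(type(layer).__name__)         # Linear8bitLt
print(layer.source_cls is Conv1D)   # True, so the loader knows to transpose the weight
```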
15 changes: 14 additions & 1 deletion tests/bnb/test_4bit.py
@@ -39,6 +39,12 @@
from transformers.utils.versions import importlib_metadata


def get_some_linear_layer(model):
if model.config.model_type == "gpt2":
return model.transformer.h[0].mlp.c_fc
return model.transformer.h[0].mlp.dense_4h_to_h


if is_torch_available():
import torch
import torch.nn as nn
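The helper simply picks one representative projection per architecture: gpt2's MLP stores its projections as `Conv1D` modules named `c_fc`/`c_proj`, while the bloom model used by the base suite exposes an `nn.Linear` called `dense_4h_to_h`. A quick sketch of what it resolves to (smaller checkpoints used here purely for illustration):

```python
from transformers import AutoModelForCausalLM

gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")
print(type(get_some_linear_layer(gpt2)))   # <class 'transformers.pytorch_utils.Conv1D'>

bloom = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
print(type(get_some_linear_layer(bloom)))  # <class 'torch.nn.modules.linear.Linear'>
```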
@@ -83,6 +89,7 @@ class Base4bitTest(unittest.TestCase):
EXPECTED_OUTPUTS = set()
EXPECTED_OUTPUTS.add("Hello my name is John and I am a professional photographer. I")
EXPECTED_OUTPUTS.add("Hello my name is John.\nI am a friend of your father.\n")
EXPECTED_OUTPUTS.add("Hello my name is John Doe, I am a student at the University")
MAX_NEW_TOKENS = 10

def setUp(self):
@@ -135,7 +142,8 @@ def test_memory_footprint(self):
mem_4bit = self.model_4bit.get_memory_footprint()

self.assertAlmostEqual(mem_fp16 / mem_4bit, self.EXPECTED_RELATIVE_DIFFERENCE)
self.assertTrue(self.model_4bit.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Params4bit)
linear = get_some_linear_layer(self.model_4bit)
self.assertTrue(linear.weight.__class__ == Params4bit)

def test_linear_are_4bit(self):
r"""
@@ -473,3 +481,8 @@ def test_training(self):
self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0)
elif isinstance(module, nn.Embedding):
self.assertTrue(module.weight.grad is None)


class Bnb4BitGPT2Test(Bnb4BitTest):
model_name = "gpt2-xl"
EXPECTED_RELATIVE_DIFFERENCE = 3.3191854854152187
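`Bnb4BitGPT2Test` simply reruns the whole base 4-bit suite against `gpt2-xl`. A condensed sketch of the core check it inherits (needs a CUDA GPU with bitsandbytes; the expected ratio is the class attribute above):

```python
from bitsandbytes.nn import Params4bit
from transformers import AutoModelForCausalLM

model_4bit = AutoModelForCausalLM.from_pretrained("gpt2-xl", load_in_4bit=True, device_map="auto")

# the Conv1D projections become Linear4bit modules, so their weights are Params4bit
linear = get_some_linear_layer(model_4bit)   # -> transformer.h[0].mlp.c_fc for gpt2
assert linear.weight.__class__ == Params4bit

# the 4-bit footprint should be roughly EXPECTED_RELATIVE_DIFFERENCE times smaller than fp16
print(model_4bit.get_memory_footprint())
```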
33 changes: 26 additions & 7 deletions tests/bnb/test_mixed_int8.py
@@ -41,6 +41,12 @@
from transformers.utils.versions import importlib_metadata


def get_some_linear_layer(model):
if model.config.model_type == "gpt2":
return model.transformer.h[0].mlp.c_fc
return model.transformer.h[0].mlp.dense_4h_to_h


if is_accelerate_available():
from accelerate import PartialState
from accelerate.logging import get_logger
@@ -142,7 +148,7 @@ def test_memory_footprint(self):
mem_8bit = self.model_8bit.get_memory_footprint()

self.assertAlmostEqual(mem_fp16 / mem_8bit, self.EXPECTED_RELATIVE_DIFFERENCE)
self.assertTrue(self.model_8bit.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params)
self.assertTrue(get_some_linear_layer(self.model_8bit).weight.__class__ == Int8Params)

def test_linear_are_8bit(self):
r"""
@@ -292,8 +298,9 @@ def test_int8_serialization(self):

model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname, load_in_8bit=True, device_map="auto")

self.assertTrue(model_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params)
self.assertTrue(hasattr(model_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight, "SCB"))
linear = get_some_linear_layer(model_from_saved)
self.assertTrue(linear.weight.__class__ == Int8Params)
self.assertTrue(hasattr(linear.weight, "SCB"))

# generate
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
@@ -318,8 +325,9 @@ def test_int8_serialization_sharded(self):

model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname)

self.assertTrue(model_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params)
self.assertTrue(hasattr(model_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight, "SCB"))
linear = get_some_linear_layer(model_from_saved)
self.assertTrue(linear.weight.__class__ == Int8Params)
self.assertTrue(hasattr(linear.weight, "SCB"))

# generate
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
@@ -339,8 +347,9 @@ def test_int8_from_pretrained(self):

model = AutoModelForCausalLM.from_pretrained(model_id)

self.assertTrue(model.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params)
self.assertTrue(hasattr(model.transformer.h[0].mlp.dense_4h_to_h.weight, "SCB"))
linear = get_some_linear_layer(model)
self.assertTrue(linear.weight.__class__ == Int8Params)
self.assertTrue(hasattr(linear.weight, "SCB"))

# generate
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
@@ -748,3 +757,13 @@ def test_training(self):
self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0)
elif isinstance(module, nn.Embedding):
self.assertTrue(module.weight.grad is None)


class MixedInt8GPT2Test(MixedInt8Test):
model_name = "gpt2-xl"
EXPECTED_RELATIVE_DIFFERENCE = 1.8720077507258357
EXPECTED_OUTPUT = "Hello my name is John Doe, and I am a member of the"

def test_int8_from_pretrained(self):
# TODO @younesbelkada: Test loading quantized gpt2 model from the hub.
pass
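And the int8 counterpart: a minimal sketch of what `MixedInt8GPT2Test` exercises end to end through the base suite (prompt and decoding settings are illustrative, and the generated text may differ from the expected string above):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2-xl")
model_8bit = AutoModelForCausalLM.from_pretrained("gpt2-xl", load_in_8bit=True, device_map="auto")

inputs = tokenizer("Hello my name is", return_tensors="pt").to(model_8bit.device)
output = model_8bit.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(output[0], skip_special_tokens=True))
# e.g. "Hello my name is John Doe, and I am a member of the"

# the quantized Conv1D weights carry the int8 scale statistics (SCB) used at inference
weight = model_8bit.transformer.h[0].mlp.c_fc.weight
print(weight.__class__.__name__, hasattr(weight, "SCB"))  # Int8Params True
```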