Supporting seq2seq models for bitsandbytes integration #18579

Merged (3 commits) on Aug 12, 2022
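With this change, encoder-decoder models such as T5 can be loaded in 8-bit the same way as the decoder-only models already supported. A minimal usage sketch, assuming a CUDA GPU with `bitsandbytes` and `accelerate` installed (the checkpoint and prompt are only illustrative):

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_name = "t5-small"

# Load the seq2seq model with int8 linear layers; the output head is left
# unconverted (see the get_key_to_not_convert change below).
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, load_in_8bit=True, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

inputs = tokenizer("translate English to German: Hello, my dog is cute.", return_tensors="pt").to(0)
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```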
src/transformers/utils/bitsandbytes.py (14 changes: 13 additions & 1 deletion)
@@ -1,3 +1,5 @@
+from copy import deepcopy
+
 from transformers.utils import is_accelerate_available, is_bitsandbytes_available


@@ -9,6 +11,7 @@

 if is_accelerate_available():
     from accelerate import init_empty_weights
+    from accelerate.utils import find_tied_parameters


 def set_module_8bit_tensor_to_device(module, tensor_name, device, value=None):

@@ -132,8 +135,17 @@ def get_key_to_not_convert(model):
     model (`torch.nn.Module`):
         Input model
     """
+    # Create a copy of the model and tie the weights, then
+    # check if it contains tied weights
+    tied_model = deepcopy(model)  # this has 0 cost since it is done inside `init_empty_weights` context manager`
+    tied_model.tie_weights()
+    has_tied_params = len(find_tied_parameters(tied_model)) > 0
+
+    # Check if it is a base model
+    is_base_model = not hasattr(model, model.base_model_prefix)
+
     # Ignore this for base models (BertModel, GPT2Model, etc.)
-    if not hasattr(model, model.base_model_prefix):
+    if (not has_tied_params) and is_base_model:
         return ""

     # otherwise they have an attached head
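The updated `get_key_to_not_convert` no longer relies only on `base_model_prefix`: it ties the weights on a throwaway copy and asks `accelerate`'s `find_tied_parameters` whether any parameters share storage, so models whose head is tied to the embeddings (typical for seq2seq and causal LMs) keep that head out of the int8 conversion. A rough illustration of the check, run here on a real (non-meta) model so the `deepcopy` is not free; `t5-small` is just an example checkpoint:

```python
from copy import deepcopy

from accelerate.utils import find_tied_parameters
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

# Tie the weights on a copy, then collect the groups of parameters that share storage.
tied_model = deepcopy(model)
tied_model.tie_weights()
tied_params = find_tied_parameters(tied_model)
print(tied_params)  # for T5 this should include the shared embeddings and lm_head

# With tied parameters present, the model is treated as having an attached head to
# keep unconverted; only base models with no tied parameters return "".
has_tied_params = len(tied_params) > 0
print(has_tied_params)
```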
tests/mixed_int8/test_mixed_int8.py (22 changes: 20 additions & 2 deletions)
@@ -15,7 +15,14 @@
 import gc
 import unittest

-from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, pipeline
+from transformers import (
+    AutoModel,
+    AutoModelForCausalLM,
+    AutoModelForSeq2SeqLM,
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
+    pipeline,
+)
 from transformers.testing_utils import (
     is_torch_available,
     require_accelerate,
@@ -106,12 +113,21 @@ def setUp(self):
        super().setUp()
        # model_name
        self.model_name = "bigscience/bloom-560m"
        # Models and tokenizer
        self.seq_to_seq_name = "t5-small"

        # Different types of model

        self.base_model = AutoModel.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto")
        # Sequence classification model
        self.sequence_model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name, load_in_8bit=True, device_map="auto"
        )
        # CausalLM model
        self.model_8bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto")
        # Seq2seq model
        self.seq_to_seq_model = AutoModelForSeq2SeqLM.from_pretrained(
            self.seq_to_seq_name, load_in_8bit=True, device_map="auto"
        )

    def tearDown(self):
        r"""
@@ -121,6 +137,7 @@ def tearDown(self):
         del self.base_model
         del self.sequence_model
         del self.model_8bit
+        del self.seq_to_seq_model

         gc.collect()
         torch.cuda.empty_cache()
@@ -138,6 +155,7 @@ def test_correct_head_class(self):
         # Other heads should be nn.Parameter
         self.assertTrue(self.model_8bit.lm_head.weight.__class__ == torch.nn.Parameter)
         self.assertTrue(self.sequence_model.score.weight.__class__ == torch.nn.Parameter)
+        self.assertTrue(self.seq_to_seq_model.lm_head.weight.__class__ == torch.nn.Parameter)


 class MixedInt8TestPipeline(BaseMixedInt8Test):
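The new assertion follows the existing pattern: only the output head keeps a regular `torch.nn.Parameter` weight, while the inner linear layers are replaced by `bitsandbytes` 8-bit modules. A quick manual check along the same lines, where the T5 submodule path is an assumption used only for illustration:

```python
import torch
import bitsandbytes as bnb
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained("t5-small", load_in_8bit=True, device_map="auto")

# The lm_head weight stays a plain torch.nn.Parameter (not converted to int8) ...
print(model.lm_head.weight.__class__ is torch.nn.Parameter)

# ... while e.g. a self-attention projection has been swapped for an 8-bit linear layer.
print(isinstance(model.encoder.block[0].layer[0].SelfAttention.q, bnb.nn.Linear8bitLt))
```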