From d22502997960455def67885aa03bca77170ab61b Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Thu, 11 Aug 2022 12:55:56 +0000
Subject: [PATCH 1/2] Supporting seq2seq models for `bitsandbytes` integration

- the `bitsandbytes` integration now supports seq2seq models
- additionally check whether a model has tied weights

---
 src/transformers/utils/bitsandbytes.py | 10 +++++++++-
 tests/mixed_int8/test_mixed_int8.py    | 22 ++++++++++++++++++++--
 2 files changed, 29 insertions(+), 3 deletions(-)

diff --git a/src/transformers/utils/bitsandbytes.py b/src/transformers/utils/bitsandbytes.py
index ee4e52d421fd09..d18d240372bd3d 100644
--- a/src/transformers/utils/bitsandbytes.py
+++ b/src/transformers/utils/bitsandbytes.py
@@ -9,6 +9,7 @@
 
 if is_accelerate_available():
     from accelerate import init_empty_weights
+    from accelerate.utils import find_tied_parameters
 
 
 def set_module_8bit_tensor_to_device(module, tensor_name, device, value=None):
@@ -132,8 +133,15 @@ def get_key_to_not_convert(model):
         model (`torch.nn.Module`):
             Input model
     """
+    # Check if the model has tied parameters
+    # This will keep the lm_head etc. in their original class type
+    has_tied_params = len(find_tied_parameters(model)) > 0
+
+    # Check if it is a base model
+    is_base_model = not hasattr(model, model.base_model_prefix)
+
     # Ignore this for base models (BertModel, GPT2Model, etc.)
-    if not hasattr(model, model.base_model_prefix):
+    if (not has_tied_params) and is_base_model:
         return ""
 
     # otherwise they have an attached head
diff --git a/tests/mixed_int8/test_mixed_int8.py b/tests/mixed_int8/test_mixed_int8.py
index 0cd7ca16411c19..2911d67748809a 100644
--- a/tests/mixed_int8/test_mixed_int8.py
+++ b/tests/mixed_int8/test_mixed_int8.py
@@ -15,7 +15,14 @@
 import gc
 import unittest
 
-from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, pipeline
+from transformers import (
+    AutoModel,
+    AutoModelForCausalLM,
+    AutoModelForSeq2SeqLM,
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
+    pipeline,
+)
 from transformers.testing_utils import (
     is_torch_available,
     require_accelerate,
@@ -106,12 +113,21 @@ def setUp(self):
         super().setUp()
         # model_name
         self.model_name = "bigscience/bloom-560m"
 
-        # Models and tokenizer
+        self.seq_to_seq_name = "t5-small"
+
+        # Different types of model
+        self.base_model = AutoModel.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto")
+        # Sequence classification model
         self.sequence_model = AutoModelForSequenceClassification.from_pretrained(
             self.model_name, load_in_8bit=True, device_map="auto"
         )
+        # CausalLM model
         self.model_8bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto")
+        # Seq2seq model
+        self.seq_to_seq_model = AutoModelForSeq2SeqLM.from_pretrained(
+            self.seq_to_seq_name, load_in_8bit=True, device_map="auto"
+        )
 
     def tearDown(self):
         r"""
@@ -121,6 +137,7 @@ def tearDown(self):
         del self.base_model
         del self.sequence_model
         del self.model_8bit
+        del self.seq_to_seq_model
 
         gc.collect()
         torch.cuda.empty_cache()
@@ -138,6 +155,7 @@ def test_correct_head_class(self):
         # Other heads should be nn.Parameter
         self.assertTrue(self.model_8bit.lm_head.weight.__class__ == torch.nn.Parameter)
         self.assertTrue(self.sequence_model.score.weight.__class__ == torch.nn.Parameter)
+        self.assertTrue(self.seq_to_seq_model.lm_head.weight.__class__ == torch.nn.Parameter)
 
 
 class MixedInt8TestPipeline(BaseMixedInt8Test):
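For illustration, a minimal usage sketch of what the patch above enables, mirroring the new test: a seq2seq checkpoint loaded in 8-bit keeps its lm_head as a plain nn.Parameter. The generation call and translation prompt are illustrative additions rather than part of the patch, and a CUDA GPU with `bitsandbytes` and `accelerate` installed is assumed:

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Load a seq2seq model in 8-bit, exactly as the new test does
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small", load_in_8bit=True, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("t5-small")

# The head is kept in its original class rather than converted to int8
assert model.lm_head.weight.__class__ == torch.nn.Parameter

# Illustrative generation call (not part of the patch)
inputs = tokenizer("translate English to German: Hello, my dog is cute", return_tensors="pt").to(0)
outputs = model.generate(**inputs)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))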
From efd78bc739a6ed16890904ccf1f0b5e585ccc389 Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Thu, 11 Aug 2022 13:55:24 +0000
Subject: [PATCH 2/2] small modification

- tie the weights before looking for tied weights!

---
 src/transformers/utils/bitsandbytes.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/src/transformers/utils/bitsandbytes.py b/src/transformers/utils/bitsandbytes.py
index d18d240372bd3d..eca605b2edef94 100644
--- a/src/transformers/utils/bitsandbytes.py
+++ b/src/transformers/utils/bitsandbytes.py
@@ -1,3 +1,5 @@
+from copy import deepcopy
+
 from transformers.utils import is_accelerate_available, is_bitsandbytes_available
 
 
@@ -133,9 +135,11 @@ def get_key_to_not_convert(model):
         model (`torch.nn.Module`):
             Input model
     """
-    # Check if the model has tied parameters
-    # This will keep the lm_head etc. in their original class type
-    has_tied_params = len(find_tied_parameters(model)) > 0
+    # Create a copy of the model and tie the weights, then
+    # check whether it contains tied weights
+    tied_model = deepcopy(model)  # this has 0 cost since it is done inside the `init_empty_weights` context manager
+    tied_model.tie_weights()
+    has_tied_params = len(find_tied_parameters(tied_model)) > 0
 
     # Check if it is a base model
     is_base_model = not hasattr(model, model.base_model_prefix)
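For illustration, a standalone sketch of the tied-weights check that the patch above arrives at, extracted from `get_key_to_not_convert`. Building the model under `init_empty_weights` mirrors the context in which `from_pretrained` invokes this helper; since the weights live on the meta device, no GPU is needed and the `deepcopy` duplicates no real tensors:

from copy import deepcopy

from accelerate import init_empty_weights
from accelerate.utils import find_tied_parameters
from transformers import AutoConfig, AutoModelForSeq2SeqLM

# Instantiate the architecture on the meta device: no memory is allocated
config = AutoConfig.from_pretrained("t5-small")
with init_empty_weights():
    model = AutoModelForSeq2SeqLM.from_config(config)

# Tie the weights first, then look for tied parameters: before tying,
# find_tied_parameters reports nothing, so the lm_head would wrongly
# be converted to int8
tied_model = deepcopy(model)
tied_model.tie_weights()
has_tied_params = len(find_tied_parameters(tied_model)) > 0
print(has_tied_params)  # True: T5 ties lm_head to the shared embedding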