From 42e25f2ecbe20ef9eb7e1dcaebde4f6030e4d3eb Mon Sep 17 00:00:00 2001 From: His-Wardship <139779341+His-Wardship@users.noreply.github.com> Date: Tue, 19 Sep 2023 14:39:32 +0100 Subject: [PATCH 01/11] Update __init__.py --- src/peft/tuners/ia3/__init__.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/peft/tuners/ia3/__init__.py b/src/peft/tuners/ia3/__init__.py index 517ace0b15..04aa42fd5f 100644 --- a/src/peft/tuners/ia3/__init__.py +++ b/src/peft/tuners/ia3/__init__.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from peft.import_utils import is_bnb_available +from peft.import_utils import is_bnb_4bit_available, is_bnb_available from .config import IA3Config from .layer import IA3Layer, Linear @@ -27,3 +27,8 @@ from .bnb import Linear8bitLt __all__ += ["Linear8bitLt"] + +if is_bnb_4bit_available(): + from .bnb import Linear4bit + + __all__ += ["Linear4bit"] From d12f8efc6e43b7e820d96b07a3ef60c90855a264 Mon Sep 17 00:00:00 2001 From: His-Wardship <139779341+His-Wardship@users.noreply.github.com> Date: Tue, 19 Sep 2023 14:39:57 +0100 Subject: [PATCH 02/11] Update model.py --- src/peft/tuners/ia3/model.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/src/peft/tuners/ia3/model.py b/src/peft/tuners/ia3/model.py index 06a8d19684..d8cb8823aa 100644 --- a/src/peft/tuners/ia3/model.py +++ b/src/peft/tuners/ia3/model.py @@ -21,7 +21,7 @@ import torch from transformers.pytorch_utils import Conv1D -from peft.import_utils import is_bnb_available +from peft.import_utils import is_bnb_4bit_available, is_bnb_available from peft.tuners.tuners_utils import BaseTuner from peft.utils import ( TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING, @@ -39,6 +39,11 @@ from .bnb import Linear8bitLt +if is_bnb_4bit_available(): + import bitsandbytes as bnb + + from .bnb import Linear4bit, Linear8bitLt + class IA3Model(BaseTuner): """ @@ -82,6 +87,7 @@ def __init__(self, model, config, adapter_name): def _create_new_module(ia3_config, adapter_name, target, **kwargs): bias = hasattr(target, "bias") and target.bias is not None loaded_in_8bit = kwargs.pop("loaded_in_8bit", False) + loaded_in_4bit = kwargs.pop("loaded_in_4bit", False) is_feedforward = kwargs.pop("is_feedforward", False) if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt): @@ -102,6 +108,23 @@ def _create_new_module(ia3_config, adapter_name, target, **kwargs): bias=bias, **eightbit_kwargs, ) + elif loaded_in_4bit and isinstance(target, bnb.nn.Linear4bit): + fourbit_kwargs = kwargs.copy() + fourbit_kwargs.update( + { + "compute_dtype": target.compute_dtype, + "compress_statistics": target.weight.compress_statistics, + "quant_type": target.weight.quant_type, + } + ) + new_module = Linear4bit( + adapter_name, + target.in_features, + target.out_features, + is_feedforward, + bias=bias, + **fourbit_kwargs, + ) else: # Create a new Linear module with (IA)^3 parameters for torch.nn.Linear # or Conv1D modules @@ -156,6 +179,7 @@ def _create_and_replace( **optionnal_kwargs, ): loaded_in_8bit = optionnal_kwargs["loaded_in_8bit"] + loaded_in_4bit = optionnal_kwargs["loaded_in_4bit"] current_key = optionnal_kwargs["current_key"] # check if target module is in feedforward_modules @@ -168,6 +192,7 @@ def _create_and_replace( "fan_in_fan_out": ia3_config.fan_in_fan_out, "init_ia3_weights": ia3_config.init_ia3_weights, "loaded_in_8bit": loaded_in_8bit, + "loaded_in_4bit": loaded_in_4bit, "is_feedforward": is_feedforward, } From 632bd4b4604f8eef18fd6cab603e94853305abcd Mon Sep 17 00:00:00 2001 From: His-Wardship <139779341+His-Wardship@users.noreply.github.com> Date: Tue, 19 Sep 2023 14:40:19 +0100 Subject: [PATCH 03/11] Update bnb.py --- src/peft/tuners/ia3/bnb.py | 56 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/src/peft/tuners/ia3/bnb.py b/src/peft/tuners/ia3/bnb.py index c88829af0f..13e29e1507 100644 --- a/src/peft/tuners/ia3/bnb.py +++ b/src/peft/tuners/ia3/bnb.py @@ -70,3 +70,59 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: result = result.to(expected_dtype) return result + + +class Linear4bit(bnb.nn.Linear4bit, IA3Layer): + # IA3 implemented in a dense layer + def __init__( + self, + adapter_name, + in_features, + out_features, + is_feedforward, + **kwargs, + ) -> None: + bnb.nn.Linear4bit.__init__( + self, + in_features, + out_features, + bias=kwargs.get("bias", True), + compute_dtype=kwargs.get("compute_dtype", torch.float32), + compress_statistics=kwargs.get("compress_statistics", True), + quant_type=kwargs.get("quant_type", "nf4"), + ) + IA3Layer.__init__(self, in_features=in_features, out_features=out_features, is_feedforward=is_feedforward) + + # Freezing the pre-trained weight matrix + self.weight.requires_grad = False + + init_ia3_weights = kwargs.pop("init_ia3_weights", True) + self.update_layer(adapter_name, init_ia3_weights) + self.active_adapter = adapter_name + self.is_feedforward = is_feedforward + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.disable_adapters or (self.active_adapter not in self.ia3_l.keys()): + return super().forward(x) + + requires_conversion = (not torch.is_autocast_enabled()) and (x.dtype != torch.float32) + if requires_conversion: + x = x.float() + + ia3_scaling = self.ia3_l[self.active_adapter].flatten() + if self.is_feedforward: + result = super().forward(x * ia3_scaling) + expected_dtype = result.dtype + else: + result = super().forward(x) + expected_dtype = result.dtype + result = result * ia3_scaling + + result = ( + result.clone() + ) # adalora.py and lora.py both suggested that the inclusion of this was necessary for 4-bit training on older versions of Pytorch. This has been duplicated here. + + if requires_conversion: + result = result.to(expected_dtype) + + return result From d741f9c6a1aef0f4da3115fc6c97b42a0f6dfc62 Mon Sep 17 00:00:00 2001 From: His-Wardship <139779341+His-Wardship@users.noreply.github.com> Date: Tue, 19 Sep 2023 14:41:11 +0100 Subject: [PATCH 04/11] Update test_common_gpu.py Adjusted alias for Lora and IA3 layers --- tests/test_common_gpu.py | 135 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 125 insertions(+), 10 deletions(-) diff --git a/tests/test_common_gpu.py b/tests/test_common_gpu.py index 86fc956fb9..4d590db60f 100644 --- a/tests/test_common_gpu.py +++ b/tests/test_common_gpu.py @@ -28,7 +28,14 @@ WhisperForConditionalGeneration, ) -from peft import AdaptionPromptConfig, LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training +from peft import ( + AdaptionPromptConfig, + IA3Config, + LoraConfig, + PeftModel, + get_peft_model, + prepare_model_for_kbit_training, +) from peft.import_utils import is_bnb_4bit_available, is_bnb_available from .testing_utils import require_bitsandbytes, require_torch_gpu, require_torch_multi_gpu @@ -37,10 +44,12 @@ if is_bnb_available(): import bitsandbytes as bnb - from peft.tuners.lora import Linear8bitLt + from peft.tuners.ia3 import Linear8bitLt as IA3Linear8bitLt + from peft.tuners.lora import Linear8bitLt as LoraLinear8bitLt if is_bnb_4bit_available(): - from peft.tuners.lora import Linear4bit + from peft.tuners.ia3 import Linear4bit as IA3Linear4bit + from peft.tuners.lora import Linear4bit as LoraLinear4bit @require_torch_gpu @@ -107,14 +116,68 @@ def test_lora_bnb_8bit_quantization(self): config = LoraConfig(r=32, lora_alpha=64, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none") flan_8bit = get_peft_model(flan_8bit, flan_lora_config) - self.assertTrue(isinstance(flan_8bit.base_model.model.encoder.block[0].layer[0].SelfAttention.q, Linear8bitLt)) + self.assertTrue( + isinstance(flan_8bit.base_model.model.encoder.block[0].layer[0].SelfAttention.q, LoraLinear8bitLt) + ) opt_8bit = get_peft_model(opt_8bit, opt_lora_config) - self.assertTrue(isinstance(opt_8bit.base_model.model.model.decoder.layers[0].self_attn.v_proj, Linear8bitLt)) + self.assertTrue( + isinstance(opt_8bit.base_model.model.model.decoder.layers[0].self_attn.v_proj, LoraLinear8bitLt) + ) whisper_8bit = get_peft_model(whisper_8bit, config) self.assertTrue( - isinstance(whisper_8bit.base_model.model.model.decoder.layers[0].self_attn.v_proj, Linear8bitLt) + isinstance(whisper_8bit.base_model.model.model.decoder.layers[0].self_attn.v_proj, LoraLinear8bitLt) + ) + + @require_bitsandbytes + @pytest.mark.multi_gpu_tests + @pytest.mark.single_gpu_tests + def test_ia3_bnb_8bit_quantization(self): + r""" + Test that tests if the 8bit quantization using IA3 works as expected + """ + whisper_8bit = WhisperForConditionalGeneration.from_pretrained( + self.audio_model_id, + device_map="auto", + load_in_8bit=True, + ) + + opt_8bit = AutoModelForCausalLM.from_pretrained( + self.causal_lm_model_id, + device_map="auto", + load_in_8bit=True, + ) + + flan_8bit = AutoModelForSeq2SeqLM.from_pretrained( + self.seq2seq_model_id, + device_map="auto", + load_in_8bit=True, + ) + + flan_ia3_config = IA3Config(target_modules=["q", "v"], task_type="SEQ_2_SEQ_LM") + + opt_ia3_config = IA3Config( + target_modules=["q_proj", "v_proj"], + feedforward_modules=["down_proj"], + task_type="CAUSAL_LM", + ) + + config = IA3Config(target_modules=["q_proj", "v_proj"], feedforward_modules=["down_proj"]) + + flan_8bit = get_peft_model(flan_8bit, flan_ia3_config) + self.assertTrue( + isinstance(flan_8bit.base_model.model.encoder.block[0].layer[0].SelfAttention.q, IA3Linear8bitLt) + ) + + opt_8bit = get_peft_model(opt_8bit, opt_ia3_config) + self.assertTrue( + isinstance(opt_8bit.base_model.model.model.decoder.layers[0].self_attn.v_proj, IA3Linear8bitLt) + ) + + whisper_8bit = get_peft_model(whisper_8bit, config) + self.assertTrue( + isinstance(whisper_8bit.base_model.model.model.decoder.layers[0].self_attn.v_proj, IA3Linear8bitLt) ) @require_bitsandbytes @@ -173,13 +236,65 @@ def test_lora_bnb_4bit_quantization(self): config = LoraConfig(r=32, lora_alpha=64, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none") flan_4bit = get_peft_model(flan_4bit, flan_lora_config) - self.assertTrue(isinstance(flan_4bit.base_model.model.encoder.block[0].layer[0].SelfAttention.q, Linear4bit)) + self.assertTrue( + isinstance(flan_4bit.base_model.model.encoder.block[0].layer[0].SelfAttention.q, LoraLinear4bit) + ) opt_4bit = get_peft_model(opt_4bit, opt_lora_config) - self.assertTrue(isinstance(opt_4bit.base_model.model.model.decoder.layers[0].self_attn.v_proj, Linear4bit)) + self.assertTrue(isinstance(opt_4bit.base_model.model.model.decoder.layers[0].self_attn.v_proj, LoraLinear4bit)) + + whisper_4bit = get_peft_model(whisper_4bit, config) + self.assertTrue( + isinstance(whisper_4bit.base_model.model.model.decoder.layers[0].self_attn.v_proj, LoraLinear4bit) + ) + + @require_bitsandbytes + @pytest.mark.multi_gpu_tests + @pytest.mark.single_gpu_tests + def test_ia3_bnb_4bit_quantization(self): + r""" + Test that tests if the 4bit quantization using IA3 works as expected + """ + whisper_4bit = WhisperForConditionalGeneration.from_pretrained( + self.audio_model_id, + device_map="auto", + load_in_4bit=True, + ) + + opt_4bit = AutoModelForCausalLM.from_pretrained( + self.causal_lm_model_id, + device_map="auto", + load_in_4bit=True, + ) + + flan_4bit = AutoModelForSeq2SeqLM.from_pretrained( + self.seq2seq_model_id, + device_map="auto", + load_in_4bit=True, + ) + + flan_ia3_config = IA3Config(target_modules=["q", "v"], task_type="SEQ_2_SEQ_LM") + + opt_ia3_config = IA3Config( + target_modules=["q_proj", "v_proj"], + feedforward_modules=["down_proj"], + task_type="CAUSAL_LM", + ) + + config = IA3Config(target_modules=["q_proj", "v_proj"], feedforward_modules=["down_proj"]) + + flan_4bit = get_peft_model(flan_4bit, flan_ia3_config) + self.assertTrue( + isinstance(flan_4bit.base_model.model.encoder.block[0].layer[0].SelfAttention.q, IA3Linear4bit) + ) + + opt_4bit = get_peft_model(opt_4bit, opt_ia3_config) + self.assertTrue(isinstance(opt_4bit.base_model.model.model.decoder.layers[0].self_attn.v_proj, IA3Linear4bit)) whisper_4bit = get_peft_model(whisper_4bit, config) - self.assertTrue(isinstance(whisper_4bit.base_model.model.model.decoder.layers[0].self_attn.v_proj, Linear4bit)) + self.assertTrue( + isinstance(whisper_4bit.base_model.model.model.decoder.layers[0].self_attn.v_proj, IA3Linear4bit) + ) @pytest.mark.multi_gpu_tests @require_torch_multi_gpu @@ -228,7 +343,7 @@ def test_lora_seq2seq_lm_mutli_gpu_inference(self): model = get_peft_model(model, lora_config) self.assertTrue(isinstance(model, PeftModel)) - self.assertTrue(isinstance(model.base_model.model.encoder.block[0].layer[0].SelfAttention.q, Linear8bitLt)) + self.assertTrue(isinstance(model.base_model.model.encoder.block[0].layer[0].SelfAttention.q, LoraLinear8bitLt)) dummy_input = "This is a dummy input:" input_ids = tokenizer(dummy_input, return_tensors="pt").input_ids.to(self.device) From 5de3122414dd73b783056c635d97faf7882034a1 Mon Sep 17 00:00:00 2001 From: His-Wardship <139779341+His-Wardship@users.noreply.github.com> Date: Wed, 20 Sep 2023 16:31:18 +0100 Subject: [PATCH 05/11] Update bnb.py Add guard to check for BNB --- src/peft/tuners/ia3/bnb.py | 208 +++++++++++++++++++------------------ 1 file changed, 107 insertions(+), 101 deletions(-) diff --git a/src/peft/tuners/ia3/bnb.py b/src/peft/tuners/ia3/bnb.py index 13e29e1507..9a2c73c7e6 100644 --- a/src/peft/tuners/ia3/bnb.py +++ b/src/peft/tuners/ia3/bnb.py @@ -16,113 +16,119 @@ import bitsandbytes as bnb import torch +from peft.import_utils import is_bnb_4bit_available, is_bnb_available + from .layer import IA3Layer -class Linear8bitLt(bnb.nn.Linear8bitLt, IA3Layer): - # (IA)^3 implemented in a dense layer - def __init__( - self, - adapter_name, - in_features, - out_features, - is_feedforward, - **kwargs, - ) -> None: - bnb.nn.Linear8bitLt.__init__( +if is_bnb_available(): + + class Linear8bitLt(bnb.nn.Linear8bitLt, IA3Layer): + # (IA)^3 implemented in a dense layer + def __init__( self, + adapter_name, in_features, out_features, - bias=kwargs.get("bias", True), - has_fp16_weights=kwargs.get("has_fp16_weights", True), - memory_efficient_backward=kwargs.get("memory_efficient_backward", False), - threshold=kwargs.get("threshold", 0.0), - index=kwargs.get("index", None), - ) - IA3Layer.__init__(self, in_features=in_features, out_features=out_features, is_feedforward=is_feedforward) - - # Freezing the pre-trained weight matrix - self.weight.requires_grad = False - - init_ia3_weights = kwargs.pop("init_ia3_weights", True) - self.update_layer(adapter_name, init_ia3_weights) - self.active_adapter = adapter_name - self.is_feedforward = is_feedforward - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.disable_adapters or (self.active_adapter not in self.ia3_l.keys()): - return super().forward(x) - - requires_conversion = (not torch.is_autocast_enabled()) and (x.dtype != torch.float32) - if requires_conversion: - x = x.float() - - ia3_scaling = self.ia3_l[self.active_adapter].flatten() - if self.is_feedforward: - result = super().forward(x * ia3_scaling) - expected_dtype = result.dtype - else: - result = super().forward(x) - expected_dtype = result.dtype - result = result * ia3_scaling - - if requires_conversion: - result = result.to(expected_dtype) - - return result - - -class Linear4bit(bnb.nn.Linear4bit, IA3Layer): - # IA3 implemented in a dense layer - def __init__( - self, - adapter_name, - in_features, - out_features, - is_feedforward, - **kwargs, - ) -> None: - bnb.nn.Linear4bit.__init__( + is_feedforward, + **kwargs, + ) -> None: + bnb.nn.Linear8bitLt.__init__( + self, + in_features, + out_features, + bias=kwargs.get("bias", True), + has_fp16_weights=kwargs.get("has_fp16_weights", True), + memory_efficient_backward=kwargs.get("memory_efficient_backward", False), + threshold=kwargs.get("threshold", 0.0), + index=kwargs.get("index", None), + ) + IA3Layer.__init__(self, in_features=in_features, out_features=out_features, is_feedforward=is_feedforward) + + # Freezing the pre-trained weight matrix + self.weight.requires_grad = False + + init_ia3_weights = kwargs.pop("init_ia3_weights", True) + self.update_layer(adapter_name, init_ia3_weights) + self.active_adapter = adapter_name + self.is_feedforward = is_feedforward + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.disable_adapters or (self.active_adapter not in self.ia3_l.keys()): + return super().forward(x) + + requires_conversion = (not torch.is_autocast_enabled()) and (x.dtype != torch.float32) + if requires_conversion: + x = x.float() + + ia3_scaling = self.ia3_l[self.active_adapter].flatten() + if self.is_feedforward: + result = super().forward(x * ia3_scaling) + expected_dtype = result.dtype + else: + result = super().forward(x) + expected_dtype = result.dtype + result = result * ia3_scaling + + if requires_conversion: + result = result.to(expected_dtype) + + return result + + +if is_bnb_4bit_available(): + + class Linear4bit(bnb.nn.Linear4bit, IA3Layer): + # IA3 implemented in a dense layer + def __init__( self, + adapter_name, in_features, out_features, - bias=kwargs.get("bias", True), - compute_dtype=kwargs.get("compute_dtype", torch.float32), - compress_statistics=kwargs.get("compress_statistics", True), - quant_type=kwargs.get("quant_type", "nf4"), - ) - IA3Layer.__init__(self, in_features=in_features, out_features=out_features, is_feedforward=is_feedforward) - - # Freezing the pre-trained weight matrix - self.weight.requires_grad = False - - init_ia3_weights = kwargs.pop("init_ia3_weights", True) - self.update_layer(adapter_name, init_ia3_weights) - self.active_adapter = adapter_name - self.is_feedforward = is_feedforward - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.disable_adapters or (self.active_adapter not in self.ia3_l.keys()): - return super().forward(x) - - requires_conversion = (not torch.is_autocast_enabled()) and (x.dtype != torch.float32) - if requires_conversion: - x = x.float() - - ia3_scaling = self.ia3_l[self.active_adapter].flatten() - if self.is_feedforward: - result = super().forward(x * ia3_scaling) - expected_dtype = result.dtype - else: - result = super().forward(x) - expected_dtype = result.dtype - result = result * ia3_scaling - - result = ( - result.clone() - ) # adalora.py and lora.py both suggested that the inclusion of this was necessary for 4-bit training on older versions of Pytorch. This has been duplicated here. - - if requires_conversion: - result = result.to(expected_dtype) - - return result + is_feedforward, + **kwargs, + ) -> None: + bnb.nn.Linear4bit.__init__( + self, + in_features, + out_features, + bias=kwargs.get("bias", True), + compute_dtype=kwargs.get("compute_dtype", torch.float32), + compress_statistics=kwargs.get("compress_statistics", True), + quant_type=kwargs.get("quant_type", "nf4"), + ) + IA3Layer.__init__(self, in_features=in_features, out_features=out_features, is_feedforward=is_feedforward) + + # Freezing the pre-trained weight matrix + self.weight.requires_grad = False + + init_ia3_weights = kwargs.pop("init_ia3_weights", True) + self.update_layer(adapter_name, init_ia3_weights) + self.active_adapter = adapter_name + self.is_feedforward = is_feedforward + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.disable_adapters or (self.active_adapter not in self.ia3_l.keys()): + return super().forward(x) + + requires_conversion = (not torch.is_autocast_enabled()) and (x.dtype != torch.float32) + if requires_conversion: + x = x.float() + + ia3_scaling = self.ia3_l[self.active_adapter].flatten() + if self.is_feedforward: + result = super().forward(x * ia3_scaling) + expected_dtype = result.dtype + else: + result = super().forward(x) + expected_dtype = result.dtype + result = result * ia3_scaling + + result = ( + result.clone() + ) # adalora.py and lora.py both suggested that the inclusion of this was necessary for 4-bit training on older versions of Pytorch. This has been duplicated here. + + if requires_conversion: + result = result.to(expected_dtype) + + return result From 301d500d181eecd31afa633e8f78f4ca48e33285 Mon Sep 17 00:00:00 2001 From: His-Wardship <139779341+His-Wardship@users.noreply.github.com> Date: Thu, 21 Sep 2023 18:23:25 +0100 Subject: [PATCH 06/11] Update model.py Add error for merging in 4-bit --- src/peft/tuners/ia3/model.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/peft/tuners/ia3/model.py b/src/peft/tuners/ia3/model.py index d8cb8823aa..6b515be3a3 100644 --- a/src/peft/tuners/ia3/model.py +++ b/src/peft/tuners/ia3/model.py @@ -281,6 +281,9 @@ def merge_and_unload(self): if getattr(self.model, "is_loaded_in_8bit", False): raise ValueError("Cannot merge ia3 layers when the model is loaded in 8-bit mode") + if getattr(self.model, "is_loaded_in_4bit", False): + raise ValueError("Cannot merge ia3 layers when the model is loaded in 4-bit mode") + key_list = [key for key, _ in self.model.named_modules() if "ia3" not in key] for key in key_list: try: From cc3d737c70a8a8aff2a0e231022c6675b4e56380 Mon Sep 17 00:00:00 2001 From: His-Wardship <139779341+His-Wardship@users.noreply.github.com> Date: Mon, 25 Sep 2023 12:49:32 +0100 Subject: [PATCH 07/11] Update bnb.py Resolve conflicts with PR #873 --- src/peft/tuners/ia3/bnb.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/src/peft/tuners/ia3/bnb.py b/src/peft/tuners/ia3/bnb.py index 9a2c73c7e6..1f9800e7cc 100644 --- a/src/peft/tuners/ia3/bnb.py +++ b/src/peft/tuners/ia3/bnb.py @@ -54,14 +54,18 @@ def __init__( self.is_feedforward = is_feedforward def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.disable_adapters or (self.active_adapter not in self.ia3_l.keys()): + if self.disable_adapters: return super().forward(x) + ia3_scaling = 1 + for active_adapter in self.active_adapters: + if active_adapter not in self.ia3_l.keys(): + continue + ia3_scaling *= self.ia3_l[active_adapter].flatten() + requires_conversion = (not torch.is_autocast_enabled()) and (x.dtype != torch.float32) if requires_conversion: x = x.float() - - ia3_scaling = self.ia3_l[self.active_adapter].flatten() if self.is_feedforward: result = super().forward(x * ia3_scaling) expected_dtype = result.dtype @@ -108,14 +112,18 @@ def __init__( self.is_feedforward = is_feedforward def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.disable_adapters or (self.active_adapter not in self.ia3_l.keys()): + if self.disable_adapters: return super().forward(x) + ia3_scaling = 1 + for active_adapter in self.active_adapters: + if active_adapter not in self.ia3_l.keys(): + continue + ia3_scaling *= self.ia3_l[active_adapter].flatten() + requires_conversion = (not torch.is_autocast_enabled()) and (x.dtype != torch.float32) if requires_conversion: x = x.float() - - ia3_scaling = self.ia3_l[self.active_adapter].flatten() if self.is_feedforward: result = super().forward(x * ia3_scaling) expected_dtype = result.dtype From 55eaad1139264978aa0fcda9cccf38e7782b9a7f Mon Sep 17 00:00:00 2001 From: His-Wardship <139779341+His-Wardship@users.noreply.github.com> Date: Mon, 25 Sep 2023 12:51:41 +0100 Subject: [PATCH 08/11] Update model.py Resolve conflicts with PR #873 --- src/peft/tuners/ia3/model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/peft/tuners/ia3/model.py b/src/peft/tuners/ia3/model.py index 6b515be3a3..bb08b215d8 100644 --- a/src/peft/tuners/ia3/model.py +++ b/src/peft/tuners/ia3/model.py @@ -176,11 +176,11 @@ def _create_and_replace( target, target_name, parent, - **optionnal_kwargs, + **optional_kwargs, ): - loaded_in_8bit = optionnal_kwargs["loaded_in_8bit"] - loaded_in_4bit = optionnal_kwargs["loaded_in_4bit"] - current_key = optionnal_kwargs["current_key"] + loaded_in_8bit = optional_kwargs["loaded_in_8bit"] + loaded_in_4bit = optional_kwargs["loaded_in_4bit"] + current_key = optional_kwargs["current_key"] # check if target module is in feedforward_modules if isinstance(ia3_config.feedforward_modules, str): From b98dc86139ded827012bce3d743fe2b6abe7f641 Mon Sep 17 00:00:00 2001 From: His-Wardship <139779341+His-Wardship@users.noreply.github.com> Date: Mon, 25 Sep 2023 13:44:21 +0100 Subject: [PATCH 09/11] Update bnb.py Remove accidental blank indent for make quality test --- src/peft/tuners/ia3/bnb.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/peft/tuners/ia3/bnb.py b/src/peft/tuners/ia3/bnb.py index c1d9a76794..9eb913874f 100644 --- a/src/peft/tuners/ia3/bnb.py +++ b/src/peft/tuners/ia3/bnb.py @@ -120,7 +120,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if active_adapter not in self.ia3_l.keys(): continue ia3_scaling *= self.ia3_l[active_adapter].flatten() - + requires_conversion = (not torch.is_autocast_enabled()) and (x.dtype != torch.float32) if requires_conversion: x = x.float() @@ -139,4 +139,4 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if requires_conversion: result = result.to(expected_dtype) - return result \ No newline at end of file + return result From be320f970bbca31c60b12eb707615e7b4fc28f8b Mon Sep 17 00:00:00 2001 From: His-Wardship <139779341+His-Wardship@users.noreply.github.com> Date: Tue, 26 Sep 2023 11:23:46 +0100 Subject: [PATCH 10/11] Update bnb.py Resolve conflicts introduced by PR #905 --- src/peft/tuners/ia3/bnb.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/peft/tuners/ia3/bnb.py b/src/peft/tuners/ia3/bnb.py index 9eb913874f..ce9567a9e2 100644 --- a/src/peft/tuners/ia3/bnb.py +++ b/src/peft/tuners/ia3/bnb.py @@ -44,14 +44,14 @@ def __init__( index=kwargs.get("index", None), ) IA3Layer.__init__(self, in_features=in_features, out_features=out_features, is_feedforward=is_feedforward) + self.is_feedforward = is_feedforward # Freezing the pre-trained weight matrix self.weight.requires_grad = False init_ia3_weights = kwargs.pop("init_ia3_weights", True) self.update_layer(adapter_name, init_ia3_weights) - self.active_adapter = adapter_name - self.is_feedforward = is_feedforward + self.set_adapter(adapter_name) def forward(self, x: torch.Tensor) -> torch.Tensor: if self.disable_adapters: @@ -102,14 +102,14 @@ def __init__( quant_type=kwargs.get("quant_type", "nf4"), ) IA3Layer.__init__(self, in_features=in_features, out_features=out_features, is_feedforward=is_feedforward) + self.is_feedforward = is_feedforward # Freezing the pre-trained weight matrix self.weight.requires_grad = False init_ia3_weights = kwargs.pop("init_ia3_weights", True) self.update_layer(adapter_name, init_ia3_weights) - self.active_adapter = adapter_name - self.is_feedforward = is_feedforward + self.set_adapter(adapter_name) def forward(self, x: torch.Tensor) -> torch.Tensor: if self.disable_adapters: From b42777310ffa92bcdf205ad75d2a1190602f187a Mon Sep 17 00:00:00 2001 From: His-Wardship <139779341+His-Wardship@users.noreply.github.com> Date: Tue, 26 Sep 2023 12:14:04 +0100 Subject: [PATCH 11/11] Update bnb.py Re-arrange to remove unnecessary parenthesis. --- src/peft/tuners/ia3/bnb.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/peft/tuners/ia3/bnb.py b/src/peft/tuners/ia3/bnb.py index 3eb2060e5d..2aa37c1d5c 100644 --- a/src/peft/tuners/ia3/bnb.py +++ b/src/peft/tuners/ia3/bnb.py @@ -132,11 +132,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: expected_dtype = result.dtype result = result * ia3_scaling - result = ( - result.clone() - ) # adalora.py and lora.py both suggested that the inclusion of this was necessary for 4-bit training on older versions of Pytorch. This has been duplicated here. + result = result.clone() + # adalora.py and lora.py both suggest that this is necessary for 4-bit training on older versions of Pytorch. + # This has been duplicated here. if requires_conversion: result = result.to(expected_dtype) - return result \ No newline at end of file + return result