diff --git a/src/peft/tuners/lora/bnb.py b/src/peft/tuners/lora/bnb.py
index d00e2d0a0d..4bd8151ed3 100644
--- a/src/peft/tuners/lora/bnb.py
+++ b/src/peft/tuners/lora/bnb.py
@@ -26,32 +26,21 @@ if is_bnb_available():
 
-    class Linear8bitLt(bnb.nn.Linear8bitLt, LoraLayer):
+    class Linear8bitLt(torch.nn.Module, LoraLayer):
         # Lora implemented in a dense layer
         def __init__(
             self,
             adapter_name,
-            in_features,
-            out_features,
+            base_layer,
             r: int = 0,
             lora_alpha: int = 1,
             lora_dropout: float = 0.0,
             **kwargs,
         ) -> None:
-            bnb.nn.Linear8bitLt.__init__(
-                self,
-                in_features,
-                out_features,
-                bias=kwargs.get("bias", True),
-                has_fp16_weights=kwargs.get("has_fp16_weights", True),
-                memory_efficient_backward=kwargs.get("memory_efficient_backward", False),
-                threshold=kwargs.get("threshold", 0.0),
-                index=kwargs.get("index", None),
-            )
-            LoraLayer.__init__(self, in_features=in_features, out_features=out_features)
+            super().__init__()
+            LoraLayer.__init__(self, in_features=base_layer.in_features, out_features=base_layer.out_features)
+            self.base_layer = base_layer
 
-            # Freezing the pre-trained weight matrix
-            self.weight.requires_grad = False
             init_lora_weights = kwargs.pop("init_lora_weights", True)
             self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
             self.set_adapter(adapter_name)
@@ -71,6 +60,7 @@ def merge(self, safe_merge: bool = False):
                     f"Already following adapters were merged {','.join(self.merged_adapters)}. "
                     f"You are now additionally merging {','.join(self.active_adapters)}."
                 )
+
             for active_adapter in self.active_adapters:
                 if active_adapter not in self.lora_A.keys():
                     continue
@@ -79,37 +69,38 @@ def merge(self, safe_merge: bool = False):
                 )
                 lora_data = self.get_delta_weight(active_adapter)
 
-                if self.state.SCB is None:
-                    self.state.SCB = self.weight.SCB
+                weight = self.base_layer.weight
+                state = self.base_layer.state
+                if state.SCB is None:
+                    state.SCB = weight.SCB
+
                 # Dequantize the result of identity matrix and int8 weight because bitsandbytes does not support int8
                 # dequantization directly
-                im = torch.eye(self.weight.data.shape[-1]).contiguous().half().to(self.weight.device)
+                im = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device)
                 im, imt, SCim, SCimt, coo_tensorim = bnb.functional.double_quant(im)
                 im, Sim = bnb.functional.transform(im, "col32")
+                if state.CxB is None:
+                    state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB)
+                out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB)
+                output = bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t()
-                if self.state.CxB is None:
-                    self.state.CxB, self.state.SB = bnb.functional.transform(
-                        self.weight.data, to_order=self.state.formatB
-                    )
-                out32, Sout32 = bnb.functional.igemmlt(im, self.state.CxB, Sim, self.state.SB)
-                output = bnb.functional.mm_dequant(out32, Sout32, SCim, self.state.SCB, bias=None).t()
                 w_data = output.to(lora_data.dtype).to(lora_data.device) + lora_data
-
                 if safe_merge and not torch.isfinite(w_data).all():
                     raise ValueError(
                         f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                     )
-                self.weight = bnb.nn.Int8Params(
-                    w_data.to("cpu"), requires_grad=False, has_fp16_weights=self.weight.has_fp16_weights
-                ).to(self.weight.device)
-                self.state.reset_grads()
+                self.base_layer.weight = bnb.nn.Int8Params(
+                    w_data.to("cpu"), requires_grad=False, has_fp16_weights=weight.has_fp16_weights
+                ).to(weight.device)
+                state.reset_grads()
                 self.merged_adapters.append(active_adapter)
 
         def unmerge(self):
             if not self.merged:
                 warnings.warn("Already unmerged. Nothing to do.")
                 return
+
             while len(self.merged_adapters) > 0:
                 active_adapter = self.merged_adapters.pop()
                 if active_adapter not in self.lora_A.keys():
@@ -119,23 +110,24 @@ def unmerge(self):
                 )
                 lora_data = self.get_delta_weight(active_adapter)
 
-                if self.state.SCB is None:
-                    self.state.SCB = self.weight.SCB
-                im = torch.eye(self.weight.data.shape[-1]).contiguous().half().to(self.weight.device)
+                weight = self.base_layer.weight
+                state = self.base_layer.state
+                if state.SCB is None:
+                    state.SCB = weight.SCB
+                im = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device)
                 im, imt, SCim, SCimt, coo_tensorim = bnb.functional.double_quant(im)
                 im, Sim = bnb.functional.transform(im, "col32")
-                if self.state.CxB is None:
-                    self.state.CxB, self.state.SB = bnb.functional.transform(
-                        self.weight.data, to_order=self.state.formatB
-                    )
-                out32, Sout32 = bnb.functional.igemmlt(im, self.state.CxB, Sim, self.state.SB)
-                output = bnb.functional.mm_dequant(out32, Sout32, SCim, self.state.SCB, bias=None).t()
+                if state.CxB is None:
+                    state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB)
+                out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB)
+                output = bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t()
+
                 w_data = output.to(lora_data.dtype).to(lora_data.device) - lora_data
-                self.weight = bnb.nn.Int8Params(
-                    w_data.to("cpu"), requires_grad=False, has_fp16_weights=self.weight.has_fp16_weights
-                ).to(self.weight.device)
-                self.state.reset_grads()
+                self.base_layer.weight = bnb.nn.Int8Params(
+                    w_data.to("cpu"), requires_grad=False, has_fp16_weights=weight.has_fp16_weights
+                ).to(weight.device)
+                state.reset_grads()
 
         def get_delta_weight(self, adapter):
             return (
@@ -146,15 +138,15 @@ def get_delta_weight(self, adapter):
                 * self.scaling[adapter]
             )
 
-        def forward(self, x: torch.Tensor) -> torch.Tensor:
+        def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
             if self.disable_adapters:
                 if self.merged:
                     self.unmerge()
-                result = super().forward(x)
+                result = self.base_layer(x, *args, **kwargs)
             elif self.merged:
-                result = super().forward(x)
+                result = self.base_layer(x, *args, **kwargs)
             else:
-                result = super().forward(x)
+                result = self.base_layer(x, *args, **kwargs)
                 for active_adapter in self.active_adapters:
                     if active_adapter not in self.lora_A.keys():
                         continue
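Note (illustrative, not part of the diff): the bnb.functional calls in the merge/unmerge hunks above exist only to dequantize the int8 weight; the update itself is the standard LoRA arithmetic. A plain-torch sketch of what merge() and unmerge() compute per adapter, where merged_weight is a hypothetical helper and scaling plays the role of self.scaling[adapter] (typically lora_alpha / r):

import torch

def merged_weight(w_dequant: torch.Tensor, lora_A: torch.Tensor, lora_B: torch.Tensor, scaling: float) -> torch.Tensor:
    # get_delta_weight(adapter) amounts to lora_B @ lora_A scaled by `scaling`.
    # merge() stores w_dequant + delta back as a fresh bnb.nn.Int8Params;
    # unmerge() subtracts the same delta again.
    delta = (lora_B @ lora_A) * scaling
    return w_dequant + delta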
@@ -180,31 +172,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 if is_bnb_4bit_available():
 
-    class Linear4bit(bnb.nn.Linear4bit, LoraLayer):
+    class Linear4bit(torch.nn.Module, LoraLayer):
         # Lora implemented in a dense layer
         def __init__(
             self,
             adapter_name,
-            in_features,
-            out_features,
+            base_layer,
             r: int = 0,
             lora_alpha: int = 1,
             lora_dropout: float = 0.0,
             **kwargs,
         ) -> None:
-            bnb.nn.Linear4bit.__init__(
-                self,
-                in_features,
-                out_features,
-                bias=kwargs.get("bias", True),
-                compute_dtype=kwargs.get("compute_dtype", torch.float32),
-                compress_statistics=kwargs.get("compress_statistics", True),
-                quant_type=kwargs.get("quant_type", "nf4"),
-            )
-            LoraLayer.__init__(self, in_features=in_features, out_features=out_features)
-
-            # Freezing the pre-trained weight matrix
-            self.weight.requires_grad = False
+            super().__init__()
+            LoraLayer.__init__(self, in_features=base_layer.in_features, out_features=base_layer.out_features)
+            self.base_layer = base_layer
 
             init_lora_weights = kwargs.pop("init_lora_weights", True)
             self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
@@ -225,6 +206,7 @@ def merge(self, safe_merge: bool = False):
                     f"Already following adapters were merged {','.join(self.merged_adapters)}. "
                     f"You are now additionally merging {','.join(self.active_adapters)}."
                 )
+
             for active_adapter in self.active_adapters:
                 if active_adapter not in self.lora_A.keys():
                     continue
@@ -232,21 +214,26 @@ def merge(self, safe_merge: bool = False):
                     "Merge lora module to 4-bit linear may get different generations due to rounding errors."
                 )
                 # Refer to https://gist.github.com/ChrisHayduk/1a53463331f52dca205e55982baf9930
-                kwargs = self.weight.__dict__
+                weight = self.base_layer.weight
+                kwargs = weight.__dict__
                 lora_data = self.get_delta_weight(active_adapter)
-                w_data = bnb.functional.dequantize_4bit(self.weight.data, self.weight.quant_state) + lora_data
+                w_data = bnb.functional.dequantize_4bit(weight.data, weight.quant_state) + lora_data
                 if safe_merge and not torch.isfinite(w_data).all():
                     raise ValueError(
                         f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                     )
-                self.weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to(self.weight.device)
+
+                self.base_layer.weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to(
+                    weight.device
+                )
                 self.merged_adapters.append(active_adapter)
 
         def unmerge(self):
             if not self.merged:
                 warnings.warn("Already unmerged. Nothing to do.")
                 return
+
             while len(self.merged_adapters) > 0:
                 active_adapter = self.merged_adapters.pop()
                 if active_adapter not in self.lora_A.keys():
@@ -254,10 +241,13 @@ def unmerge(self):
                 warnings.warn(
                     "Unmerge lora module to 4-bit linear may get different generations due to rounding errors."
                 )
-                kwargs = self.weight.__dict__
+                weight = self.base_layer.weight
+                kwargs = weight.__dict__
                 lora_data = self.get_delta_weight(active_adapter)
-                w_data = bnb.functional.dequantize_4bit(self.weight.data, self.weight.quant_state) - lora_data
-                self.weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to(self.weight.device)
+                w_data = bnb.functional.dequantize_4bit(weight.data, weight.quant_state) - lora_data
+                self.base_layer.weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to(
+                    weight.device
+                )
 
         def get_delta_weight(self, adapter):
             return (
@@ -268,15 +258,15 @@ def get_delta_weight(self, adapter):
                 * self.scaling[adapter]
             )
 
-        def forward(self, x: torch.Tensor) -> torch.Tensor:
+        def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
             if self.disable_adapters:
                 if self.merged:
                     self.unmerge()
-                result = super().forward(x)
+                result = self.base_layer.forward(x, *args, **kwargs)
             elif self.merged:
-                result = super().forward(x)
+                result = self.base_layer.forward(x, *args, **kwargs)
             else:
-                result = super().forward(x)
+                result = self.base_layer.forward(x, *args, **kwargs)
                 # As per Tim Dettmers, for 4bit, we need to defensively clone here.
                 # The reason is that in some cases, an error can occur that backprop
                 # does not work on a manipulated view. This issue may be solved with
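Note (illustrative, not part of the diff): both quantized LoRA classes now follow the same pattern, holding the original bitsandbytes module in self.base_layer and delegating the forward pass to it instead of subclassing bnb.nn.Linear8bitLt / bnb.nn.Linear4bit. A minimal sketch of that wrapping pattern, with a plain nn.Linear standing in for the quantized layer; the real classes additionally handle dropout, multiple adapters, and merge/unmerge:

import torch
import torch.nn as nn

class WrappedLoraLinear(nn.Module):
    def __init__(self, base_layer: nn.Linear, r: int = 8, lora_alpha: int = 16) -> None:
        super().__init__()
        self.base_layer = base_layer  # the original (possibly quantized) module, kept frozen
        self.lora_A = nn.Linear(base_layer.in_features, r, bias=False)
        self.lora_B = nn.Linear(r, base_layer.out_features, bias=False)
        self.scaling = lora_alpha / r
        for p in self.base_layer.parameters():
            p.requires_grad = False

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # Delegate to the wrapped layer instead of calling super().forward(x)
        result = self.base_layer(x, *args, **kwargs)
        return result + self.lora_B(self.lora_A(x)) * self.scaling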
diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py
index f1b33d07e6..873eb25a2a 100644
--- a/src/peft/tuners/lora/model.py
+++ b/src/peft/tuners/lora/model.py
@@ -201,12 +201,24 @@ def _replace_module(parent, child_name, new_module, child):
         setattr(parent, child_name, new_module)
         # It's not necessary to set requires_grad here, as that is handled by
         # _mark_only_adapters_as_trainable
-        new_module.weight = child.weight
-        if hasattr(child, "bias"):
-            new_module.bias = child.bias
+
+        # child layer wraps the original module, unpack it
+        if hasattr(child, "base_layer"):
+            child = child.base_layer
+        elif hasattr(child, "quant_linear_module"):
+            child = child.quant_linear_module
+
+        # TODO: layers with base_layer don't need the weight to be copied, as they have a reference already
+        if not hasattr(new_module, "base_layer"):
+            new_module.weight = child.weight
+            if hasattr(child, "bias"):
+                new_module.bias = child.bias
 
         if getattr(child, "state", None) is not None:
-            new_module.state = child.state
+            if hasattr(new_module, "base_layer"):
+                new_module.base_layer.state = child.state
+            else:
+                new_module.state = child.state
             new_module.to(child.weight.device)
 
         # dispatch to correct device
@@ -256,9 +268,7 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
                     "index": target.index,
                 }
            )
-            new_module = Linear8bitLt(
-                adapter_name, target.in_features, target.out_features, bias=bias, **eightbit_kwargs
-            )
+            new_module = Linear8bitLt(adapter_name, target, **eightbit_kwargs)
         elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(target, bnb.nn.Linear4bit):
             fourbit_kwargs = kwargs.copy()
             fourbit_kwargs.update(
@@ -268,7 +278,7 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
                     "quant_type": target.weight.quant_type,
                 }
             )
-            new_module = Linear4bit(adapter_name, target.in_features, target.out_features, bias=bias, **fourbit_kwargs)
+            new_module = Linear4bit(adapter_name, target, **fourbit_kwargs)
         elif AutoGPTQQuantLinear is not None and isinstance(target, AutoGPTQQuantLinear):
             new_module = QuantLinear(adapter_name, target, **kwargs)
             target.weight = target.qweight
@@ -389,28 +399,28 @@ def _unload_and_optionally_merge(self, merge=True, progressbar: bool = False, sa
                     padding=target.padding,
                     dilation=target.dilation,
                 )
-            elif is_bnb_available() and isinstance(target, bnb.nn.Linear8bitLt):
-                bias = target.bias is not None
+            elif is_bnb_available() and isinstance(target, Linear8bitLt):
+                bias = target.base_layer.bias is not None
                 new_module = bnb.nn.Linear8bitLt(
                     target.in_features,
                     target.out_features,
                     bias=bias,
-                    has_fp16_weights=target.state.has_fp16_weights,
-                    memory_efficient_backward=target.state.memory_efficient_backward,
-                    threshold=target.state.threshold,
-                    index=target.index,
-                    device=target.weight.device,
+                    has_fp16_weights=target.base_layer.state.has_fp16_weights,
+                    memory_efficient_backward=target.base_layer.state.memory_efficient_backward,
+                    threshold=target.base_layer.state.threshold,
+                    index=target.base_layer.index,
+                    device=target.base_layer.weight.device,
                 )
-            elif is_bnb_4bit_available() and isinstance(target, bnb.nn.Linear4bit):
-                bias = target.bias is not None
+            elif is_bnb_4bit_available() and isinstance(target, Linear4bit):
+                bias = target.base_layer.bias is not None
                 new_module = bnb.nn.Linear4bit(
                     target.in_features,
                     target.out_features,
                     bias=bias,
-                    compute_dtype=target.compute_dtype,
-                    compress_statistics=target.weight.compress_statistics,
-                    quant_type=target.weight.quant_type,
-                    device=target.weight.device,
+                    compute_dtype=target.base_layer.compute_dtype,
+                    compress_statistics=target.base_layer.weight.compress_statistics,
+                    quant_type=target.base_layer.weight.quant_type,
+                    device=target.base_layer.weight.device,
                 )
             else:
                 bias = target.bias is not None
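Note (illustrative, not part of the diff): because the new LoRA classes wrap the original module instead of being the original module, _replace_module first has to unpack the wrapper before copying weights and state. The unwrap step boils down to something like the following, where unwrap_child is a hypothetical helper name used only for illustration:

def unwrap_child(child):
    # LoRA layers that wrap the original module keep it in `base_layer`
    if hasattr(child, "base_layer"):
        return child.base_layer
    # GPTQ-style layers keep it in `quant_linear_module`
    if hasattr(child, "quant_linear_module"):
        return child.quant_linear_module
    return child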
diff --git a/tests/test_common_gpu.py b/tests/test_common_gpu.py
index 4d590db60f..f329600f94 100644
--- a/tests/test_common_gpu.py
+++ b/tests/test_common_gpu.py
@@ -545,12 +545,8 @@ def test_8bit_merge_and_disable_lora(self):
         self.assertFalse(torch.allclose(out_base, out_before, atol=atol, rtol=rtol))
         self.assertTrue(torch.allclose(out_base, out_after, atol=atol, rtol=rtol))
         self.assertTrue(isinstance(model, PeftModel))
-        self.assertTrue(
-            isinstance(model.base_model.model.model.decoder.layers[0].self_attn.q_proj, bnb.nn.Linear8bitLt)
-        )
-        self.assertTrue(
-            isinstance(model.base_model.model.model.decoder.layers[0].self_attn.v_proj, bnb.nn.Linear8bitLt)
-        )
+        self.assertTrue(isinstance(model.base_model.model.model.decoder.layers[0].self_attn.q_proj, LoraLinear8bitLt))
+        self.assertTrue(isinstance(model.base_model.model.model.decoder.layers[0].self_attn.v_proj, LoraLinear8bitLt))
 
     @require_torch_gpu
     @pytest.mark.single_gpu_tests
@@ -633,5 +629,5 @@ def test_4bit_merge_and_disable_lora(self):
         self.assertFalse(torch.allclose(out_base, out_before, atol=atol, rtol=rtol))
         self.assertTrue(torch.allclose(out_base, out_after, atol=atol, rtol=rtol))
         self.assertTrue(isinstance(model, PeftModel))
-        self.assertTrue(isinstance(model.base_model.model.model.decoder.layers[0].self_attn.q_proj, bnb.nn.Linear4bit))
-        self.assertTrue(isinstance(model.base_model.model.model.decoder.layers[0].self_attn.v_proj, bnb.nn.Linear4bit))
+        self.assertTrue(isinstance(model.base_model.model.model.decoder.layers[0].self_attn.q_proj, LoraLinear4bit))
+        self.assertTrue(isinstance(model.base_model.model.model.decoder.layers[0].self_attn.v_proj, LoraLinear4bit))
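Note (assumption, not part of the diff): after the refactor the wrapped projections are instances of the PEFT LoRA classes rather than of the raw bitsandbytes modules, so the assertions compare against LoraLinear8bitLt / LoraLinear4bit. These names are presumably import aliases in the test module, chosen to avoid clashing with the bnb.nn classes of the same name, along these lines (requires bitsandbytes to be installed):

from peft.tuners.lora import Linear8bitLt as LoraLinear8bitLt
from peft.tuners.lora import Linear4bit as LoraLinear4bit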