From aaa7e9f44a6405af819e721d7ee7fc6dd190c980 Mon Sep 17 00:00:00 2001
From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date: Fri, 13 Oct 2023 12:23:16 +0200
Subject: [PATCH] FEAT: Add fp16 + cpu merge support (#1017)

* add fp16 + cpu merge support

* fix tests

* add fp16 tests for custom models

* fix tests

* adapt from comments

* more clarifications
---
 src/peft/tuners/lora/layer.py | 115 +++++++++++++++++++++++++++++-----
 tests/test_custom_models.py   |  20 ++++--
 tests/test_decoder_models.py  |   4 ++
 tests/testing_common.py       |  20 ++++++
 4 files changed, 141 insertions(+), 18 deletions(-)

diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py
index d78ab498db..8796b6a638 100644
--- a/src/peft/tuners/lora/layer.py
+++ b/src/peft/tuners/lora/layer.py
@@ -250,13 +250,38 @@ def unmerge(self) -> None:
                 self.weight.data -= self.get_delta_weight(active_adapter)
 
     def get_delta_weight(self, adapter) -> torch.Tensor:
-        return (
-            transpose(
-                self.lora_B[adapter].weight @ self.lora_A[adapter].weight,
-                self.fan_in_fan_out,
-            )
-            * self.scaling[adapter]
-        )
+        """
+        Compute the delta weight for the given adapter.
+
+        Args:
+            adapter (str):
+                The name of the adapter for which the delta weight should be computed.
+        """
+        device = self.lora_B[adapter].weight.device
+        dtype = self.lora_B[adapter].weight.dtype
+
+        # In case users wants to merge the adapter weights that are in
+        # float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to
+        # float16 because the `@` and matmul operation in general is not supported in torch + cpu + fp16.
+        cast_to_fp32 = device.type == "cpu" and dtype == torch.float16
+
+        weight_A = self.lora_A[adapter].weight
+        weight_B = self.lora_B[adapter].weight
+
+        if cast_to_fp32:
+            weight_A = weight_A.float()
+            weight_B = weight_B.float()
+
+        output_tensor = transpose(weight_B @ weight_A, self.fan_in_fan_out) * self.scaling[adapter]
+
+        if cast_to_fp32:
+            output_tensor = output_tensor.to(dtype=dtype)
+
+            # cast back the weights
+            self.lora_A[adapter].weight.data = weight_A.to(dtype)
+            self.lora_B[adapter].weight.data = weight_B.to(dtype)
+
+        return output_tensor
 
     def _linear(self, input: torch.Tensor) -> torch.Tensor:
         return F.linear(input, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
@@ -347,7 +372,38 @@ def unmerge(self) -> None:
                 self.weight.data -= self.get_delta_weight(active_adapter)
 
     def get_delta_weight(self, adapter) -> torch.Tensor:
-        return transpose(self.lora_embedding_B[adapter] @ self.lora_embedding_A[adapter], True) * self.scaling[adapter]
+        """
+        Compute the delta weight for the given adapter.
+
+        Args:
+            adapter (str):
+                The name of the adapter for which the delta weight should be computed.
+        """
+        device = self.lora_embedding_B[adapter].device
+        dtype = self.lora_embedding_A[adapter].dtype
+
+        # In case users wants to merge the adapter weights that are in
+        # float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to
+        # float16 because the `@` and matmul operation in general is not supported in torch + cpu + fp16.
+        cast_to_fp32 = device.type == "cpu" and dtype == torch.float16
+
+        weight_A = self.lora_embedding_A[adapter]
+        weight_B = self.lora_embedding_B[adapter]
+
+        if cast_to_fp32:
+            weight_A = weight_A.float()
+            weight_B = weight_B.float()
+
+        output_tensor = transpose(weight_B @ weight_A, True) * self.scaling[adapter]
+
+        if cast_to_fp32:
+            output_tensor = output_tensor.to(dtype=dtype)
+
+            # cast back the weights
+            self.lora_embedding_A[adapter] = weight_A.to(dtype)
+            self.lora_embedding_B[adapter] = weight_B.to(dtype)
+
+        return output_tensor
 
     def _embed(self, input: torch.Tensor, weight: Optional[torch.Tensor] = None) -> torch.Tensor:
         weight = self.weight if weight is None else weight
@@ -455,22 +511,53 @@ def unmerge(self) -> None:
                 self.weight.data -= self.get_delta_weight(active_adapter)
 
     def get_delta_weight(self, adapter) -> torch.Tensor:
+        """
+        Compute the delta weight for the given adapter.
+
+        Args:
+            adapter (str):
+                The name of the adapter for which the delta weight should be computed.
+        """
+        device = self.lora_B[adapter].weight.device
+        dtype = self.lora_A[adapter].weight.dtype
+
+        # In case users wants to merge the adapter weights that are in
+        # float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to
+        # float16 because the `@` and matmul operation in general is not supported in torch + cpu + fp16.
+        cast_to_fp32 = device.type == "cpu" and dtype == torch.float16
+
+        weight_A = self.lora_A[adapter].weight
+        weight_B = self.lora_B[adapter].weight
+
+        if cast_to_fp32:
+            weight_A = weight_A.float()
+            weight_B = weight_B.float()
+
         # https://github.com/bmaltais/kohya_ss/blob/feb6728762a8f463d15ba936d189d4c3abfaa1ab/networks/lora.py#L117
         if self.weight.size()[2:4] == (1, 1):
             # conv2d 1x1
-            return (
-                self.lora_B[adapter].weight.squeeze(3).squeeze(2) @ self.lora_A[adapter].weight.squeeze(3).squeeze(2)
-            ).unsqueeze(2).unsqueeze(3) * self.scaling[adapter]
+            output_tensor = (weight_B.squeeze(3).squeeze(2) @ weight_A.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(
+                3
+            ) * self.scaling[adapter]
         else:
             # conv2d 3x3
-            return (
+            output_tensor = (
                 F.conv2d(
-                    self.lora_A[adapter].weight.permute(1, 0, 2, 3),
-                    self.lora_B[adapter].weight,
+                    weight_A.permute(1, 0, 2, 3),
+                    weight_B,
                 ).permute(1, 0, 2, 3)
                 * self.scaling[adapter]
             )
 
+        if cast_to_fp32:
+            output_tensor = output_tensor.to(dtype=dtype)
+
+            # cast back the weights
+            self.lora_A[adapter].weight.data = weight_A.to(dtype)
+            self.lora_B[adapter].weight.data = weight_B.to(dtype)
+
+        return output_tensor
+
     def _conv2d(self, input: torch.Tensor) -> torch.Tensor:
         return F.conv2d(
             input,
diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py
index 51568919d0..f5ad3ee838 100644
--- a/tests/test_custom_models.py
+++ b/tests/test_custom_models.py
@@ -312,18 +312,21 @@ class MockTransformerWrapper:
     """
 
     @classmethod
-    def from_pretrained(cls, model_id):
+    def from_pretrained(cls, model_id, torch_dtype=None):
         # set the seed so that from_pretrained always returns the same model
         torch.manual_seed(0)
 
+        if torch_dtype is None:
+            torch_dtype = torch.float32
+
         if model_id == "MLP":
-            return MLP()
+            return MLP().to(torch_dtype)
 
         if model_id == "EmbConv1D":
-            return ModelEmbConv1D()
+            return ModelEmbConv1D().to(torch_dtype)
 
         if model_id == "Conv2d":
-            return ModelConv2D()
+            return ModelConv2D().to(torch_dtype)
 
         raise ValueError(f"model_id {model_id} not implemented")
 
@@ -370,6 +373,15 @@ def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs):
             config_kwargs["init_ia3_weights"] = False
         self._test_merge_layers(model_id, config_cls, config_kwargs)
 
+    @parameterized.expand(TEST_CASES)
+    def test_merge_layers_fp16(self, test_name, model_id, config_cls, config_kwargs):
+        config_kwargs = config_kwargs.copy()
+        if issubclass(config_cls, LoraConfig):
+            config_kwargs["init_lora_weights"] = False
+        elif issubclass(config_cls, IA3Config):
+            config_kwargs["init_ia3_weights"] = False
+        self._test_merge_layers_fp16(model_id, config_cls, config_kwargs)
+
     @parameterized.expand(TEST_CASES)
     def test_generate(self, test_name, model_id, config_cls, config_kwargs):
         # Custom models do not (necessarily) have a generate method, so this test is not performed
diff --git a/tests/test_decoder_models.py b/tests/test_decoder_models.py
index b6efc5da43..ea30a8183c 100644
--- a/tests/test_decoder_models.py
+++ b/tests/test_decoder_models.py
@@ -118,6 +118,10 @@ def test_merge_layers_nan(self, test_name, model_id, config_cls, config_kwargs):
     def test_generate(self, test_name, model_id, config_cls, config_kwargs):
         self._test_generate(model_id, config_cls, config_kwargs)
 
+    @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
+    def test_merge_layers_fp16(self, test_name, model_id, config_cls, config_kwargs):
+        self._test_merge_layers_fp16(model_id, config_cls, config_kwargs)
+
     @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
     def test_generate_half_prec(self, test_name, model_id, config_cls, config_kwargs):
         self._test_generate_half_prec(model_id, config_cls, config_kwargs)
diff --git a/tests/testing_common.py b/tests/testing_common.py
index a69ac1f86d..e4ad873883 100644
--- a/tests/testing_common.py
+++ b/tests/testing_common.py
@@ -430,6 +430,26 @@ def _test_from_pretrained_config_construction(self, model_id, config_cls, config
         self.assertTrue(model_from_pretrained.peft_config["default"].inference_mode)
         self.assertIs(model_from_pretrained.peft_config["default"], config)
 
+    def _test_merge_layers_fp16(self, model_id, config_cls, config_kwargs):
+        if config_cls not in (LoraConfig,):
+            # Merge layers only supported for LoRA and IA³
+            return
+        if ("gpt2" in model_id.lower()) and (config_cls != LoraConfig):
+            self.skipTest("Merging GPT2 adapters not supported for IA³ (yet)")
+
+        model = self.transformers_class.from_pretrained(model_id, torch_dtype=torch.float16)
+        config = config_cls(
+            base_model_name_or_path=model_id,
+            **config_kwargs,
+        )
+        model = get_peft_model(model, config)
+        model = model.to(device="cpu", dtype=torch.float16)
+
+        model.eval()
+
+        # This should simply work
+        _ = model.merge_and_unload()
+
     def _test_merge_layers_nan(self, model_id, config_cls, config_kwargs):
         if config_cls not in (LoraConfig, IA3Config, AdaLoraConfig):
             # Merge layers only supported for LoRA and IA³
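
For reference, the code path this patch enables can be exercised as in the following sketch, which mirrors the new _test_merge_layers_fp16 test: a LoRA adapter is merged into a base model kept in torch.float16 on CPU. The model id and LoRA hyperparameters below are illustrative assumptions, not values taken from the patch.

    import torch
    from transformers import AutoModelForCausalLM
    from peft import LoraConfig, get_peft_model

    # Load an example base model directly in float16 and keep it on CPU.
    base = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.float16)

    # Illustrative LoRA settings; init_lora_weights=False makes the delta
    # non-zero so that the merge actually changes the base weights.
    config = LoraConfig(r=8, lora_alpha=16, init_lora_weights=False)
    model = get_peft_model(base, config).to(device="cpu", dtype=torch.float16)
    model.eval()

    # Before this patch, get_delta_weight performed an fp16 @ fp16 matmul,
    # which the CPU backend rejects; the float32 round trip added here lets
    # the merge succeed and returns the merged float16 base model.
    merged = model.merge_and_unload()
    print(next(merged.parameters()).dtype)  # torch.float16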