FEAT: Add fp16 + cpu merge support (#1017)
* add fp16 + cpu merge support

* fix tests

* add fp16 tests for custom models

* fix tests

* adapt from comments

* more clarifications
younesbelkada authored Oct 13, 2023
1 parent 07f2b82 commit aaa7e9f
Showing 4 changed files with 141 additions and 18 deletions.
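Why the cast is needed: half-precision matmul has historically not been implemented for CPU tensors in PyTorch, so merging fp16 LoRA weights on a CPU model could fail inside `get_delta_weight`. A minimal sketch of the failure mode and of the fp32 round-trip this commit adds (the exact error text depends on the torch build and is illustrative only):

import torch

a = torch.randn(4, 8, dtype=torch.float16)  # fp16 tensors on CPU
b = torch.randn(8, 4, dtype=torch.float16)

try:
    _ = a @ b  # on older torch builds this raises something like: "addmm_impl_cpu_" not implemented for 'Half'
except RuntimeError as exc:
    print(f"fp16 matmul on CPU failed: {exc}")

# The workaround used throughout this commit: upcast to fp32, multiply, cast back.
delta = (a.float() @ b.float()).to(torch.float16)
print(delta.dtype, delta.shape)  # torch.float16, torch.Size([4, 4])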
115 changes: 101 additions & 14 deletions src/peft/tuners/lora/layer.py
@@ -250,13 +250,38 @@ def unmerge(self) -> None:
self.weight.data -= self.get_delta_weight(active_adapter)

def get_delta_weight(self, adapter) -> torch.Tensor:
return (
transpose(
self.lora_B[adapter].weight @ self.lora_A[adapter].weight,
self.fan_in_fan_out,
)
* self.scaling[adapter]
)
"""
Compute the delta weight for the given adapter.
Args:
adapter (str):
The name of the adapter for which the delta weight should be computed.
"""
device = self.lora_B[adapter].weight.device
dtype = self.lora_B[adapter].weight.dtype

# In case users want to merge the adapter weights that are in
# float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to
# float16 because `@` and matmul operations in general are not supported in torch + cpu + fp16.
cast_to_fp32 = device.type == "cpu" and dtype == torch.float16

weight_A = self.lora_A[adapter].weight
weight_B = self.lora_B[adapter].weight

if cast_to_fp32:
weight_A = weight_A.float()
weight_B = weight_B.float()

output_tensor = transpose(weight_B @ weight_A, self.fan_in_fan_out) * self.scaling[adapter]

if cast_to_fp32:
output_tensor = output_tensor.to(dtype=dtype)

# cast back the weights
self.lora_A[adapter].weight.data = weight_A.to(dtype)
self.lora_B[adapter].weight.data = weight_B.to(dtype)

return output_tensor

def _linear(self, input: torch.Tensor) -> torch.Tensor:
return F.linear(input, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
@@ -347,7 +372,38 @@ def unmerge(self) -> None:
self.weight.data -= self.get_delta_weight(active_adapter)

def get_delta_weight(self, adapter) -> torch.Tensor:
return transpose(self.lora_embedding_B[adapter] @ self.lora_embedding_A[adapter], True) * self.scaling[adapter]
"""
Compute the delta weight for the given adapter.
Args:
adapter (str):
The name of the adapter for which the delta weight should be computed.
"""
device = self.lora_embedding_B[adapter].device
dtype = self.lora_embedding_A[adapter].dtype

# In case users want to merge the adapter weights that are in
# float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to
# float16 because `@` and matmul operations in general are not supported in torch + cpu + fp16.
cast_to_fp32 = device.type == "cpu" and dtype == torch.float16

weight_A = self.lora_embedding_A[adapter]
weight_B = self.lora_embedding_B[adapter]

if cast_to_fp32:
weight_A = weight_A.float()
weight_B = weight_B.float()

output_tensor = transpose(weight_B @ weight_A, True) * self.scaling[adapter]

if cast_to_fp32:
output_tensor = output_tensor.to(dtype=dtype)

# cast back the weights
self.lora_embedding_A[adapter] = weight_A.to(dtype)
self.lora_embedding_B[adapter] = weight_B.to(dtype)

return output_tensor

def _embed(self, input: torch.Tensor, weight: Optional[torch.Tensor] = None) -> torch.Tensor:
weight = self.weight if weight is None else weight
@@ -455,22 +511,53 @@ def unmerge(self) -> None:
self.weight.data -= self.get_delta_weight(active_adapter)

def get_delta_weight(self, adapter) -> torch.Tensor:
"""
Compute the delta weight for the given adapter.
Args:
adapter (str):
The name of the adapter for which the delta weight should be computed.
"""
device = self.lora_B[adapter].weight.device
dtype = self.lora_A[adapter].weight.dtype

# In case users want to merge the adapter weights that are in
# float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to
# float16 because `@` and matmul operations in general are not supported in torch + cpu + fp16.
cast_to_fp32 = device.type == "cpu" and dtype == torch.float16

weight_A = self.lora_A[adapter].weight
weight_B = self.lora_B[adapter].weight

if cast_to_fp32:
weight_A = weight_A.float()
weight_B = weight_B.float()

# https://github.com/bmaltais/kohya_ss/blob/feb6728762a8f463d15ba936d189d4c3abfaa1ab/networks/lora.py#L117
if self.weight.size()[2:4] == (1, 1):
# conv2d 1x1
return (
self.lora_B[adapter].weight.squeeze(3).squeeze(2) @ self.lora_A[adapter].weight.squeeze(3).squeeze(2)
).unsqueeze(2).unsqueeze(3) * self.scaling[adapter]
output_tensor = (weight_B.squeeze(3).squeeze(2) @ weight_A.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(
3
) * self.scaling[adapter]
else:
# conv2d 3x3
return (
output_tensor = (
F.conv2d(
self.lora_A[adapter].weight.permute(1, 0, 2, 3),
self.lora_B[adapter].weight,
weight_A.permute(1, 0, 2, 3),
weight_B,
).permute(1, 0, 2, 3)
* self.scaling[adapter]
)

if cast_to_fp32:
output_tensor = output_tensor.to(dtype=dtype)

# cast back the weights
self.lora_A[adapter].weight.data = weight_A.to(dtype)
self.lora_B[adapter].weight.data = weight_B.to(dtype)

return output_tensor

def _conv2d(self, input: torch.Tensor) -> torch.Tensor:
return F.conv2d(
input,
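A quick shape check for the two Conv2d branches above; the rank and channel counts are made up for illustration, and the assertions only confirm that the composed delta matches the base kernel's (out_channels, in_channels, kH, kW) layout:

import torch
import torch.nn.functional as F

r, in_ch, out_ch, k = 4, 16, 32, 3        # hypothetical LoRA rank and conv sizes
weight_A = torch.randn(r, in_ch, k, k)    # lora_A.weight for a k x k base conv
weight_B = torch.randn(out_ch, r, 1, 1)   # lora_B.weight is a 1x1 conv

# conv2d 3x3 branch: compose the two convolutions into a single (out, in, k, k) kernel
delta = F.conv2d(weight_A.permute(1, 0, 2, 3), weight_B).permute(1, 0, 2, 3)
assert delta.shape == (out_ch, in_ch, k, k)

# conv2d 1x1 branch: the kernels collapse to matrices, so a plain matmul suffices
weight_A_1x1 = torch.randn(r, in_ch, 1, 1)
delta_1x1 = (weight_B.squeeze(3).squeeze(2) @ weight_A_1x1.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
assert delta_1x1.shape == (out_ch, in_ch, 1, 1)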
20 changes: 16 additions & 4 deletions tests/test_custom_models.py
@@ -312,18 +312,21 @@ class MockTransformerWrapper:
"""

@classmethod
def from_pretrained(cls, model_id):
def from_pretrained(cls, model_id, torch_dtype=None):
# set the seed so that from_pretrained always returns the same model
torch.manual_seed(0)

if torch_dtype is None:
torch_dtype = torch.float32

if model_id == "MLP":
return MLP()
return MLP().to(torch_dtype)

if model_id == "EmbConv1D":
return ModelEmbConv1D()
return ModelEmbConv1D().to(torch_dtype)

if model_id == "Conv2d":
return ModelConv2D()
return ModelConv2D().to(torch_dtype)

raise ValueError(f"model_id {model_id} not implemented")

@@ -370,6 +373,15 @@ def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs):
config_kwargs["init_ia3_weights"] = False
self._test_merge_layers(model_id, config_cls, config_kwargs)

@parameterized.expand(TEST_CASES)
def test_merge_layers_fp16(self, test_name, model_id, config_cls, config_kwargs):
config_kwargs = config_kwargs.copy()
if issubclass(config_cls, LoraConfig):
config_kwargs["init_lora_weights"] = False
elif issubclass(config_cls, IA3Config):
config_kwargs["init_ia3_weights"] = False
self._test_merge_layers_fp16(model_id, config_cls, config_kwargs)

@parameterized.expand(TEST_CASES)
def test_generate(self, test_name, model_id, config_cls, config_kwargs):
# Custom models do not (necessarily) have a generate method, so this test is not performed
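The mock `from_pretrained` above just builds the requested custom model and moves it to `torch_dtype`, so the new fp16 merge test runs against parameters that are genuinely half precision. A tiny sketch of that idea with a throwaway module (the module is illustrative, not one of the actual test models):

import torch
import torch.nn as nn

mlp = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2)).to(torch.float16)
assert all(p.dtype == torch.float16 for p in mlp.parameters())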
4 changes: 4 additions & 0 deletions tests/test_decoder_models.py
@@ -118,6 +118,10 @@ def test_merge_layers_nan(self, test_name, model_id, config_cls, config_kwargs):
def test_generate(self, test_name, model_id, config_cls, config_kwargs):
self._test_generate(model_id, config_cls, config_kwargs)

@parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
def test_merge_layers_fp16(self, test_name, model_id, config_cls, config_kwargs):
self._test_merge_layers_fp16(model_id, config_cls, config_kwargs)

@parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
def test_generate_half_prec(self, test_name, model_id, config_cls, config_kwargs):
self._test_generate_half_prec(model_id, config_cls, config_kwargs)
20 changes: 20 additions & 0 deletions tests/testing_common.py
@@ -430,6 +430,26 @@ def _test_from_pretrained_config_construction(self, model_id, config_cls, config
self.assertTrue(model_from_pretrained.peft_config["default"].inference_mode)
self.assertIs(model_from_pretrained.peft_config["default"], config)

def _test_merge_layers_fp16(self, model_id, config_cls, config_kwargs):
if config_cls not in (LoraConfig,):
# Merging in fp16 is only tested for LoRA here
return
if ("gpt2" in model_id.lower()) and (config_cls != LoraConfig):
self.skipTest("Merging GPT2 adapters not supported for IA³ (yet)")

model = self.transformers_class.from_pretrained(model_id, torch_dtype=torch.float16)
config = config_cls(
base_model_name_or_path=model_id,
**config_kwargs,
)
model = get_peft_model(model, config)
model = model.to(device="cpu", dtype=torch.float16)

model.eval()

# This should simply work
_ = model.merge_and_unload()

def _test_merge_layers_nan(self, model_id, config_cls, config_kwargs):
if config_cls not in (LoraConfig, IA3Config, AdaLoraConfig):
# Merge layers only supported for LoRA and IA³
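For reference, the end-user flow that `_test_merge_layers_fp16` exercises looks roughly like the sketch below; the model id and LoRA hyperparameters are illustrative and not part of the commit:

import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", torch_dtype=torch.float16)
config = LoraConfig(r=8, lora_alpha=16, init_lora_weights=False)
model = get_peft_model(base, config).to(device="cpu", dtype=torch.float16)
model.eval()

# Before this change, merging could fail on CPU because the fp16 matmul inside
# get_delta_weight is not supported there; with the fp32 round-trip it should simply work.
merged = model.merge_and_unload()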
