FEAT: Add fp16 + cpu merge support #1017

Merged: 6 commits, merged on Oct 13, 2023
115 changes: 101 additions & 14 deletions src/peft/tuners/lora/layer.py
@@ -250,13 +250,38 @@ def unmerge(self) -> None:
self.weight.data -= self.get_delta_weight(active_adapter)

def get_delta_weight(self, adapter) -> torch.Tensor:
-    return (
-        transpose(
-            self.lora_B[adapter].weight @ self.lora_A[adapter].weight,
-            self.fan_in_fan_out,
-        )
-        * self.scaling[adapter]
-    )
"""
Compute the delta weight for the given adapter.

Args:
adapter (str):
The name of the adapter for which the delta weight should be computed.
"""
device = self.lora_B[adapter].weight.device
dtype = self.lora_B[adapter].weight.dtype

# If the user wants to merge adapter weights that are in float16 while on CPU, we need to
# cast the weights to float32, perform the merge, and then cast back to float16, because
# matmul (the `@` operator) is generally not supported for fp16 tensors on CPU in torch.
cast_to_fp32 = device.type == "cpu" and dtype == torch.float16

weight_A = self.lora_A[adapter].weight
weight_B = self.lora_B[adapter].weight

if cast_to_fp32:
weight_A = weight_A.float()
weight_B = weight_B.float()

output_tensor = transpose(weight_B @ weight_A, self.fan_in_fan_out) * self.scaling[adapter]
Member (review comment):
I wonder if we should put all the code starting from here until the return into a try ... finally to ensure that we always cast the weights back to their original dtype, even if there is an error. Might be overkill, not sure.

Too bad there isn't a PyTorch context manager that would allow something like this:

with torch.force_dtype('float32'):
    output_tensor = transpose(weight_B @ weight_A, self.fan_in_fan_out) * self.scaling[adapter]

Would be quite handy here but I couldn't find anything that would do that.
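
One way to approximate this, as a minimal sketch rather than an actual PyTorch API: a small context manager that yields fp32 copies of the LoRA matrices (the helper name as_float32 is made up, and since the originals are never modified in place, there is nothing to restore on error):

from contextlib import contextmanager

import torch


@contextmanager
def as_float32(*tensors):
    # Yield fp32 copies of any fp16 inputs; the originals are left untouched,
    # so no cleanup is needed even if the body raises.
    yield tuple(t.float() if t.dtype == torch.float16 else t for t in tensors)


# Hypothetical usage inside get_delta_weight:
# with as_float32(weight_A, weight_B) as (wa, wb):
#     output_tensor = transpose(wb @ wa, self.fan_in_fan_out) * self.scaling[adapter]
# output_tensor = output_tensor.to(dtype)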


if cast_to_fp32:
output_tensor = output_tensor.to(dtype=dtype)

# cast back the weights
self.lora_A[adapter].weight.data = weight_A.to(dtype)
self.lora_B[adapter].weight.data = weight_B.to(dtype)

return output_tensor

def _linear(self, input: torch.Tensor) -> torch.Tensor:
return F.linear(input, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
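
For context on the comment about `@` / matmul not being supported for fp16 on CPU: the failure mode can be reproduced in a few lines. On the CPU builds of torch current at the time of this PR, a half-precision matmul raises a RuntimeError along the lines of "addmm_impl_cpu_" not implemented for 'Half' (the exact message may vary by version), which is what the fp32 round-trip above works around:

import torch

a = torch.randn(4, 8, dtype=torch.float16)  # fp16 tensors on CPU
b = torch.randn(8, 4, dtype=torch.float16)

try:
    _ = a @ b  # may fail on CPU with fp16, depending on the torch version
except RuntimeError as exc:
    print(f"fp16 matmul on CPU failed: {exc}")

# The workaround used in this PR: upcast, multiply, downcast.
delta = (a.float() @ b.float()).to(torch.float16)
print(delta.dtype)  # torch.float16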
@@ -347,7 +372,38 @@ def unmerge(self) -> None:
self.weight.data -= self.get_delta_weight(active_adapter)

def get_delta_weight(self, adapter) -> torch.Tensor:
-    return transpose(self.lora_embedding_B[adapter] @ self.lora_embedding_A[adapter], True) * self.scaling[adapter]
"""
Compute the delta weight for the given adapter.

Args:
adapter (str):
The name of the adapter for which the delta weight should be computed.
"""
device = self.lora_embedding_B[adapter].device
dtype = self.lora_embedding_A[adapter].dtype

# If the user wants to merge adapter weights that are in float16 while on CPU, we need to
# cast the weights to float32, perform the merge, and then cast back to float16, because
# matmul (the `@` operator) is generally not supported for fp16 tensors on CPU in torch.
cast_to_fp32 = device.type == "cpu" and dtype == torch.float16

weight_A = self.lora_embedding_A[adapter]
weight_B = self.lora_embedding_B[adapter]

if cast_to_fp32:
weight_A = weight_A.float()
weight_B = weight_B.float()

output_tensor = transpose(weight_B @ weight_A, True) * self.scaling[adapter]

if cast_to_fp32:
output_tensor = output_tensor.to(dtype=dtype)

# cast back the weights
self.lora_embedding_A[adapter] = weight_A.to(dtype)
self.lora_embedding_B[adapter] = weight_B.to(dtype)

return output_tensor
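
A quick shape check (a sketch, not part of the diff) of why the `True` transpose is needed here, assuming the layout peft uses for embedding LoRA weights, with lora_embedding_A of shape (r, num_embeddings) and lora_embedding_B of shape (embedding_dim, r):

import torch

num_embeddings, embedding_dim, r = 10, 8, 4
weight_A = torch.randn(r, num_embeddings)   # lora_embedding_A[adapter]
weight_B = torch.randn(embedding_dim, r)    # lora_embedding_B[adapter]

delta = (weight_B @ weight_A).T             # what transpose(..., True) computes here
print(delta.shape)                          # torch.Size([10, 8]), same as nn.Embedding(10, 8).weight.shape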

def _embed(self, input: torch.Tensor, weight: Optional[torch.Tensor] = None) -> torch.Tensor:
weight = self.weight if weight is None else weight
@@ -455,22 +511,53 @@ def unmerge(self) -> None:
self.weight.data -= self.get_delta_weight(active_adapter)

def get_delta_weight(self, adapter) -> torch.Tensor:
"""
Compute the delta weight for the given adapter.

Args:
adapter (str):
The name of the adapter for which the delta weight should be computed.
"""
device = self.lora_B[adapter].weight.device
dtype = self.lora_A[adapter].weight.dtype

# If the user wants to merge adapter weights that are in float16 while on CPU, we need to
# cast the weights to float32, perform the merge, and then cast back to float16, because
# matmul (the `@` operator) is generally not supported for fp16 tensors on CPU in torch.
cast_to_fp32 = device.type == "cpu" and dtype == torch.float16

weight_A = self.lora_A[adapter].weight
weight_B = self.lora_B[adapter].weight

if cast_to_fp32:
weight_A = weight_A.float()
weight_B = weight_B.float()

# https://github.com/bmaltais/kohya_ss/blob/feb6728762a8f463d15ba936d189d4c3abfaa1ab/networks/lora.py#L117
if self.weight.size()[2:4] == (1, 1):
# conv2d 1x1
-    return (
-        self.lora_B[adapter].weight.squeeze(3).squeeze(2) @ self.lora_A[adapter].weight.squeeze(3).squeeze(2)
-    ).unsqueeze(2).unsqueeze(3) * self.scaling[adapter]
+    output_tensor = (weight_B.squeeze(3).squeeze(2) @ weight_A.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(
+        3
+    ) * self.scaling[adapter]

Member (review comment): Haha, the lone 3 is really ugly, thanks black.
else:
# conv2d 3x3
-    return (
+    output_tensor = (
        F.conv2d(
-            self.lora_A[adapter].weight.permute(1, 0, 2, 3),
-            self.lora_B[adapter].weight,
+            weight_A.permute(1, 0, 2, 3),
+            weight_B,
).permute(1, 0, 2, 3)
* self.scaling[adapter]
)

if cast_to_fp32:
output_tensor = output_tensor.to(dtype=dtype)

# cast back the weights
self.lora_A[adapter].weight.data = weight_A.to(dtype)
self.lora_B[adapter].weight.data = weight_B.to(dtype)

return output_tensor

def _conv2d(self, input: torch.Tensor) -> torch.Tensor:
return F.conv2d(
input,
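As a sanity check on the two Conv2d branches of get_delta_weight above: composing lora_B after lora_A on an input should match a single convolution with the computed delta weight. The sketch below assumes the usual LoRA Conv2d layout (lora_A as Conv2d(in, r, k), lora_B as Conv2d(r, out, 1)) and omits the scaling factor for brevity:

import torch
import torch.nn.functional as F

in_ch, out_ch, r, k = 4, 6, 2, 3
weight_A = torch.randn(r, in_ch, k, k)    # lora_A.weight for a k x k conv
weight_B = torch.randn(out_ch, r, 1, 1)   # lora_B.weight, always a 1x1 conv

# Delta weight for the k x k branch, as computed in get_delta_weight (scaling omitted)
delta = F.conv2d(weight_A.permute(1, 0, 2, 3), weight_B).permute(1, 0, 2, 3)  # (out, in, k, k)

x = torch.randn(1, in_ch, 8, 8)
composed = F.conv2d(F.conv2d(x, weight_A, padding=1), weight_B)  # lora_B(lora_A(x))
merged = F.conv2d(x, delta, padding=1)

print(torch.allclose(composed, merged, atol=1e-4))  # expected: True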
20 changes: 16 additions & 4 deletions tests/test_custom_models.py
@@ -312,18 +312,21 @@ class MockTransformerWrapper:
"""

@classmethod
-    def from_pretrained(cls, model_id):
+    def from_pretrained(cls, model_id, torch_dtype=None):
# set the seed so that from_pretrained always returns the same model
torch.manual_seed(0)

if torch_dtype is None:
torch_dtype = torch.float32

if model_id == "MLP":
-        return MLP()
+        return MLP().to(torch_dtype)

if model_id == "EmbConv1D":
-        return ModelEmbConv1D()
+        return ModelEmbConv1D().to(torch_dtype)

if model_id == "Conv2d":
-        return ModelConv2D()
+        return ModelConv2D().to(torch_dtype)

raise ValueError(f"model_id {model_id} not implemented")

@@ -370,6 +373,15 @@ def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs):
config_kwargs["init_ia3_weights"] = False
self._test_merge_layers(model_id, config_cls, config_kwargs)

@parameterized.expand(TEST_CASES)
def test_merge_layers_fp16(self, test_name, model_id, config_cls, config_kwargs):
config_kwargs = config_kwargs.copy()
if issubclass(config_cls, LoraConfig):
config_kwargs["init_lora_weights"] = False
elif issubclass(config_cls, IA3Config):
config_kwargs["init_ia3_weights"] = False
self._test_merge_layers_fp16(model_id, config_cls, config_kwargs)

@parameterized.expand(TEST_CASES)
def test_generate(self, test_name, model_id, config_cls, config_kwargs):
# Custom models do not (necessarily) have a generate method, so this test is not performed
4 changes: 4 additions & 0 deletions tests/test_decoder_models.py
@@ -118,6 +118,10 @@ def test_merge_layers_nan(self, test_name, model_id, config_cls, config_kwargs):
def test_generate(self, test_name, model_id, config_cls, config_kwargs):
self._test_generate(model_id, config_cls, config_kwargs)

@parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
def test_merge_layers_fp16(self, test_name, model_id, config_cls, config_kwargs):
self._test_merge_layers_fp16(model_id, config_cls, config_kwargs)

@parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
def test_generate_half_prec(self, test_name, model_id, config_cls, config_kwargs):
self._test_generate_half_prec(model_id, config_cls, config_kwargs)
20 changes: 20 additions & 0 deletions tests/testing_common.py
@@ -430,6 +430,26 @@ def _test_from_pretrained_config_construction(self, model_id, config_cls, config
self.assertTrue(model_from_pretrained.peft_config["default"].inference_mode)
self.assertIs(model_from_pretrained.peft_config["default"], config)

def _test_merge_layers_fp16(self, model_id, config_cls, config_kwargs):
Member (review comment): Small issue with this test: When it is run locally on a cuda-enabled device, it will automatically use the GPU and not CPU, right? Therefore, it wouldn't really test the casting back and forth. On CI, it will run correctly though. Not sure if it's worth fixing.

Contributor (author): Yes great point, I have forced the test to happen on CPU

if config_cls not in (LoraConfig,):
# Merge layers only supported for LoRA and IA³
return
if ("gpt2" in model_id.lower()) and (config_cls != LoraConfig):
self.skipTest("Merging GPT2 adapters not supported for IA³ (yet)")

model = self.transformers_class.from_pretrained(model_id, torch_dtype=torch.float16)
config = config_cls(
base_model_name_or_path=model_id,
**config_kwargs,
)
model = get_peft_model(model, config)
model = model.to(device="cpu", dtype=torch.float16)

model.eval()

# This should simply work
_ = model.merge_and_unload()

def _test_merge_layers_nan(self, model_id, config_cls, config_kwargs):
if config_cls not in (LoraConfig, IA3Config, AdaLoraConfig):
# Merge layers only supported for LoRA and IA³
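To illustrate what the new tests exercise at the user level, a rough end-to-end sketch (the model name and adapter path below are placeholders, not taken from this PR):

import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float16)
model = PeftModel.from_pretrained(base, "path/to/lora-adapter")  # placeholder adapter path
model = model.to(device="cpu", dtype=torch.float16)

# Before this change, merging fp16 LoRA weights on CPU could fail inside get_delta_weight
# because of the unsupported half-precision matmul; with the fp32 round-trip it should work.
merged = model.merge_and_unload()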