Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add LoRA multihead attention module #1324

Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
61 commits
Select commit Hold shift + click to select a range
49fab86
[WIP] Add LoRA multihead attention module
BenjaminBossan Jan 5, 2024
d8e9589
Make style
BenjaminBossan Jan 5, 2024
0e188a3
Remove commented code
BenjaminBossan Jan 5, 2024
b409d81
Remove assignment of weight to new module
BenjaminBossan Jan 5, 2024
173062c
Make state_dict and named_parameters work
BenjaminBossan Jan 5, 2024
1e007f5
Extend test coverage a bit
BenjaminBossan Jan 8, 2024
557c4a1
Clean ups after reviewer feedback:
BenjaminBossan Jan 9, 2024
add1f51
Reviewer feedback: removed another unnecessary arg
BenjaminBossan Jan 9, 2024
e44e030
Make style
BenjaminBossan Jan 9, 2024
8d62579
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Jan 9, 2024
c5d8a6b
Apply LoRA also to the out_proj of MHA
BenjaminBossan Jan 12, 2024
9dc4a4d
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Feb 7, 2024
c3fb2ce
Fix bug with incorrectly set gradient
BenjaminBossan Feb 7, 2024
17d407b
Fix failing tests
BenjaminBossan Feb 7, 2024
4cbf6e9
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Feb 26, 2024
e0cae11
Move to pytest style asserts
BenjaminBossan Feb 26, 2024
52c8d9b
Fix safe merging code
BenjaminBossan Feb 26, 2024
977c84b
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Mar 11, 2024
96d376d
No need to set bias for MHA anymore, see #1530
BenjaminBossan Mar 11, 2024
0c17476
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Mar 26, 2024
4b8db0c
Fix style
BenjaminBossan Mar 26, 2024
7e91712
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan May 21, 2024
e12070b
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Jul 25, 2024
7b6c7cb
Remove duplicate merge
BenjaminBossan Jul 25, 2024
e6ab8ed
Raise error for multi adapter batch inference
BenjaminBossan Jul 25, 2024
8ec6c3c
Raise error for DoRA + MHA
BenjaminBossan Jul 25, 2024
f6ba465
Fix error when adding multiple adapters to MHA
BenjaminBossan Jul 25, 2024
fb18886
Better way of param initialization
BenjaminBossan Jul 26, 2024
4ff2ec3
Add tests for broken loading and workaround
BenjaminBossan Jul 26, 2024
d1f6ab2
make style
BenjaminBossan Jul 26, 2024
65363be
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Sep 3, 2024
7ba2e68
Fix wrong merge conflict resolution in test
BenjaminBossan Sep 4, 2024
6ef04b0
Ensure that base weights have requires_grad False
BenjaminBossan Sep 4, 2024
07c7240
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Sep 4, 2024
cc3ac3d
Remove xpass-ing test
BenjaminBossan Sep 4, 2024
03c466f
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Sep 12, 2024
e558caa
MAINT: Give stale bot permissions for PRs too (#2064)
BenjaminBossan Sep 12, 2024
38f4a98
ENH BOFT don't save boft_P buffer (#2050)
sywangyi Sep 13, 2024
7e5c61d
FIX Command line args in PiSSA preprocess (#2053)
keakon Sep 13, 2024
183bf52
MNT Update deprecated evaluation_strategy (#1664)
muellerzr Sep 13, 2024
b970607
ENH Multi adapters in same batch: modules_to_save (#1990)
saeid93 Sep 17, 2024
732e8e7
FIX Bug that prevents BOFT from loading 2 adapters (#2068)
BenjaminBossan Sep 18, 2024
79e2b38
TST Skip some quantization tests on XPU (#2074)
faaany Sep 18, 2024
61e6934
Improve test coverage for initialization of MHA
BenjaminBossan Sep 18, 2024
ced2f15
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Oct 14, 2024
4c31bbc
Fix bug with unloading multihead attention layer
BenjaminBossan Oct 21, 2024
1dbb9a5
Fix bug in unloading
BenjaminBossan Oct 22, 2024
e094234
Fix for low_cpu_mem_usage
BenjaminBossan Nov 1, 2024
e90af48
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Nov 1, 2024
30a08e7
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Nov 1, 2024
09f5ea6
Add tests for init_empty_weights
BenjaminBossan Nov 26, 2024
6a83bd7
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Nov 26, 2024
3b0471a
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Dec 9, 2024
465a85e
Add MHA to modules unsupported by EVA
BenjaminBossan Dec 9, 2024
266f9da
Add comment on why/how empty init works
BenjaminBossan Jan 6, 2025
39e755e
Expose attributes of underlying MHA module
BenjaminBossan Jan 6, 2025
4857858
Apply suggestions from code review
BenjaminBossan Jan 6, 2025
74cbba6
Remove trailing whitespace
BenjaminBossan Jan 6, 2025
14deb9f
Linting..
BenjaminBossan Jan 6, 2025
ba2a8dd
Reviewer comment: Add comments for clarification
BenjaminBossan Jan 8, 2025
ac10b18
Reviewer feedback: Remove q_proj_weight
BenjaminBossan Jan 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 183 additions & 0 deletions src/peft/tuners/lora/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
in_features, out_features = (
base_layer.weight.ds_shape if hasattr(base_layer.weight, "ds_shape") else base_layer.weight.shape
)
elif isinstance(base_layer, nn.MultiheadAttention):
if not base_layer._qkv_same_embed_dim:
raise ValueError(f"Only same dim for query/key/value is supported as of now for {self.__class__}.")
in_features, out_features = base_layer.embed_dim, 3 * base_layer.embed_dim
elif hasattr(base_layer, "infeatures") and hasattr(base_layer, "outfeatures"):
# QuantLinear
in_features, out_features = base_layer.infeatures, base_layer.outfeatures
Expand Down Expand Up @@ -688,6 +692,182 @@ def __repr__(self) -> str:
return "lora." + rep


class MultiheadAttention(nn.Module, LoraLayer):
"""LoRA implemented in a multihead attention layer

This is currently only implemented for the case of `_qkv_same_embed_dim = True`, i.e. query, key, and value having
the same dimension.

This is a little bit hacky because of the way that MultiheadAttention is implemented in PyTorch. It works by
merging the weights before the forward call and unmerging them after the forward call.
BenjaminBossan marked this conversation as resolved.
Show resolved Hide resolved
"""

def __init__(
self,
base_layer,
adapter_name: str,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
init_lora_weights: Union[bool, str] = True,
use_rslora: bool = False,
**kwargs,
) -> None:
# TODO work with separate weights
if not base_layer._qkv_same_embed_dim:
raise ValueError(f"Only same embed for query/key/value is supported as of now for {self.__class__}.")

super().__init__()
LoraLayer.__init__(self, base_layer, **kwargs)

self._active_adapter = adapter_name
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)

def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
"""
Merge the active adapter weights into the base weights

Args:
safe_merge (`bool`, *optional*):
If True, the merge operation will be performed in a copy of the original weights and check for NaNs
before merging the weights. This is useful if you want to check if the merge operation will produce
NaNs. Defaults to `False`.
adapter_names (`List[str]`, *optional*):
The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
to `None`.
"""
if self.merged:
warnings.warn(
f"Already following adapters were merged {','.join(self.merged_adapters)}. "
f"You are now additionally merging {','.join(self.active_adapters)}."
)

if adapter_names is None:
adapter_names = self.active_adapters

# Implementation follows this:
# https://github.com/Baijiong-Lin/LoRA-Torch/blob/4bfed6820b64fcf47064c30f30606a190a4f0d2e/loratorch/layers.py#L73-L79
# Notably, instead of mutating the weight, we delete the original weight and replace it by the merged weight
# TODO: work with separate weights
for active_adapter in adapter_names:
if active_adapter in self.lora_A.keys():
base_layer = self.get_base_layer()
if safe_merge:
orig_weights = base_layer.in_proj_weight.data.detach().clone()
orig_weights += self.get_delta_weight(active_adapter)

if not torch.isfinite(orig_weights).all():
raise ValueError(
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)

del base_layer.in_proj_weight
base_layer.in_proj_weight = orig_weights
else:
# TODO: work with separate weights
weight_merged = base_layer.in_proj_weight.data.detach() + self.get_delta_weight(active_adapter)
del base_layer.in_proj_weight
base_layer.in_proj_weight = weight_merged
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this throw an exception? AFAICS we're assigning a tensor to a parameter value:

foo = torch.nn.Linear(10, 100)
foo.weight = foo.weight.detach() # raises

What am I missing?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's true that we change the type here, I guess you could consider this part of the hack to make this work. At the end, through _restore_weights, the correct type is restored.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, yes. I missed the del statement which unregisters the parameter and, thus, removes the setattr constraint. WDYT about something along the lines of

# unregister parameter implicitly and overwrite using merged weights; gradients are computed
# after forward and, thus, after unmerging (see forward()), therefore this is safe to do.
del base_layer.in_proj_weight
base_layer.in_proj_weight = orig_weights_in

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

self.merged_adapters.append(active_adapter)

def unmerge(self) -> None:
"""
This method unmerges all merged adapter layers from the base weights.
"""
if not self.merged:
warnings.warn("Already unmerged. Nothing to do.")
return

# TODO work with separate weights
while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop()
if active_adapter in self.lora_A.keys():
self.get_base_layer().in_proj_weight.data -= self.get_delta_weight(active_adapter)

def get_delta_weight(self, adapter) -> torch.Tensor:
"""
Compute the delta weight for the given adapter.

Args:
adapter (str):
The name of the adapter for which the delta weight should be computed.
"""
device = self.lora_B[adapter].weight.device
dtype = self.lora_B[adapter].weight.dtype

# In case users wants to merge the adapter weights that are in
# float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to
# float16 because the `@` and matmul operation in general is not supported in torch + cpu + fp16.
cast_to_fp32 = device.type == "cpu" and dtype == torch.float16

weight_A = self.lora_A[adapter].weight
weight_B = self.lora_B[adapter].weight

if cast_to_fp32:
weight_A = weight_A.float()
weight_B = weight_B.float()

output_tensor = (weight_B @ weight_A) * self.scaling[adapter]

if cast_to_fp32:
output_tensor = output_tensor.to(dtype=dtype)

# cast back the weights
self.lora_A[adapter].weight.data = weight_A.to(dtype)
self.lora_B[adapter].weight.data = weight_B.to(dtype)

return output_tensor

def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
previous_dtype = x.dtype

if self.disable_adapters:
if self.merged:
self.unmerge()
result = self.base_layer(x, *args, **kwargs)
elif self.merged:
result = self.base_layer(x, *args, **kwargs)
else:
# merge all adapters that are active for this module
active_adapters = [a for a in self.active_adapters if a in self.lora_A]
try:
self.merge(adapter_names=active_adapters)
result = self.base_layer(x, *args, **kwargs)
finally:
# it's safe to call unmerge(), which unmerges all adapters, because we checked that not self.merged,
# i.e. there is was no merged layer before
self.unmerge()

result = (result[0].to(previous_dtype), result[1].to(previous_dtype) if result[1] is not None else result[1])
return result

def _restore_weights(self):
# Restore the weights as registered parameters on the base layer.
# This is necessary because the way that weights are merged/unmerged (which is necessary for forward to work
# correctly), the Module "forgets" these attributes. Therefore, we need to call register_parameter explicitly.
# We cannot call register_parameter for merging/unmerging because that cuts them off from the autograd graph.
githubnemo marked this conversation as resolved.
Show resolved Hide resolved
# Note that this is hacky, since we need to ensure that _restore_weights is called by each method that needs it.

# TODO work with separate weights
base_layer = self.get_base_layer()
weight = base_layer.in_proj_weight.data
del base_layer.in_proj_weight
base_layer.register_parameter("in_proj_weight", nn.Parameter(weight))

def state_dict(self, *args, **kwargs):
self._restore_weights()
return super().state_dict(*args, **kwargs)

def named_modules(self, *args, **kwargs):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need also to over-write the modules() method?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not needed, as modules calls named_modules under the hood. I added a comment to that effect.

# Note: no need to also implement modules(), as modules() calls named_modules() under the hood
self._restore_weights()
return super().named_modules(*args, **kwargs)

def __repr__(self) -> str:
rep = super().__repr__()
return "lora." + rep


def dispatch_default(
target: torch.nn.Module,
adapter_name: str,
Expand All @@ -709,6 +889,9 @@ def dispatch_default(
elif isinstance(target_base_layer, torch.nn.Conv2d):
kwargs.update(lora_config.loftq_config)
new_module = Conv2d(target, adapter_name, **kwargs)
elif isinstance(target_base_layer, torch.nn.MultiheadAttention):
kwargs.update(lora_config.loftq_config)
new_module = MultiheadAttention(target, adapter_name, **kwargs)
elif isinstance(target_base_layer, torch.nn.Linear):
if kwargs["fan_in_fan_out"]:
warnings.warn(
Expand Down
19 changes: 12 additions & 7 deletions src/peft/tuners/lora/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,11 +188,6 @@ def _replace_module(self, parent, child_name, new_module, child):
if hasattr(child, "base_layer"):
child = child.base_layer

if not hasattr(new_module, "base_layer"):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this has been removed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, forgot to put this into the description of the PR.

These lines are obsolete for some time now. They only apply when we unload the model (otherwise, the if does not match). Remember when we made the base_layer switch, we ensured that when unloading, we simply return the base_layer, no more need to create a new layer (say, a new nn.Linear when using lora.Linear) and replace the new layer's weight by the parent layer's weight. The base_layer already has the original weight. Therefore, these lines are unnecessary.

I removed them now because they were annoying with MultiheadAttention, because that layer has no weight attribute, so this line would fail.

new_module.weight = child.weight
if hasattr(child, "bias"):
new_module.bias = child.bias

if getattr(child, "state", None) is not None:
if hasattr(new_module, "base_layer"):
new_module.base_layer.state = child.state
Expand All @@ -203,7 +198,16 @@ def _replace_module(self, parent, child_name, new_module, child):
# dispatch to correct device
for name, module in new_module.named_modules():
if (self.prefix in name) or ("ranknum" in name):
weight = child.qweight if hasattr(child, "qweight") else child.weight
if hasattr(child, "qweight"):
weight = child.qweight
elif hasattr(child, "weight"):
weight = child.weight
elif getattr(child, "in_proj_weight", None) is not None: # MHA
weight = child.in_proj_weight
elif getattr(child, "q_proj_weight", None) is not None: # MHA
weight = child.q_proj_weight
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the case we support this is never not None, right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mean getattr(child, "q_proj_weight", None) is not None can never evaluate to False, thus the else clause below is not needed? I think it would be good to have that fallback, in case we do miss something.

Copy link
Collaborator

@githubnemo githubnemo Jan 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, I meant that q_proj_weight is always None in our case. (_qkv_same_embed_dim = True)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yes, sorry, you're right. This is there in case we add support for the other mode in the future.

else:
raise ValueError(f"Encountered unknown module type: {type(child)}")
module.to(weight.device)

def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None:
Expand Down Expand Up @@ -256,7 +260,8 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
# no module could be matched
raise ValueError(
f"Target module {target} is not supported. Currently, only the following modules are supported: "
"`torch.nn.Linear`, `torch.nn.Embedding`, `torch.nn.Conv2d`, `transformers.pytorch_utils.Conv1D`."
"`torch.nn.Linear`, `torch.nn.Embedding`, `torch.nn.Conv2d`, `transformers.pytorch_utils.Conv1D`, "
"`torch.nn.MultiheadAttention.`"
)

return new_module
Expand Down
28 changes: 26 additions & 2 deletions tests/test_custom_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@
("Embedding + transformers Conv1D 3 LoRA", "EmbConv1D", LoraConfig, {"target_modules": ["emb", "conv1d"]}),
("Conv2d 1 LoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d"]}),
("Conv2d 2 LoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d", "lin0"]}),
("MHA 1 LoRA", "MHA", LoraConfig, {"target_modules": ["mha"]}),
("MHA 1 LoRA", "MHA", LoraConfig, {"target_modules": ["mha", "lin0"]}),
#######
# IA³ #
#######
Expand Down Expand Up @@ -402,6 +404,21 @@ def forward(self, X):
return X


class ModelMha(nn.Module):
def __init__(self):
super().__init__()
self.mha = nn.MultiheadAttention(10, 2)
self.lin0 = nn.Linear(10, 2)
self.sm = nn.LogSoftmax(dim=-1)

def forward(self, X):
X = X.float()
X, _ = self.mha(X, X, X)
X = self.lin0(X)
X = self.sm(X)
return X


class MockTransformerWrapper:
"""Mock class to behave like a transformers model.

Expand All @@ -426,6 +443,9 @@ def from_pretrained(cls, model_id, torch_dtype=None):
if model_id == "Conv2d":
return ModelConv2D().to(torch_dtype)

if model_id == "MHA":
return ModelMha().to(torch_dtype)

raise ValueError(f"model_id {model_id} not implemented")


Expand Down Expand Up @@ -543,7 +563,9 @@ def test_only_params_are_updated(self, test_name, model_id, config_cls, config_k
model_before = copy.deepcopy(model)

model.train()
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)
# we get exploding gradients with MHA when learning rate is too high
lr = 0.5 if "mha" not in model_id.lower() else 1e-3
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

# train at least 3 steps for all parameters to be updated (probably this is required because of symmetry
# breaking of some LoRA layers that are initialized with constants)
Expand Down Expand Up @@ -580,7 +602,9 @@ def test_parameters_after_loading_model(self, test_name, model_id, config_cls, c
)
model = get_peft_model(model, config)
model.train()
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)
# we get exploding gradients with MHA when learning rate is too high
lr = 0.5 if "mha" not in model_id.lower() else 1e-3
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

# train at least 3 steps for all parameters to be updated (probably this is required because of symmetry
# breaking of some LoRA layers that are initialized with constants)
Expand Down
4 changes: 0 additions & 4 deletions tests/test_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,6 @@ def test_lora_linear_init_gaussian(self):
normal = self.get_normal(0.0, 1 / config.r)
_, p_value = stats.kstest(weight_A.detach().flatten().cpu().numpy(), normal.flatten().cpu().numpy())

# import matplotlib.pyplot as plt
# x = weight_A.detach().flatten().cpu().numpy()
# breakpoint()

self.assertGreater(p_value, 0.5)

# check that weight A is *not* from a uniform distribution
Expand Down
Loading