A (possible) bug in lora merging method: add_weighted_adapter #1155
Comments
Thank you for reporting the issue. At first glance, I thought the current implementation should be correct. To check, I put together the following test:

import torch
from torch import nn
from peft import LoraConfig, get_peft_model

class MLP(nn.Module):
    def __init__(self, bias=True):
        super().__init__()
        self.lin0 = nn.Linear(10, 20, bias=bias)
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(0.5)
        self.lin1 = nn.Linear(20, 2, bias=bias)
        self.sm = nn.LogSoftmax(dim=-1)

    def forward(self, X):
        X = X.float()
        X = self.lin0(X)
        X = self.relu(X)
        X = self.drop(X)
        X = self.lin1(X)
        X = self.sm(X)
        return X

torch.manual_seed(0)
inputs = torch.arange(90).view(9, 10)
config = LoraConfig(target_modules=["lin0"], init_lora_weights=False)
model = MLP()
model = get_peft_model(model, config, "adapter0")
model.add_adapter("adapter1", config)
model = model.eval()
assert model.active_adapters == ["adapter0"]
output0 = model(inputs)
model.set_adapter("adapter1")
assert model.active_adapters == ["adapter1"]
output1 = model(inputs)
assert not torch.allclose(output0, output1)
model.base_model.set_adapter(["adapter0", "adapter1"])
assert model.active_adapters == ["adapter0", "adapter1"]
output_both = model(inputs)
assert not torch.allclose(output0, output_both)
assert not torch.allclose(output1, output_both)
model.add_weighted_adapter(
["adapter0", "adapter1"], weights=[1, 0], adapter_name="weight10", combination_type="linear"
)
model.set_adapter("weight10")
assert model.active_adapters == ["weight10"]
output_weighted10 = model(inputs)
assert torch.allclose(output0, output_weighted10)
model.add_weighted_adapter(
["adapter0", "adapter1"], weights=[0, 1], adapter_name="weight01", combination_type="linear"
)
model.set_adapter("weight01")
assert model.active_adapters == ["weight01"]
output_weighted01 = model(inputs)
assert torch.allclose(output1, output_weighted01)

This test currently fails with PEFT, but it succeeds if this line:

peft/src/peft/tuners/lora/model.py, line 521 at commit 0ae52fe (target_lora_B.data += current_adapter_lora_B.data)

is modified to:

target_lora_B.data += current_adapter_lora_B.data * weight * target.scaling[adapter]

@pacman100 could you please take a look?

When I extended the test a bit with these lines:

model.add_weighted_adapter(
["adapter0", "adapter1"], weights=[1, 1], adapter_name="weight11", combination_type="linear"
)
model.set_adapter("weight11")
assert model.active_adapters == ["weight11"]
output_weighted11 = model(inputs)
assert torch.allclose(output_both, output_weighted11)

both variants failed.

Note: @jihuishan For the future, please add permalinks to the code in question, as line numbers change constantly; otherwise we don't know what line of code you're referring to. |
@BenjaminBossan Thank you for your response and for your patience in reproducing the results, and I'm sorry for the confusing way of referring to the code (I've never opened an issue before). However, the result will still be a little bit different due to the way we merge adapters. Let's say the weights are w1 and w2: summing the weighted lora_A and lora_B matrices first and then taking the product gives a delta of (w1*B1 + w2*B2) @ (w1*A1 + w2*A2), which contains cross terms such as w1*w2*B1@A2, whereas a true weighted average of the adapters would be w1*B1@A1 + w2*B2@A2. From another perspective, any merge that combines the lora_A matrices and the lora_B matrices separately cannot, in general, reproduce the weighted sum of the individual products. Due to the reasons above, my fix of adding the weight to B works, but really only for extreme cases where the weights are like [0, 1, 0], and it seems quite tough to produce a merge that is truly independent. Still, adding the weight to B is of some help: it made the performance much better than without it when merging adapters in my experiments, even though it is not identical to the literal meaning of "average". |
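A minimal numeric sketch of the cross-term issue described above (my own illustration with hypothetical random matrices, not code from the thread):

import torch

torch.manual_seed(0)
r, d_in, d_out = 4, 10, 20
A1, A2 = torch.randn(r, d_in), torch.randn(r, d_in)    # lora_A factors of two adapters
B1, B2 = torch.randn(d_out, r), torch.randn(d_out, r)  # lora_B factors of two adapters
w1, w2 = 0.7, 0.3

# What a weighted average of the adapters should mean:
delta_expected = w1 * (B1 @ A1) + w2 * (B2 @ A2)

# What merging the factors first gives, even with the weight applied to both A and B:
delta_merged = (w1 * B1 + w2 * B2) @ (w1 * A1 + w2 * A2)
# = w1**2 * B1@A1 + w2**2 * B2@A2 + w1*w2 * (B1@A2 + B2@A1), i.e. cross terms appear.

print(torch.allclose(delta_expected, delta_merged))  # False in general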
Thanks for explaining. Indeed, the two formulas are not equivalent, and there is in fact no general way to combine the lora_A and lora_B matrices of the different adapters into a single pair whose product equals the weighted sum of the individual deltas. That being said, I can still see an argument for applying the weight to lora_B in some form as well. |
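As an aside, one way to distribute a weight across both LoRA factors (a sketch of one possible convention, my own illustration rather than anything decided in this thread) is to scale each factor by the square root of the weight, so that a single adapter merged with weight w contributes exactly w times its delta:

import math
import torch

torch.manual_seed(0)
A, B, w = torch.randn(4, 10), torch.randn(20, 4), 0.25   # lora_A, lora_B, weight

A_merged = math.sqrt(w) * A
B_merged = math.sqrt(w) * B

print(torch.allclose(B_merged @ A_merged, w * (B @ A)))  # True
# With several adapters the cross terms are still there, but each adapter's own
# contribution is scaled by w_i rather than by w_i**2 or not at all.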
Hello, for exact equivalence, please use the cat combination type in add_weighted_adapter.
For this reason, I multiplied only with lora_A: if both lora_A and lora_B were multiplied by the weight, the resulting delta would be scaled by the square of the weight. Earlier, the implementation multiplied the weight into both lora_A and lora_B.
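A small check of the scaling argument above (my own sketch with hypothetical matrices): applying the weight to lora_A only scales the delta linearly, while applying it to both factors scales the delta by the square of the weight.

import torch

torch.manual_seed(0)
A, B, w = torch.randn(4, 10), torch.randn(20, 4), 0.5  # lora_A, lora_B, weight

# Weight on lora_A only: the delta scales linearly with w.
print(torch.allclose(B @ (w * A), w * (B @ A)))               # True

# Weight on both factors: the delta scales with w**2 instead.
print(torch.allclose((w * B) @ (w * A), (w ** 2) * (B @ A)))  # True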
Yes, that seems reasonable. |
Thank you for the detailed explanation, makes sense |
That's a good idea, I presume.
Thank you! I will try 'svd'. |
My go-to setting is |
Indeed, |
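For reference, a usage sketch of the alternatives mentioned above, continuing from the model built in the test near the top of the thread and assuming an installed PEFT version in which both the svd and cat combination types are available (check your version's documentation for the exact options):

model.add_weighted_adapter(
    ["adapter0", "adapter1"], weights=[0.5, 0.5], adapter_name="merged_svd",
    combination_type="svd",  # builds the weighted sum of the deltas, then refactors it via SVD
)
model.add_weighted_adapter(
    ["adapter0", "adapter1"], weights=[0.5, 0.5], adapter_name="merged_cat",
    combination_type="cat",  # concatenates the factors; exact, but the effective rank grows
)
model.set_adapter("merged_svd")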
System Info
peft 0.6.1
Reproduction
In both line 565 and line 578 of add_weighted_adapter in peft/tuners/lora/model.py, lora_A and lora_B are updated by:

target_lora_A.data += current_adapter_lora_A.data * weight * target.scaling[adapter]
target_lora_B.data += current_adapter_lora_B.data

It is quite confusing (and performance-degrading) that the update of target_lora_B does not involve the weight, although this seems like a deliberate design (as target_lora_A is updated correctly).
Expected behavior
When using add_weighted_adapter, I expect all parameters of each adapter to be combined according to the weights, instead of taking a weighted average over lora_A while averaging lora_B evenly.
During my experiments on merging several LoRAs into a unified model, the performance drops significantly if the problem above is not solved. For example, let's say five adapters with the same config are trained individually on five datasets. When using the weights [1, 0, 0, 0, 0] with add_weighted_adapter to merge these adapters, I expect the result to be the same as using the first adapter alone, which is how weighted averaging works, right? But in practice there is a huge gap between the performance of these two setups, which should have been almost the same.
A simple fix is to change both line 565 and line 578 by adding the weight, i.e.

target_lora_B.data += current_adapter_lora_B.data ---> target_lora_B.data += current_adapter_lora_B.data * weight

And it works: assigning the weights [1, 0, 0, 0, 0] to five adapters now performs nearly the same as using the first adapter alone.
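To make the effect of the proposed change concrete, here is a small standalone sketch (my own illustration with hypothetical random matrices, scaling omitted for brevity): with weights [1, 0], applying the weight to lora_B as well reproduces the first adapter's delta exactly, whereas leaving lora_B unweighted does not.

import torch

torch.manual_seed(0)
A0, A1 = torch.randn(4, 10), torch.randn(4, 10)   # lora_A of adapters 0 and 1
B0, B1 = torch.randn(20, 4), torch.randn(20, 4)   # lora_B of adapters 0 and 1
w0, w1 = 1.0, 0.0

# Behaviour described in the issue: weight applied to lora_A only.
A_cur, B_cur = w0 * A0 + w1 * A1, B0 + B1
# Proposed fix: weight applied to lora_B as well.
A_fix, B_fix = w0 * A0 + w1 * A1, w0 * B0 + w1 * B1

delta0 = B0 @ A0                              # delta of adapter 0 alone
print(torch.allclose(B_cur @ A_cur, delta0))  # False: adapter 1's lora_B leaks in
print(torch.allclose(B_fix @ A_fix, delta0))  # True: matches adapter 0 exactly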
Thank you for your amazing work for the community, and for your patience in reading this.