A (possible) bug in lora merging method: add_weighted_adapter #1155

Closed

jihuishan opened this issue Nov 21, 2023 · 8 comments · Fixed by #1169
Comments

@jihuishan
Contributor

jihuishan commented Nov 21, 2023

System Info

peft 0.6.1

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

In both line 565 and line 578 of add_weighted_adapter in peft/tuners/lora/model.py, lora_A and lora_B are updated by

target_lora_A.data += current_adapter_lora_A.data * weight * target.scaling[adapter]
target_lora_B.data += current_adapter_lora_B.data

It is quite confusing (and degrades performance) that the update of target_lora_B does not involve the weight, although this seems like a deliberate design (as target_lora_A is updated correctly).
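A standalone numeric sketch of the consequence (plain torch code, not the PEFT source; shapes and names are just for illustration):

import torch

torch.manual_seed(0)
rank, d_in, d_out = 4, 10, 20
A1, A2 = torch.randn(rank, d_in), torch.randn(rank, d_in)
B1, B2 = torch.randn(d_out, rank), torch.randn(d_out, rank)
weights = [1.0, 0.0]

# Current behavior: weight applied to lora_A only, lora_B combined unweighted.
merged_A = weights[0] * A1 + weights[1] * A2
merged_B = B1 + B2

# The merged delta is (B1 + B2) @ A1, not the expected B1 @ A1.
print(torch.allclose(merged_B @ merged_A, B1 @ A1))  # False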

Expected behavior

When using add_weighted_adapter, I expect all parameters of each adapter to be combined according to the given weights, rather than weighted averaging across lora_A while evenly averaging lora_B.

In my experiments on merging several LoRAs into a unified model, performance drops significantly if this problem is not fixed. For example, say five adapters with the same config are trained individually on five datasets, and add_weighted_adapter is used with the weights [1, 0, 0, 0, 0] to merge them. I would expect the result to be the same as using the first adapter alone, which is how weighted averaging works, right? But in fact there is a huge performance gap between these two setups, which should have been almost identical.

A simple fix is to add the weight in both line 565 and line 578:

target_lora_B.data += current_adapter_lora_B.data ---> target_lora_B.data += current_adapter_lora_B.data * weight

And it works: assigning the weights [1, 0, 0, 0, 0] to five adapters now performs nearly the same as merely using the first adapter.
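To illustrate the fix with the same kind of standalone sketch (torch code, names illustrative):

import torch

torch.manual_seed(0)
rank, d_in, d_out = 4, 10, 20
A1, A2 = torch.randn(rank, d_in), torch.randn(rank, d_in)
B1, B2 = torch.randn(d_out, rank), torch.randn(d_out, rank)
weights = [1.0, 0.0]

merged_A = weights[0] * A1 + weights[1] * A2
merged_B = weights[0] * B1 + weights[1] * B2  # weight applied to B as well (the fix)

# With weights [1, 0], the merged delta now reproduces the first adapter exactly.
assert torch.allclose(merged_B @ merged_A, B1 @ A1)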
Thank you for your amazing work for the community, and for your patience in reading this.

@BenjaminBossan
Member

Thank you for reporting the issue.

At first glance, I thought the current implementation should be correct: because lora_A and lora_B are multiplied with each other, it should be enough for the weight to be applied to one of the two. However, I could write a test that fails with the current implementation but succeeds when applying the weight to lora_B too:

import torch
from torch import nn
from peft import LoraConfig, get_peft_model

class MLP(nn.Module):
    def __init__(self, bias=True):
        super().__init__()
        self.lin0 = nn.Linear(10, 20, bias=bias)
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(0.5)
        self.lin1 = nn.Linear(20, 2, bias=bias)
        self.sm = nn.LogSoftmax(dim=-1)

    def forward(self, X):
        X = X.float()
        X = self.lin0(X)
        X = self.relu(X)
        X = self.drop(X)
        X = self.lin1(X)
        X = self.sm(X)
        return X

torch.manual_seed(0)
inputs = torch.arange(90).view(9, 10)
# init_lora_weights=False initializes the LoRA weights randomly (instead of zeros),
# so each adapter changes the output right away
config = LoraConfig(target_modules=["lin0"], init_lora_weights=False)
model = MLP()
model = get_peft_model(model, config, "adapter0")
model.add_adapter("adapter1", config)
model = model.eval()

assert model.active_adapters == ["adapter0"]
output0 = model(inputs)

model.set_adapter("adapter1")
assert model.active_adapters == ["adapter1"]
output1 = model(inputs)

assert not torch.allclose(output0, output1)

model.base_model.set_adapter(["adapter0", "adapter1"])
assert model.active_adapters == ["adapter0", "adapter1"]
output_both = model(inputs)

assert not torch.allclose(output0, output_both)
assert not torch.allclose(output1, output_both)

model.add_weighted_adapter(
    ["adapter0", "adapter1"], weights=[1, 0], adapter_name="weight10", combination_type="linear"
)
model.set_adapter("weight10")
assert model.active_adapters == ["weight10"]
output_weighted10 = model(inputs)
assert torch.allclose(output0, output_weighted10)

model.add_weighted_adapter(
    ["adapter0", "adapter1"], weights=[0, 1], adapter_name="weight01", combination_type="linear"
)
model.set_adapter("weight01")
assert model.active_adapters == ["weight01"]
output_weighted01 = model(inputs)
assert torch.allclose(output1, output_weighted01)

This test currently fails with PEFT, but it succeeds if this line:

target_lora_B.data += current_adapter_lora_B.data

is modified to

target_lora_B.data += current_adapter_lora_B.data * weight * target.scaling[adapter]

@pacman100 could you please take a look?

When I extended the test a bit with these lines:

model.add_weighted_adapter(
    ["adapter0", "adapter1"], weights=[1, 1], adapter_name="weight11", combination_type="linear"
)
model.set_adapter("weight11")
assert model.active_adapters == ["weight11"]
output_weighted11 = model(inputs)
assert torch.allclose(output_both, output_weighted11)

both variants failed.

Note: @jihuishan For the future, please add permalinks to the code in question; line numbers change constantly, so we don't know which line of code you're referring to.

@jihuishan
Contributor Author

@BenjaminBossan Thank you for your response and your patience in reproducing the results, and I'm sorry for the confusing way of referring to the code (I've never opened an issue before).

There may be an explanation for the problem. When merging LoRA adapters, we hope that: $W = W_0 + \alpha B_1 A_1 + \beta B_2 A_2 + \dots$

Therefore, it is understandable that no weight was applied to B in the implementation: given the formula above, putting the weight on A alone looks sufficient.

However, things work out a little differently due to the way we merge adapters. Let's say the weights are $[1, 0, 0]$ for three adapters; then:
$$W = W_0 + BA = W_0 + (B_1 + B_2 + B_3)(\alpha A_1 + \beta A_2 + \gamma A_3) = W_0 + (B_1 + B_2 + B_3)(1 \times A_1 + 0 \times A_2 + 0 \times A_3) = W_0 + (B_1 + B_2 + B_3)A_1$$
In other words, the problem is that B is combined before multiplying with A. Therefore, no matter what the weights are, the unweighted combination of $B_1$, $B_2$, and $B_3$ is actually used, instead of the intended weighted one.

From another perspective, the $B_i$ and $A_i$ of any adapter being merged are coupled with the $B_j$ and $A_j$ of the other adapters; they are not independent, unlike what the first formula above suggests.

For the reasons above, my fix of adding the weight to B helps, but it is exact only for extreme cases where the weights look like [0, 1, 0], and it seems quite hard to produce a merge that is truly independent. Still, adding the weight to B is of some help: in my experiments, merged adapters perform much better with it than without, even though the result does not match the literal meaning of "average".
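A small standalone sketch of this coupling (torch code, not PEFT; weights [0.5, 0.5] as an example):

import torch

torch.manual_seed(0)
rank, d_in, d_out = 4, 10, 20
A1, A2 = torch.randn(rank, d_in), torch.randn(rank, d_in)
B1, B2 = torch.randn(d_out, rank), torch.randn(d_out, rank)
w1, w2 = 0.5, 0.5

# What we would like: a weighted sum of the per-adapter deltas.
intended = w1 * (B1 @ A1) + w2 * (B2 @ A2)

# What merging in A/B space gives: cross terms B1 @ A2 and B2 @ A1 appear.
merged = (B1 + B2) @ (w1 * A1 + w2 * A2)

print(torch.allclose(intended, merged))  # False: the cross terms remain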

@BenjaminBossan
Member

Thanks for explaining. Indeed, the two formulas are not equivalent and there is in fact no way to combine lora_A and lora_B to make the two equivalent, except for trivial cases (like weights 0 or 1). Therefore, when using that method, we have to accept that we will only get an approximation of what we would get when adding the LoRAs one after another. SVD might be the preferred method.

That being said, I can still see an argument for applying the weight to both lora_A and lora_B so that we can get identical results for edge cases like weights=[1, 0, 0]. However, let's say we have weights [0.5, 0.5]. If we apply 0.5 to both sides, we actually shrink the total weight norm because it is applied twice, right? So I wonder if we should actually apply sqrt(weight) to both sides instead. WDYT @jihuishan and @pacman100?
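As a quick illustration of the sqrt(weight) idea (using the notation from above, with $w_i$ the weight of adapter $i$): applying $\sqrt{w_i}$ to both factors keeps each adapter's own term at exactly the intended scale,

$$(\sqrt{w_i}\,B_i)(\sqrt{w_i}\,A_i) = w_i\,B_iA_i,$$

whereas applying $w_i$ to both sides would give $w_i^2\,B_iA_i$. The cross terms $B_iA_j$ are then scaled by $\sqrt{w_iw_j}$, so the combination remains an approximation either way.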

@pacman100
Contributor

pacman100 commented Nov 22, 2023

Hello, for exact equivalence, please use cat, and for approximate equivalence use svd. linear is a way of combining LoRAs that doesn't conform to merging them; it is a simple weighted linear combination.
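For reference, a minimal usage sketch of the two recommended modes (assuming a PeftModel with adapters "adapter0" and "adapter1" already loaded; the adapter names are illustrative):

# Exact: concatenates the LoRA factors; the merged rank is the sum of the adapters' ranks.
model.add_weighted_adapter(
    ["adapter0", "adapter1"], weights=[0.5, 0.5],
    adapter_name="merged_cat", combination_type="cat",
)

# Approximate: forms the weighted combination of the deltas, then truncates it
# back to the target rank via SVD.
model.add_weighted_adapter(
    ["adapter0", "adapter1"], weights=[0.5, 0.5],
    adapter_name="merged_svd", combination_type="svd",
)

model.set_adapter("merged_cat")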

However, let's say we have weights [0.5, 0.5]. If we apply 0.5 to both sides, we actually shrink the total weight norm because it is applied twice, right?

For this reason, I multiplied only lora_A by the weight. Earlier, the implementation multiplied both lora_A and lora_B.

So I wonder if we should actually apply sqrt(weight) to both sides instead.

Yes, that seems reasonable.

@pacman100
Contributor

However, things work out a little differently due to the way we merge adapters. Let's say the weights are $[1, 0, 0]$ for three adapters; then:
$$W = W_0 + BA = W_0 + (B_1 + B_2 + B_3)(\alpha A_1 + \beta A_2 + \gamma A_3) = W_0 + (B_1 + B_2 + B_3)A_1$$

In other words, the problem is that B is combined before multiplying with A. Therefore, no matter what the weights are, the unweighted combination of $B_1$, $B_2$, and $B_3$ is actually used, instead of the intended weighted one.

Thank you for the detailed explanation, that makes sense.

@jihuishan
Contributor Author

So I wonder if we should actually apply sqrt(weight) to both sides instead.

That's a good idea, I presume.

Hello, for exact equivalence, please use cat, and for approximate equivalence use svd. linear is a way of combining LoRAs that doesn't conform to merging them; it is a simple weighted linear combination.

Thank you! I will try 'svd'.

@pacman100
Contributor

pacman100 commented Nov 22, 2023

My go-to setting is cat, as it is exact and faster when the ranks are small and not many adapters are being merged.
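A sketch of why cat is exact (the implementation may place the weights differently, but the idea is block-matrix multiplication): concatenating the factors reproduces the weighted sum with no cross terms, at the cost of rank $r_1 + r_2$:

$$B_{\text{cat}}A_{\text{cat}} = \begin{bmatrix} B_1 & B_2 \end{bmatrix}\begin{bmatrix} w_1 A_1 \\ w_2 A_2 \end{bmatrix} = w_1 B_1 A_1 + w_2 B_2 A_2$$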

@jihuishan
Contributor Author

My go-to setting is cat, as it is exact and faster when the ranks are small and not many adapters are being merged.

Indeed, cat seems more convenient and stable here.
