Combine multiple (IA)^3 Adapters and delete (IA)^3 adapters #980
Conversation
Please take a look at our contribution guide.
@BenjaminBossan Thanks for the tip! I did look at the contribution guide and ran the tests and style commands. My question was more related to where to add useful tests, but I think I figured it out. Let me know if there are other tests I should add!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Thanks a lot for adding this feature. Sorry that it took longer to review. I have encountered a couple of issues, could you please take a look? Thanks.
src/peft/tuners/ia3/model.py
Outdated
target_modules_type = type(self.peft_config[adapters[0]].target_modules)
new_target_modules = set() if target_modules_type == list else ""
feedforward_modules_type = type(self.peft_config[adapters[0]].feedforward_modules)
new_feedforward_modules = set() if feedforward_modules_type == list else ""
for adapter in adapters:
    if type(self.peft_config[adapter].target_modules) != target_modules_type:
        raise ValueError(
            "all adapter configs should follow the same target modules type. "
            "Combining adapters with `target_modules` type being a mix of list and string is not supported."
        )
    if target_modules_type == list:
        new_target_modules |= set(self.peft_config[adapter].target_modules)
    else:
        new_target_modules += f"({self.peft_config[adapter].target_modules})|"

    if type(self.peft_config[adapter].feedforward_modules) != feedforward_modules_type:
        raise ValueError(
            "all adapter configs should follow the same feedforward modules type. "
            "Combining adapters with `feedforward_modules` type being a mix of list and string is not supported."
        )
    if feedforward_modules_type == list:
        new_feedforward_modules |= set(self.peft_config[adapter].feedforward_modules)
    else:
        new_feedforward_modules += f"({self.peft_config[adapter].feedforward_modules})|"

new_target_modules = list(new_target_modules) if target_modules_type == list else new_target_modules[:-1]
new_feedforward_modules = (
    list(new_feedforward_modules) if target_modules_type == list else new_feedforward_modules[:-1]
)
The whole logic has been refactored a bit for LoRA in #993 since you created this PR:
peft/src/peft/tuners/lora/model.py
Lines 499 to 515 in 0c16918
target_module_types = [type(self.peft_config[adapter].target_modules) for adapter in adapters]
if not target_module_types:
    raise ValueError(f"Found no adapter matching the names in {adapters}")
if len(set(target_module_types)) > 1:
    raise ValueError(
        "all adapter configs should follow the same target modules type. "
        "Combining adapters with `target_modules` type being a mix of list/set and string is not supported."
    )
if target_module_types[0] == str:
    new_target_modules = "|".join(f"({self.peft_config[adapter].target_modules})" for adapter in adapters)
elif target_module_types[0] == set:
    new_target_modules = reduce(
        operator.or_, (self.peft_config[adapter].target_modules for adapter in adapters)
    )
else:
    raise TypeError(f"Invalid type {target_module_types[0]} found in target_modules")
Could you please adopt those changes here for consistency? Note that the type of target_modules and feedforward_modules has been changed from list to set (str is still valid).
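For reference, a minimal sketch of how that refactored pattern could carry over to feedforward_modules as well (illustrative only; it assumes the same method context as the snippet above and is not code from this PR):

# from functools import reduce / import operator are assumed at module level, as in lora/model.py
feedforward_module_types = [type(self.peft_config[adapter].feedforward_modules) for adapter in adapters]
if len(set(feedforward_module_types)) > 1:
    raise ValueError(
        "all adapter configs should follow the same feedforward modules type. "
        "Combining adapters with `feedforward_modules` type being a mix of set and string is not supported."
    )
if feedforward_module_types[0] == str:
    # regex alternation over the per-adapter patterns, mirroring target_modules
    new_feedforward_modules = "|".join(
        f"({self.peft_config[adapter].feedforward_modules})" for adapter in adapters
    )
elif feedforward_module_types[0] == set:
    # set union of the per-adapter module names
    new_feedforward_modules = reduce(
        operator.or_, (self.peft_config[adapter].feedforward_modules for adapter in adapters)
    )
else:
    raise TypeError(f"Invalid type {feedforward_module_types[0]} found in feedforward_modules")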
Let me know if my changes are correct!
target_ia3_l.data = target_ia3_l.data * 0.0
for adapter, weight in zip(adapters, weights):
    if adapter in target.ia3_l:
        current_adapter_ia3_l = target.ia3_l[adapter]
    else:
        continue
    target_ia3_l.data += current_adapter_ia3_l.data * weight
I think this is not correct: When using IA³, the IA³ weights have to be multiplied, not added, right? I.e. they should be initialized as 1.0 and then each IA³ weight is multiplied on top, not added. See how it's accomplished in the forward method of IA³:
peft/src/peft/tuners/ia3/layer.py
Lines 177 to 182 in 0c16918
ia3_scaling = 1
for active_adapter in self.active_adapters:
    if active_adapter not in self.ia3_l.keys():
        continue
    dtype = self.ia3_l[active_adapter].dtype
    ia3_scaling *= self.ia3_l[active_adapter].flatten()
If this is correct, we encounter a second problem, namely that the weights argument makes little sense: since we just multiply each IA³ weight and each weight from weights, due to commutativity, the order in weights doesn't matter. Whether a user passes weights=[2, 3] or weights=[3, 2] makes no difference.
We could still leave it as is for consistency, but I would be afraid that this would confuse many users. Instead, we could also 1) remove the weights argument entirely for IA³ or 2) only pass a single scalar to weights, which is applied once to all weights (could be set as the initial value). WDYT?
Thanks for the feedback and review!
When using IA³, the IA³ weights have to be multiplied, not added, right?
This is true in the forward pass, where the learned vectors are multiplied with the activations. What I had in mind, however, is a linear combination of the learned vectors themselves. Let's assume we have two adapters that we combine with weights [0.6, 0.4]. The way I wanted to combine these adapters into a new adapter was as a weighted sum of their vectors, i.e. 0.6 * vector_1 + 0.4 * vector_2.
If we also target the FF layers, we would compute the resulting vector using the same procedure.
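To make this concrete, here is a toy version of that weighted sum with two made-up IA³ vectors (the numbers are invented for illustration):

import torch

ia3_l_1 = torch.tensor([1.2, 0.8, 1.0])  # learned vector of adapter 1 (made up)
ia3_l_2 = torch.tensor([0.9, 1.5, 1.1])  # learned vector of adapter 2 (made up)
weights = [0.6, 0.4]

# linear combination of the two adapters into a new adapter's vector
new_ia3_l = weights[0] * ia3_l_1 + weights[1] * ia3_l_2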
the weights argument makes little sense
If we multiply vectors, yes. However, that would not result in a linear combination of vectors, which was my goal.
Let me know if this makes sense!
Hmm, not sure. Let's work with scalars for a second. Let's say we have one IA³ weight with value 2 and one with value 3. As they are multiplied consecutively on the input, I would expect that we should multiply by 6, not by their sum 5. Am I missing something?
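To make that concrete, a toy check with made-up scalar IA³ weights (mine, not from the PR):

import torch

x = torch.tensor([1.0, 2.0])
out_consecutive = x * 2 * 3    # both adapters active: scales by 6
out_sum_merged = x * (2 + 3)   # additive merge: scales by 5
out_mul_merged = x * (2 * 3)   # multiplicative merge: scales by 6

assert torch.allclose(out_consecutive, out_mul_merged)
assert not torch.allclose(out_consecutive, out_sum_merged)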
Anyway, I thought why not just test if the results are right or not. For this, I changed the test you added to do this instead:
elif isinstance(config, (IA3Config)):
    model = get_peft_model(model, config, adapter_list[0])
    model = model.to(self.torch_device)
    dummy_input = self.prepare_inputs_for_testing()
    output0 = model(**dummy_input)[0]
    model.add_adapter(adapter_list[1], config)
    model.add_adapter(adapter_list[2], config)
    model.set_adapter(adapter_list)
    output1 = model(**dummy_input)[0]
    model.merge_adapter()
    output2 = model(**dummy_input)[0]
    model.unmerge_adapter()
    output3 = model(**dummy_input)[0]
    # using addition
    model.add_weighted_adapter(adapter_list, torch.ones(3) / 3, "merged-add")
    model.set_adapter("merged-add")
    output4 = model(**dummy_input)[0]
    # using multiplication
    model.add_weighted_adapter_mul(adapter_list, torch.ones(3), "merged-mul")
    model.set_adapter("merged-mul")
    output5 = model(**dummy_input)[0]
    assert not torch.allclose(output0, output1)
    torch.testing.assert_allclose(output1, output2)
    torch.testing.assert_allclose(output1, output3)
    torch.testing.assert_allclose(output1, output5)  # passes
    torch.testing.assert_allclose(output1, output4)  # fails
As you can see, we test the outputs from an IA³ model with the 3 adapters active but unmerged vs merged vs merged using add_weighted_adapter (your implementation) vs merged using add_weighted_adapter_mul (my implementation using multiply). When I run the tests, the multiply version passes but the addition version fails, which makes me think that multiplying is the way to go.
If you want to replicate this result, it will require a few steps because our code isn't really set up to work with multiple active adapters yet, so I had to make a few ad hoc changes to even get this far. I created a PR on top of your branch containing those changes:
https://github.com/alexrs/peft/pull/1/files
Obviously, it should not be merged, it's just to show you what steps I took. WDYT, is this plausible?
I see your point! However, I'm not sure this is consistent with the LoRA implementation. As far as I understand, there are two different scenarios here:
1. Stacking Adapters: When using set_adapter on multiple adapters, what we are doing is stacking adapters. That's how it works right now, and how it works in LoRA (I think!). This is equivalent to using combination_type=cat in LoRA's add_weighted_adapter:
peft/tests/test_custom_models.py
Lines 819 to 827 in e98df91
if tuner_method == "lora":
    # create a weighted adapter combining both adapters and check that
    # its output is same as setting multiple active adapters
    peft_model.add_weighted_adapter(
        ["adapter_1", "adapter_2"], [1.0, 1.0], "new_combined_adapter", combination_type="cat"
    )
    peft_model.set_adapter("new_combined_adapter")
    new_combined_output = peft_model(**X)
    self.assertTrue(torch.allclose(new_combined_output, combined_output, atol=1e-5))
2. Linear combination of Adapters: In this case, we are not stacking adapters but combining them to create a new adapter that is a linear combination of the input adapters and the input weights. This is equivalent to combination_type=linear in LoRA's add_weighted_adapter. If we change the code linked above to use linear, the test fails:
if tuner_method == "lora":
# create a weighted adapter combining both adapters and check that
# its output is same as setting multiple active adapters
peft_model.add_weighted_adapter(
["adapter_1", "adapter_2"], [1.0, 1.0], "new_combined_adapter", combination_type="linear"
)
peft_model.set_adapter("new_combined_adapter")
new_combined_output = peft_model(**X)
self.assertTrue(torch.allclose(new_combined_output, combined_output, atol=1e-5))
And the same happens if we give both adapters equal weights that sum to 1:
peft_model.add_weighted_adapter(
    ["adapter_1", "adapter_2"], [0.5, 0.5], "new_combined_adapter", combination_type="linear"
)
I guess a solution is to add the different combination_types to add_weighted_adapter. Does this sound reasonable? Or do I have the wrong understanding of how this works?
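For illustration, a rough sketch of what such a combination_type switch could look like inside the per-module merge loop (hypothetical shape, not the code in this PR; torch is assumed to be imported at module level):

if combination_type == "linear":
    # weighted sum of the adapters' vectors
    target_ia3_l.data = target_ia3_l.data * 0.0
    for adapter, weight in zip(adapters, weights):
        if adapter in target.ia3_l:
            target_ia3_l.data += target.ia3_l[adapter].data * weight
elif combination_type == "mul":
    # product of the weighted adapter vectors, matching the forward pass
    # when all adapters are active (with weights of 1)
    target_ia3_l.data = torch.ones_like(target_ia3_l.data)
    for adapter, weight in zip(adapters, weights):
        if adapter in target.ia3_l:
            target_ia3_l.data *= target.ia3_l[adapter].data * weight
else:
    raise ValueError(f"Invalid combination_type: {combination_type}")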
Yes, you're right in the sense that for IA³, it is not quite clear how to interpret the combination of results. Unfortunately, I don't think there is any existing evidence on what the best way of combining IA³ adapters is. I agree that we could offer multiple methods and that hopefully, with time, the best method will emerge. When it comes to which default to choose, I'd argue it's a nice property to have the same output for combining the adapters as if they were all active at once, WDYT?
Another possibility that comes to mind would be the geometric mean, which seems appropriate for a multiplicative operation, but it wouldn't work for negative numbers, so it has to be ruled out.
When it comes to naming the combination types, the analogy to LoRA is a bit difficult, because the mathematical operation is different. I think for IA³ it is necessary to think from first principles.
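Regarding the geometric mean, a quick toy example of why it breaks down with a negative entry:

import torch

vals = torch.tensor([-2.0, 3.0])
geo_mean = vals.prod() ** (1 / len(vals))  # (-6) ** 0.5 -> nan for real-valued tensors
print(geo_mean)  # tensor(nan)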
Unfortunately, I don't think there is any existing evidence on what the best way of combining IA³ adapters is
Agreed.
I'd argue it's a nice property to have the same output for combining the adapters as if they were all active at once, WDYT?
That makes sense! But as discussed above, it is not how it works in LoRA by default, is it?
I guess the way to proceed is to allow both multiplication and linear combination methods using different combination_types, and setting the default to multiplication?
All in all, given that there is no evidence on what the best way of combining adapters is, I will try to run some experiments using both methods to get more clarity on this topic. Let me know if you have any suggestions or ideas for this!
That makes sense! But as discussed above, it is not how it works in LoRA by default, is it?
Yes, but we cannot really compare the two as I mentioned. E.g. it would not make sense to have an "svd" method for IA³, so I think we shouldn't put too much stress on consistency here.
I will try to run some experiments using both methods to get more clarity on this topic. Let me know if you have any suggestions or ideas for this!
That would be fantastic. Loading and combining multiple LoRAs is mostly a thing in image generation AFAIK, so that's probably what I would investigate, but I'm not sure how well IA³ lends itself to image generation in general.
    model.eval()
    _ = model(**dummy_input)[0]

elif isinstance(config, (IA3Config)):
Linear for the get_delta_weight method. How should we test that the result is correct?
Sorry for the late reply. I had written a review but somehow forgot to send it.
Could you please also fix the merge conflict?
@@ -278,13 +280,15 @@ def _prepare_adapter_config(self, peft_config, model_config):
        if peft_config.target_modules is None:
            if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING:
                raise ValueError("Please specify `target_modules` in `peft_config`")
            peft_config.target_modules = TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING[model_config["model_type"]]
            peft_config.target_modules = set(
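The added line is cut off in the rendered diff above; presumably it continues by wrapping the mapping lookup in a set, along these lines (a reconstruction, not verbatim from the diff):

peft_config.target_modules = set(
    TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING[model_config["model_type"]]
)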
Good catch, this is a bug in the existing code base.
    )
    target.set_adapter(resetting_active_adapter)

def _new_modules(self, adapters, module_type):
I like having a separate method for this, but the name is not quite fitting. This combines different module names, right? So could you please adjust the name to reflect that? Also, please add a sentence to the docstring that explains what happens.
Hi @BenjaminBossan, sorry for the late reply. Some other stuff came up and I didn't have much time to look into this lately. I'll get back to it as soon as I can!
Thank you @alexrs for adding new utils for IA3 🤗! The discussion between you and @BenjaminBossan was quite interesting and insightful. I do agree that combining IA3 adapters needs to be thought through from first principles, as it follows a multiplicative operation. At the same time, a linear combination of IA3 adapters is a nice feature. I believe we can support it, provided there is clear documentation explaining how it differs from having multiple active adapters and that it is a weighted-average combination of IA3 adapters. Thank you for adding support for deleting IA3 adapters, which is useful when working with multiple adapters.
Description: The job of deleting an adapter is now transferred to the adapter layer, instead of the adapter model. This makes it easier for users or other libraries who don't use the adapter model to delete adapters.
Implementation: The code should now be more generic, relying less on hard-coded attributes. As a precaution, I also changed the type of adapter_layer_names from list to tuple, as it should not be mutated. When deleting the active adapter, the logic for choosing the new active adapter has been changed slightly to ensure consistency across layers. In practice, this should rarely make a difference. An error is now raised if the last remaining adapter is deleted. Test coverage has been increased:
- Deleting adapters is now also tested for custom models.
- It is also tested for LoHa and LoKr, not only LoRA.
- I added a test for deleting the non-active adapter.
Not implemented: I did not add adapter deletion to IA³, since it is included in #980. LMK if it should be added here instead.
@alexrs Do you still plan to work on this PR? Note that there are some merge conflicts now due to recent changes in PEFT. I think they should be straightforward to resolve, but let me know if you need help.
@BenjaminBossan Hey! Yes, sorry for all the delay. I'm planning on working on it, but I'll probably need some more time. Given that this PR is quite big, if there are some parts with higher priority (i.e. adapter deletion), I can extract those to separate PRs so we can merge them faster. WDYT?
No worries, take the time you need. I agree it is a good idea to separate features into multiple PRs when possible.
@BenjaminBossan As we discussed, I extracted the
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
@alexrs Are you interested in resurrecting this PR?
@BenjaminBossan Yes! I'm finishing some work for my MSc thesis but I hope to have some time in the next few weeks. Are we pressed for time on this?
We're not pressed, take your time. I'm just asking in case you want someone else to take over. Good luck with your thesis.
Problem
It is currently not possible to combine multiple (IA)³ adapters or to delete adapters from an IA3Model.
Solution
- Add a delete_adapter method based on the one in LoraModel.
- Add an add_weighted_adapter method based on the one in LoraModel. This method, however, is a simplified version of the LoRA one, as it does not include the cat and svd options.
Other minor modifications
- Removed self.scaling in IA3Layer as it was not used. I believe this was a leftover from the LoRA implementation this layer seems to be based on.
Discussion
I have tested this locally using a simple script (see below) but I have not run any automated tests yet. What is the best way to test these changes?
Test script
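(The actual script is collapsed and not shown here. Below is a rough sketch of the kind of script described, assuming a gpt2 base model and the add_weighted_adapter/delete_adapter signatures discussed above; it is not the author's script.)

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import IA3Config, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

config = IA3Config(
    task_type="CAUSAL_LM",
    target_modules=["c_attn", "mlp.c_proj"],
    feedforward_modules=["mlp.c_proj"],
)

# the first adapter is named "default"; add a second one on top
model = get_peft_model(base_model, config)
model.add_adapter("second", config)

# combine both adapters into a new one (method added in this PR)
model.add_weighted_adapter(["default", "second"], [0.5, 0.5], "combined")
model.set_adapter("combined")

inputs = tokenizer("Hello world", return_tensors="pt")
with torch.no_grad():
    print(model(**inputs).logits.shape)

# delete an adapter again (method added in this PR)
model.delete_adapter("second")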