Combine multiple (IA)^3 Adapters and delete (IA)^3 adapters #980

Closed
wants to merge 8 commits
2 changes: 0 additions & 2 deletions src/peft/tuners/ia3/layer.py
@@ -34,7 +34,6 @@ def __init__(
out_features: int,
is_feedforward: bool,
):
self.scaling = {}
BenjaminBossan marked this conversation as resolved.
self.ia3_l = nn.ParameterDict({})
# Mark the weight as unmerged
self._disable_adapters = False
@@ -166,7 +165,6 @@ def _linear(self, input: torch.Tensor) -> torch.Tensor:

def forward(self, x: torch.Tensor) -> torch.Tensor:
dtype = previous_dtype = x.dtype

if self.disable_adapters:
if self.merged:
self.unmerge()
116 changes: 110 additions & 6 deletions src/peft/tuners/ia3/model.py
@@ -12,11 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import operator
import re
import warnings
from dataclasses import asdict
from dataclasses import asdict, replace
from enum import Enum
from functools import reduce

import torch
from transformers.pytorch_utils import Conv1D
@@ -27,6 +28,7 @@
TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING,
TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING,
ModulesToSaveWrapper,
_freeze_adapter,
_get_submodules,
)

@@ -278,13 +280,15 @@ def _prepare_adapter_config(self, peft_config, model_config):
if peft_config.target_modules is None:
if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING:
raise ValueError("Please specify `target_modules` in `peft_config`")
peft_config.target_modules = TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING[model_config["model_type"]]
peft_config.target_modules = set(
Member:

Good catch, this is a bug in the existing code base.

TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING[model_config["model_type"]]
)
if peft_config.feedforward_modules is None:
if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING:
raise ValueError("Please specify `feedforward_modules` in `peft_config`")
peft_config.feedforward_modules = TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING[
model_config["model_type"]
]
peft_config.feedforward_modules = set(
TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING[model_config["model_type"]]
)
return peft_config

def merge_and_unload(self, safe_merge: bool = False):
@@ -336,3 +340,103 @@ def merge_and_unload(self, safe_merge: bool = False):
self._replace_module(parent, target_name, new_module, target)

return self.model

def delete_adapter(self, adapter_name: str):
"""
Deletes an existing adapter.

Args:
adapter_name (str): Name of the adapter to be deleted.
"""
if adapter_name not in self.peft_config:
raise ValueError(f"Adapter {adapter_name} does not exist")
del self.peft_config[adapter_name]

key_list = [key for key, _ in self.model.named_modules() if "ia3" not in key]
for key in key_list:
_, target, _ = _get_submodules(self.model, key)
if isinstance(target, IA3Layer):
if adapter_name in target.ia3_l:
target.ia3_l.pop(adapter_name)
if adapter_name in target.active_adapters:
resetting_active_adapter = (
list(self.peft_config.keys())[0] if len(self.peft_config) > 0 else "default"
)
warnings.warn(
f"Adapter {adapter_name} was active which is now deleted. Setting active adapter to {resetting_active_adapter}. "
)
target.set_adapter(resetting_active_adapter)

def _new_modules(self, adapters, module_type):
Member:

I like having a separate method for this, but the name is not quite fitting. This combines different module names, right? So could you please adjust the name to reflect that? Also, please add a sentence to the docstring that explains what happens.

"""
Args:
adapters (`list`):
List of adapter names to be merged.
module_type (`str`):
Type of the module to be merged.
"""
module_types = [type(getattr(self.peft_config[adapter], module_type)) for adapter in adapters]
if not module_types:
raise ValueError(f"Found no adapter matching the names in {adapters}")
if len(set(module_types)) > 1:
raise ValueError(
"all adapter configs should follow the same target modules type. "
f"Combining adapters with `{module_type}` type being a mix of list/set and string is not supported."
)
if module_types[0] == str:
new_modules = "|".join(f"({getattr(self.peft_config[adapter], module_type)})" for adapter in adapters)
elif module_types[0] == set:
new_modules = reduce(
operator.or_, (getattr(self.peft_config[adapter], module_type) for adapter in adapters)
)
else:
raise TypeError(f"Invalid type {module_types[0]} found in {module_type}")
return new_modules
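
As an illustration of what this helper returns (the adapter configs and module names below are made up for the example): when both configs store the value as a regex string, the strings are wrapped in parentheses and joined with "|"; when both store a set, the sets are unioned.

# hypothetical values for peft_config[adapter].target_modules
str_case = "|".join(f"({tm})" for tm in ["q_proj|v_proj", "k_proj"])
# -> "(q_proj|v_proj)|(k_proj)"

set_case = {"q_proj", "v_proj"} | {"k_proj"}
# -> {"q_proj", "v_proj", "k_proj"}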

def add_weighted_adapter(self, adapters, weights, adapter_name):
"""
This method adds a new adapter by merging the given adapters with the given weights.

Args:
adapters (`list`):
List of adapter names to be merged.
weights (`list`):
List of weights for each adapter.
adapter_name (`str`):
Name of the new adapter.
"""
if adapter_name in list(self.peft_config.keys()):
return
for adapter in adapters:
if adapter not in list(self.peft_config.keys()):
raise ValueError(f"Adapter {adapter} does not exist")

new_target_modules = self._new_modules(adapters, "target_modules")
new_feedforward_modules = self._new_modules(adapters, "feedforward_modules")

self.peft_config[adapter_name] = replace(
self.peft_config[adapters[0]],
target_modules=new_target_modules,
feedforward_modules=new_feedforward_modules,
)
self.inject_adapter(self.model, adapter_name)

# Do we really need that?
_freeze_adapter(self.model, adapter_name)

key_list = [key for key, _ in self.model.named_modules() if "ia3" not in key]
for key in key_list:
_, target, _ = _get_submodules(self.model, key)
if isinstance(target, IA3Layer):
if adapter_name in target.ia3_l:
target_ia3_l = target.ia3_l[adapter_name]
else:
continue

target_ia3_l.data = target_ia3_l.data * 0.0
for adapter, weight in zip(adapters, weights):
if adapter in target.ia3_l:
current_adapter_ia3_l = target.ia3_l[adapter]
else:
continue
target_ia3_l.data += current_adapter_ia3_l.data * weight
Member (on lines +436 to +442):

I think this is not correct: When using IA³, the IA³ weights have to be multiplied, not added, right? I.e. they should be initialized as 1.0 and then each IA³ weight is multiplied on top, not added. See how it's accomplished in the forward method of IA³:

ia3_scaling = 1
for active_adapter in self.active_adapters:
if active_adapter not in self.ia3_l.keys():
continue
dtype = self.ia3_l[active_adapter].dtype
ia3_scaling *= self.ia3_l[active_adapter].flatten()

If this is correct, we encounter a second problem, namely that the weights argument makes little sense: Since we just multiply each IA³ weight and each weight from weights, due to commutativity, the order in weights doesn't matter. Whether a user passes weights=[2, 3] or weights=[3, 2] makes no difference.

We could still leave it as is for consistency, but I would be afraid that this would confuse many users. Instead, we could also 1) remove the weights argument entirely for IA³ or 2) only pass a single scalar to weights, which is applied once to all weights (could be set as the initial value). WDYT?
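
A small standalone sketch of that commutativity point (plain tensors rather than the PEFT classes; values are arbitrary):

import torch

l1, l2 = torch.rand(4), torch.rand(4)  # IA3 vectors of two adapters for the same layer
w1, w2 = 2.0, 3.0                      # user-supplied weights

merged_a = (w1 * l1) * (w2 * l2)       # weights in the given order
merged_b = (w2 * l1) * (w1 * l2)       # weights swapped

print(torch.allclose(merged_a, merged_b))  # True: under multiplication the order of `weights` is irrelevant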

Contributor Author (@alexrs, Oct 10, 2023):

Thanks for the feedback and review!

> When using IA³, the IA³ weights have to be multiplied, not added, right?

This is true in the forward pass. The learned vectors $l$ are multiplied with (in the case of attention) $K^T$ and $V$. However, here we are not considering the Key and Value matrices, only the learned vectors $l$ (as far as I understand), so my approach here was to compute a linear combination of the vectors (which is what we do in LoRA?).

Let's assume we have two adapters that target $K$ and $V$ with associated vectors $l_K$ and $l_V$, and weights [0.6, 0.4]. The way I wanted to combine these adapters into a new adapter was:

$l_K^{\text{new}} = l_K^1 * w_1 + l_K^2 * w_2$
$l_V^{\text{new}} = l_V^1 * w_1 + l_V^2 * w_2$

If we also target the FF layers, we would compute the resulting vector using the same procedure.

> the weights argument makes little sense

If we multiply vectors, yes. However, that would not result in a linear combination of vectors, which was my goal.

Let me know if this makes sense!

Member:

Hmm, not sure. Let's work with scalars for a second. Let's say we have one IA³ weight with value 2 and one with value 3. As they are multiplied consecutively on the input, I would expect that we should multiply by 6, not by their sum 5. Am I missing something?
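
Spelled out as a toy calculation (purely illustrative numbers, not code from this PR):

x = 10.0
both_active = x * 2.0 * 3.0    # 60.0 -- the forward pass applies each IA3 scaling in turn
merged_mul  = x * (2.0 * 3.0)  # 60.0 -- a multiplicative merge reproduces that output
merged_add  = x * (2.0 + 3.0)  # 50.0 -- an additive merge does not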

Anyway, I thought why not just test if the results are right or not. For this, I changed the test you added to do this instead:

        elif isinstance(config, (IA3Config)):
            model = get_peft_model(model, config, adapter_list[0])
            model = model.to(self.torch_device)
            dummy_input = self.prepare_inputs_for_testing()
            output0 = model(**dummy_input)[0]

            model.add_adapter(adapter_list[1], config)
            model.add_adapter(adapter_list[2], config)

            model.set_adapter(adapter_list)
            output1 = model(**dummy_input)[0]

            model.merge_adapter()
            output2 = model(**dummy_input)[0]

            model.unmerge_adapter()
            output3 = model(**dummy_input)[0]

            # using addition
            model.add_weighted_adapter(adapter_list, torch.ones(3) / 3, "merged-add")
            model.set_adapter("merged-add")
            output4 = model(**dummy_input)[0]

            # using multiplication
            model.add_weighted_adapter_mul(adapter_list, torch.ones(3), "merged-mul")
            model.set_adapter("merged-mul")
            output5 = model(**dummy_input)[0]

            assert not torch.allclose(output0, output1)
            torch.testing.assert_allclose(output1, output2)
            torch.testing.assert_allclose(output1, output3)
            torch.testing.assert_allclose(output1, output5)  # passes
            torch.testing.assert_allclose(output1, output4)  # fails

As you can see, we test the outputs from an IA³ model with the 3 adapters active but unmerged vs merged vs merged using add_weighted_adapter (your implementation) vs merged using add_weighted_adapter_mul (my implementation using multiply). When I run the tests, the multiply version passes but the addition version fails, which makes me think that multiplying is the way to go.

If you want to replicate this result, it will require a few steps because our code isn't really set up to work with multiple active adapters yet, so I had to make a few ad hoc changes to even get this far. I created a PR on top of your branch containing those changes:

https://github.com/alexrs/peft/pull/1/files

Obviously, it should not be merged, it's just to show you what steps I took. WDYT, is this plausible?
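
For readers following along, a rough sketch of what the multiplicative variant does differently (a sketch only, not the exact code from the linked PR; add_weighted_adapter_mul is the name used in the test above): the merged vector starts at ones and each weighted adapter vector is multiplied onto it, instead of starting at zeros and summing.

# per target IA3Layer, inside the hypothetical add_weighted_adapter_mul:
target_ia3_l.data = torch.ones_like(target_ia3_l.data)
for adapter, weight in zip(adapters, weights):
    if adapter not in target.ia3_l:
        continue
    target_ia3_l.data *= target.ia3_l[adapter].data * weight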

Contributor Author (@alexrs, Oct 10, 2023):

I see your point! However, I'm not sure this is consistent with the LoRA implementation. As far as I understand, there are two different scenarios here:
1. Stacking Adapters: When using set_adapter on multiple adapters, what we are doing is stacking adapters. That's how it works right now, and how it works in LoRA (I think!). This is equivalent to using combination_type=cat in LoRA's add_weighted_adapter:

        if tuner_method == "lora":
            # create a weighted adapter combining both adapters and check that
            # its output is same as setting multiple active adapters
            peft_model.add_weighted_adapter(
                ["adapter_1", "adapter_2"], [1.0, 1.0], "new_combined_adapter", combination_type="cat"
            )
            peft_model.set_adapter("new_combined_adapter")
            new_combined_output = peft_model(**X)
            self.assertTrue(torch.allclose(new_combined_output, combined_output, atol=1e-5))
2. Linear combination of Adapters: In this case, we are not stacking adapters but combining them to create a new adapter that is a linear combination of the input adapters and the input weights. This is equivalent to combination_type=linear in LoRA's add_weighted_adapter. If we change the code linked above to use linear, the test fails:

        if tuner_method == "lora":
            # create a weighted adapter combining both adapters and check that
            # its output is same as setting multiple active adapters
            peft_model.add_weighted_adapter(
                ["adapter_1", "adapter_2"], [1.0, 1.0], "new_combined_adapter", combination_type="linear"
            )
            peft_model.set_adapter("new_combined_adapter")
            new_combined_output = peft_model(**X)
            self.assertTrue(torch.allclose(new_combined_output, combined_output, atol=1e-5))

And the same happens if we decide to give both adapters equal weights that sum to 1:

            peft_model.add_weighted_adapter(
                ["adapter_1", "adapter_2"], [0.5, 0.5], "new_combined_adapter", combination_type="linear"
            )

I guess a solution is to add the different combination_types to $(IA)^3$'s add_weighted_adapter. Does this sound reasonable? Or do I have the wrong understanding of how this works?
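
One possible shape for that, sketched as a standalone helper (the parameter name combination_type and the option names are assumptions; nothing here is settled in this thread):

import torch

def combine_ia3_vectors(vectors, weights, combination_type="mul"):
    # vectors: one IA3 scaling vector (torch.Tensor) per adapter; weights: one float per adapter
    if combination_type == "linear":
        # weighted sum of the vectors, as add_weighted_adapter currently does in this PR
        return sum(w * v for v, w in zip(vectors, weights))
    if combination_type == "mul":
        # weighted product, which matches the output of having all adapters active at once
        merged = torch.ones_like(vectors[0])
        for v, w in zip(vectors, weights):
            merged = merged * v * w
        return merged
    raise ValueError(f"Unknown combination_type: {combination_type}")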

Member:

Yes, you're right in the sense that for IA³, it is not quite clear how to interpret the combination of results. Unfortunately, I don't think there is any existing evidence on what the best way of combining IA³ adapters is. I agree that we could offer multiple methods and that hopefully, with time, the best method will emerge. When it comes to which default to choose, I'd argue it's a nice property to have the same output for combining the adapters as if they were all active at once, WDYT?

Another possibility that comes to mind would be the geometric mean, which seems appropriate for a multiplicative operation, but it wouldn't work for negative numbers, so it has to be ruled out.

When it comes to naming the combination types, the analogy to LoRA is a bit difficult, because the mathematical operation is different. I think for IA³ it is necessary to think from first principles.

Contributor Author (@alexrs, Oct 11, 2023):

> Unfortunately, I don't think there is any existing evidence on what the best way of combining IA³ adapters is

Agreed.

> I'd argue it's a nice property to have the same output for combining the adapters as if they were all active at once, WDYT?

That makes sense! But as discussed above, it is not how it works in LoRA by default, is it?

I guess the way to proceed is to allow both multiplication and linear combination methods using different combination_types, and setting the default to multiplication?

All in all, given that there is no evidence on what the best way of combining adapters is, I will try to run some experiments using both methods to get more clarity on this topic. Let me know if you have any suggestions or ideas for this!

Member:

> That makes sense! But as discussed above, it is not how it works in LoRA by default, is it?

Yes, but we cannot really compare the two as I mentioned. E.g. it would not make sense to have an "svd" method for IA³, so I think we shouldn't put too much stress on consistency here.

> I will try to run some experiments using both methods to get more clarity on this topic. Let me know if you have any suggestions or ideas for this!

That would be fantastic. Loading and combining multiple LoRAs is mostly a thing in image generation AFAIK, so that's probably what I would investigate, but I'm not sure how well IA³ lends itself to image generation in general.
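
Before the test changes below, a minimal usage sketch of the two methods added in this PR (model name, adapter names, and weights are placeholders; it assumes the base model type has default IA³ target modules):

from transformers import AutoModelForCausalLM
from peft import IA3Config, get_peft_model

base = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder base model
config = IA3Config(task_type="CAUSAL_LM")

model = get_peft_model(base, config, adapter_name="adapter_a")
model.add_adapter("adapter_b", config)

# combine the two adapters into a new one with the given weights
model.add_weighted_adapter(["adapter_a", "adapter_b"], [0.6, 0.4], "merged")
model.set_adapter("merged")

# remove an adapter that is no longer needed
model.delete_adapter("adapter_b")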

142 changes: 85 additions & 57 deletions tests/testing_common.py
@@ -37,6 +37,7 @@
get_peft_model_state_dict,
prepare_model_for_int8_training,
)
from peft.tuners.ia3 import IA3Layer
from peft.tuners.lora import LoraLayer
from peft.utils import _get_submodules, infer_device

@@ -810,7 +811,7 @@ def _test_delete_adapter(self, model_id, config_cls, config_kwargs):
model.set_adapter(adapter_to_delete)
model = model.to(self.torch_device)

if config.peft_type not in ("LORA"):
if config.peft_type not in ("LORA", "IA3"):
with self.assertRaises(AttributeError):
model.delete_adapter(adapter_to_delete)
else:
@@ -831,6 +832,8 @@ def _test_delete_adapter(self, model_id, config_cls, config_kwargs):
"lora_dropout",
]:
self.assertFalse(adapter_to_delete in getattr(target, attr))
elif isinstance(target, IA3Layer):
self.assertFalse(adapter_to_delete in target.ia3_l)

def _test_unload_adapter(self, model_id, config_cls, config_kwargs):
model = self.transformers_class.from_pretrained(model_id)
@@ -870,70 +873,95 @@ def _test_weighted_combination_of_adapters(self, model_id, config_cls, config_kw
base_model_name_or_path=model_id,
**config_kwargs,
)
if not isinstance(config, (LoraConfig)):
if not isinstance(config, (LoraConfig, IA3Config)):
return
model = get_peft_model(model, config, adapter_list[0])
model.add_adapter(adapter_list[1], config)
model.add_adapter(adapter_list[2], replace(config, r=20))
model = model.to(self.torch_device)

# test re-weighting single adapter
model.add_weighted_adapter([adapter_list[0]], [weight_list[0]], "single_adapter_reweighting")

# test svd re-weighting with multiple adapters
model.add_weighted_adapter(adapter_list[1:], weight_list[1:], "multi_adapter_svd_reweighting")
if isinstance(config, (LoraConfig)):
model = get_peft_model(model, config, adapter_list[0])
model.add_adapter(adapter_list[1], config)
model.add_adapter(adapter_list[2], replace(config, r=20))
model = model.to(self.torch_device)
# test re-weighting single adapter
model.add_weighted_adapter([adapter_list[0]], [weight_list[0]], "single_adapter_reweighting")

# test cat re-weighting with multiple adapters
model.add_weighted_adapter(
adapter_list[1:], weight_list[1:], "multi_adapter_cat_reweighting", combination_type="cat"
)
# test svd re-weighting with multiple adapters
model.add_weighted_adapter(adapter_list[1:], weight_list[1:], "multi_adapter_svd_reweighting")

# test linear re-weighting with multiple adapters
model.add_weighted_adapter(
adapter_list[:2], weight_list[:2], "multi_adapter_linear_reweighting", combination_type="linear"
)
# test cat re-weighting with multiple adapters
model.add_weighted_adapter(
adapter_list[1:], weight_list[1:], "multi_adapter_cat_reweighting", combination_type="cat"
)

with self.assertRaises(ValueError):
# test linear re-weighting with multiple adapters
model.add_weighted_adapter(
adapter_list[1:],
weight_list[1:],
"multi_adapter_linear_reweighting_uneven_r",
combination_type="linear",
adapter_list[:2], weight_list[:2], "multi_adapter_linear_reweighting", combination_type="linear"
)

new_adapters = [
"single_adapter_reweighting",
"multi_adapter_svd_reweighting",
"multi_adapter_cat_reweighting",
"multi_adapter_linear_reweighting",
]
for new_adapter in new_adapters:
self.assertTrue(new_adapter in model.peft_config)

key_list = [key for key, _ in model.named_modules() if "lora" not in key]
for key in key_list:
_, target, _ = _get_submodules(model, key)
if isinstance(target, LoraLayer):
for adapter_name in new_adapters:
if "single" in adapter_name:
new_delta_weight = target.get_delta_weight(adapter_name)
weighted_original_delta_weights = target.get_delta_weight(adapter_list[0]) * weight_list[0]
self.assertTrue(
torch.allclose(new_delta_weight, weighted_original_delta_weights, atol=1e-4, rtol=1e-4)
)
elif "svd" in adapter_name:
self.assertTrue(target.r[adapter_name] == 20)
elif "linear" in adapter_name:
self.assertTrue(target.r[adapter_name] == 8)
elif "cat" in adapter_name:
self.assertTrue(target.r[adapter_name] == 28)

for adapter_name in new_adapters:
# ensuring new adapters pass the forward loop
model.set_adapter(adapter_name)
dummy_input = self.prepare_inputs_for_testing()
model.eval()
_ = model(**dummy_input)[0]
with self.assertRaises(ValueError):
model.add_weighted_adapter(
adapter_list[1:],
weight_list[1:],
"multi_adapter_linear_reweighting_uneven_r",
combination_type="linear",
)

new_adapters = [
"single_adapter_reweighting",
"multi_adapter_svd_reweighting",
"multi_adapter_cat_reweighting",
"multi_adapter_linear_reweighting",
]
for new_adapter in new_adapters:
self.assertTrue(new_adapter in model.peft_config)

key_list = [key for key, _ in model.named_modules() if "lora" not in key]
for key in key_list:
_, target, _ = _get_submodules(model, key)
if isinstance(target, LoraLayer):
for adapter_name in new_adapters:
if "single" in adapter_name:
new_delta_weight = target.get_delta_weight(adapter_name)
weighted_original_delta_weights = target.get_delta_weight(adapter_list[0]) * weight_list[0]
self.assertTrue(
torch.allclose(new_delta_weight, weighted_original_delta_weights, atol=1e-4, rtol=1e-4)
)
elif "svd" in adapter_name:
self.assertTrue(target.r[adapter_name] == 20)
elif "linear" in adapter_name:
self.assertTrue(target.r[adapter_name] == 8)
elif "cat" in adapter_name:
self.assertTrue(target.r[adapter_name] == 28)

for adapter_name in new_adapters:
# ensuring new adapters pass the forward loop
model.set_adapter(adapter_name)
dummy_input = self.prepare_inputs_for_testing()
model.eval()
_ = model(**dummy_input)[0]

elif isinstance(config, (IA3Config)):
Contributor Author:

Linear for $(IA)^3$ does not have a get_delta_weight method. How should we test that the result is correct?

model = get_peft_model(model, config, adapter_list[0])
model.add_adapter(adapter_list[1], config)
model.add_adapter(adapter_list[2], config)
model = model.to(self.torch_device)
# single adapter re-weighting and multi adapter linear re-weighting
# Note: IA3 only supports linear re-weighting
model.add_weighted_adapter([adapter_list[0]], [weight_list[0]], "single_adapter_reweighting")
model.add_weighted_adapter(adapter_list[:2], weight_list[:2], "multi_adapter_linear_reweighting")

new_adapters = [
"single_adapter_reweighting",
"multi_adapter_linear_reweighting",
]
for new_adapter in new_adapters:
self.assertTrue(new_adapter in model.peft_config)

for adapter_name in new_adapters:
# ensuring new adapters pass the forward loop
model.set_adapter(adapter_name)
dummy_input = self.prepare_inputs_for_testing()
model.eval()
_ = model(**dummy_input)[0]

def _test_disable_adapter(self, model_id, config_cls, config_kwargs):
task_type = config_kwargs.get("task_type")