Mixed adapter types #1069

Closed

Conversation

@BenjaminBossan (Member) commented Nov 1, 2023

This is a POC to show how we could achieve mixing adapter types such as LoRA and LoKr.

Description

The very general idea is that we can already mix multiple adapters of the same type, e.g. add two LoRA adapters, but right now we fail when trying to mix different types. This restriction has been lifted by adding a new class PeftMixedModel which deals with different adapter types.

The usage looks something like this:

base_model = ...
config0 = LoraConfig(...)
# set mixed=True
peft_model = get_peft_model(base_model, config0, mixed=True)
config1 = LoHaConfig(...)
peft_model.add_adapter(config1, "other")
peft_model.set_adapter(["default", "other"])
peft_model(x)

At this point, both adapters are active at the same time.

Existing code should not be affected by this change, since users need to opt into this behavior by setting mixed=True.

Also interesting is that this method can be used for a single adapter type but with very different configs. Right now, we have limited support for that (e.g. for LoRA, different r values by using rank_pattern), but with this, we don't need to special case the differing arguments anymore.
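As an illustration, this is roughly what mixing two differently configured LoRA adapters could look like. This is a sketch based on the usage above; the target module names and hyperparameters are made up:

from peft import LoraConfig, get_peft_model

base_model = ...  # e.g. a transformers model
# two LoRA adapters whose configs differ in more than just the rank
config0 = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.1, target_modules=["q_proj", "v_proj"])
peft_model = get_peft_model(base_model, config0, mixed=True)
config1 = LoraConfig(r=32, lora_alpha=64, target_modules=["q_proj", "k_proj", "v_proj"])
peft_model.add_adapter(config1, "other")
peft_model.set_adapter(["default", "other"])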

Implementation

Apart from adding the new PeftMixedModel class to replace PeftModel, I added a new class LycorisModel which replaces LoraModel, LoHaModel etc. This class checks the config type and then uses the corresponding LoraModel, LoHaModel etc. to create the adapter.
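For illustration, the dispatch could boil down to something like the following. This is a simplified sketch, not the actual implementation; the real class also handles injection, adapter names, and more:

from peft import LoHaConfig, LoraConfig
from peft.tuners.loha import LoHaModel
from peft.tuners.lora import LoraModel

# map config types to the tuner model class that knows how to create the adapter
CONFIG_TO_MODEL = {
    LoraConfig: LoraModel,
    LoHaConfig: LoHaModel,
}

def get_tuner_cls(config):
    # pick the tuner implementation based on the config type
    for config_cls, model_cls in CONFIG_TO_MODEL.items():
        if isinstance(config, config_cls):
            return model_cls
    raise ValueError(f"Unsupported config type: {type(config).__name__}")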

Another crucial change I had to make was to adopt the "base layer pattern". This is the pattern that was, for instance, used to speed up initialization in LoRA bnb layers in PR #994.

The main change is that the adapter layer wraps the original layer and calls forward on that layer, instead of doing stuff like this:

F.linear(input, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)

which completely circumvents the call to the target layer's forward method. With the base layer pattern, we now call the target layer's forward method. Therefore, if the target layer is another adapter layer, we call its forward method correctly.
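To make the pattern concrete, here is a stripped-down toy version of such a wrapping layer. This is a sketch only, not the actual LoRA implementation, and the class name is made up:

import torch
import torch.nn as nn

class LoraLinearSketch(nn.Module):
    """Toy adapter layer following the base layer pattern (illustrative only)."""

    def __init__(self, base_layer: nn.Linear, r: int = 8):
        super().__init__()
        # keep a reference to the wrapped layer instead of copying its weights
        self.base_layer = base_layer
        self.lora_A = nn.Linear(base_layer.in_features, r, bias=False)
        self.lora_B = nn.Linear(r, base_layer.out_features, bias=False)
        nn.init.zeros_(self.lora_B.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # delegate to the wrapped layer's forward; if base_layer is itself an
        # adapter layer, its adapter logic runs too, which makes nesting work
        result = self.base_layer(x)
        return result + self.lora_B(self.lora_A(x))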

This change has the nice side benefit that we no longer need to use _init_empty_weights -- in fact, we don't initialize any of the target layer's weights anymore, since we have a reference to it.

Note that, same as for the bnb layers, this should not break backwards compatibility, since the adapter weights and their state_dicts are not affected by this change.

I considered implementing this via hooks, which might be more elegant, but that would require larger changes to the adapter layers, e.g. rewriting LoRA's Linear.forward to return only the diff. The chosen approach seemed less disruptive to me.

Somewhat unrelated changes

  1. During debugging, I got very annoyed with the fact that the reprs of adapter layers and normal PyTorch layers are hard to distinguish, e.g. the type is just "Linear". Now, for adapter layers, it is prefixed by the adapter type, e.g. "lora.Linear" (see the small sketch after this list).
  2. For LoHa (and in the future, LoKr), I had to change the init of weights when using init_weights=False. This is because of what is discussed in #1058 (Numerical instabilities with LoHa).
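The repr change from point 1 boils down to something like this (a minimal sketch, not the exact PEFT code):

import torch.nn as nn

class Linear(nn.Module):  # stand-in for an adapter layer such as lora.Linear
    def __repr__(self) -> str:
        # prefix the default module repr with the adapter type
        return "lora." + super().__repr__()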

TODOs

  • For now, I only added this capability for LoRA and LoHa as a POC. It needs to be added to LoKr, AdaLora, and LoRA bnb too. (update: done)
  • The unit tests are very rudimentary right now, only a simple model is tested in two settings. (update: much broader coverage now, though there is still room for improvement)
  • There is no documentation so far.
  • I'm not yet sure if the same logic can be applied to IA³ or if it may fail because IA³ can apply its scaling to the input, not the output.
  • It is currently not possible to represent a mixed adapter model as a single config. I think we can come up with a solution, but I don't think it is necessary for a first version of this.
  • I haven't checked modules_to_save with mixed adapters yet (I guess we have to forbid those except for one adapter -- or make them a completely independent adapter type).
  • Merging may be tricky, not sure. (update: merging, unloading, and disabling work)
  • LycorisModel may not be the best name, as non-Lycoris adapters may be supported; suggestions are welcome.
  • Saving and loading are not yet implemented for mixed models.

@BenjaminBossan marked this pull request as draft November 1, 2023 16:47
@BenjaminBossan (Member, Author) commented:

@pacman100 @younesbelkada This draft PR took quite some time to get to this state. Before putting more work into this, I would like to know if we want to continue on this path or not, so your input is welcome. Also interested in your opinion @kovalexal, since we discussed this topic in #1006.

Docs don't build otherwise...

Some tests are still failing, probably because of mix-ups when AdaLora checks are performed: they check for LoRA instances, which might result in false positives if the tuners are mixed.
Seems to work only on newer Python versions.
@kovalexal (Contributor) commented:

@BenjaminBossan, thank you for actively working on this problem, I am sure that it will be a great addition to PEFT!

At first glance, it looks quite stunning. Am I right that, when using the base layer pattern, the resulting modified model will look like an onion (adapters will be wrapped inside each other sequentially)?

I've tried out your code and found several points for improvement:

  • Right now there is a problem with adding an adapter of the same type as one that already exists. Here is a small code snippet that demonstrates it:
import torch
import torch.nn as nn
from peft import LoraConfig, LoHaConfig, LoKrConfig, get_peft_model, get_peft_model_state_dict

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(10, 10)
        self.lin1 = nn.Linear(10, 10)

    def forward(self, x):
        x = self.lin0(x)
        x = self.lin1(x)
        return x
    
model = Model()

# Add LoRA
config0 = LoraConfig(r=4, lora_alpha=4, target_modules=["lin0", "lin1"])
peft_model = get_peft_model(model, config0, mixed=True)

# Add LoHa
config1 = LoHaConfig(r=4, alpha=4, target_modules=["lin0", "lin1"])
peft_model.add_adapter("other", config1)

# Add LoRA again
config2 = LoraConfig(r=4, lora_alpha=4, target_modules=["lin0", "lin1"])
peft_model.add_adapter("other_lora", config2)
# AttributeError: 'Linear' object has no attribute 'weight'

I am almost sure that, in order to properly cope with that, we need to recursively search for an existing adapter layer of the corresponding type: if it exists, we should add the new adapter to it; otherwise, we wrap the layer with the new adapter type.

  • I know that this is a draft PR, but right now there is a problem with get_peft_model_state_dict - it just returns the outer adapter weights (plus it seems that there are additional base_layer wraps in the keys):
PeftMixedModel(
  (base_model): LycorisModel(
    (model): Model(
      (lin0): loha.Linear(
        (base_layer): lora.Linear(
          (base_layer): Linear(in_features=10, out_features=10, bias=True)
          (lora_dropout): ModuleDict(
            (default): Identity()
          )
          (lora_A): ModuleDict(
            (default): Linear(in_features=10, out_features=4, bias=False)
          )
          (lora_B): ModuleDict(
            (default): Linear(in_features=4, out_features=10, bias=False)
          )
          (lora_embedding_A): ParameterDict()
          (lora_embedding_B): ParameterDict()
        )
        (hada_w1_a): ParameterDict(  (other): Parameter containing: [torch.FloatTensor of size 10x4])
        (hada_w1_b): ParameterDict(  (other): Parameter containing: [torch.FloatTensor of size 4x10])
        (hada_w2_a): ParameterDict(  (other): Parameter containing: [torch.FloatTensor of size 10x4])
        (hada_w2_b): ParameterDict(  (other): Parameter containing: [torch.FloatTensor of size 4x10])
        (hada_t1): ParameterDict()
        (hada_t2): ParameterDict()
      )
      (lin1): loha.Linear(
        (base_layer): lora.Linear(
          (base_layer): Linear(in_features=10, out_features=10, bias=True)
          (lora_dropout): ModuleDict(
            (default): Identity()
          )
          (lora_A): ModuleDict(
            (default): Linear(in_features=10, out_features=4, bias=False)
          )
          (lora_B): ModuleDict(
            (default): Linear(in_features=4, out_features=10, bias=False)
          )
          (lora_embedding_A): ParameterDict()
          (lora_embedding_B): ParameterDict()
        )
        (hada_w1_a): ParameterDict(  (other): Parameter containing: [torch.FloatTensor of size 10x4])
        (hada_w1_b): ParameterDict(  (other): Parameter containing: [torch.FloatTensor of size 4x10])
        (hada_w2_a): ParameterDict(  (other): Parameter containing: [torch.FloatTensor of size 10x4])
        (hada_w2_b): ParameterDict(  (other): Parameter containing: [torch.FloatTensor of size 4x10])
        (hada_t1): ParameterDict()
        (hada_t2): ParameterDict()
      )
    )
  )
)
list(get_peft_model_state_dict(peft_model).keys())

#['base_model.model.lin0.base_layer.base_layer.lora_A.weight', # <- shouldn't there be just `base_model.model.lin0.base_layer.lora_A.weight`?
# 'base_model.model.lin0.base_layer.base_layer.lora_B.weight', # <- shouldn't there be just `base_model.model.lin0.base_layer.lora_B.weight`?
# 'base_model.model.lin1.base_layer.lora_A.weight',
# 'base_model.model.lin1.base_layer.lora_B.weight']

I've spent some time reading the IA³ code and it seems that now I truly understand what you were talking about in terms of commutativity. I'll try to investigate how webui copes with that, so I'll be back in some time ;)

@BenjaminBossan (Member, Author) commented:

At first glance, it looks quite stunning. Am I right that, when using the base layer pattern, the resulting modified model will look like an onion (adapters will be wrapped inside each other sequentially)?

Indeed, this is a somewhat unfortunate side-effect of this approach. So e.g. if we create a model that is LoRA1 > LoRA2 > LoHa1 > LoHa2, we will end up with one additional layer of nesting, where the LoHa adapter will wrap the LoRA adapter (which wraps the original layer). However, when we do LoRA1 > LoHa1 > LoRA2 > LoHa2, we will end up with 3 additional layers of nesting, even though the resulting model is the same on paper.

I think this is not a huge deal and we can accept some limitations on the mixed adapter class, as it would still enable features that are simply impossible now. But I haven't figured out a way to avoid this without some complicated logic or a bigger refactor (e.g. switching to hooks). As is, I don't think it's going to add significant overhead, and as long as the adapters are commutative, we could encourage users to add them in groups of the same type.

there is a problem with get_peft_model_state_dict

Yes, I have purposefully excluded persistence from the scope of this PR, as it will require some extra work. I have yet to fully think through how important persistence is for a typical user. Right now, I imagine that the most common use case for this feature would be:

  1. The user (or someone else) trains multiple adapters independently, e.g. LoRA1, LoRA2, and LoHa1.
  2. The user wants to use all 3 at once.
  3. The user loads the base model with LoRA1.
  4. The user adds LoRA2.
  5. The user adds LoHa1.
  6. The user does inference.

If the user wants to repeat the whole process, I would imagine that they just go through steps 3-6 again, instead of saving the final model with all 3 adapters, which is why I would put persistence low on the priority list. However, I do think that merging could be a desirable feature, so I plan to support it. If you think that my imaginary use case does not resemble the most common one, please let me know.
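For concreteness, a rough sketch of how steps 3-6 might look. This assumes a from_pretrained/load_adapter-style API analogous to PeftModel; loading for mixed models is not implemented in this PR, and the class name, paths, and adapter names below are placeholders:

from transformers import AutoModelForCausalLM
from peft import PeftMixedModel  # hypothetical import, see caveat above

base_model = AutoModelForCausalLM.from_pretrained("gpt2")

# step 3: load the base model together with the first trained adapter
peft_model = PeftMixedModel.from_pretrained(base_model, "path/to/lora1", adapter_name="lora1")

# steps 4-5: add the other independently trained adapters
peft_model.load_adapter("path/to/lora2", adapter_name="lora2")
peft_model.load_adapter("path/to/loha1", adapter_name="loha1")

# step 6: activate all three adapters at once and run inference
peft_model.set_adapter(["lora1", "lora2", "loha1"])
outputs = peft_model(input_ids)  # input_ids: tokenized model inputs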

(For the same reason, I also haven't worked on implementing a mapping between mixed adapter and PEFT config.)

Similarly, I think that training is not a high priority for mixed adapter types, so I would also not specifically try to make it work.

In update_layer, not in __init__.
Decreased tolerance, as one test currently fails on Windows, presumably
due to precision.

Also, better test function names by providing a name_func.
As side effects, added the prefix attribute on LoRA for consistency and
added safe merging on LoHa, LoKr.
Add test for deeply nested models
@BenjaminBossan (Member, Author) commented:

  • Right now there is a problem with adding an adapter of the same type as one that already exists. Here is a small code snippet that demonstrates it:

@kovalexal Indeed there was a bug, it is now fixed and your example should work.

BenjaminBossan added a commit that referenced this pull request Nov 16, 2023
Description

Refactor all tuners (where it applies, i.e. not prompt tuning) to use
the "base layer pattern". This means that the adapter layer will always
hold a reference to the original layer that it modifies. This pattern is
already partly used (e.g. LoRA bnb, gptq layers), now it is consistently
used everywhere when applicable.

This PR is a companion PR to #1069, where I first added these changes.
They are now extracted to a separate PR to make code review easier and
to advance more quickly.

Implementation

The main change is that the adapter layer wraps the original layer and
calls forward on that layer, instead of doing stuff like this:

F.linear(input, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)

which completely circumvents the call to the target layer's forward
method. With the base layer pattern, we now call the target layer's
forward method. Therefore, if the target layer is another adapter
layer (which will be crucial for mixed adapters), we call its forward
method correctly. Also, this should allow passing extra arguments, like
lora_scale to forward.

This change has the nice side benefit that we no longer need to use
_init_empty_weights -- in fact, we don't initialize any of the target
layer's weights anymore, since we have a reference to it. There is thus
no risk of having slow but superfluous initialization of layers.

Moreover, I could greatly simplify merge_and_unload by just using the
base_layer instead of having to create a completely new layer. For
OPT-350m, this results in a 15x speedup.
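As a rough sketch of why the base layer reference simplifies unloading (illustrative only; the function name is made up and get_delta_weight stands in for the per-tuner merge math):

import torch.nn as nn

def unload_one_layer(parent: nn.Module, child_name: str, adapter_layer: nn.Module) -> None:
    # merge the adapter's contribution directly into the wrapped base layer
    base_layer = adapter_layer.base_layer
    base_layer.weight.data += adapter_layer.get_delta_weight("default")
    # then simply put the base layer back in place of the wrapper, instead of
    # constructing a fresh nn.Linear and copying the merged weights over
    setattr(parent, child_name, base_layer)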

Note that, same as for the bnb layers, this should not break backwards compatibility, since the adapter weights and their state_dicts are not affected by this change. I used #1115 for regression testing.

Somewhat unrelated changes

During debugging, I got very annoyed with the fact that the reprs of
adapter layers and normal PyTorch layers are hard to distinguish, e.g.
the type is just "Linear". Now, for adapter layers, it is prefixed by
the adapter type, e.g. "lora.Linear". This should have no further
implications except for the repr (e.g. state_dict remains unaffected).

For LoHa and LoKr, I had to change the init of weights when using init_weights=False. This is because of what is discussed in #1058 (Numerical instabilities with LoHa).

IA³ now has the unload method too.

LoHa and LoKr now support safe_merge=True when merging layers.

Migration guide

For 99% of users, the code should continue working as usual, because the API stays the same. Only low-level details have been changed.

Code that relies on isinstance checks on specific PEFT classes may
break. E.g. the LoRA Linear layer no longer inherits from nn.Linear. It
is, however, still a BaseTunerLayer. The same logic applies for other
layer types like Conv2d and for other tuners like IA³.

To retrieve the base layer of an adapter layer, you should now call
module.get_base_layer() if you deal with a BaseTunerLayer. Don't rely on
something like module.weight being present (though it might be).
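For example, code along these lines (a minimal sketch; the helper name is made up) keeps working after the refactor:

import torch.nn as nn
from peft.tuners.tuners_utils import BaseTunerLayer

def find_adapted_linear_layers(model: nn.Module):
    # previously one might have checked `isinstance(module, nn.Linear)` on the
    # adapter layer itself; that no longer holds, so unwrap via get_base_layer()
    layers = []
    for module in model.modules():
        if isinstance(module, BaseTunerLayer) and isinstance(module.get_base_layer(), nn.Linear):
            layers.append(module)
    return layers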
@BenjaminBossan (Member, Author) commented:

Note: #1106 is merged, which contains the basic refactor on top of which mixed adapter models are built. I will probably create a new PR with just the mixed adapter code and close this one instead of trying to fix this up.

@BenjaminBossan (Member, Author) commented:

Closing in favor of #1163
