Increase speed of adapter layer initialization #896

Closed
5 of 6 tasks
BenjaminBossan opened this issue Sep 1, 2023 · 6 comments

Comments

@BenjaminBossan
Member

BenjaminBossan commented Sep 1, 2023

Feature request

We are working on increasing the speed of adapter layer initialization. The first step towards this is #887, which tackles linear LoRA layers, arguably the most important ones. As discussed with @pacman100, a few TODOs remain after that PR; they are tracked in the task list of this issue.

Motivation

Initial testing showed a ~10x increase in speed with the optimized initialization.

Your contribution

#887 is the first step.

BenjaminBossan added a commit to BenjaminBossan/peft that referenced this issue Oct 5, 2023
Partly addresses huggingface#896

Description

After speeding up normal LoRA layer initialization, this PR improves
initialization speed of bnb LoRA layers.

The method used here differs from the previous one: this time, the base layer is
stored as a reference on the LoRA layer. This lets us avoid calling __init__ on
the bnb layer, which is the slow part.

Notes

We cannot use the same method as for the normal LoRA layers (i.e. calling the
super class's __init__ on the meta device), because the bnb layers have extra
logic that would still create unnecessary weights.

However, the way used here could also be a solution to the normal
layers, so if we want to have consistency, the normal layers could be
refactored to use the same approach.
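
For illustration, a minimal sketch of the two approaches; the class and attribute names are simplified stand-ins, not the actual PEFT implementation:

import torch
import torch.nn as nn


class FastLinearLoraLayer(nn.Linear):
    # Sketch of the meta-device approach used for normal Linear LoRA layers.
    def __init__(self, base_layer: nn.Linear, r: int):
        # Calling the parent __init__ on the meta device allocates no real
        # storage, which is what makes the initialization fast.
        super().__init__(
            base_layer.in_features,
            base_layer.out_features,
            bias=base_layer.bias is not None,
            device="meta",
        )
        self.weight = base_layer.weight  # reuse the already existing weight
        self.bias = base_layer.bias
        self.lora_A = nn.Linear(base_layer.in_features, r, bias=False)
        self.lora_B = nn.Linear(r, base_layer.out_features, bias=False)


class BnbLoraLayer(nn.Module):
    # Sketch of the bnb approach: keep the quantized base layer as a reference.
    def __init__(self, base_layer: nn.Module, r: int):
        super().__init__()
        # The expensive bnb __init__ is never called again; the existing,
        # already initialized layer is simply stored and reused.
        self.base_layer = base_layer
        self.lora_A = nn.Linear(base_layer.in_features, r, bias=False)
        self.lora_B = nn.Linear(r, base_layer.out_features, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base_layer(x) + self.lora_B(self.lora_A(x))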

Interestingly, even though we now save the base layer as a reference,
which results in a different state_dict, the existing models can still
be loaded successfully. This is because the adapter state_dict is not
affected by the change, so users can still load their existing adapters.

The only problem would occur if users dump the whole model, i.e. base model and
adapter, using torch.save and then try to load it with torch.load. For those
users, we could theoretically provide a script to convert the state_dict
(i.e. rename some keys).
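
Such a conversion script could look roughly like the sketch below; the assumed key mapping (base weights moving under a base_layer attribute) is an illustration, not a committed format.

import torch


def convert_full_model_state_dict(old_path: str, new_path: str) -> None:
    # Sketch: rename keys of a model saved with torch.save so that the base
    # layer parameters of LoRA-wrapped modules move under `base_layer`
    # (assumed mapping), e.g. "...q_proj.weight" -> "...q_proj.base_layer.weight".
    state_dict = torch.load(old_path, map_location="cpu")
    # Modules that carry LoRA weights are the ones whose base layer moved.
    lora_modules = {key.split(".lora_")[0] for key in state_dict if ".lora_" in key}
    new_state_dict = {}
    for key, value in state_dict.items():
        prefix, _, param_name = key.rpartition(".")
        if prefix in lora_modules and ".lora_" not in key:
            key = f"{prefix}.base_layer.{param_name}"
        new_state_dict[key] = value
    torch.save(new_state_dict, new_path)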

To ensure that the old adapters can still be loaded successfully, I'm
working at the same time on adding regression tests. I'll create a
separate PR for those to avoid blowing up this one.

Tests

I ran a test on bloomz-1b1 to measure how long it takes to create the
PeftModel; the results are:

8-bit: 1108.34 ms → 26.82 ms
4-bit: 1101.96 ms → 23.69 ms
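
The measurement was of roughly this form (a sketch; the model id, quantization config, and timing code are assumptions rather than the exact benchmark script):

import time

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Assumed setup: load bloomz-1b1 in 8-bit and time how long it takes to wrap
# it in a PeftModel; 4-bit works analogously with load_in_4bit=True.
model = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloomz-1b1",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)
config = LoraConfig(task_type="CAUSAL_LM")

start = time.perf_counter()
peft_model = get_peft_model(model, config)
print(f"PeftModel creation: {(time.perf_counter() - start) * 1000:.2f} ms")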
BenjaminBossan added a commit that referenced this issue Oct 10, 2023 (with the same commit message as the Oct 5 commit above)

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@BenjaminBossan
Member Author

Not stale

@osmalpkoras

osmalpkoras commented Jun 17, 2024

Hi guys, I am trying to load a meta-llama/Meta-Llama-3-70B-Instruct PEFT model with

from peft import LoraConfig

lora_config = LoraConfig(
    init_lora_weights="pissa",
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules="all-linear",
)

and I noticed that it is very slow in initializing the layers (it takes hours). To be precise, it's the for key in key_list loop in peft.tuners.tuners_utils.BaseTuner.inject_adapter that is very slow.

Should this actually be fixed by this feature? Because if so, it does not work for meta-llama/Meta-Llama-3-70B-Instruct with the above LoraConfig.
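
For reference, the config above would typically be applied like this (a minimal sketch; the model loading arguments are assumptions, not taken from the report):

from peft import get_peft_model
from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-70B-Instruct",
    torch_dtype="auto",
    device_map="auto",
)
# lora_config is the LoraConfig shown above; the reported slowness happens
# inside this call, in the `for key in key_list` loop of
# BaseTuner.inject_adapter.
peft_model = get_peft_model(base_model, lora_config)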

@BenjaminBossan
Member Author

Thanks for reporting this. Is it still slow when you remove init_lora_weights="pissa" from the config?

@osmalpkoras

Actually, if I use the default method, it is very fast. Switching to init_lora_weights="pissa_niter_32" also made it very fast. I just noticed that if I use this initialization in a multi-GPU DeepSpeed setup, it again takes a couple of hours, maybe due to some concurrency issue.

@BenjaminBossan
Member Author

It is expected that the default init, or PiSSA with niter, is faster than plain pissa. However, a couple of hours still sounds excessive. Does it finish after a couple of hours, or do you just cancel the process at that point?
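
To illustrate why the two modes differ so much in cost: as far as I can tell from the PEFT source, plain pissa decomposes each targeted weight with a full SVD, while pissa_niter_[k] uses a randomized low-rank SVD. A rough standalone comparison (not PEFT code) on one Llama-3-70B-sized projection weight:

import time

import torch

weight = torch.randn(8192, 8192)  # roughly the size of a 70B q_proj weight
r = 64

start = time.perf_counter()
torch.linalg.svd(weight, full_matrices=False)  # what plain "pissa" needs
print(f"full SVD:     {time.perf_counter() - start:.2f} s")

start = time.perf_counter()
torch.svd_lowrank(weight, q=r, niter=32)  # what "pissa_niter_32" needs
print(f"low-rank SVD: {time.perf_counter() - start:.2f} s")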

BenjaminBossan added a commit to BenjaminBossan/peft that referenced this issue Jun 25, 2024
Resolves huggingface/accelerate#2886

Possibly resolves
huggingface#896 (comment)

Some LoRA init methods need to access the base layer weight. Getting
this access can fail or stall in distributed settings. For DeepSpeed,
the weight is now gathered before trying to access it.

Note: Without DeepSpeed, this is a no-op and should thus not have any
disadvantage. We don't have DS in our CI, so this is not tested.

I also made some small changes to OLoRA init to use
self.get_base_layer() instead of self.base_layer.
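
The fix follows the usual ZeRO-3 pattern of gathering a partitioned parameter before reading it; a sketch of that pattern (not the exact PEFT code):

import deepspeed
import torch.nn as nn


def get_full_weight(base_layer: nn.Linear):
    # Gather a ZeRO-3 partitioned weight so that init methods like PiSSA or
    # OLoRA can read its full value; without DeepSpeed this is a plain read.
    weight = base_layer.weight
    if hasattr(weight, "ds_id"):  # parameter is partitioned by DeepSpeed ZeRO-3
        with deepspeed.zero.GatheredParameters([weight], modifier_rank=None):
            return weight.data.clone()
    return weight.data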
BenjaminBossan added a commit that referenced this issue Jun 26, 2024 (with the same commit message as the Jun 25 commit above)