
Faster init with LoRA #872

Closed

Conversation

poedator
Contributor

@poedator commented Aug 28, 2023

Purpose: accelerate PEFT model initialisation with a LoRA config.
See the issue description in #871.

The idea is to avoid costly calls to weight initialisation functions on the CPU when those weights are not needed, since they will be replaced by the original weights anyway.

This PR initially changes only the init for Linear layers.

Potential further improvements:

  • apply the same change to Embedding and Conv2d layers
  • apply it to quantized layer types, like Linear4bit and Linear8bit
  • initialize the LoRA modules (A, B) directly on the target device, not the CPU
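
For context, a minimal sketch of the kind of change this PR started with (as discussed below, the first version monkey patched a hard-coded list of torch.nn.init functions, an idea borrowed from GPTQ code; this is an illustration under those assumptions, not the exact PR diff):

import contextlib
import torch

@contextlib.contextmanager
def no_init_weights():
    # Temporarily replace the expensive in-place init functions with no-ops,
    # so constructing the full-sized base modules skips the costly CPU init.
    patched = ["kaiming_uniform_", "kaiming_normal_", "uniform_", "normal_"]
    saved = {name: getattr(torch.nn.init, name) for name in patched}
    try:
        for name in patched:
            setattr(torch.nn.init, name, lambda tensor, *args, **kwargs: tensor)
        yield
    finally:
        for name, fn in saved.items():
            setattr(torch.nn.init, name, fn)

# Example: build the replacement module without paying for initialization;
# its weight is overwritten with the original layer's weight afterwards.
with no_init_weights():
    new_module = torch.nn.Linear(4096, 4096, bias=False)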

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@BenjaminBossan
Member

I wonder if we can use skip_init or something else from PyTorch instead of monkey patching a hard-coded list of initialization methods.
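
For reference, a rough illustration of how torch.nn.utils.skip_init is used (it constructs a module without running its parameter initialization; the sizes below are arbitrary and this is not PEFT code):

import torch
from torch.nn.utils import skip_init

# The parameters of the returned module hold uninitialized memory and must
# be overwritten (e.g. with the original weights) before the module is used.
linear = skip_init(torch.nn.Linear, 4096, 4096, bias=False)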

@poedator
Contributor Author

I wonder if we can use skip_init or something else from PyTorch instead of monkey patching a hard-coded list of initialization methods.

Well, I borrowed this idea from GPTQ code, but I support your suggestion to avoid monkey patching.
Let me try this improvement - I will write back soon.

@poedator
Contributor Author

@BenjaminBossan!
skip_init() turned out to be not very convenient to apply from inside a subclass. However, looking at its source code, I got the idea to use the meta device. It works much faster than my original code, since creating meta tensors is faster than creating empty CPU tensors. Thanks for the suggestion!
To accommodate this change I also:
- made self.to(self.weight.device) in LoraLayer.update_layer() conditional, to avoid moving the LoRA weights to the 'meta' device
- removed self.weight.requires_grad = False in Linear.__init__(), as it is no longer relevant (lora.py:850)

Now, with Llama-2-7B, applying the Guanaco LoRA adapter takes 1.5 seconds vs. 103 seconds originally.
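
For illustration, a minimal sketch of the meta-device trick described above (the actual change lives in the PR diff; the names and sizes here are arbitrary):

import torch.nn as nn

# The full-sized base weight is created on the meta device, so nothing is
# allocated or initialized on the CPU; it is replaced later by the original
# weight of the target layer anyway.
base = nn.Linear(4096, 4096, bias=False, device="meta")

# The LoRA matrices are small, so initializing them normally is cheap.
lora_A = nn.Linear(4096, 16, bias=False)
lora_B = nn.Linear(16, 4096, bias=False)
nn.init.zeros_(lora_B.weight)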

@BenjaminBossan
Member

@poedator Great, thanks for experimenting. This looks like a very nice speedup indeed. I'm not 100% sure if we can remove all of that code, but let's see. Could you please run make style on your code? Then the code quality checks should pass and the unit tests should run.

@poedator
Contributor Author

poedator commented Aug 28, 2023

I added a couple more commits to pass a device param to the LoRA layer constructors so that the LoRA adapters are initialized directly on the target device. Apparently I need to do some more local testing ...
By way of introduction: I am a junior researcher at Yandex, second author of the SpQR paper on LLM quantization: https://arxiv.org/abs/2306.03078. We use HF packages daily and want them to get even better.

@BenjaminBossan
Member

BenjaminBossan commented Aug 29, 2023

Thanks for applying the fixes and providing more context.

Unfortunately, the tests are failing right now. The main reason seems to be that the code path for init_lora_weights=False is not properly implemented yet, resulting in the weights staying on the meta device. This needs to be fixed, and it might require a bit of fiddling, as the logic is a bit all over the place. It would probably help if we cleaned up the initialization logic, but that requires some non-trivial refactoring.

To proceed, would you be up for fixing the failing tests for the code path I mentioned?

Edit: We merged #807, so there is a merge conflict now. But it is easy to fix. Please apply your changes to this file: https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/layer.py instead of lora.py and you should be good.

@poedator
Contributor Author

poedator commented Aug 29, 2023

@BenjaminBossan, thank you for your guidance on this PR.

I updated the PR to catch up with #807 and also limited my changes to just the Linear layer init to prove the concept.
Some of the tests with init_lora_weights=False are also failing locally (test_merge_layers_09 and 10). This is puzzling because they do not seem to involve LoRA for Linear. Nevertheless I want to see this PR coming thru, so let me try to fix it with your guidance.

Please be more specific about the failing tests: do you want me to update the test code or to change some LoRA code?
By the way, I re-ran the failing tests on the main branch and they failed there too. What could that be? Test names:

  • test_lora_bnb_4bit_quantization_from_pretrained_safetensors
  • test_lora_causal_lm_mutli_gpu_inference
  • test_lora_seq2seq_lm_mutli_gpu_inference
  • test_merge_layers_09_Conv2d_1
  • test_merge_layers_10_Conv2d_2
    (Python 3.10.9, Ubuntu 18.04, A100, mucho RAM, accelerate 0.21.0, bitsandbytes 0.41.1, peft 0.6.0.dev0, torch 2.0.1, transformers 4.30.2)

On PR scope: let me know if you want me to include code for all of the other layer types. I wanted to do that initially but got confused by the failing tests and scaled back.
Also, I may try to pass a device argument to the LoRA inits so they are created on CUDA where they belong (still faster). But that may be better left to a separate PR.

@BenjaminBossan
Member

Thanks for the updates and the further investigation. Regarding the failing tests, the idea is that all existing tests should pass after the change. There could be tests that need adjustment, but I'm not aware of any. Anyway, for me, all the tests passed, and as you can see, the CI passes too. So I'm not sure why they failed for you.

I think the strategy you suggested, which is to start with one layer type first, is valid. We can do that and once that has proven to work well, we can tackle the others.

I did some further digging into the code and IIUC, I think the weight initialization happening in these two lines is completely unnecessary:

nn.Linear.__init__(self, in_features, out_features, **kwargs)

nn.Linear.reset_parameters(self)

It creates a randomly initialized self.weight (and possibly self.bias) on the LoRA layer, which is initialized with kaiming init twice! Later, it is replaced by the original weight of the target layer, so the whole initialization was unnecessary:

new_module.weight = child.weight

Therefore, I think this step could be completely skipped. Since we're talking about the full-sized weight here, not the smaller LoRA weights, this could take quite some time. The fix you proposed would create that weight on the meta device, but I wonder if we can rewrite the code to completely skip the initialization instead. If you want to explore that, feel free to do so, otherwise I'll take a look at it later. If you have a benchmark script to test the initialization speed, please share it.
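
A rough sketch of what skipping the base initialization entirely could look like (roughly the direction explored in #887; the class and attribute names here are illustrative, not the actual PEFT API):

import torch.nn as nn

class LoraLinearSketch(nn.Module):
    def __init__(self, in_features: int, out_features: int, r: int = 8):
        # Call nn.Module.__init__ instead of nn.Linear.__init__, so no
        # full-sized weight is allocated or kaiming-initialized at all.
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Only the small LoRA matrices are created and initialized here.
        self.lora_A = nn.Linear(in_features, r, bias=False)
        self.lora_B = nn.Linear(r, out_features, bias=False)
        nn.init.zeros_(self.lora_B.weight)
        # The base weight is attached later from the target layer,
        # i.e. new_module.weight = child.weight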

Also, I may try to pass a device argument to the LoRA inits so they are created on CUDA where they belong (still faster). But that may be better left to a separate PR.

Yes, that's indeed best left for a separate PR.

@BenjaminBossan
Member

@poedator I attempted to remove the redundant initializations completely in #887. If you have time to test potential speedups of that branch, it would be great.

@poedator
Contributor Author

poedator commented Aug 30, 2023

@poedator I attempted to remove the redundant initializations completely in #887. If you have time to test potential speedups of that branch, it would be great.

@BenjaminBossan,
I just tested the speed of #887; it is comparable to my PR, and both are much faster than the main branch.
See the script and results here: https://gist.github.com/poedator/45331cd9c7837cbfea0ad578b3ce98ed

Let me know if you prefer to proceed with #887 or with this PR. I admit that having no weight at all is even better than a tensor on the meta device.

#872 passed the tests as of today. Apparently there is something in my local setup that causes errors when merging back the adapters for Conv2d. I can investigate it further, but could you share the versions of the libraries that you use for testing? My env is described above.

@BenjaminBossan
Member

Thanks for providing the script @poedator. Could you please paste your measurements for reference? Unfortunately, I still haven't gotten access to Llama 2 (really not sure why it's still pending), so I can't test that model. For the others, I got:

# main, x3
test 1 with model bert-base took  0.186 sec.
test 2 with model bloomz-1b7 took 3.740 sec.
test 1 with model bert-base took  0.208 sec.
test 2 with model bloomz-1b7 took 3.889 sec.
test 1 with model bert-base took  0.187 sec.
test 2 with model bloomz-1b7 took 3.762 sec.

# PR 887, x3
test 1 with model bert-base took  0.019 sec.
test 2 with model bloomz-1b7 took 0.057 sec.
test 1 with model bert-base took  0.019 sec.
test 2 with model bloomz-1b7 took 0.030 sec.
test 1 with model bert-base took  0.019 sec.
test 2 with model bloomz-1b7 took 0.029 sec.

So this looks like a very decent speedup, ~10x, although in practice a lot of time is spent on loading the model in the first place, so the overall wall-clock time is only 11 vs. 7 sec.

Let me know if you prefer to proceed with #887 or with this PR.

My impression is that it would be even better not to initialize the weights at all instead of initializing them on the meta device. Having useless inits just causes confusion. The main advantage of initializing on meta is that the code changes are smaller, so they're less likely to accidentally break something. But from what I can tell, removing the inits should also not break anything.

@younesbelkada @pacman100 Do you have an opinion on that? How should we proceed?

@poedator
Contributor Author

poedator commented Aug 31, 2023

Could you please paste your measurements for reference?
My results were at the top of the gist with the code.
Here I added 2 more tests with Llama-1-7B; the benefit is more noticeable with LLMs. The Guanaco tests are slower because they involve Embedding layers, which are not covered by either of our PRs.

Switched to branch 'pr887'       # PR #887 by Benjamin Bossan
test 1 with model bert-base took  0.048 sec.
test 2 with model bloomz-1b7 took 0.607 sec.
test 3  Llama-2-7B + Guanaco took 5.863 sec.
test 4 with model Llama-1-7B took 0.094 sec.
test 5  Llama-1-7B + Guanaco took 6.165 sec.

Switched to branch 'fast_init'   # PR #872 by poedator@
test 1 with model bert-base took  0.063 sec.      ~3x
test 2 with model bloomz-1b7 took 0.528 sec.      ~7x
test 3  Llama-2-7B + Guanaco took 5.618 sec.      ~15x
test 4 with model Llama-1-7B took 0.116 sec.      ~100x
test 5  Llama-1-7B + Guanaco took 6.273 sec.      ~12x

Switched to branch 'main'         # main version, commit 85013987aa82aa1af3da1236b6902556ce3e483e
test 1 with model bert-base took   0.224 sec.
test 2 with model bloomz-1b7 took  3.842 sec.
test 3  Llama-2-7B + Guanaco took 76.174 sec.  <---- terribly slow (vs 7 sec for loading whole 7B model)
test 4 with model Llama-1-7B took 11.607 sec.
test 5  Llama-1-7B + Guanaco took 77.542 sec.  <---- terribly slow

Contributor

@younesbelkada left a comment


Amazing work @poedator @BenjaminBossan thanks a lot for this! 🚀

@younesbelkada
Contributor

The current way of proceeding looks great IMO, as the code change is quite minimal for a nice gain. Also, I think the meta device has been introduced in torch>=1.13, which is the current minimum requirement for PEFT: https://github.com/huggingface/peft/blob/main/setup.py#L45

@younesbelkada
Contributor

After thinking a bit and discussing offline with @BenjaminBossan, there might be one problem with this approach: if users use LoraLayer as a standalone module, they might end up with weights initialized on the meta device. This is fine as the tests are passing, but the module is public, so there might be some users relying on it already; let's keep that in mind.
#887 will also make LoraLayer unusable as a standalone module, as it will remove the weight attribute, per my understanding from my discussion with Benjamin.

@pacman100
Contributor

Thank you @poedator and @BenjaminBossan for making the LoRA init super fast 🚀. I really like both approaches, favouring Benjamin's as it avoids the inits of weights altogether.

@regisss

regisss commented Sep 7, 2023

Anything else missing for merging this PR?

@BenjaminBossan
Member

Since #887 is merged, I think we can close this one. Progress for other layers/adapters is tracked in #896.

BenjaminBossan added a commit to BenjaminBossan/peft that referenced this pull request Sep 7, 2023
Partly resolves huggingface#872

Description

After getting faster initialization of the LoRA Linear layer,
initialization of Conv2D and Embedding is now sped up.

Implementation

The approach of how to achieve the speed up has slightly changed
compared to last time. To refresh memory, in huggingface#887, we avoided the
unnecessary initialization of the full weight matrix by completely
skipping nn.Linear.__init__.

Although it is possible to do the same for Embedding and Conv2d, we run
into some trouble here. The issue is that the __init__ methods of these
classes have quite a lot more arguments and some custom logic (i.e. not
only self.foo = foo but more on top). If we wanted to skip __init__
entirely, we would have to basically copy all of that into our code.
Although that is possible, it is brittle (e.g. the logic could be
different for different PyTorch versions or change over time).

For that reason, I opted to implement this differently, using a
suggestion we had discussed earlier. The approach is to call __init__ of
the parent class but enforce empty weights (this is what
torch.nn.utils.skip_init does, although we cannot use that function
directly). This way, we can avoid having to copy the __init__ code while
still avoiding expensive initialization of the weights.
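
A rough illustration of the described approach for Embedding (run the parent __init__ on the meta device, then materialize empty storage, which is essentially what torch.nn.utils.skip_init does internally; the class name is illustrative, not the actual PEFT code):

import torch.nn as nn

class LoraEmbeddingSketch(nn.Embedding):
    def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs):
        # Run the full parent __init__ (argument handling and custom logic
        # included), but on the meta device so no real weight is allocated
        # or initialized.
        kwargs["device"] = "meta"
        super().__init__(num_embeddings, embedding_dim, **kwargs)
        # Materialize empty (uninitialized) storage; the real weights are
        # copied in from the target layer afterwards.
        self.to_empty(device="cpu")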

I did not change the code for Linear to also use this approach because
the logic inside of Linear.__init__ is quite simple (at least for now),
so we are good here with the existing approach.

However, I was curious how changing the approach for Linear would affect
the initialization speed. Therefore, I ran the script from huggingface#872 again, 3
times each.

Current approach:

test 1 with model bert-base took 0.021 sec.
test 1 with model bert-base took 0.020 sec.
test 1 with model bert-base took 0.020 sec.
test 2 with model bloomz-1b7 took 0.030 sec.
test 2 with model bloomz-1b7 took 0.030 sec.
test 2 with model bloomz-1b7 took 0.030 sec.

New approach if applied to Linear:

test 1 with model bert-base took 0.038 sec.
test 1 with model bert-base took 0.039 sec.
test 1 with model bert-base took 0.038 sec.
test 2 with model bloomz-1b7 took 0.072 sec.
test 2 with model bloomz-1b7 took 0.048 sec.
test 2 with model bloomz-1b7 took 0.048 sec.

This shows that the new approach is indeed a bit slower than the
existing one, though still a lot faster than what we had before. IMHO, I
think we're safe to leave the code inside of Linear as is and benefit
from the slightly better performance at the cost of slightly more
fragile code. But please let me know if you prefer:

1. The new approach should also be applied to Linear
2. The existing approach should also be applied to Embedding and Conv2d
BenjaminBossan added a commit that referenced this pull request Sep 12, 2023
Partly resolves #872
