unnecessarily slow get_peft_model() with LoRA (esp. in LLMs) #871
Comments
I noticed that these are some of the functions called inside BaseTuner.inject_adapter() (tuner_utils.py:30):

```
LoraModel._create_and_replace()                   # lora.py:320
    _create_new_module()                          # lora.py:415
        new_module = lora.Linear()                # lora.py:834, class Linear(nn.Linear, LoraLayer)
            nn.Linear.__init__(self, in_features, out_features, **kwargs)
                # creates the empty weight matrix
                # calls nn.Linear.reset_parameters() -> kaiming init
                # UNNECESSARY: these weights are scrapped later in _replace_module()
            LoraLayer.__init__(self, in_features=in_features, out_features=out_features)
                # just creates the LoRA placeholders
            nn.Linear.reset_parameters(self)
                # another kaiming init -- REDUNDANT
            self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
                # 2 more nn.Linear constructions for the LoRA matrices (each incl. kaiming init)
    _replace_module()
        new_module.weight = old_module.weight
            # the Linear initializations above are irrelevant, the weight is overridden here
        # dispatch the LoRA layer modules to the correct device
```

Apparently the reason for the long execution is the repeated kaiming initialization of large weight matrices that are discarded anyway as soon as _replace_module() copies the original weights back in.
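For illustration, here is a standalone sketch of the general technique (plain PyTorch, not the actual PEFT patch): torch.nn.utils.skip_init constructs a module without running reset_parameters(), which is all that is needed when the weights are overwritten right afterwards, the way _replace_module() overwrites them.

```python
import time

import torch
from torch import nn

in_features, out_features = 4096, 4096

# Naive construction: allocates the weight AND runs the kaiming init.
t0 = time.perf_counter()
naive = nn.Linear(in_features, out_features)
print(f"with init:    {time.perf_counter() - t0:.4f}s")

# skip_init builds the module on the meta device and materializes it
# uninitialized, so reset_parameters() (the kaiming init) never runs.
t0 = time.perf_counter()
fast = torch.nn.utils.skip_init(nn.Linear, in_features, out_features)
print(f"without init: {time.perf_counter() - t0:.4f}s")

# Later, mimic what _replace_module() does: copy the original weights in,
# which is why the skipped initialization was never needed.
with torch.no_grad():
    fast.weight.copy_(naive.weight)
    fast.bias.copy_(naive.bias)
```

The only remaining cost of the skip_init path is the allocation itself, whereas the naive path pays for a random initialization per targeted module that is immediately thrown away.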
Thank you for reporting this and for digging into the underlying reasons. As you mentioned, there is probably room for improvement here and we'll take a look at your suggestion and what else we can do to speed things up. Edit: Just now saw your PR, we'll take a look, thanks.
How this story ended:
@BenjaminBossan @pacman100 I don't think this issue is resolved yet -- or at least there's a highly related issue. The slowness caused by unnecessary inits also occurs in
I discovered one possible partial fix to this, which is to add
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Update: The latest PEFT version released today (v0.6.0) includes the speed improvement for bnb LoRA layers.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Note that we should now pretty much cover all cases, resulting in much faster initialization overall (bnb layers, adapters other than LoRA, merging), as can be seen from the state of progress tracked in #896. I'll thus close the issue. Feel free to re-open if you come across a case that is still slower than it should be.
Thanks so much!
System Info
peft 0.6.0.dev0
Who can help?
@pacman100
@BenjaminBossan
Information
Tasks
Reproduction
Just measure the time it takes to run get_peft_model() with any large LLM and a massive peft_config (guanaco style). Or run this:
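A minimal timing sketch along those lines (the checkpoint name and the guanaco-style target_modules below are assumptions; any sufficiently large causal LM shows the effect):

```python
import time

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Example checkpoint only -- substitute any large causal LM.
model_id = "huggyllama/llama-7b"

t0 = time.perf_counter()
model = AutoModelForCausalLM.from_pretrained(model_id)
print(f"loading the base model took {time.perf_counter() - t0:.1f}s")

# Guanaco-style config: LoRA on all linear projections of each transformer block.
config = LoraConfig(
    r=64,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)

t0 = time.perf_counter()
model = get_peft_model(model, config)
print(f"get_peft_model() took {time.perf_counter() - t0:.1f}s")
```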
Expected behavior
get_peft_model() with LoRA taking much less time than loading the original model.