
unnecessarily slow get_peft_model() with LoRA (esp. in LLMs) #871

Closed
poedator opened this issue Aug 28, 2023 · 11 comments
Labels
enhancement New feature or request

Comments

@poedator
Contributor

poedator commented Aug 28, 2023

System Info

peft 0.6.0.dev0

Who can help?

@pacman100
@BenjaminBossan

Information

Tasks

Reproduction

Just measure the time it takes to run get_peft_model() with any large LLM and a massive peft_config (Guanaco style), or run this:

import peft
import torch
from torch import nn
import time

n = 8192
device = torch.device('cuda')

model1 = nn.Sequential(
    nn.Linear(n, n, device=device),
)

config1 = peft.LoraConfig(
    target_modules=['0'],
    lora_dropout=0.0,
)

start = time.perf_counter()
model2 = peft.get_peft_model(model1, config1)
elapsed = time.perf_counter() - start

print(f"elapsed {elapsed:.3f}")

Expected behavior

get_peft_model() with LoRA taking much less time than loading the original model.

@poedator
Contributor Author

poedator commented Aug 28, 2023

I noticed that get_peft_model() takes an unnecessarily long time, which is quite annoying when working with LLMs. For instance, applying the Guanaco adapter to Llama-2 70B takes 15+ minutes.

These are some of the functions called inside get_peft_model():

BaseTuner.inject_adapter()  # tuner_utils.py:30
    LoraModel._create_and_replace()  # lora.py:320
        _create_new_module()  # lora.py:415
            new_module = lora.Linear()  # lora.py:834, class Linear(nn.Linear, LoraLayer)
                nn.Linear.__init__(self, in_features, out_features, **kwargs)  # init the Linear layer, incl. a Kaiming call
                    create empty weight matrix
                    call nn.Linear.reset_parameters()  # UNNECESSARY - these weights are discarded later in _replace_module()
                        call kaiming_init
                LoraLayer.__init__(self, in_features=in_features, out_features=out_features)  # just create LoRA placeholders
                nn.Linear.reset_parameters(self)
                    another call of kaiming_init  # REDUNDANT
                self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
                    2 calls of Linear init for the LoRA matrices (incl. kaiming_init)
        _replace_module()
            new_module.weight = old_module.weight  # the earlier Linear initializations are irrelevant; the weights are overwritten here
            # dispatch LoRA layer modules to the correct device

Apparently the reasons for the long execution are:

  • use of the CPU to create the replacement layer (instead of the original layer's device)
  • creating a whole new set of weights for the replacement module requires a call to kaiming_init(), which is slow on the CPU
  • an extra, unnecessary call of reset_parameters() (the module constructor already does this)
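
A rough, standalone illustration of the first two points (not PEFT code; it assumes a PyTorch version that supports the device= factory argument and the "meta" device): building an nn.Linear on the CPU pays for an allocation plus a Kaiming init that is later thrown away, while building it on the "meta" device materializes no weights at all.

import time

import torch
from torch import nn

n = 8192

# replacement layer created on the CPU: allocates and Kaiming-initializes an n x n matrix
start = time.perf_counter()
cpu_linear = nn.Linear(n, n)
print(f"CPU init (with reset_parameters): {time.perf_counter() - start:.3f}s")

# the same layer on the "meta" device: no storage is allocated and the init is a no-op,
# which would be fine here because the weights get overwritten in _replace_module() anyway
start = time.perf_counter()
meta_linear = nn.Linear(n, n, device="meta")
print(f"meta-device init (no weights materialized): {time.perf_counter() - start:.3f}s")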

@BenjaminBossan
Member

BenjaminBossan commented Aug 28, 2023

Thank you for reporting this and for digging into the underlying reasons. As you mentioned, there is probably room for improvement here and we'll take a look at your suggestion and what else we can do to speed things up.

Edit: Just saw your PR, we'll take a look, thanks.

@BenjaminBossan BenjaminBossan added the enhancement New feature or request label Aug 28, 2023
@poedator poedator changed the title unecessarily slow get_peft_model() with LoRA (esp. in LLMs) unnecessarily slow get_peft_model() with LoRA (esp. in LLMs) Aug 28, 2023
@poedator
Contributor Author

How this story ended:
PR #872 did not go through because a more elegant solution was found.
The fix for the slow-init issue was implemented in PRs #887 and #915 by @BenjaminBossan.

@jph00

jph00 commented Oct 10, 2023

@BenjaminBossan @pacman100 I don't think this issue is resolved yet -- or at least there's a highly related issue. The slowness caused by unnecessary inits also occurs in merge_and_unload, via this code path:

  File "/usr/local/lib/python3.10/dist-packages/peft/tuners/lora/model.py", line 701, in merge_and_unload
    return self._unload_and_optionally_merge(progressbar=progressbar, safe_merge=safe_merge)
  File "/usr/local/lib/python3.10/dist-packages/peft/tuners/lora/model.py", line 420, in _unload_and_optionally_merge
    new_module = torch.nn.Linear(target.in_features, target.out_features, bias=bias)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py", line 101, in __init__
    self.reset_parameters()
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py", line 107, in reset_parameters
    init.kaiming_uniform_(self.weight, a=math.sqrt(5))

@jph00

jph00 commented Oct 10, 2023

I discovered one possible partial fix: add device=target.weight.device to the call new_module = torch.nn.Linear(target.in_features, target.out_features, bias=bias) in _unload_and_optionally_merge. By creating the module directly on the target's device, the unnecessary init is at least accelerated.
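
A minimal sketch of that change (paraphrased, not the exact PEFT source; target and bias refer to the local variables in _unload_and_optionally_merge):

# before: the replacement module is created on the CPU, so the Kaiming init runs there
new_module = torch.nn.Linear(target.in_features, target.out_features, bias=bias)

# suggested partial fix: create it directly on the target's device, so the
# (still unnecessary) init at least runs on the GPU before the weights are copied over
new_module = torch.nn.Linear(
    target.in_features,
    target.out_features,
    bias=bias,
    device=target.weight.device,
)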

@BenjaminBossan
Member

@jph00 Indeed there are still parts of the code that have not been optimized for faster init yet. We track the progress in #896. There is already a PR for bnb layers (#994). Also thanks for bringing up merge_and_unload, I added it to the list of TODOs.


github-actions bot commented Nov 3, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@BenjaminBossan
Member

Update: The latest PEFT version released today (v0.6.0) includes the speed improvement for bnb LoRA layers.


@BenjaminBossan
Member

Note that we should now pretty much cover all cases, resulting in much faster initialization overall (bnb layers, adapters other than LoRA, merging), as can be seen from the state of progress tracked in #896. I'll thus close the issue. Feel free to re-open if you come across a case that is still slower than it should be.

@jph00

jph00 commented Nov 28, 2023 via email
