Fixing multiple LoRA in the same batch or ViT #1990
Changes from 2 commits
```diff
@@ -261,7 +261,27 @@ def _create_new_hook(self, old_hook):
     def forward(self, *args, **kwargs):
         if self.disable_adapters or (self.active_adapter not in self.modules_to_save):
             return self.original_module(*args, **kwargs)
-        return self.modules_to_save[self.active_adapter](*args, **kwargs)
+        if "adapter_names" not in kwargs.keys():
+            return self.modules_to_save[self.active_adapter](*args, **kwargs)
+        # Batches requests with similar LoRAs into microbatches
+        adapter_names = kwargs["adapter_names"]
+        kwargs = {}
+        batch = args[0]  # Get the batch dimension
+        unique_adapters = set(adapter_names)
+        sub_batch_indices_list = []
+        for adapter in unique_adapters:
+            sub_batch_indices_list.append(
+                [index for index, item in enumerate(adapter_names) if item == adapter]
+            )
+
+        results = [0 for i in range(len(batch))]
+        for i, active_adapter in enumerate(unique_adapters):
+            sub_batch = batch[sub_batch_indices_list[i]]
+            output = self.modules_to_save[active_adapter](*(sub_batch,), **kwargs)
+            for index, j in enumerate(sub_batch_indices_list[i]):
+                results[j] = output[index]
+        return torch.stack(results)

     def enable_adapters(self, enabled: bool):
         """Toggle the enabling and disabling of adapters
```

Inline review thread on the new sub-batching code:

Reviewer: Hmm, here we assume that there is only 1 positional argument and that `args[0]` is the batch. One suggestion that I have: Check all

Author: I changed the input definition in the new version with x as the input to avoid the problems that you mentioned.
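For orientation, a hedged usage sketch of what this change targets: routing each sample of a single batch through a different LoRA adapter via `adapter_names`, now also for modules wrapped by `ModulesToSaveWrapper` (for example a ViT classification head). The checkpoint name and adapter paths below are placeholders, not taken from this PR.

```python
# Sketch only: mixed-adapter batching; checkpoint and adapter paths are placeholders.
import torch
from transformers import AutoModelForImageClassification
from peft import PeftModel

base = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224")
model = PeftModel.from_pretrained(base, "path/to/lora_a", adapter_name="lora_a")
model.load_adapter("path/to/lora_b", adapter_name="lora_b")
model.eval()

pixel_values = torch.randn(4, 3, 224, 224)  # a batch of 4 images
with torch.no_grad():
    # one adapter name per sample; the batch is split into per-adapter
    # sub-batches and the outputs are reassembled in the original order
    out = model(
        pixel_values=pixel_values,
        adapter_names=["lora_a", "lora_a", "lora_b", "lora_b"],
    )
```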
Reviewer: Let's move this to a sub-method, similar to how we do this for LoRA:
peft/src/peft/tuners/lora/layer.py, line 327 in 4611034
Also, with this added, I think it makes sense to have a similar method as in LoRA to check the arguments:
peft/src/peft/tuners/lora/layer.py, line 302 in 4611034
Of course, we have to be careful not to be too restrictive here, given the other issue that you raised, and since the underlying module could be of any type.

Author: Both of the functions are added in the new commit; please check that.
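For reference, a minimal sketch of the structure the reviewer is pointing at: a separate mixed-batch sub-method plus a permissive argument check, mirroring the pattern in `lora/layer.py`. This is not the actual follow-up commit; the class and method names below are illustrative assumptions.

```python
# Illustrative stand-in for the wrapper; names are assumptions, not the PR's final code.
import torch
from torch import nn


class MixedBatchHeadSketch(nn.Module):
    def __init__(self, original_module, modules_to_save, active_adapter):
        super().__init__()
        self.original_module = original_module
        self.modules_to_save = nn.ModuleDict(modules_to_save)
        self.active_adapter = active_adapter
        self.disable_adapters = False

    def _check_forward_args(self, x, *args, **kwargs):
        # Keep the check permissive: only verify that adapter_names matches the batch size.
        adapter_names = kwargs.get("adapter_names")
        if adapter_names is not None and len(adapter_names) != len(x):
            raise ValueError(
                f"Length of adapter_names ({len(adapter_names)}) does not match the batch size ({len(x)})."
            )

    def _mixed_batch_forward(self, x, *args, adapter_names=None, **kwargs):
        # Group samples by adapter, run each sub-batch through its module,
        # then put the outputs back in the original batch order.
        results = [None] * len(x)
        for adapter in set(adapter_names):
            indices = [i for i, name in enumerate(adapter_names) if name == adapter]
            output = self.modules_to_save[adapter](x[indices], *args, **kwargs)
            for out_idx, batch_idx in enumerate(indices):
                results[batch_idx] = output[out_idx]
        return torch.stack(results)

    def forward(self, x, *args, **kwargs):
        self._check_forward_args(x, *args, **kwargs)
        adapter_names = kwargs.pop("adapter_names", None)
        if self.disable_adapters or (self.active_adapter not in self.modules_to_save):
            return self.original_module(x, *args, **kwargs)
        if adapter_names is None:
            return self.modules_to_save[self.active_adapter](x, *args, **kwargs)
        return self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
```

A quick check under these assumptions: wrap two small linear heads, call the sketch on a `(3, d)` input with `adapter_names=["a", "a", "b"]`, and the stacked output comes back in the original sample order.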