fixing multiple LoRA in the same batch or vit #1990

Merged: 10 commits, Sep 17, 2024
src/peft/tuners/lora/model.py (2 additions, 1 deletion)

@@ -46,6 +46,7 @@
 )
 from peft.utils.merge_utils import dare_linear, dare_ties, magnitude_prune, task_arithmetic, ties
 
+from ...utils.other import ModulesToSaveWrapper
 from .aqlm import dispatch_aqlm
 from .awq import dispatch_awq
 from .config import LoraConfig
@@ -432,7 +433,7 @@ def _enable_peft_forward_hooks(self, *args, **kwargs):

         hook_handles = []
         for module in self.modules():
-            if isinstance(module, LoraLayer):
+            if isinstance(module, LoraLayer) or isinstance(module, ModulesToSaveWrapper):
                 pre_forward = partial(_adapter_names_pre_forward_hook, adapter_names=adapter_names)
                 handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True)
                 hook_handles.append(handle)
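For context: the hook registered in this loop is what injects `adapter_names` into each wrapped module's forward call, and with this change it now fires for `ModulesToSaveWrapper` as well as for `LoraLayer`. A minimal sketch of such a pre-forward hook, assuming the `with_kwargs=True` hook signature (the actual `_adapter_names_pre_forward_hook` defined alongside this method may differ slightly):

def _adapter_names_pre_forward_hook(target, args, kwargs, adapter_names):
    # Forward pre-hook: inject the per-sample adapter names into the kwargs so the
    # wrapped layer can route each sample of the batch to its own adapter.
    kwargs["adapter_names"] = adapter_names
    return args, kwargs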
src/peft/utils/other.py (21 additions, 1 deletion)

@@ -261,7 +261,27 @@ def _create_new_hook(self, old_hook):
     def forward(self, *args, **kwargs):
         if self.disable_adapters or (self.active_adapter not in self.modules_to_save):
             return self.original_module(*args, **kwargs)
-        return self.modules_to_save[self.active_adapter](*args, **kwargs)
+        if "adapter_names" not in kwargs.keys():
+            return self.modules_to_save[self.active_adapter](*args, **kwargs)
+        # Batches requests with similar LoRAs into microbatches

Member:
Let's move this to a sub-method, similar to how we do this for LoRA:

def _mixed_batch_forward(

Also, with this added, I think it makes sense to have a similar method as in LoRA to check the arguments:

def _check_forward_args(self, x, *args, **kwargs):

Of course, we have to be careful not to be too restrictive here, given the other issue that you raised, and since the underlying module could be of any type.


Contributor Author:
Both functions have been added in the new commit; please take a look.
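For illustration only (this is not part of the diff above): one possible shape for the two methods the reviewer asks for, mirroring the `_check_forward_args` / `_mixed_batch_forward` split used by `LoraLayer`. The method names follow that precedent, but the bodies below are a hedged sketch and the actual commit may differ:

import torch


class ModulesToSaveWrapper(torch.nn.Module):  # sketch: only the methods under discussion are shown
    def _check_forward_args(self, x, *args, **kwargs):
        # Stay permissive: the wrapped module can be of any type, so only check what
        # holds generically, namely one adapter name per sample in the batch.
        adapter_names = kwargs.get("adapter_names", None)
        if adapter_names is not None and len(adapter_names) != len(x):
            raise ValueError(
                f"Length of adapter_names ({len(adapter_names)}) must match the batch size ({len(x)})."
            )

    def _mixed_batch_forward(self, x, *args, adapter_names, **kwargs):
        # Group samples that share an adapter and run one sub-batch forward per adapter.
        results = [None] * len(x)
        for adapter in set(adapter_names):
            indices = [i for i, name in enumerate(adapter_names) if name == adapter]
            output = self.modules_to_save[adapter](x[indices], *args, **kwargs)
            for res_idx, batch_idx in enumerate(indices):
                results[batch_idx] = output[res_idx]
        return torch.stack(results)

    def forward(self, x, *args, **kwargs):
        self._check_forward_args(x, *args, **kwargs)
        adapter_names = kwargs.pop("adapter_names", None)
        if self.disable_adapters or (self.active_adapter not in self.modules_to_save):
            return self.original_module(x, *args, **kwargs)
        if adapter_names is None:
            return self.modules_to_save[self.active_adapter](x, *args, **kwargs)
        return self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)

Keeping the argument check deliberately loose reflects the reviewer's caveat that the underlying module could be of any type.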

+        adapter_names = kwargs["adapter_names"]
+        kwargs = {}
+        batch = args[0]  # Get the batch dimension
+        unique_adapters = set(adapter_names)
+        sub_batch_indices_list = []
+        for adapter in unique_adapters:
+            sub_batch_indices_list.append(
+                [index for index, item in enumerate(adapter_names) if item == adapter]
+            )
+
+        results = [0 for i in range(len(batch))]
+        for i, active_adapter in enumerate(unique_adapters):
+            sub_batch = batch[sub_batch_indices_list[i]]

Member:
Hmm, here we assume that there is only one positional arg, since any other args would be dropped, right? Also, what if other args or kwargs need to be sliced? We don't really know that, so I think the best we can do is make a guess.

One suggestion that I have:

Check all args and kwargs to see whether they are tensors, and if a tensor has the same length (i.e. the batch size), slice it too. Otherwise, leave it as is. It's not perfect, but I'm not sure what else could be done. WDYT?


Contributor Author:
I changed the input signature in the new version to take x as the input, which avoids the problems you mentioned.
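For reference, the reviewer's heuristic could look roughly like the helper below (hypothetical; `_slice_batch_like` is not a peft function, and the PR ultimately switched to a single `x` argument instead):

import torch


def _slice_batch_like(value, indices, batch_size):
    # Guess-based slicing from the review discussion: slice tensors whose first
    # dimension matches the batch size; leave every other arg / kwarg untouched.
    if isinstance(value, torch.Tensor) and value.ndim > 0 and value.shape[0] == batch_size:
        return value[indices]
    return value


# Inside the per-adapter loop this could be applied as (sketch):
#   sub_args = [_slice_batch_like(a, indices, batch_size) for a in args]
#   sub_kwargs = {k: _slice_batch_like(v, indices, batch_size) for k, v in kwargs.items()}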

+            output = self.modules_to_save[active_adapter](*(sub_batch,), **kwargs)
+            for index, j in enumerate(sub_batch_indices_list[i]):
+                results[j] = output[index]
+        return torch.stack(results)


     def enable_adapters(self, enabled: bool):
         """Toggle the enabling and disabling of adapters
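Taken together, these changes let modules_to_save layers (for example a fine-tuned classification head on a ViT) take part in mixed-adapter batches. A usage sketch under assumed names; the checkpoint id, label count, target modules, and adapter names below are placeholders, not values taken from this PR:

import torch
from transformers import ViTForImageClassification
from peft import LoraConfig, get_peft_model

base = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224", num_labels=10, ignore_mismatched_sizes=True
)
config_1 = LoraConfig(target_modules=["query", "value"], modules_to_save=["classifier"])
config_2 = LoraConfig(target_modules=["query", "value"], modules_to_save=["classifier"])
model = get_peft_model(base, config_1, adapter_name="adapter_1")
model.add_adapter("adapter_2", config_2)

model.eval()
pixel_values = torch.randn(4, 3, 224, 224)  # dummy batch of 4 images
# One adapter name per sample; the forward hooks above route each sample accordingly.
adapter_names = ["adapter_1", "adapter_2", "adapter_2", "adapter_1"]
with torch.no_grad():
    logits = model(pixel_values, adapter_names=adapter_names).logits

Before this PR, the LoRA layers were already routed per sample, but every sample went through the classifier head of the currently active adapter; with the new hook on ModulesToSaveWrapper, the saved head is switched per sample as well.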