
Fixing multiple LoRA in the same batch or ViT #1990

Merged: 10 commits, Sep 17, 2024
3 changes: 2 additions & 1 deletion src/peft/tuners/lora/model.py
@@ -46,6 +46,7 @@
 )
 from peft.utils.merge_utils import dare_linear, dare_ties, magnitude_prune, task_arithmetic, ties

+from ...utils.other import ModulesToSaveWrapper
 from .aqlm import dispatch_aqlm
 from .awq import dispatch_awq
 from .config import LoraConfig
@@ -432,7 +433,7 @@ def _enable_peft_forward_hooks(self, *args, **kwargs):

         hook_handles = []
         for module in self.modules():
-            if isinstance(module, LoraLayer):
+            if isinstance(module, LoraLayer) or isinstance(module, ModulesToSaveWrapper):
                 pre_forward = partial(_adapter_names_pre_forward_hook, adapter_names=adapter_names)
                 handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True)
                 hook_handles.append(handle)
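Note on the hook mechanism: the pre-forward hook registered above is what carries `adapter_names` from the top-level forward call down to every wrapped module, now including the `ModulesToSaveWrapper` instances added to the isinstance check. The body of `_adapter_names_pre_forward_hook` is not part of this diff; a minimal stand-in consistent with how it is used here (bound with `functools.partial` and registered with `with_kwargs=True`) could look like this sketch:

# Hedged sketch only; the actual hook is defined elsewhere in src/peft/tuners/lora/model.py.
def _adapter_names_pre_forward_hook(module, args, kwargs, adapter_names):
    # With with_kwargs=True, PyTorch calls the hook as hook(module, args, kwargs) and,
    # when a value is returned, expects an (args, kwargs) tuple to use for the forward call.
    kwargs["adapter_names"] = adapter_names
    return args, kwargs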
48 changes: 44 additions & 4 deletions src/peft/utils/other.py
@@ -16,7 +16,7 @@
 import os
 import warnings
 from contextlib import nullcontext
-from typing import Optional, Tuple
+from typing import Any, Optional, Tuple

 import accelerate
 import torch
@@ -258,10 +258,50 @@ def _create_new_hook(self, old_hook):
         new_hook = old_hook_cls(**filtered_old_hook_attr)
         return new_hook

-    def forward(self, *args, **kwargs):
+    def _check_forward_args(self, x, *args, **kwargs):
+        """Check if the arguments are compatible with the configs and state of the model"""
+        adapter_names = kwargs.get("adapter_names", None)
+        if adapter_names is None:
+            return
+
+        if len(x) != len(adapter_names):
+            msg = (
+                "Length of `adapter_names` should be the same as the number of inputs, but got "
+                f"{len(adapter_names)} and {len(x)} respectively."
+            )
+            raise ValueError(msg)
+
+    def _mixed_batch_forward(
+        self, x: torch.Tensor, *args: Any, adapter_names: list[str], **kwargs: Any
+    ) -> torch.Tensor:
+        # This is a special method that handles the case when users pass the argument `adapter_names`. This is an
+        # extra argument that allows mixing different adapters in the same batch at inference time.
+        unique_adapters = set(adapter_names)
+        sub_batch_indices_list = []
+        for adapter in unique_adapters:
+            sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter])
+
+        results = [0 for i in range(len(x))]
+        for i, active_adapter in enumerate(unique_adapters):
+            sub_batch = x[sub_batch_indices_list[i]]
+            if active_adapter == "__base__":
+                output = self.original_module(*(sub_batch,), **kwargs)
+            else:
+                output = self.modules_to_save[active_adapter](*(sub_batch,), **kwargs)
+            for index, j in enumerate(sub_batch_indices_list[i]):
+                results[j] = output[index]
+        return torch.stack(results)
+
+    def forward(self, x: torch.Tensor, *args, **kwargs):
+        self._check_forward_args(x, *args, **kwargs)
+        adapter_names = kwargs.pop("adapter_names", None)
+
         if self.disable_adapters or (self.active_adapter not in self.modules_to_save):
-            return self.original_module(*args, **kwargs)
-        return self.modules_to_save[self.active_adapter](*args, **kwargs)
+            return self.original_module(x, *args, **kwargs)
+        if adapter_names is None:
+            return self.modules_to_save[self.active_adapter](x, *args, **kwargs)
+        else:
+            return self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)

     def enable_adapters(self, enabled: bool):
         """Toggle the enabling and disabling of adapters
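Taken together, `_check_forward_args`, `_mixed_batch_forward`, and the updated `forward` let a module wrapped by `ModulesToSaveWrapper` (for example a classifier head) route each row of a batch to a different adapter, or to the unmodified module via the reserved name "__base__". A hedged end-to-end sketch of how this is meant to be used; the tiny MLP, module names, and adapter names below are illustrative stand-ins, not code from this PR:

# Hedged, self-contained usage sketch; all names here are placeholders.
import torch
from torch import nn

from peft import LoraConfig, get_peft_model


class TinyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(10, 20)
        self.lin1 = nn.Linear(20, 2)
        self.head = nn.Linear(2, 2)  # classifier head, copied per adapter via modules_to_save

    def forward(self, X):
        return self.head(self.lin1(self.lin0(X)))


base_model = TinyMLP().eval()
config0 = LoraConfig(target_modules=["lin0"], modules_to_save=["head"], init_lora_weights=False)
config1 = LoraConfig(target_modules=["lin1"], modules_to_save=["head"], init_lora_weights=False)
peft_model = get_peft_model(base_model, config0, "adapter0").eval()
peft_model.add_adapter("adapter1", config1)

X = torch.arange(90).view(-1, 10).float()
# One adapter name per row; "__base__" sends that row through the original modules.
adapter_names = ["adapter0", "adapter1", "__base__"] * 3
with torch.no_grad():
    out = peft_model(X, adapter_names=adapter_names)  # shape: (9, 2)

As the comment in `_mixed_batch_forward` notes, mixing adapters this way is an inference-time feature.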
9 changes: 8 additions & 1 deletion tests/test_custom_models.py
@@ -3178,12 +3178,19 @@ def test_mixed_adapter_batches_lora_mlp(self, mlp_lora):

     def test_mixed_adapter_batches_lora_different_target_layers(self, mlp_lora):
         base_model = MLP().to(self.torch_device).eval()
         # target different lora layers
         config0 = LoraConfig(target_modules=["lin0"], init_lora_weights=False)
         config1 = LoraConfig(target_modules=["lin1"], init_lora_weights=False)
         peft_model = get_peft_model(base_model, config0, "adapter0").eval()
         peft_model.add_adapter("adapter1", config1)
         inputs = {"X": torch.arange(90).view(-1, 10).to(self.torch_device)}
         self.run_checks(peft_model, inputs)

+    def test_mixed_adapter_batches_lora_different_classifiers(self, mlp_lora):
+        base_model = MLP().to(self.torch_device).eval()
+        config0 = LoraConfig(target_modules=["lin0"], modules_to_save=["sm"], init_lora_weights=False)
+        config1 = LoraConfig(target_modules=["lin1"], modules_to_save=["sm"], init_lora_weights=False)
+        peft_model = get_peft_model(base_model, config0, "adapter0").eval()
+        peft_model.add_adapter("adapter1", config1)
+        inputs = {"X": torch.arange(90).view(-1, 10).to(self.torch_device)}
+        self.run_checks(peft_model, inputs)

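`run_checks` is defined elsewhere in tests/test_custom_models.py and its body is not shown in this diff. The invariant it presumably verifies is that a mixed batch yields the same rows as running each adapter on its own sub-batch; a hedged illustration of that check (not the actual helper) follows, assuming a peft_model like the one built in the tests above:

# Illustration of the invariant only; not the actual run_checks helper.
import torch


def check_mixed_batch_matches_per_adapter(peft_model, X, adapter_names):
    with torch.no_grad():
        mixed_out = peft_model(X, adapter_names=adapter_names)

    # Rows routed to a named adapter should match that adapter run alone on its sub-batch.
    for name in set(adapter_names) - {"__base__"}:
        idx = [i for i, n in enumerate(adapter_names) if n == name]
        peft_model.set_adapter(name)
        with torch.no_grad():
            ref = peft_model(X[idx])
        assert torch.allclose(mixed_out[idx], ref, atol=1e-6)

    # Rows tagged "__base__" should match the model with all adapters disabled.
    idx = [i for i, n in enumerate(adapter_names) if n == "__base__"]
    if idx:
        with peft_model.disable_adapter(), torch.no_grad():
            ref = peft_model(X[idx])
        assert torch.allclose(mixed_out[idx], ref, atol=1e-6)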