
Fixing multiple LoRA in the same batch or ViT #1990

Merged: 10 commits, Sep 17, 2024
17 changes: 17 additions & 0 deletions docs/source/developer_guides/lora.md
@@ -369,6 +369,22 @@ output = peft_model.generate(**inputs, adapter_names=adapter_names, max_new_toke

Note that the order does not matter here, i.e. the samples in the batch don't need to be grouped by adapter as in the example above. We just need to ensure that the `adapter_names` argument is aligned correctly with the samples.
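For instance, a shuffled batch is fine as long as each entry in `adapter_names` matches the sample at the same position. A minimal sketch, assuming `peft_model`, `tokenizer`, and adapters named `"adapter0"`/`"adapter1"` already exist:

```python
# adapters are interleaved rather than grouped; each entry simply has to match
# the sample at the same position in the batch ("__base__" selects the base model)
texts = ["sample 0", "sample 1", "sample 2", "sample 3"]
adapter_names = ["adapter1", "__base__", "adapter0", "adapter1"]

inputs = tokenizer(texts, return_tensors="pt", padding=True)
output = peft_model.generate(**inputs, adapter_names=adapter_names, max_new_tokens=20)
```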

In certain scenarios, specific neural network layers need to be replicated for each set of LoRA weights. A common use case is classification, where the classification head may need to be customized per adapter. The `modules_to_save` feature creates a separate replica of a given layer for each adapter. For example:
```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig

model_id = "facebook/opt-350m"
model = AutoModelForCausalLM.from_pretrained(model_id)

lora_config = LoraConfig(
    target_modules=["q_proj", "k_proj"],
    # each adapter gets its own trainable copy of lm_head
    modules_to_save=["lm_head"],
)

model.add_adapter(lora_config)
```
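With a second adapter added under its own name, each adapter owns a separate copy of `lm_head`, and mixed-batch inference routes every sample through the replica belonging to its adapter. The following is a hedged sketch using `get_peft_model` rather than the transformers `add_adapter` integration; the adapter names, the tokenizer, and the fresh base model are assumptions made for illustration:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

model_id = "facebook/opt-350m"
base_model = AutoModelForCausalLM.from_pretrained(model_id)

config0 = LoraConfig(target_modules=["q_proj", "k_proj"], modules_to_save=["lm_head"])
config1 = LoraConfig(target_modules=["q_proj", "k_proj"], modules_to_save=["lm_head"])

peft_model = get_peft_model(base_model, config0, adapter_name="adapter0")
peft_model.add_adapter("adapter1", config1)  # creates a second, independent lm_head replica

tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(["first sample", "second sample"], return_tensors="pt", padding=True)

# one adapter name per sample; each sample uses its own adapter's lm_head copy
adapter_names = ["adapter0", "adapter1"]
output = peft_model.generate(**inputs, adapter_names=adapter_names, max_new_tokens=10)
```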

### Caveats

Using this feature has some drawbacks, namely:
@@ -382,3 +398,4 @@ Using this feature has some drawbacks, namely:
- Increase the batch size.
- Try to avoid having a large number of different adapters in the same batch and prefer homogeneous batches. This can be achieved by buffering samples with the same adapter and only performing inference with a small handful of different adapters (see the sketch after this list).
- Take a look at alternative implementations such as [LoRAX](https://github.com/predibase/lorax), [punica](https://github.com/punica-ai/punica), or [S-LoRA](https://github.com/S-LoRA/S-LoRA), which are specialized to work with a large number of different adapters.
- The `modules_to_save` feature is currently only supported for the layers of types `Linear`, `Embedding`, `Conv2d` and `Conv1d`.
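The buffering idea mentioned above can be as simple as grouping pending requests by adapter and running one homogeneous `generate` call per group. A minimal sketch, assuming `peft_model`, `tokenizer`, and a list of `(text, adapter_name)` pairs already exist:

```python
from collections import defaultdict

requests = [("text a", "adapter0"), ("text b", "adapter1"), ("text c", "adapter0")]

# bucket the texts by adapter so that each generate call only uses a single adapter
buckets = defaultdict(list)
for text, adapter in requests:
    buckets[adapter].append(text)

outputs = {}
for adapter, texts in buckets.items():
    inputs = tokenizer(texts, return_tensors="pt", padding=True)
    outputs[adapter] = peft_model.generate(
        **inputs, adapter_names=[adapter] * len(texts), max_new_tokens=20
    )
```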
3 changes: 2 additions & 1 deletion src/peft/tuners/lora/model.py
@@ -46,6 +46,7 @@
)
from peft.utils.merge_utils import dare_linear, dare_ties, magnitude_prune, task_arithmetic, ties

from ...utils.other import ModulesToSaveWrapper
from .aqlm import dispatch_aqlm
from .awq import dispatch_awq
from .config import LoraConfig
@@ -432,7 +433,7 @@ def _enable_peft_forward_hooks(self, *args, **kwargs):

hook_handles = []
for module in self.modules():
if isinstance(module, LoraLayer):
if isinstance(module, LoraLayer) or isinstance(module, ModulesToSaveWrapper):
pre_forward = partial(_adapter_names_pre_forward_hook, adapter_names=adapter_names)
handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True)
hook_handles.append(handle)
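The change above registers the `adapter_names` pre-forward hook on `ModulesToSaveWrapper` instances as well, not only on `LoraLayer`s, so the replicated `modules_to_save` layers also receive the per-sample adapter names. Roughly, such a hook only needs to inject the argument into the module's keyword arguments; a sketch under that assumption (the actual definition lives elsewhere in `model.py`):

```python
def _adapter_names_pre_forward_hook(target, args, kwargs, adapter_names):
    # inject the per-sample adapter names so that the wrapped module's forward
    # (LoraLayer or ModulesToSaveWrapper) can route each sample to its adapter
    kwargs["adapter_names"] = adapter_names
    return args, kwargs
```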
59 changes: 55 additions & 4 deletions src/peft/utils/other.py
@@ -16,7 +16,7 @@
import os
import warnings
from contextlib import nullcontext
from typing import Optional, Tuple
from typing import Any, Optional, Tuple

import accelerate
import torch
@@ -258,10 +258,61 @@ def _create_new_hook(self, old_hook):
new_hook = old_hook_cls(**filtered_old_hook_attr)
return new_hook

def forward(self, *args, **kwargs):
def _check_forward_args(self, x, *args, **kwargs):
"""Check if the arguments are compatible with the configs and state of the model"""
adapter_names = kwargs.get("adapter_names", None)
if adapter_names is None:
return

if len(x) != len(adapter_names):
msg = (
"Length of `adapter_names` should be the same as the number of inputs, but got "
f"{len(adapter_names)} and {len(x)} respectively."
)
raise ValueError(msg)

def _mixed_batch_forward(
self, input: torch.Tensor, *args: Any, adapter_names: list[str], **kwargs: Any
) -> torch.Tensor:
# This is a special method that handles the case when users pass the argument `adapter_names`. This is an
# extra argument that allows mixing different adapters in the same batch at inference time.

SUPPORTED_MODULES = (torch.nn.Linear, torch.nn.Embedding, torch.nn.Conv2d, torch.nn.Conv1d)

if not isinstance(self.original_module, SUPPORTED_MODULES):
raise TypeError("Mixed batching is only supported for Linear, Embedding, Conv2d, and Conv1D modules.")

unique_adapters = set(adapter_names)
sub_batch_indices_list = []

for adapter in unique_adapters:
sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter])

results = [0 for _ in range(len(input))]

for i, active_adapter in enumerate(unique_adapters):
sub_batch = input[sub_batch_indices_list[i]]

if active_adapter == "__base__":
output = self.original_module(sub_batch, *args, **kwargs)
else:
output = self.modules_to_save[active_adapter](sub_batch, *args, **kwargs)

for index, j in enumerate(sub_batch_indices_list[i]):
results[j] = output[index]

return torch.stack(results)

def forward(self, x: torch.Tensor, *args, **kwargs):
self._check_forward_args(x, *args, **kwargs)
adapter_names = kwargs.pop("adapter_names", None)

if self.disable_adapters or (self.active_adapter not in self.modules_to_save):
return self.original_module(*args, **kwargs)
return self.modules_to_save[self.active_adapter](*args, **kwargs)
return self.original_module(x, *args, **kwargs)
if adapter_names is None:
return self.modules_to_save[self.active_adapter](x, *args, **kwargs)
else:
return self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)

def enable_adapters(self, enabled: bool):
"""Toggle the enabling and disabling of adapters
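For intuition, `_mixed_batch_forward` partitions the batch by adapter name, runs each sub-batch through the matching `modules_to_save` replica (or through the original module for `"__base__"`), and then reassembles the outputs in the original order. A self-contained toy sketch of that routing logic, with plain `nn.Linear` heads standing in for the saved module replicas:

```python
import torch
import torch.nn as nn

heads = {"adapter0": nn.Linear(4, 2), "adapter1": nn.Linear(4, 2)}  # per-adapter replicas
base_head = nn.Linear(4, 2)  # stands in for original_module

x = torch.randn(5, 4)
adapter_names = ["adapter0", "__base__", "adapter1", "adapter0", "adapter1"]

results = [None] * len(x)
for adapter in set(adapter_names):
    # indices of the samples that belong to this adapter
    indices = [i for i, name in enumerate(adapter_names) if name == adapter]
    sub_batch = x[indices]
    head = base_head if adapter == "__base__" else heads[adapter]
    out = head(sub_batch)
    # scatter the sub-batch results back to their original positions
    for pos, i in enumerate(indices):
        results[i] = out[pos]

output = torch.stack(results)  # same order as the incoming batch
print(output.shape)  # torch.Size([5, 2])
```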
53 changes: 52 additions & 1 deletion tests/test_custom_models.py
@@ -530,6 +530,29 @@ def forward(self, X):
return X


class MLPWithGRU(nn.Module):
def __init__(self, bias=True):
super().__init__()
self.lin0 = nn.Linear(10, 20, bias=bias)
self.relu = nn.ReLU()
self.drop = nn.Dropout(0.5)
self.gru = nn.GRU(input_size=20, hidden_size=20, num_layers=1, batch_first=True, bias=bias)
self.fc = nn.Linear(20, 2, bias=bias)
self.sm = nn.LogSoftmax(dim=-1)

def forward(self, X):
X = X.float()
X = self.lin0(X)
X = self.relu(X)
X = self.drop(X)
X = X.unsqueeze(1)
X, _ = self.gru(X)
X = X.squeeze(1)
X = self.fc(X)
X = self.sm(X)
return X


class MLP_LayerNorm(nn.Module):
def __init__(self, bias=True):
super().__init__()
@@ -3178,15 +3201,34 @@ def test_mixed_adapter_batches_lora_mlp(self, mlp_lora):

def test_mixed_adapter_batches_lora_different_target_layers(self, mlp_lora):
base_model = MLP().to(self.torch_device).eval()
# target different lora layers
config0 = LoraConfig(target_modules=["lin0"], init_lora_weights=False)
config1 = LoraConfig(target_modules=["lin1"], init_lora_weights=False)
peft_model = get_peft_model(base_model, config0, "adapter0").eval()
peft_model.add_adapter("adapter1", config1)
inputs = {"X": torch.arange(90).view(-1, 10).to(self.torch_device)}
self.run_checks(peft_model, inputs)

def test_mixed_adapter_batches_lora_different_classifiers(self, mlp_lora):
base_model = MLP().to(self.torch_device).eval()
config0 = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"], init_lora_weights=False)
config1 = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"], init_lora_weights=False)
peft_model = get_peft_model(base_model, config0, "adapter0").eval()
peft_model.add_adapter("adapter1", config1)
inputs = {"X": torch.arange(90).view(-1, 10).to(self.torch_device)}
self.run_checks(peft_model, inputs)

def test_mixed_adapter_batches_lora_unsupported_layer(self, mlp_lora):
base_model = MLPWithGRU().to(self.torch_device).eval()
config0 = LoraConfig(target_modules=["lin0"], modules_to_save=["gru"], init_lora_weights=False)
config1 = LoraConfig(target_modules=["lin0"], modules_to_save=["gru"], init_lora_weights=False)
peft_model = get_peft_model(base_model, config0, "adapter0").eval()
peft_model.add_adapter("adapter1", config1)
inputs = {"X": torch.arange(90).view(-1, 10).to(self.torch_device)}
with pytest.raises(
TypeError, match="Mixed batching is only supported for Linear, Embedding, Conv2d, and Conv1D modules."
):
self.run_checks(peft_model, inputs)

def test_mixed_adapter_batches_lora_partly_overlapping_target_layers(self, mlp_lora):
base_model = MLP().to(self.torch_device).eval()
# target different lora layers
@@ -3208,6 +3250,15 @@ def test_mixed_adapter_batches_lora_conv1d_emb(self):
inputs = {"X": torch.arange(90).view(-1, 10).to(self.torch_device)}
self.run_checks(peft_model, inputs)

def test_mixed_adapter_batches_lora_conv1d_emb_different_classifiers(self):
base_model = ModelEmbConv1D().to(self.torch_device).eval()
config0 = LoraConfig(target_modules=["emb", "conv1d"], modules_to_save=["lin0"], init_lora_weights=False)
config1 = LoraConfig(target_modules=["emb", "conv1d"], modules_to_save=["lin0"], init_lora_weights=False)
peft_model = get_peft_model(base_model, config0, "adapter0").eval()
peft_model.add_adapter("adapter1", config1)
inputs = {"X": torch.arange(90).view(-1, 10).to(self.torch_device)}
self.run_checks(peft_model, inputs)

def test_mixed_adapter_batches_lora_conv2d(self):
base_model = ModelConv2D().to(self.torch_device).eval()
config0 = LoraConfig(target_modules=["conv2d"], init_lora_weights=False)