fixing multiple LoRA in the same batch or vit #1990

Merged: 10 commits, Sep 17, 2024
4 changes: 4 additions & 0 deletions src/peft/peft_model.py
@@ -78,6 +78,7 @@
    set_peft_model_state_dict,
    shift_tokens_right,
)
from .utils.patches import SUPPORTED_MULTILORA_MULTIHEAD_MAP, patch_multi_lora_forward


PEFT_TYPE_TO_MODEL_MAPPING = {
@@ -767,6 +768,9 @@ def forward(self, *args: Any, **kwargs: Any):
        Forward pass of the model.
        """
        with self._enable_peft_forward_hooks(*args, **kwargs):
            # Patch the forward function dynamically to support multiple LoRA adapters in the same batch
            if "adapter_names" in kwargs.keys() and self.get_base_model().__class__ in SUPPORTED_MULTILORA_MULTIHEAD_MAP.keys():
                return patch_multi_lora_forward(self.get_base_model())(*args, **kwargs)
            kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
            return self.get_base_model()(*args, **kwargs)

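Note: to see what this dispatch is meant to enable, here is a minimal usage sketch of a mixed-adapter batch. The checkpoint paths and adapter names are placeholders, and google/vit-base-patch16-224 is only an example base model.

import torch
from transformers import ViTForImageClassification
from peft import PeftModel

# Load a ViT base model plus two hypothetical LoRA adapters
base = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
model = PeftModel.from_pretrained(base, "path/to/lora_a", adapter_name="adapter_a")
model.load_adapter("path/to/lora_b", adapter_name="adapter_b")

pixel_values = torch.randn(3, 3, 224, 224)  # a batch of 3 images
# one adapter name per sample; the patched forward routes each sub-batch
# through the matching adapter
outputs = model(pixel_values=pixel_values, adapter_names=["adapter_a", "adapter_b", "adapter_a"])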
22 changes: 21 additions & 1 deletion src/peft/utils/other.py
@@ -261,7 +261,27 @@ def _create_new_hook(self, old_hook):
    def forward(self, *args, **kwargs):
        if self.disable_adapters or (self.active_adapter not in self.modules_to_save):
            return self.original_module(*args, **kwargs)
        return self.modules_to_save[self.active_adapter](*args, **kwargs)
        if "adapter_names" not in kwargs.keys():
            return self.modules_to_save[self.active_adapter](*args, **kwargs)
        # Group requests that use the same LoRA adapter into sub-batches
Member:

Let's move this to a sub-method, similar to how we do this for LoRA:

def _mixed_batch_forward(

Also, with this added, I think it makes sense to have a similar method as in LoRA to check the arguments:

def _check_forward_args(self, x, *args, **kwargs):

Of course, we have to be careful not to be too restrictive here, given the other issue that you raised, and since the underlying module could be of any type.
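
(For illustration only: a rough, standalone sketch of the kind of argument check being suggested here. The helper name, signature, and error message are placeholders, not the code that ended up in the PR.)

def check_forward_args(x, adapter_names):
    # Hypothetical helper: adapter_names must have one entry per sample in the batch
    if adapter_names is None:
        return
    if len(x) != len(adapter_names):
        raise ValueError(
            f"Length of `adapter_names` ({len(adapter_names)}) should match the "
            f"number of inputs in the batch ({len(x)})."
        )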

Contributor Author:

Both functions have been added in the new commit, please take a look.

        adapter_names = kwargs["adapter_names"]
        kwargs = {}
        batch = args[0]  # the first positional argument is taken as the batched input
        unique_adapters = set(adapter_names)
        sub_batch_indices_list = []
        for adapter in unique_adapters:
            sub_batch_indices_list.append(
                [index for index, item in enumerate(adapter_names) if item == adapter]
            )

        results = [0 for i in range(len(batch))]
        for i, active_adapter in enumerate(unique_adapters):
            sub_batch = batch[sub_batch_indices_list[i]]
Member:

Hmm, here we assume that there is only one positional arg, as any other args would be dropped, right? Also, what if other args or kwargs need to be sliced? We don't really know that, so I think the best we can do is make a guess.

One suggestion that I have:

Check all args and kwargs to see whether they are tensors and, if so, whether they have the same length (i.e. batch size). In that case, slice those too. Otherwise, leave them as is. It's not perfect, but I'm not sure what else could be done. WDYT?
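
(As an illustration, a rough standalone sketch of the slicing heuristic suggested above; the helper name and exact behavior are placeholders, not code from the PR.)

import torch

def slice_batch_like(args, kwargs, batch_size, indices):
    # Slice any tensor argument whose first dimension matches the batch size;
    # leave every other argument untouched.
    def maybe_slice(value):
        if isinstance(value, torch.Tensor) and value.shape[0] == batch_size:
            return value[indices]
        return value

    sliced_args = tuple(maybe_slice(a) for a in args)
    sliced_kwargs = {key: maybe_slice(value) for key, value in kwargs.items()}
    return sliced_args, sliced_kwargs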

Contributor Author:

I changed the input signature in the new version so that x is the explicit input, which avoids the problems you mentioned.

            output = self.modules_to_save[active_adapter](*(sub_batch,), **kwargs)
            for index, j in enumerate(sub_batch_indices_list[i]):
                results[j] = output[index]
        return torch.stack(results)


    def enable_adapters(self, enabled: bool):
        """Toggle the enabling and disabling of adapters
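To make the sub-batch grouping in the forward above concrete, here is a small standalone illustration with made-up adapter names:

adapter_names = ["adapter_a", "adapter_b", "adapter_a", "adapter_b"]
unique_adapters = set(adapter_names)
sub_batch_indices_list = [
    [index for index, item in enumerate(adapter_names) if item == adapter]
    for adapter in unique_adapters
]
# yields, e.g., [[0, 2], [1, 3]] (set iteration order may vary); each index group
# is run through its adapter's module and the outputs are written back to their
# original batch positions before torch.stack reassembles the result
print(sub_batch_indices_list)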
108 changes: 108 additions & 0 deletions src/peft/utils/patches.py
@@ -0,0 +1,108 @@
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union

import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers import ViTForImageClassification
from transformers.modeling_outputs import ImageClassifierOutput

from . import PeftType


class ViTForImageClassificationFixed(ViTForImageClassification):
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.vit(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        if "adapter_names" in kwargs.keys():
            # Changed from the stock ViT forward: pass adapter_names through to the classifier head
            logits = self.classifier(
                sequence_output[:, 0, :], adapter_names=kwargs["adapter_names"]
            )
        else:
            logits = self.classifier(sequence_output[:, 0, :])

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (
                    labels.dtype == torch.long or labels.dtype == torch.int
                ):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
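
Note: the extra adapter_names keyword passed to self.classifier above only makes sense when the classification head has been wrapped in PEFT's ModulesToSaveWrapper. A hedged configuration sketch of how that typically happens (the rank and target module names are placeholder choices):

from transformers import ViTForImageClassification
from peft import LoraConfig, get_peft_model

base = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
config = LoraConfig(
    r=8,                                # placeholder rank
    target_modules=["query", "value"],  # attention projections inside the ViT encoder
    modules_to_save=["classifier"],     # wraps the classification head in ModulesToSaveWrapper
)
peft_model = get_peft_model(base, config)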

SUPPORTED_MULTILORA_MULTIHEAD_MAP = {
    ViTForImageClassification: ViTForImageClassificationFixed
}


def patch_multi_lora_forward(model: PeftType.LORA) -> PeftType.LORA:
    # Rebind the forward of the corresponding "fixed" class onto this base model instance
    model.forward = SUPPORTED_MULTILORA_MULTIHEAD_MAP[model.__class__].forward.__get__(model)
    return model
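
The __get__ call above binds the unbound forward function of the fixed class to the existing model instance. A toy sketch with made-up classes showing the same rebinding trick:

class Original:
    def forward(self):
        return "original forward"

class Patched(Original):
    def forward(self):
        return "patched forward"

obj = Original()
# Bind Patched.forward to this specific instance via the descriptor protocol
obj.forward = Patched.forward.__get__(obj)
print(obj.forward())  # -> "patched forward"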