-
Notifications
You must be signed in to change notification settings - Fork 1.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fixing multiple LoRAs in the same batch for ViT #1990
Changes from 1 commit
2579b85
fd0a9ce
6b0290f
d143b13
a46ad62
d3bce93
60384bd
683da8b
e0a12b3
bed1a10
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -261,7 +261,27 @@ def _create_new_hook(self, old_hook): | |
def forward(self, *args, **kwargs):
    """Run the active saved module, optionally routing each sample of the
    batch to a different saved module.

    If ``adapter_names`` is passed as a keyword argument (a sequence with one
    adapter name per sample), samples that share an adapter are grouped into
    one sub-batch ("micro-batch") so each saved module runs once per adapter
    rather than once per sample. Otherwise the call is forwarded unchanged to
    the module registered for ``self.active_adapter``.

    Args:
        *args: positional arguments; the first one is treated as the batched
            input tensor when ``adapter_names`` is given.
        **kwargs: keyword arguments forwarded to the wrapped module;
            ``adapter_names`` is consumed here and not forwarded.

    Returns:
        The wrapped module's output; in the mixed-adapter path, a tensor
        stacked back into the original batch order.
    """
    if self.disable_adapters or (self.active_adapter not in self.modules_to_save):
        return self.original_module(*args, **kwargs)

    # Consume adapter_names; BUG FIX: the previous code did `kwargs = {}`,
    # silently discarding every other keyword argument from the sub-calls.
    adapter_names = kwargs.pop("adapter_names", None)
    if adapter_names is None:
        return self.modules_to_save[self.active_adapter](*args, **kwargs)

    # Mixed-adapter path. Assumes args[0] carries the batch dimension and
    # len(adapter_names) == len(batch) -- TODO confirm against callers.
    batch = args[0]
    results = [None] * len(batch)
    for adapter in set(adapter_names):
        # Indices of all samples that requested this adapter.
        indices = [i for i, name in enumerate(adapter_names) if name == adapter]
        # Forward the sub-batch with the remaining kwargs preserved.
        output = self.modules_to_save[adapter](batch[indices], **kwargs)
        for out_pos, batch_pos in enumerate(indices):
            results[batch_pos] = output[out_pos]
    # Reassemble per-sample outputs in the original batch order.
    return torch.stack(results)
def enable_adapters(self, enabled: bool): | ||
"""Toggle the enabling and disabling of adapters | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
# Copyright 2023-present the HuggingFace Inc. team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Optional, Union | ||
|
||
import torch | ||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss | ||
from transformers import ViTForImageClassification | ||
from transformers.modeling_outputs import ImageClassifierOutput | ||
|
||
from . import PeftType | ||
|
||
|
||
class ViTForImageClassificationFixed(ViTForImageClassification):
    """Drop-in replacement for ``ViTForImageClassification`` whose ``forward``
    accepts an optional ``adapter_names`` keyword and forwards it to the
    classifier head, so different samples in one batch can use different
    LoRA adapters.
    """

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.vit(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        # Changed vs. upstream: pass per-sample adapter names to the classifier
        # head (the [CLS] token embedding) when provided.
        if "adapter_names" in kwargs:
            logits = self.classifier(sequence_output[:, 0, :], adapter_names=kwargs["adapter_names"])
        else:
            logits = self.classifier(sequence_output[:, 0, :])

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                # BUG FIX: the loss function was instantiated but never applied,
                # so multi-label batches silently returned loss=None. Matches the
                # upstream ViTForImageClassification implementation.
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
||
# Maps a supported transformers model class to its multi-LoRA-aware variant.
SUPPORTED_MULTILORA_MULTIHEAD_MAP = {
    ViTForImageClassification: ViTForImageClassificationFixed
}


def patch_multi_lora_forward(model: PeftType.LORA) -> PeftType.LORA:
    """Rebind ``model.forward`` to the multi-LoRA-aware implementation.

    Looks up the patched class for the model's concrete class and binds its
    ``forward`` to this instance via the descriptor protocol, so the patched
    method receives ``model`` as ``self``. Raises ``KeyError`` for model
    classes not present in ``SUPPORTED_MULTILORA_MULTIHEAD_MAP``.
    """
    patched_cls = SUPPORTED_MULTILORA_MULTIHEAD_MAP[model.__class__]
    model.forward = patched_cls.forward.__get__(model)
    return model
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's move this to a sub-method, similar to how we do this for LoRA:
peft/src/peft/tuners/lora/layer.py
Line 327 in 4611034
Also, with this added, I think it makes sense to have a similar method as in LoRA to check the arguments:
peft/src/peft/tuners/lora/layer.py
Line 302 in 4611034
Of course, we have to be careful not to be too restrictive here, given the other issue that you raised, and since the underlying module could be of any type.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Both of the functions are added in the new commit, please check that.