I am curious whether I can update only one adapter in a multi-adapter situation.
For example, I have two adapters, A and B.
In the forward pass, I run LLM + adapter A + adapter B to get the loss, and then I want to update only adapter A using PEFT.
I found `_mixed_batch_forward` in your code; is it related to my question?
```python
def _mixed_batch_forward(
    self, x: torch.Tensor, *args: Any, adapter_names: list[str], **kwargs: Any
) -> torch.Tensor:
    # This is a special method that handles the case when users pass the argument `adapter_names`. This is an
    # extra argument that allows mixing different adapters in the same batch at inference time.
    result = self.base_layer(x, *args, **kwargs)
    torch_result_dtype = result.dtype

    unique_adapters = set(adapter_names)
    sub_batch_indices_list = []
    for adapter in unique_adapters:
        sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter])

    for i, active_adapter in enumerate(unique_adapters):
        if active_adapter == "__base__":
            continue
        if active_adapter not in self.lora_A.keys():
            continue

        lora_A = self.lora_A[active_adapter]
        lora_B = self.lora_B[active_adapter]
        dropout = self.lora_dropout[active_adapter]
        scaling = self.scaling[active_adapter]

        # getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear
        # layer output
        sub_batch = x[sub_batch_indices_list[i]].to(lora_A.weight.dtype)
        lora_output = lora_B(lora_A(dropout(sub_batch))) * scaling
        result[sub_batch_indices_list[i]] += lora_output.to(torch_result_dtype)

    return result
```
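
For context, the comment in this method says `adapter_names` exists to mix different adapters in the same batch at inference time, so it may not be what is needed for training. Below is a minimal usage sketch, assuming a `PeftModel` that already has two LoRA adapters loaded under the hypothetical names "adapter_a" and "adapter_b"; one adapter name is supplied per sample, and `"__base__"` selects the base model with no adapter applied.

```python
# Minimal sketch (not from the issue): mixed-batch inference with per-sample adapters.
# Assumes `peft_model` is a PeftModel with LoRA adapters "adapter_a" and "adapter_b"
# already loaded, and `inputs` is a tokenized batch of three samples.
adapter_names = ["adapter_a", "adapter_b", "__base__"]  # one entry per sample in the batch
outputs = peft_model.generate(**inputs, adapter_names=adapter_names)
```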
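
As for computing the loss with both adapters active but updating only adapter A, one common approach (independent of `_mixed_batch_forward`) is to freeze adapter B's parameters before training. The sketch below rests on several assumptions: a LoRA `PeftModel` with adapters named "adapter_a" and "adapter_b" (hypothetical names), a PEFT version whose LoRA-model `set_adapter` accepts a list of adapter names, and the fact that PEFT stores per-adapter LoRA weights in `ModuleDict`s keyed by adapter name, so the adapter name appears in parameter names (e.g. `...lora_A.adapter_b.weight`).

```python
import torch

# Sketch only: activate both adapters so the forward pass is base model + A + B.
# `model` is assumed to be a PeftModel with LoRA adapters "adapter_a" and "adapter_b".
model.base_model.set_adapter(["adapter_a", "adapter_b"])

# Freeze adapter B so that only adapter A is trainable; the adapter name is part
# of each parameter name because PEFT keys the LoRA ModuleDicts by adapter name.
for name, param in model.named_parameters():
    if ".adapter_b." in name:
        param.requires_grad = False

optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad), lr=1e-4
)

# One training step: the loss reflects both adapters, but only adapter A's
# still-trainable parameters accumulate gradients and get updated.
outputs = model(**batch)  # `batch` is an assumed tokenized training batch with labels
outputs.loss.backward()
optimizer.step()
optimizer.zero_grad()
```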