Ability to switch between different adapter types and mix different adapter types #1006
Comments
Thanks for bringing up this feature request. Indeed, it's also something that we have discussed internally; we agree it would be good to have, but it will be hard to implement. Part of the effect can already be achieved by merging and unloading one adapter type before adding the next, but of course this makes it impossible to unmerge the former, so we should find a better solution. I think that full support for this feature will be very difficult to achieve because of various assumptions that are made (implicitly) throughout the code base. Some types of adapters may not even be possible to combine. However, if we restrict this feature to a subset of adapters (e.g. LoRA, LoHa, LoKr, maybe IA³), it could be possible. One very big issue I see right now is the way we handle the control of which adapters are active. Just as an example, here is the meat of the forward method of LoRA:

(peft/src/peft/tuners/lora/layer.py, lines 275 to 283 at commit eced2ed)
As is, this layer "controls" which adapters are active by applying them one after the other. However, if we want the active adapters to be of different types, this kind of per-layer control no longer works. One example of how this may work is how I prototyped it here. The general idea is to implement all adapters purely through forward hooks (pre and post). There is a wrapper class (as in the prototype) that registers and controls these hooks. Of course, making such a switch would be an enormous task for a code base such as PEFT, so I'm not sure it isn't already too late for that at this point. Maybe there are other ideas for how to achieve this; I'm very open to suggestions.
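To make the hook-based idea more concrete, here is a minimal sketch of what such an adapter could look like. All class and attribute names below are made up for illustration; this is not the prototype's or PEFT's actual API.

```python
import torch
import torch.nn as nn


class LoraForwardHookAdapter(nn.Module):
    """Sketch: a LoRA-style adapter attached to a base nn.Linear via a forward hook."""

    def __init__(self, base_linear: nn.Linear, r: int = 8, scaling: float = 1.0):
        super().__init__()
        self.lora_A = nn.Linear(base_linear.in_features, r, bias=False)
        self.lora_B = nn.Linear(r, base_linear.out_features, bias=False)
        nn.init.zeros_(self.lora_B.weight)  # start as a no-op
        self.scaling = scaling
        self.enabled = True
        # post-forward hook: add the low-rank update on top of the base output
        self._handle = base_linear.register_forward_hook(self._post_hook)

    def _post_hook(self, module, inputs, output):
        if not self.enabled:
            return output
        (x,) = inputs
        return output + self.lora_B(self.lora_A(x)) * self.scaling

    def detach_from_base(self):
        # removing the hook fully restores the original module, no unmerge needed
        self._handle.remove()


base = nn.Linear(16, 32)
adapter = LoraForwardHookAdapter(base)
y = base(torch.randn(4, 16))  # base output + LoRA delta, added by the hook
```

A pre-forward hook could play a similar role for adapters that need to modify the input instead, and a wrapper class could own and toggle a whole collection of such hook handles.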
Thanks for sharing your view on this problem. Your prototype looks neat! I am wondering how it compares to the current PEFT approach in terms of performance and reliability. I thought that attaching hooks to torch modules was mostly suitable for debugging / profiling purposes, but it provides so much flexibility for these adapter approaches.
Based on my experience, for Stable Diffusion the most used adapters right now are probably LoRA, LoHa, and LoKr, so this subset seems to cover most of the functionality needed for those models. Do you by chance know whether it is common to mix different adapters for LLMs? I suppose there might exist some language-specific adapters that could be combined with some domain-specific knowledge adapters, but I am not sure what the most common adapter types are for those kinds of tasks.
Hmmm, if I understand it correctly, if we restrict a mixture to a subset of LoRA, LoHa, and LoKr, these adapters are fully commutative (if we apply them at the same time): each of them in general just provides an additive delta to the base weight. Also, we can see that the webui takes a similar approach: it just accumulates the total delta over all selected adapters.
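As a toy illustration of the commutativity argument (purely additive deltas can be merged in any order); the shapes and scales below are arbitrary:

```python
import torch

W = torch.randn(32, 16)                  # base weight
delta_lora = torch.randn(32, 16) * 0.01  # stand-in for lora_B @ lora_A
delta_loha = torch.randn(32, 16) * 0.01  # stand-in for hada_w1 * hada_w2

# the merge order does not matter as long as every adapter is a pure additive delta
assert torch.allclose(W + delta_lora + delta_loha, W + delta_loha + delta_lora)
```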
Your approach leads me to an idea: why don't we try to mimic this behavior in the existing code base? The key thing that needs to be done is to separate the actual adapter delta modules from the layers that we are trying to modify. We might transition to something like this (pseudocode):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict


class DeltaProviderProtocol(nn.Module):
    """Common interface for all adapter 'delta' modules."""

    def __init__(self):
        ...

    def get_delta_weight(self):
        ...

    def forward(self, x):
        ...


class LinearLoraDeltaProvider(DeltaProviderProtocol):
    def __init__(self):
        self.lora_A = nn.Linear(...)
        self.lora_B = nn.Linear(...)
        ...

    def get_delta_weight(self):
        return self.lora_B.weight @ self.lora_A.weight

    def forward(self, x):
        return self.lora_B(self.lora_A(x))


class LinearLohaDeltaProvider(DeltaProviderProtocol):
    def __init__(self):
        self.hada_w1 = nn.Parameter(...)
        self.hada_w2 = nn.Parameter(...)
        ...

    def get_delta_weight(self):
        return self.hada_w1 * self.hada_w2

    def forward(self, x):
        # apply the delta like a linear layer: x @ delta^T
        return x @ self.get_delta_weight().T


class LinearLokrDeltaProvider(DeltaProviderProtocol):
    def __init__(self):
        self.lokr_w1 = nn.Parameter(...)
        self.lokr_w2 = nn.Parameter(...)
        ...

    def get_delta_weight(self):
        return torch.kron(self.lokr_w1, self.lokr_w2)

    def forward(self, x):
        return x @ self.get_delta_weight().T


class LinearAdapterLayer(nn.Linear):
    def __init__(self, *args, **kwargs):
        self.adapters: Dict[str, DeltaProviderProtocol] = nn.ModuleDict({})
        ...

    def merge(self):
        for adapter in self.active_adapters:
            self.weight += self.adapters[adapter].get_delta_weight()

    def forward(self, x):
        result = F.linear(...)
        for adapter in self.active_adapters:
            result += self.adapters[adapter](x)
        return result
```

So, instead of modifying the base model's layers with ones that support only a single adapter type (…), we could swap in a single layer like `LinearAdapterLayer` above, which holds an arbitrary collection of delta providers. In my opinion, this can be done without messing up the existing code base: the current adapters may exist alongside this new implementation, which would just reuse those delta modules. In terms of enabling / mixing different adapters, it would come down to managing the `active_adapters` list of such a layer. So, in general, this approach would work pretty much the same as your prototype, but it would allow us to reuse the existing code base and stay flexible. What do you think?
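To illustrate how that might be used (hypothetical; this continues the pseudocode above, so the constructor arguments and the `active_adapters` attribute are assumptions, not an existing PEFT API):

```python
# hypothetical usage, continuing the pseudocode above
layer = LinearAdapterLayer(...)                       # wraps / replaces a base nn.Linear
layer.adapters["style_lora"] = LinearLoraDeltaProvider(...)
layer.adapters["style_loha"] = LinearLohaDeltaProvider(...)

# mix two different adapter types at inference time
layer.active_adapters = ["style_lora", "style_loha"]
y_mixed = layer(x)

# or switch to a single adapter without unloading the other one
layer.active_adapters = ["style_loha"]
y_single = layer(x)
```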
I haven't tested that, since it was more of a proof of concept for me, but in general, I don't think that the performance characteristics should be different, as the same amount of computation is being carried out during forward/backward. It might be faster when initializing and switching adapters, but I'm not sure how noticeable that is.
I think at one point the PyTorch docs said so, but they no longer do, so I think hooks are a good way of implementing this type of feature. (Note that …)
I'm not aware of this being a common pattern, but the field is moving so fast, so who knows what's true tomorrow.
Yes, but I'd be hesitant to "lock in" the feature in a way that it only works with commutative modifications. Another solution that I discussed with @pacman100 is that we could refactor LoRA here:

(peft/src/peft/tuners/lora/layer.py, line 300 at commit 56556fa)

which ignores the … If, instead, we went ahead with the idea of adding a completely new adapter type that allows combining multiple types of adapters, I think I'd like to try working with the hooks approach, which is a bit more flexible than requiring each adapter to provide a `get_delta_weight`. Note, however, that even with hooks, not all cases are covered: if we need to make a change in the middle of a module's forward pass, pre and post hooks are not enough.
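A toy illustration of that limitation (made-up module, not from PEFT): pre and post forward hooks only see a module's inputs and outputs, so an adapter that must alter an intermediate activation inside forward cannot be expressed that way without hooking a finer-grained submodule or changing the code.

```python
import torch
import torch.nn as nn


class FusedBlock(nn.Module):
    """Toy module whose forward fuses several steps into one call."""

    def __init__(self):
        super().__init__()
        self.proj_in = nn.Linear(16, 16)
        self.proj_out = nn.Linear(16, 16)

    def forward(self, x):
        h = self.proj_in(x)       # an adapter that needs to modify `h` cannot do so
        h = torch.relu(h)         # via pre/post hooks registered on FusedBlock itself;
        return self.proj_out(h)   # it would need a hook on proj_in or a code change
```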
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Note: This will be implemented via #1069 or a spin-off of that PR.
Feature request
Hi! As far as I know, PEFT currently only allows loading and running inference with adapters of the same type.
So we cannot, for example, load a LoRA and a LoHa for Stable Diffusion and switch between them during inference: we are forced to unload all adapters of the first type first, and we may lose some state in the process, for example a mixture of several LoRA adapters.
Also, it is pretty common in the Stable Diffusion ecosystem to load and mix different adapters together to get a unique style (and those adapters may be of different types). The ability to mix adapters is partially addressed by the new API for enabling multiple adapters during inference, but we are currently limited to a single adapter type.
To sum up, it would be great to have these features:
- the ability to switch between different adapter types during inference
- the ability to mix different adapter types (e.g. adapters that provide `get_delta_weight`)

Motivation
These features would be super useful for those who are using the diffusers implementation of Stable Diffusion together with PEFT for LoRAs/LoHas/etc. in production, where losing state or reloading checkpoints is undesirable.
As far as I know, Hugging Face diffusers does not allow these kinds of manipulations with checkpoints, and PEFT also has broader support for different adapter types for SD and SDXL.
From my perspective, these features may also be useful for LLMs: the ability to switch between different adapters or to mix them could be beneficial for some downstream tasks.
Your contribution
I would be happy to help implement these features, but it is not clear to me right now how this could be achieved with the current library architecture (so we probably need to discuss your view on how it should be implemented).