[Add] DoRA Embedding #2006
Changes from 2 commits
```diff
@@ -28,7 +28,7 @@
 from peft.utils.other import transpose

 from .config import LoraConfig
-from .dora import DoraConv2dLayer, DoraLinearLayer
+from .dora import DoraConv2dLayer, DoraEmbeddingLayer, DoraLinearLayer


 class LoraLayer(BaseTunerLayer):
```
```diff
@@ -590,9 +590,6 @@ def __init__(
         super().__init__()
         LoraLayer.__init__(self, base_layer)

-        if use_dora:
-            raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False")
-
         self._active_adapter = adapter_name
         self.update_layer(
             adapter_name,
```
```diff
@@ -631,9 +628,31 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weig
         elif init_lora_weights:
             self.reset_lora_parameters(adapter_name, init_lora_weights)

+        # call this before dora_init
         self._move_adapter_to_device_of_base_layer(adapter_name)

+        if use_dora:
+            self.dora_init(adapter_name)
+            self.use_dora[adapter_name] = True
+        else:
+            self.use_dora[adapter_name] = False
+
         self.set_adapter(self.active_adapters)

+    def dora_init(self, adapter_name: str) -> None:
+        if self.lora_magnitude_vector is None:
+            # first dora layer being added, add lora_magnitude_vector to the list of learnable parameters
+            self.adapter_layer_names = self.adapter_layer_names[:] + ("lora_magnitude_vector",)
+
+        dora_layer = DoraEmbeddingLayer(fan_in_fan_out=True)
+        lora_embedding_A = self.lora_embedding_A[adapter_name]
+        lora_embedding_B = self.lora_embedding_B[adapter_name]
+        scaling = self.scaling[adapter_name]
+        dora_layer.update_layer(
+            base_layer=self.get_base_layer(), lora_A=lora_embedding_A, lora_B=lora_embedding_B, scaling=scaling
+        )
+        self.lora_magnitude_vector[adapter_name] = dora_layer
+
     def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
         """
         Merge the active adapter weights into the base weights
```
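For readers unfamiliar with DoRA: the adapted weight is re-parameterized as a learnable per-output-dimension magnitude times the direction of `W0 + scaling * (A @ B)`, and `dora_init` above stores that magnitude module in `lora_magnitude_vector`. The following is a minimal, self-contained sketch of the same idea for an embedding weight; the class name, initialization, and shapes are illustrative assumptions, not PEFT's actual `DoraEmbeddingLayer`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DoraEmbeddingSketch(nn.Module):
    """Toy DoRA-style embedding adapter (hypothetical class, not PEFT's DoraEmbeddingLayer)."""

    def __init__(self, base: nn.Embedding, r: int = 8, scaling: float = 1.0):
        super().__init__()
        self.base = base
        self.scaling = scaling
        num_embeddings, embedding_dim = base.weight.shape
        # LoRA factors for an embedding: A is (num_embeddings, r), B is (r, embedding_dim)
        self.lora_A = nn.Parameter(torch.zeros(num_embeddings, r))
        self.lora_B = nn.Parameter(torch.randn(r, embedding_dim) * 0.01)
        # DoRA magnitude: one learnable scalar per embedding dimension,
        # initialized to the column norm of the combined weight (equals the base norm since A is zero)
        combined = base.weight + self.scaling * (self.lora_A @ self.lora_B)
        self.magnitude = nn.Parameter(combined.norm(p=2, dim=0).detach())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # direction part: W0 + scaling * A @ B, looked up like a normal embedding
        combined = self.base.weight + self.scaling * (self.lora_A @ self.lora_B)
        # per-dimension norm of the combined weight; detached so only `magnitude` is trained directly
        weight_norm = combined.norm(p=2, dim=0).detach()
        # rescale each embedding dimension by magnitude / norm (the DoRA decomposition)
        return (self.magnitude / weight_norm) * F.embedding(x, combined)
```

Note that this sketch returns the full adapted lookup, whereas the diff adds the DoRA term on top of the base layer's output in `forward` below. The per-embedding-dimension magnitude is presumably also why the diff constructs `DoraEmbeddingLayer(fan_in_fan_out=True)`: an embedding weight is stored as `(num_embeddings, embedding_dim)`, i.e. transposed relative to a `Linear` weight.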
```diff
@@ -781,8 +800,19 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
                 embedding_A = self.lora_embedding_A[active_adapter].T
                 embedding_B = self.lora_embedding_B[active_adapter].T
                 scaling = self.scaling[active_adapter]
-                after_A = self._embed(x, embedding_A)
-                result = result + (after_A @ embedding_B) * scaling
+
+                if not self.use_dora[active_adapter]:
+                    after_A = self._embed(x, embedding_A)
+                    result = result + (after_A @ embedding_B) * scaling
+                else:
+                    result = result + self.lora_magnitude_vector[active_adapter](
+                        x,
+                        lora_A=embedding_A,
+                        lora_B=embedding_B,
+                        scaling=scaling,
+                        base_layer=self.get_base_layer(),
+                        embed_fn=self._embed,
+                    )
             result = result.to(torch_result_dtype)

        return result
```

Review discussion on this change:

> With the previous commits, the DoRA embedding layer did not forward propagate the inputs. With this code snippet, we are now forward propagating the inputs through the DoraEmbedding layer.
>
> @BenjaminBossan for visibility.

> I don't understand.

> I apologise for not being able to communicate well. I meant that the code snippet added in this commit (the one that is linked) now lets the inputs propagate into the DoRA embedding layer. Before this, the DoRA embedding layer was being created but not used: the inputs were being forward propagated through the DoRA linear layer (as it is the parent class). This also means that the results you noted in this comment should be re-run in order to make use of the embedding adapters properly.
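Since the review notes that earlier results should be re-run with the embedding adapter actually active, a typical way to exercise this code path would presumably look like the sketch below; the checkpoint and the target module names (`embed_tokens`, `q_proj`, `v_proj`) are illustrative assumptions, not part of this PR.

```python
from transformers import AutoModelForCausalLM

from peft import LoraConfig, get_peft_model

# checkpoint and module names are illustrative; any model with an nn.Embedding target works
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

config = LoraConfig(
    r=8,
    lora_alpha=16,
    # targeting the input embedding exercises the Embedding LoRA layer patched in this PR
    target_modules=["embed_tokens", "q_proj", "v_proj"],
    use_dora=True,  # previously raised a ValueError for Embedding targets
)

peft_model = get_peft_model(model, config)
peft_model.print_trainable_parameters()
```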
> Ah, too bad we can't use the same approach as in the other lora layers but instead have to multiply the parameters directly. This may cause trouble in some situations like with FSDP. But still, it's better than not having support for embeddings at all.
> I agree! I tried with a linear layer to mimic the embedding weights, but the code was getting a little complicated.
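Regarding the "linear layer to mimic the embedding weights" idea: an embedding lookup is mathematically a linear map applied to one-hot inputs, as the toy check below illustrates (self-contained example, not code from this PR). Routing DoRA through that equivalence would mean materializing one-hot vectors or juggling weight transposes, which is plausibly the complication alluded to here.

```python
import torch
import torch.nn.functional as F

vocab_size, embedding_dim = 10, 4
weight = torch.randn(vocab_size, embedding_dim)  # embedding weight matrix
token_ids = torch.tensor([[1, 3, 7]])

lookup = F.embedding(token_ids, weight)
# the same result expressed as a linear layer acting on one-hot vectors
via_linear = F.linear(F.one_hot(token_ids, vocab_size).float(), weight.T)

assert torch.allclose(lookup, via_linear)
```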