[Add] DoRA Embedding #2006

Merged · 7 commits · Aug 23, 2024
26 changes: 26 additions & 0 deletions src/peft/tuners/lora/dora.py
@@ -107,6 +107,32 @@ def __repr__(self) -> str:
return "lora.dora." + rep


class DoraEmbeddingLayer(DoraLinearLayer):
def forward(self, x, *, lora_A, lora_B, scaling, base_layer, embed_fn):
"""
For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer
output.
"""
lora_weight = (lora_A @ lora_B).T
Member commented:
Ah, too bad we can't use the same approach as in the other lora layers but instead have to multiply the parameters directly. This may cause trouble in some situations like with FSDP. But still, it's better than not having support for embeddings at all.

@ariG23498 (Contributor, Author) replied on Aug 23, 2024:
I agree!

I tried with a linear layer to mimic the embedding weights, but the code was getting a little complicated.

magnitude = self.weight
weight = base_layer.weight
weight_norm = self.get_weight_norm(weight, lora_weight.detach(), scaling)
# see section 4.3 of DoRA (https://arxiv.org/abs/2402.09353)
# "[...] we suggest treating ||V +∆V ||_c in
# Eq. (5) as a constant, thereby detaching it from the gradient
# graph. This means that while ||V + ∆V ||_c dynamically
# reflects the updates of ∆V , it won’t receive any gradient
# during backpropagation"
weight_norm = weight_norm.detach()
mag_norm_scale = magnitude / weight_norm
result_dora = mag_norm_scale * (embed_fn(x, lora_A) @ lora_B) * scaling
return mag_norm_scale, result_dora

def __repr__(self) -> str:
rep = super().__repr__()
return "lora.dora." + rep


class DoraConv2dLayer(DoraLinearLayer):
def get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:
# calculate L2 norm of weight matrix, column-wise
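To make the decomposition implemented by DoraEmbeddingLayer.forward concrete, here is a minimal standalone sketch (not part of the diff; shapes and values are invented for illustration). It checks that scaling the base embedding output by mag_norm_scale and adding the scaled LoRA path matches a lookup in the fully merged DoRA weight m * (W + scaling * ΔW) / ||W + scaling * ΔW||_c, with the norm taken column-wise as in get_weight_norm.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab, dim, r, scaling = 10, 8, 4, 2.0

base_weight = torch.randn(vocab, dim)   # frozen embedding weight W
lora_A = torch.randn(vocab, r)          # corresponds to lora_embedding_A.T in the diff
lora_B = torch.randn(r, dim)            # corresponds to lora_embedding_B.T in the diff
magnitude = torch.rand(dim) + 0.5       # DoRA magnitude vector m, one entry per embedding dim
x = torch.randint(0, vocab, (2, 5))     # token ids

# reference: embedding lookup in the fully merged DoRA weight
delta_w = scaling * (lora_A @ lora_B)                    # ΔW, shape (vocab, dim)
merged = base_weight + delta_w
weight_norm = torch.linalg.norm(merged, dim=0)           # column-wise L2 norm, shape (dim,)
reference = F.embedding(x, magnitude * merged / weight_norm)

# decomposed form returned by DoraEmbeddingLayer.forward and combined in Embedding.forward
mag_norm_scale = magnitude / weight_norm                 # broadcasts over the last output dim
base_out = F.embedding(x, base_weight)
dora_out = mag_norm_scale * (F.embedding(x, lora_A) @ lora_B) * scaling
decomposed = mag_norm_scale * base_out + dora_out

assert torch.allclose(reference, decomposed, atol=1e-5)

The weight_norm.detach() in the diff only changes the backward pass (per section 4.3 of the DoRA paper); the forward values compared above are unaffected by it.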
43 changes: 37 additions & 6 deletions src/peft/tuners/lora/layer.py
@@ -29,7 +29,7 @@
from peft.utils.other import transpose

from .config import LoraConfig
from .dora import DoraConv2dLayer, DoraLinearLayer
from .dora import DoraConv2dLayer, DoraEmbeddingLayer, DoraLinearLayer


class LoraLayer(BaseTunerLayer):
@@ -594,9 +594,6 @@ def __init__(
super().__init__()
LoraLayer.__init__(self, base_layer)

if use_dora:
raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False")

self._active_adapter = adapter_name
self.update_layer(
adapter_name,
@@ -635,9 +632,31 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weig
elif init_lora_weights:
self.reset_lora_parameters(adapter_name, init_lora_weights)

# call this before dora_init
self._move_adapter_to_device_of_base_layer(adapter_name)

if use_dora:
self.dora_init(adapter_name)
self.use_dora[adapter_name] = True
else:
self.use_dora[adapter_name] = False

self.set_adapter(self.active_adapters)

def dora_init(self, adapter_name: str) -> None:
if self.lora_magnitude_vector is None:
# first dora layer being added, add lora_magnitude_vector to the list of learnable parameters
self.adapter_layer_names = self.adapter_layer_names[:] + ("lora_magnitude_vector",)

dora_layer = DoraEmbeddingLayer(fan_in_fan_out=True)
lora_embedding_A = self.lora_embedding_A[adapter_name]
lora_embedding_B = self.lora_embedding_B[adapter_name]
scaling = self.scaling[adapter_name]
dora_layer.update_layer(
base_layer=self.get_base_layer(), lora_A=lora_embedding_A, lora_B=lora_embedding_B, scaling=scaling
)
self.lora_magnitude_vector[adapter_name] = dora_layer

def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
"""
Merge the active adapter weights into the base weights
@@ -785,8 +804,20 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
embedding_A = self.lora_embedding_A[active_adapter].T
embedding_B = self.lora_embedding_B[active_adapter].T
scaling = self.scaling[active_adapter]
after_A = self._embed(x, embedding_A)
result = result + (after_A @ embedding_B) * scaling

if not self.use_dora[active_adapter]:
after_A = self._embed(x, embedding_A)
result = result + (after_A @ embedding_B) * scaling
else:
mag_norm_scale, dora_result = self.lora_magnitude_vector[active_adapter](
x,
lora_A=embedding_A,
lora_B=embedding_B,
scaling=scaling,
base_layer=self.get_base_layer(),
embed_fn=self._embed,
)
result = mag_norm_scale * result + dora_result
result = result.to(torch_result_dtype)

return result
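With the ValueError guard removed and dora_init wired into update_layer, use_dora=True can now be combined with embedding target modules. A minimal usage sketch under assumed names (the toy model below is invented for illustration; it is not the EmbConv1D test model):

import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(100, 16)
        self.lin = nn.Linear(16, 2)

    def forward(self, x):
        return self.lin(self.emb(x))

# targeting the embedding layer with DoRA enabled; before this PR this raised a ValueError
config = LoraConfig(r=8, lora_alpha=16, target_modules=["emb"], use_dora=True)
model = get_peft_model(ToyModel(), config)
model.print_trainable_parameters()

out = model(torch.randint(0, 100, (4, 10)))  # forward pass goes through the DoRA embedding branch

Only the LoRA embedding factors and the DoRA magnitude vector are trainable here; the base embedding weight stays frozen.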
21 changes: 19 additions & 2 deletions tests/test_custom_models.py
@@ -88,6 +88,19 @@
("Embedding + transformers Conv1D 1 LoRA", "EmbConv1D", LoraConfig, {"target_modules": ["conv1d"]}),
("Embedding + transformers Conv1D 2 LoRA", "EmbConv1D", LoraConfig, {"target_modules": ["emb"]}),
("Embedding + transformers Conv1D 3 LoRA", "EmbConv1D", LoraConfig, {"target_modules": ["emb", "conv1d"]}),
(
"Embedding + transformers Conv1D 1 DoRA",
"EmbConv1D",
LoraConfig,
{"target_modules": ["conv1d"], "use_dora": True},
),
("Embedding + transformers Conv1D 2 DoRA", "EmbConv1D", LoraConfig, {"target_modules": ["emb"], "use_dora": True}),
(
"Embedding + transformers Conv1D 3 DoRA",
"EmbConv1D",
LoraConfig,
{"target_modules": ["emb", "conv1d"], "use_dora": True},
),
("Conv2d 1 LoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d"]}),
("Conv2d 2 LoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d", "lin0"]}),
("Conv2d 1 LoRA with DoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d"], "use_dora": True}),
@@ -874,7 +887,9 @@ def test_only_params_are_updated(self, test_name, model_id, config_cls, config_k
model_before = copy.deepcopy(model)

model.train()
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)
# this high learning rate was found through testing to be necessary to avoid flakiness
lr = 100.0 if config_kwargs.get("use_dora") and model_id == "EmbConv1D" else 0.5
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

# train at least 3 steps for all parameters to be updated (probably this is required because of symmetry
# breaking of some LoRA layers that are initialized with constants)
@@ -943,7 +958,6 @@ def test_parameters_after_loading_model(self, test_name, model_id, config_cls, c
def test_disable_adapters(self, test_name, model_id, config_cls, config_kwargs):
X = self.prepare_inputs_for_testing()
model = self.transformers_class.from_pretrained(model_id).to(self.torch_device).eval()

outputs_base = model(**X)
if issubclass(config_cls, FourierFTConfig):
config_kwargs = config_kwargs.copy()
@@ -1046,6 +1060,9 @@ def test_disable_adapters_with_merging(self, test_name, model_id, config_cls, co
if issubclass(config_cls, IA3Config) and model_id == "Conv2d": # more instability with Conv2d + IA3
atol, rtol = 1e-3, 1e-3

if config_kwargs.get("use_dora") and model_id == "EmbConv1D":
atol, rtol = 1e-4, 1e-4

# check that there is a difference in results after training
assert not torch.allclose(outputs_before, outputs_after, atol=atol, rtol=rtol)

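The new EmbConv1D DoRA cases run through the generic checks above; the essence of test_only_params_are_updated for this path is that a few optimizer steps change only the adapter parameters (lora_embedding_A/B and the new lora_magnitude_vector) while the frozen base weights stay untouched. A rough, self-contained sketch of that idea, redefining the invented ToyModel from the usage sketch above rather than the real EmbConv1D test model:

import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model

class ToyModel(nn.Module):  # same invented toy model as in the usage sketch above
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(100, 16)
        self.lin = nn.Linear(16, 2)

    def forward(self, x):
        return self.lin(self.emb(x))

model = get_peft_model(ToyModel(), LoraConfig(r=8, target_modules=["emb"], use_dora=True))
before = {name: param.clone() for name, param in model.named_parameters()}

# the test bumps the learning rate for the DoRA + EmbConv1D case to avoid flaky "parameter did not move" failures
optimizer = torch.optim.SGD([p for p in model.parameters() if p.requires_grad], lr=100.0)
for _ in range(3):  # a few steps so the zero-initialized LoRA factor breaks symmetry
    loss = model(torch.randint(0, 100, (4, 10))).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

for name, param in model.named_parameters():
    changed = not torch.allclose(before[name], param)
    assert changed == param.requires_grad, f"unexpected update state for {name}"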