[Add] DoRA Embedding (#2006)
ariG23498 authored Aug 23, 2024
1 parent c3b63ce commit 900f96c
Showing 3 changed files with 82 additions and 8 deletions.
26 changes: 26 additions & 0 deletions src/peft/tuners/lora/dora.py
@@ -107,6 +107,32 @@ def __repr__(self) -> str:
        return "lora.dora." + rep


class DoraEmbeddingLayer(DoraLinearLayer):
    def forward(self, x, *, lora_A, lora_B, scaling, base_layer, embed_fn):
        """
        For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer
        output.
        """
        lora_weight = (lora_A @ lora_B).T
        magnitude = self.weight
        weight = base_layer.weight
        weight_norm = self.get_weight_norm(weight, lora_weight.detach(), scaling)
        # see section 4.3 of DoRA (https://arxiv.org/abs/2402.09353)
        # "[...] we suggest treating ||V + ∆V||_c in
        # Eq. (5) as a constant, thereby detaching it from the gradient
        # graph. This means that while ||V + ∆V||_c dynamically
        # reflects the updates of ∆V, it won’t receive any gradient
        # during backpropagation"
        weight_norm = weight_norm.detach()
        mag_norm_scale = magnitude / weight_norm
        result_dora = mag_norm_scale * (embed_fn(x, lora_A) @ lora_B) * scaling
        return mag_norm_scale, result_dora

    def __repr__(self) -> str:
        rep = super().__repr__()
        return "lora.dora." + rep


class DoraConv2dLayer(DoraLinearLayer):
    def get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:
        # calculate L2 norm of weight matrix, column-wise
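The forward pass above can be reproduced with a few lines of plain PyTorch. The following is a minimal, self-contained sketch of the same computation; the shapes, seed, and variable names are illustrative assumptions rather than part of the commit, and the column-wise weight norm from get_weight_norm (called here with fan_in_fan_out=True) is approximated inline.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab, dim, r, scaling = 10, 8, 4, 2.0
base = torch.nn.Embedding(vocab, dim)                     # frozen base weight V
lora_A = torch.nn.Parameter(torch.randn(vocab, r) * 0.1)  # plays the role of lora_embedding_A.T
lora_B = torch.nn.Parameter(torch.randn(r, dim) * 0.1)    # plays the role of lora_embedding_B.T
x = torch.randint(0, vocab, (2, 5))                       # token ids

# column-wise L2 norm of V + scaling * ∆V (the role of get_weight_norm with fan_in_fan_out=True)
lora_weight = (lora_A @ lora_B).T                         # (dim, vocab)
weight_norm = torch.linalg.norm(base.weight.T + scaling * lora_weight.detach(), dim=1)
weight_norm = weight_norm.detach()                        # treated as a constant, per section 4.3 of the paper
magnitude = torch.nn.Parameter(weight_norm.clone())       # learnable magnitude, initialized from the weight norm

mag_norm_scale = magnitude / weight_norm                  # (dim,)
result_dora = mag_norm_scale * (F.embedding(x, lora_A) @ lora_B) * scaling
result = mag_norm_scale * base(x) + result_dora           # mirrors the new DoRA branch in Embedding.forward
print(result.shape)                                       # torch.Size([2, 5, 8])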
43 changes: 37 additions & 6 deletions src/peft/tuners/lora/layer.py
@@ -29,7 +29,7 @@
from peft.utils.other import transpose

from .config import LoraConfig
from .dora import DoraConv2dLayer, DoraLinearLayer
from .dora import DoraConv2dLayer, DoraEmbeddingLayer, DoraLinearLayer


class LoraLayer(BaseTunerLayer):
@@ -594,9 +594,6 @@ def __init__(
        super().__init__()
        LoraLayer.__init__(self, base_layer)

        if use_dora:
            raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False")

        self._active_adapter = adapter_name
        self.update_layer(
            adapter_name,
@@ -635,9 +632,31 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weig
        elif init_lora_weights:
            self.reset_lora_parameters(adapter_name, init_lora_weights)

        # call this before dora_init
        self._move_adapter_to_device_of_base_layer(adapter_name)

        if use_dora:
            self.dora_init(adapter_name)
            self.use_dora[adapter_name] = True
        else:
            self.use_dora[adapter_name] = False

        self.set_adapter(self.active_adapters)

    def dora_init(self, adapter_name: str) -> None:
        if self.lora_magnitude_vector is None:
            # first dora layer being added, add lora_magnitude_vector to the list of learnable parameters
            self.adapter_layer_names = self.adapter_layer_names[:] + ("lora_magnitude_vector",)

        dora_layer = DoraEmbeddingLayer(fan_in_fan_out=True)
        lora_embedding_A = self.lora_embedding_A[adapter_name]
        lora_embedding_B = self.lora_embedding_B[adapter_name]
        scaling = self.scaling[adapter_name]
        dora_layer.update_layer(
            base_layer=self.get_base_layer(), lora_A=lora_embedding_A, lora_B=lora_embedding_B, scaling=scaling
        )
        self.lora_magnitude_vector[adapter_name] = dora_layer

    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
        """
        Merge the active adapter weights into the base weights
@@ -785,8 +804,20 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
                embedding_A = self.lora_embedding_A[active_adapter].T
                embedding_B = self.lora_embedding_B[active_adapter].T
                scaling = self.scaling[active_adapter]
                after_A = self._embed(x, embedding_A)
                result = result + (after_A @ embedding_B) * scaling

                if not self.use_dora[active_adapter]:
                    after_A = self._embed(x, embedding_A)
                    result = result + (after_A @ embedding_B) * scaling
                else:
                    mag_norm_scale, dora_result = self.lora_magnitude_vector[active_adapter](
                        x,
                        lora_A=embedding_A,
                        lora_B=embedding_B,
                        scaling=scaling,
                        base_layer=self.get_base_layer(),
                        embed_fn=self._embed,
                    )
                    result = mag_norm_scale * result + dora_result
            result = result.to(torch_result_dtype)

        return result
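Taken together, these changes mean use_dora=True no longer raises for embedding layers, and the DoRA magnitude becomes a regular trainable adapter parameter. A rough usage sketch follows; the TinyModel below and the module name "emb" are assumptions for illustration, while LoraConfig(..., use_dora=True) targeting an embedding module is exactly what the new tests exercise.

import torch.nn as nn
from peft import LoraConfig, get_peft_model


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(100, 16)
        self.lin = nn.Linear(16, 2)

    def forward(self, x):
        return self.lin(self.emb(x))


config = LoraConfig(r=8, target_modules=["emb"], use_dora=True)  # previously a ValueError for Embedding
model = get_peft_model(TinyModel(), config)

# the DoRA magnitude shows up as a trainable parameter on the wrapped embedding layer
print([n for n, p in model.named_parameters() if "lora_magnitude_vector" in n and p.requires_grad])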
21 changes: 19 additions & 2 deletions tests/test_custom_models.py
@@ -88,6 +88,19 @@
    ("Embedding + transformers Conv1D 1 LoRA", "EmbConv1D", LoraConfig, {"target_modules": ["conv1d"]}),
    ("Embedding + transformers Conv1D 2 LoRA", "EmbConv1D", LoraConfig, {"target_modules": ["emb"]}),
    ("Embedding + transformers Conv1D 3 LoRA", "EmbConv1D", LoraConfig, {"target_modules": ["emb", "conv1d"]}),
    (
        "Embedding + transformers Conv1D 1 DoRA",
        "EmbConv1D",
        LoraConfig,
        {"target_modules": ["conv1d"], "use_dora": True},
    ),
    ("Embedding + transformers Conv1D 2 DoRA", "EmbConv1D", LoraConfig, {"target_modules": ["emb"], "use_dora": True}),
    (
        "Embedding + transformers Conv1D 3 DoRA",
        "EmbConv1D",
        LoraConfig,
        {"target_modules": ["emb", "conv1d"], "use_dora": True},
    ),
    ("Conv2d 1 LoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d"]}),
    ("Conv2d 2 LoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d", "lin0"]}),
    ("Conv2d 1 LoRA with DoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d"], "use_dora": True}),
@@ -874,7 +887,9 @@ def test_only_params_are_updated(self, test_name, model_id, config_cls, config_k
        model_before = copy.deepcopy(model)

        model.train()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.5)
        # this high learning rate was found through testing to be necessary to avoid flakiness
        lr = 100.0 if config_kwargs.get("use_dora") and model_id == "EmbConv1D" else 0.5
        optimizer = torch.optim.SGD(model.parameters(), lr=lr)

        # train at least 3 steps for all parameters to be updated (probably this is required because of symmetry
        # breaking of some LoRA layers that are initialized with constants)
@@ -943,7 +958,6 @@ def test_parameters_after_loading_model(self, test_name, model_id, config_cls, c
    def test_disable_adapters(self, test_name, model_id, config_cls, config_kwargs):
        X = self.prepare_inputs_for_testing()
        model = self.transformers_class.from_pretrained(model_id).to(self.torch_device).eval()

        outputs_base = model(**X)
        if issubclass(config_cls, FourierFTConfig):
            config_kwargs = config_kwargs.copy()
@@ -1046,6 +1060,9 @@ def test_disable_adapters_with_merging(self, test_name, model_id, config_cls, co
        if issubclass(config_cls, IA3Config) and model_id == "Conv2d":  # more instability with Conv2d + IA3
            atol, rtol = 1e-3, 1e-3

        if config_kwargs.get("use_dora") and model_id == "EmbConv1D":
            atol, rtol = 1e-4, 1e-4

        # check that there is a difference in results after training
        assert not torch.allclose(outputs_before, outputs_after, atol=atol, rtol=rtol)
