diff --git a/src/peft/tuners/lora/dora.py b/src/peft/tuners/lora/dora.py
index 859c294f10..95c5b253f7 100644
--- a/src/peft/tuners/lora/dora.py
+++ b/src/peft/tuners/lora/dora.py
@@ -107,6 +107,32 @@ def __repr__(self) -> str:
         return "lora.dora." + rep
 
 
+class DoraEmbeddingLayer(DoraLinearLayer):
+    def forward(self, x, *, lora_A, lora_B, scaling, base_layer, embed_fn):
+        """
+        For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer
+        output.
+        """
+        lora_weight = (lora_A @ lora_B).T
+        magnitude = self.weight
+        weight = base_layer.weight
+        weight_norm = self.get_weight_norm(weight, lora_weight.detach(), scaling)
+        # see section 4.3 of DoRA (https://arxiv.org/abs/2402.09353)
+        # "[...] we suggest treating ||V +∆V ||_c in
+        # Eq. (5) as a constant, thereby detaching it from the gradient
+        # graph. This means that while ||V + ∆V ||_c dynamically
+        # reflects the updates of ∆V , it won’t receive any gradient
+        # during backpropagation"
+        weight_norm = weight_norm.detach()
+        mag_norm_scale = magnitude / weight_norm
+        result_dora = mag_norm_scale * (embed_fn(x, lora_A) @ lora_B) * scaling
+        return mag_norm_scale, result_dora
+
+    def __repr__(self) -> str:
+        rep = super().__repr__()
+        return "lora.dora." + rep
+
+
 class DoraConv2dLayer(DoraLinearLayer):
     def get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:
         # calculate L2 norm of weight matrix, column-wise
diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py
index ed7c1a74f1..8a5dc9970f 100644
--- a/src/peft/tuners/lora/layer.py
+++ b/src/peft/tuners/lora/layer.py
@@ -29,7 +29,7 @@
 from peft.utils.other import transpose
 
 from .config import LoraConfig
-from .dora import DoraConv2dLayer, DoraLinearLayer
+from .dora import DoraConv2dLayer, DoraEmbeddingLayer, DoraLinearLayer
 
 
 class LoraLayer(BaseTunerLayer):
@@ -594,9 +594,6 @@ def __init__(
         super().__init__()
         LoraLayer.__init__(self, base_layer)
 
-        if use_dora:
-            raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False")
-
         self._active_adapter = adapter_name
         self.update_layer(
             adapter_name,
@@ -635,9 +632,31 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weig
         elif init_lora_weights:
             self.reset_lora_parameters(adapter_name, init_lora_weights)
 
+        # call this before dora_init
         self._move_adapter_to_device_of_base_layer(adapter_name)
+
+        if use_dora:
+            self.dora_init(adapter_name)
+            self.use_dora[adapter_name] = True
+        else:
+            self.use_dora[adapter_name] = False
+
         self.set_adapter(self.active_adapters)
 
+    def dora_init(self, adapter_name: str) -> None:
+        if self.lora_magnitude_vector is None:
+            # first dora layer being added, add lora_magnitude_vector to the list of learnable parameters
+            self.adapter_layer_names = self.adapter_layer_names[:] + ("lora_magnitude_vector",)
+
+        dora_layer = DoraEmbeddingLayer(fan_in_fan_out=True)
+        lora_embedding_A = self.lora_embedding_A[adapter_name]
+        lora_embedding_B = self.lora_embedding_B[adapter_name]
+        scaling = self.scaling[adapter_name]
+        dora_layer.update_layer(
+            base_layer=self.get_base_layer(), lora_A=lora_embedding_A, lora_B=lora_embedding_B, scaling=scaling
+        )
+        self.lora_magnitude_vector[adapter_name] = dora_layer
+
     def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
         """
         Merge the active adapter weights into the base weights
@@ -785,8 +804,20 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
                 embedding_A = self.lora_embedding_A[active_adapter].T
                 embedding_B = self.lora_embedding_B[active_adapter].T
                 scaling = self.scaling[active_adapter]
-                after_A = self._embed(x, embedding_A)
-                result = result + (after_A @ embedding_B) * scaling
+
+                if not self.use_dora[active_adapter]:
+                    after_A = self._embed(x, embedding_A)
+                    result = result + (after_A @ embedding_B) * scaling
+                else:
+                    mag_norm_scale, dora_result = self.lora_magnitude_vector[active_adapter](
+                        x,
+                        lora_A=embedding_A,
+                        lora_B=embedding_B,
+                        scaling=scaling,
+                        base_layer=self.get_base_layer(),
+                        embed_fn=self._embed,
+                    )
+                    result = mag_norm_scale * result + dora_result
             result = result.to(torch_result_dtype)
 
         return result
diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py
index 27f367a536..52bda6f8bd 100644
--- a/tests/test_custom_models.py
+++ b/tests/test_custom_models.py
@@ -88,6 +88,19 @@
     ("Embedding + transformers Conv1D 1 LoRA", "EmbConv1D", LoraConfig, {"target_modules": ["conv1d"]}),
     ("Embedding + transformers Conv1D 2 LoRA", "EmbConv1D", LoraConfig, {"target_modules": ["emb"]}),
     ("Embedding + transformers Conv1D 3 LoRA", "EmbConv1D", LoraConfig, {"target_modules": ["emb", "conv1d"]}),
+    (
+        "Embedding + transformers Conv1D 1 DoRA",
+        "EmbConv1D",
+        LoraConfig,
+        {"target_modules": ["conv1d"], "use_dora": True},
+    ),
+    ("Embedding + transformers Conv1D 2 DoRA", "EmbConv1D", LoraConfig, {"target_modules": ["emb"], "use_dora": True}),
+    (
+        "Embedding + transformers Conv1D 3 DoRA",
+        "EmbConv1D",
+        LoraConfig,
+        {"target_modules": ["emb", "conv1d"], "use_dora": True},
+    ),
     ("Conv2d 1 LoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d"]}),
     ("Conv2d 2 LoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d", "lin0"]}),
     ("Conv2d 1 LoRA with DoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d"], "use_dora": True}),
@@ -874,7 +887,9 @@ def test_only_params_are_updated(self, test_name, model_id, config_cls, config_k
         model_before = copy.deepcopy(model)
 
         model.train()
-        optimizer = torch.optim.SGD(model.parameters(), lr=0.5)
+        # this high learning rate was found through testing to be necessary to avoid flakiness
+        lr = 100.0 if config_kwargs.get("use_dora") and model_id == "EmbConv1D" else 0.5
+        optimizer = torch.optim.SGD(model.parameters(), lr=lr)
 
         # train at least 3 steps for all parameters to be updated (probably this is required because of symmetry
         # breaking of some LoRA layers that are initialized with constants)
@@ -943,7 +958,6 @@ def test_parameters_after_loading_model(self, test_name, model_id, config_cls, c
     def test_disable_adapters(self, test_name, model_id, config_cls, config_kwargs):
         X = self.prepare_inputs_for_testing()
         model = self.transformers_class.from_pretrained(model_id).to(self.torch_device).eval()
-        outputs_base = model(**X)
 
         if issubclass(config_cls, FourierFTConfig):
             config_kwargs = config_kwargs.copy()
@@ -1046,6 +1060,9 @@ def test_disable_adapters_with_merging(self, test_name, model_id, config_cls, co
         if issubclass(config_cls, IA3Config) and model_id == "Conv2d":  # more instability with Conv2d + IA3
             atol, rtol = 1e-3, 1e-3
 
+        if config_kwargs.get("use_dora") and model_id == "EmbConv1D":
+            atol, rtol = 1e-4, 1e-4
+
         # check that there is a difference in results after training
         assert not torch.allclose(outputs_before, outputs_after, atol=atol, rtol=rtol)
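Usage note (not part of the patch): a minimal sketch of how the new capability might be exercised once this change is applied; previously, use_dora=True on an nn.Embedding target raised a ValueError. LoraConfig, get_peft_model, and the arguments shown are existing PEFT APIs; the toy model and the module name "emb" are hypothetical, loosely mirroring the EmbConv1D fixture used in the tests above.

import torch
import torch.nn as nn

from peft import LoraConfig, get_peft_model


class ToyModel(nn.Module):
    # hypothetical stand-in for the EmbConv1D test model: an embedding followed by a projection
    def __init__(self, vocab_size=100, dim=16):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, dim)
        self.lin = nn.Linear(dim, 2)

    def forward(self, x):
        return self.lin(self.emb(x))


# Targeting the embedding with use_dora=True now attaches a DoraEmbeddingLayer
# that holds the learnable magnitude vector (lora_magnitude_vector).
config = LoraConfig(r=8, lora_alpha=16, target_modules=["emb"], use_dora=True)
model = get_peft_model(ToyModel(), config)

x = torch.randint(0, 100, (4, 10))
# forward pass: per Embedding.forward above, the base embedding output is rescaled by
# magnitude / weight_norm and the scaled LoRA contribution is added on top
out = model(x)
print(out.shape)  # torch.Size([4, 10, 2])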