Implement DoRA #1474
Changes from all commits
```diff
@@ -101,6 +101,13 @@ class LoraConfig(PeftConfig):
         The configuration of LoftQ. If this is not None, then LoftQ will be used to quantize the backbone weights
         and initialize Lora layers. Also pass `init_lora_weights='loftq'`. Note that you should not pass a
         quantized model in this case, as LoftQ will quantize the model itself.
+        use_dora (`bool`):
+            Enable 'Weight-Decomposed Low-Rank Adaptation' (DoRA). This technique decomposes the updates of the weights
+            into two parts, magnitude and direction. Direction is handled by normal LoRA, whereas the magnitude is
+            handled by a separate learnable parameter. This can improve the performance of LoRA, especially at low
+            ranks. Right now, DoRA only supports non-quantized linear layers. DoRA introduces a bigger overhead than
+            pure LoRA, so it is recommended to merge weights for inference. For more information, see
+            https://arxiv.org/abs/2402.09353.
     """

     r: int = field(default=8, metadata={"help": "Lora attention dimension"})
```

Review comment on the `use_dora` docstring: Nice docstrings.
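To make the "magnitude and direction" decomposition described in the docstring more concrete, here is a minimal, self-contained sketch of a DoRA-style linear layer. It is not PEFT's actual implementation; the class name, the initialization, and the per-output-row norm convention are assumptions chosen for illustration. It only shows how a frozen weight plus a low-rank LoRA update (the direction) can be renormalized and then rescaled by a separate learnable magnitude vector.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DoraStyleLinear(nn.Module):
    """Illustrative only: direction comes from the frozen weight plus a LoRA
    update; magnitude comes from a separate learnable per-output-unit vector."""

    def __init__(self, base: nn.Linear, r: int = 8, lora_alpha: int = 8):
        super().__init__()
        self.base = base
        self.base.weight.requires_grad_(False)  # pretrained weight stays frozen
        self.lora_A = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(base.out_features, r))
        self.scaling = lora_alpha / r
        # learnable magnitude, initialized so the layer starts identical to `base`
        self.magnitude = nn.Parameter(self.base.weight.detach().norm(p=2, dim=1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # direction: frozen weight plus the low-rank LoRA update
        directed = self.base.weight + self.scaling * (self.lora_B @ self.lora_A)
        # normalize each output row, then rescale by the learned magnitude
        norm = directed.norm(p=2, dim=1, keepdim=True)
        weight = self.magnitude.unsqueeze(1) * directed / norm
        return F.linear(x, weight, self.base.bias)


layer = DoraStyleLinear(nn.Linear(16, 32), r=4)
print(layer(torch.randn(2, 16)).shape)  # torch.Size([2, 32])
```

Because the magnitude is initialized from the norm of the frozen weight and `lora_B` starts at zero, the layer reproduces the base layer exactly at initialization, mirroring how standard LoRA starts from the pretrained behavior.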
```diff
@@ -224,6 +231,19 @@ class LoraConfig(PeftConfig):
             )
         },
     )
+    use_dora: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Enable 'Weight-Decomposed Low-Rank Adaptation' (DoRA). This technique decomposes the updates of the "
+                "weights into two parts, magnitude and direction. Direction is handled by normal LoRA, whereas the "
+                "magnitude is handled by a separate learnable parameter. This can improve the performance of LoRA, "
+                "especially at low ranks. Right now, DoRA only supports non-quantized linear layers. DoRA introduces "
+                "a bigger overhead than pure LoRA, so it is recommended to merge weights for inference. For more "
+                "information, see https://arxiv.org/abs/2402.09353."
+            )
+        },
+    )

     def __post_init__(self):
         self.peft_type = PeftType.LORA
```

Review comment on `__post_init__`: Should we block some mis-intended usage? e.g. if one passes

Reply: I added a check for loftq and megatron.
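For reference, once this field exists on `LoraConfig`, enabling DoRA from user code would come down to a single flag. The model name and `target_modules` below are placeholders chosen for illustration, not something prescribed by this PR.

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# placeholder base model and target modules, purely for illustration
base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
config = LoraConfig(
    r=8,
    target_modules=["q_proj", "v_proj"],
    use_dora=True,  # the new flag introduced in this PR
)
peft_model = get_peft_model(base_model, config)
peft_model.print_trainable_parameters()
```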
```diff
@@ -238,6 +258,9 @@ def __post_init__(self):
         if isinstance(self.target_modules, str) and self.layers_pattern is not None:
             raise ValueError("`layers_pattern` cannot be used when `target_modules` is a str.")

+        if self.use_dora and (self.megatron_config or self.init_lora_weights == "loftq"):
+            raise ValueError("DoRA does not support megatron_core or LoftQ. Please set `use_dora=False`.")
+
         # handle init_lora_weights and loftq_config
         if self.init_lora_weights == "loftq":
             import importlib
```
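As a quick sanity check of the new guard, combining `use_dora=True` with LoftQ initialization should fail fast at config construction time. The sketch below assumes the check runs before the LoftQ handling shown in the hunk above and that the error message is raised verbatim.

```python
from peft import LoraConfig

try:
    # the new __post_init__ check rejects DoRA together with LoftQ
    LoraConfig(use_dora=True, init_lora_weights="loftq")
except ValueError as err:
    print(err)  # "DoRA does not support megatron_core or LoftQ. Please set `use_dora=False`."
```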