
Added the option to use the corrected scaling factor for LoRA, based on new research. #1244

Merged · 7 commits · Dec 15, 2023

Conversation

@Damjan-Kalajdzievski (Contributor) commented Dec 9, 2023:

Hi, I am proposing to add an option to use the corrected scaling factor for LoRA, based on the recent paper A Rank Stabilization Scaling Factor for Fine-Tuning with LoRA. Try setting use_rslora = True in your LoraConfig for ranks greater than 32 and see the increase in fine-tuning performance (same or better performance for ranks lower than 32 as well).
Please feel free to suggest or change the implementation; I tried to go for the minimum code length change that implements this option.
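
A minimal usage sketch of the proposed option (the base model, rank, and target modules below are illustrative choices, not taken from this PR):

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")  # illustrative base model

config = LoraConfig(
    r=64,                       # higher ranks are where the corrected scaling helps most
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    use_rslora=True,            # scale adapters by lora_alpha / sqrt(r) instead of lora_alpha / r
)
model = get_peft_model(base_model, config)
model.print_trainable_parameters()
```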

Summary of method

For a LoRA adapter of rank $r$, the factor $\frac{\alpha}{r}$ that scales the adapter is too aggressive as a function of $r$: it slows learning at higher ranks, so that no fine-tuning performance is gained over lower ranks. The paper A Rank Stabilization Scaling Factor for Fine-Tuning with LoRA shows theoretically and experimentally that the adapter scaling factor should instead be $\frac{\alpha}{\sqrt{r}}$. This corrected scaling factor unlocks a compute/performance trade-off in which increasing the rank increases fine-tuning performance. It also corrects the ongoing misconception that very low ranks (no greater than 32) suffice for maximal performance, along with the associated belief that the intrinsic dimensionality of fine-tuning is very low.
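
To make the difference concrete, a quick comparison of the two factors (the alpha value is chosen only for illustration):

```python
import math

lora_alpha = 16  # illustrative value
for r in (8, 32, 128, 512):
    conventional = lora_alpha / r                  # shrinks linearly in r
    rank_stabilized = lora_alpha / math.sqrt(r)    # shrinks only with sqrt(r)
    print(f"r={r:4d}  alpha/r={conventional:.4f}  alpha/sqrt(r)={rank_stabilized:.4f}")
```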

Description of changes

Added a use_rslora bool to LoraConfig which, when set to True, corrects the scaling factor of adapters created with _create_and_replace in LoraModel. The variable use_rslora is set to False by default for backwards compatibility.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@BenjaminBossan (Member) commented:

Thanks @Damjan-Kalajdzievski for the PR. I haven't checked it yet, but could you please run make style so that the CI can run?

@Damjan-Kalajdzievski (Contributor, Author) commented Dec 11, 2023:

> Thanks @Damjan-Kalajdzievski for the PR. I haven't checked it yet, but could you please run make style so that the CI can run?

Hi @BenjaminBossan, I have run the command and committed the modification to the files I've changed.

@BenjaminBossan (Member) left a review:


Thanks a lot for adding this seemingly small but effective option to LoRA. I only skimmed the paper, but it seems that this option helps across the board, although the effect is most prominent for large ranks.

Regarding how it's implemented, I think we need to adjust where we apply the scaling. Check my comment regarding that.

Furthermore, I would like to see two additions:

  • An entry in the docs about this new option here
  • A small unit test that checks the scaling factor after initializing a simple LoRA model. If you need help with that, let us know.

src/peft/tuners/lora/config.py
@@ -57,6 +57,10 @@ class LoraConfig(PeftConfig):
bias (`str`): Bias type for Lora. Can be 'none', 'all' or 'lora_only'. If 'all' or 'lora_only', the
corresponding biases will be updated during training. Be aware that this means that, even when disabling
the adapters, the model will not produce the same output as the base model would have without adaptation.
use_rslora (`bool`):
When set to True, uses <a href='https://doi.org/10.48550/arXiv.2312.03732'>Rank-Stabilized LoRA</a> which
sets the adapter scaling factor to the correct value of `lora_alpha/math.sqrt(r)`. Otherwise, it will use
@BenjaminBossan (Member) commented:

I wouldn't say "correct value", maybe something like:

sets the adapter scaling factor to lora_alpha/math.sqrt(r), which was shown to work better.

metadata={
"help": (
"When set to True, uses "
"<a href='https://doi.org/10.48550/arXiv.2312.03732'>Rank-Stabilized LoRA</a> "
@BenjaminBossan (Member) commented:

Let's not use html syntax in the help.

@Damjan-Kalajdzievski (Contributor, Author) replied:

I interpreted this to mean that I should still include the url but without the html syntax.

"help": (
"When set to True, uses "
"<a href='https://doi.org/10.48550/arXiv.2312.03732'>Rank-Stabilized LoRA</a> "
"which sets the adapter scaling factor to the correct value "
@BenjaminBossan (Member) commented:

Same argument about "correct".

@@ -194,6 +195,13 @@ def _create_and_replace(
new_module.requires_grad_(False)
self._replace_module(parent, target_name, new_module, target)

if lora_config.use_rslora:
@BenjaminBossan (Member) commented:

I think here is not the right place to control the scaling, as this leads to spreading the initialization of the LoRA parameters into different parts of the code. Instead, update_layer, update_layer_conv2d, and update_layer_embedding in tuners/lora/layer.py should be adjusted, since that's where we set the scale initially. This also requires updating the __init__ method to accept the new argument, as well as the kwargs variable here.
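
Roughly, the suggested placement looks like the following sketch (a simplified stand-in, not the actual PEFT LoraLayer, whose update_layer methods take additional arguments):

```python
import math

class LoraLayerSketch:
    """Minimal sketch of where the scaling would be set (not the actual PEFT class)."""

    def __init__(self):
        self.scaling = {}  # adapter_name -> scaling factor, as in peft's LoraLayer

    def update_layer(self, adapter_name, r, lora_alpha, use_rslora=False):
        # Set the factor once, at initialization, in the same place for all layer
        # types (Linear, Conv2d, Embedding) instead of patching it afterwards in
        # _create_and_replace.
        if use_rslora:
            self.scaling[adapter_name] = lora_alpha / math.sqrt(r)
        else:
            self.scaling[adapter_name] = lora_alpha / r

layer = LoraLayerSketch()
layer.update_layer("default", r=64, lora_alpha=16, use_rslora=True)
print(layer.scaling["default"])  # 2.0
```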

…odified peft/docs/source/conceptual_guides/lora.md to be consistent with the new LoraConfig and describe the use_rslora concept as suggested.
@Damjan-Kalajdzievski (Contributor, Author) commented Dec 13, 2023:

> • An entry in the docs about this new option here

I am not sure if it belongs in the description of the initializations, since I didn't want to give a person unfamiliar with LoRA the impression that the scaling factor is only applied at initialization, when it actually scales every adapter output. So here I committed a small section below the initialization section, describing the scaling factor and what use_rslora does. Not sure if this is the best place, so feel free to suggest a change.

> • A small unit test that checks the scaling factor after initializing a simple LoRA model. If you need help with that, let us know.

I added a test in test_initialization.py in commit 871ed7d. Let me know if this matches what you envisioned the test to be like.
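
A check along these lines captures the idea (a sketch assuming a toy MLP; not the exact test from commit 871ed7d):

```python
import math

from torch import nn
from peft import LoraConfig, get_peft_model


class SmallMLP(nn.Module):
    # toy model, just enough to attach LoRA adapters to
    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(10, 20)
        self.lin1 = nn.Linear(20, 2)

    def forward(self, x):
        return self.lin1(self.lin0(x))


def test_rslora_scaling():
    r, lora_alpha = 64, 16
    config = LoraConfig(target_modules=["lin0", "lin1"], r=r, lora_alpha=lora_alpha, use_rslora=True)
    model = get_peft_model(SmallMLP(), config)

    expected = lora_alpha / math.sqrt(r)
    assert model.base_model.model.lin0.scaling["default"] == expected
    assert model.base_model.model.lin1.scaling["default"] == expected
```
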

> I think here is not the right place to control the scaling, as this leads to spreading the initialization of the LoRA parameters into different parts of the code. Instead, update_layer, update_layer_conv2d, and update_layer_embedding in tuners/lora/layer.py should be adjusted, since that's where we set the scale initially. This also requires updating the __init__ method to accept the new argument, as well as the kwargs variable here.

Makes sense, I had wondered about doing it this way as well and will do this now. I am not sure exactly which class's __init__ method you mean, but I added the argument use_rslora to all the update_* methods, and then feed it through the kwargs in _create_new_module so that every subclass of LoraLayer calling update_layer has the use_rslora argument to pass along. Essentially, use_rslora is treated in the code exactly analogously to init_lora_weights. Added this now in commit a2c2f1a.
Thanks.

…n the update_layer type methods, as suggested
@BenjaminBossan (Member) left a review:


Thanks for making the adjustments. From my point of view, there are only a few small changes needed and then the PR should be good to be merged.

> I am not sure if it belongs in the description of the initializations, since I didn't want to give a person unfamiliar with LoRA the impression that the scaling factor is only applied at initialization, when it actually scales every adapter output.

IMO, we can still consider this to be initialization, as in almost all use cases, scaling is set once, at initialization, and then remains untouched.

> I added a test in test_initialization.py

Nicely done.

> Not sure exactly which class's __init__ method you mean, but I added the argument use_rslora to all the update_* methods, and then feed it through the kwargs in _create_new_module so that every subclass of LoraLayer calling update_layer has the use_rslora argument to pass along. Essentially, use_rslora is treated in the code exactly analogously to init_lora_weights.

Yes, that was exactly what I meant, thanks!

tests/test_initialization.py
src/peft/tuners/lora/config.py
…in the conceptual guide to the initialization section
@BenjaminBossan (Member) left a review:


Thanks, looks very good, nice tests.

@pacman100 (Contributor) left a review:


Thank you @Damjan-Kalajdzievski for adding the support for rank-stabilized LoRA scaling, LGTM! 🚀

@BenjaminBossan merged commit 997e6ec into huggingface:main on Dec 15, 2023
14 checks passed