Refactor DoRA to make it work with FSDP #1797
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thank you, @BenjaminBossan, the fix is neat and doesn't involve a lot of code changes! ✨
We need to init the dora params with dummy values at the start.
This was a silly idea.
Update: Unfortunately, I could not get past the last few issues mentioned above. Therefore, I'll close this PR in favor of #1806. That PR has more code changes, so it would have been nice to get this one working, but I did not succeed.
This PR moves all the DoRA functionality into a separate module class. Essentially, this is necessary because otherwise the DoRA parameter lives on the `lora.Linear` layer as a parameter, not a separate module. Since the FSDP auto wrap policy operates on the level of modules, not parameters, there is no way to modify the auto wrap policy to wrap the DoRA parameter; it must be its own module. If not for this reason, #1797 would be preferable, since the number of code changes is smaller overall. In this PR, the changes are more numerous, but the majority only involves moving code around, not actual code changes.

Since we introduce a new submodule, extra steps are required to ensure that old DoRA state dicts can still be loaded correctly. This involves a fairly trivial extra remapping step in `set_peft_model_state_dict`, which is tested via the new DoRA regression tests introduced in #1792. Similarly, there is a remapping step in `get_peft_model_state_dict` to ensure that when new state dicts with DoRA are saved, they still conform to the old format.

An additional required change was to make a defensive copy of the base layer before dequantizing its weight in order to calculate the weight norm for DoRA. Without this defensive copy, some side effect is triggered in FSDP that results in

> ValueError: Cannot flatten integer dtype tensors

even though the compute dtype of bnb is correctly set to float. Creating a fully functioning deepcopy currently does not work with 8bit BNB, but there is a fix. Once the next BNB release is out, 8bit BNB will be tested and enabled.

While working on this, I also noticed a small bug: dropout was not correctly applied when using QDoRA. This is now also fixed.

This PR was tested successfully with FSDP and (Q)DoRA using the scripts in examples/sft/, with a modification to enable DoRA.
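As a rough illustration of the defensive copy mentioned above, here is a hypothetical sketch (not the PR's actual code); it assumes a 4-bit bitsandbytes `Linear4bit` base layer and uses the public `dequantize_4bit` helper:

```python
import copy

import bitsandbytes as bnb
import torch


def dequantized_weight_for_norm(base_layer: "bnb.nn.Linear4bit") -> torch.Tensor:
    # Dequantize a *copy* of the base layer so that the original int-dtype
    # quantized weight, which FSDP has already flattened, is left untouched.
    layer_copy = copy.deepcopy(base_layer)
    weight = layer_copy.weight  # Params4bit holding the quantized data + quant_state
    return bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
```

The dequantized tensor can then be used to compute the DoRA weight norm without side effects on the FSDP-wrapped module.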
Description
With only a few changes, I managed to get DoRA running with FSDP.
Implementation
The first secret ingredient was to not use `lora_B.weight @ lora_A.weight`, as this sidesteps the `__call__` of the module, which is required for FSDP to do some magic (thanks @pacman100 for sharing this secret knowledge with me). Instead, we now route the computation through `__call__`, as sketched below. This should produce the same output but is probably a little bit slower. (I ran all DoRA tests with an added assert that the two expressions return the same results to make sure.)
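A minimal sketch of the idea (the sub-layer names are the usual LoRA ones; the exact code in the PR may differ): multiplying an identity matrix through the modules' forward recovers the same product as the direct weight multiplication.

```python
import torch
from torch import nn

# stand-ins for the LoRA sub-layers of a lora.Linear
lora_A = nn.Linear(16, 8, bias=False)
lora_B = nn.Linear(8, 16, bias=False)

# before: direct weight access, which bypasses __call__ and FSDP's hooks
lora_weight_direct = lora_B.weight @ lora_A.weight

# after: same result, obtained by calling the modules on an identity matrix
eye = torch.eye(lora_A.in_features, dtype=lora_A.weight.dtype)
lora_weight_via_call = lora_B(lora_A(eye)).T

assert torch.allclose(lora_weight_direct, lora_weight_via_call, atol=1e-6)
```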
The second necessary step, required to make DoRA work with bitsandbytes (QDoRA), was to initialize the DoRA parameters lazily. Normally, we initialize all the parameters when we initialize the layer. However, DoRA initialization requires us to dequantize the bnb weights, and FSDP doesn't like that, as it results in FSDP trying to flatten int params.
Instead of eagerly initializing the DoRA weights, they are now lazily initialized during the forward step. This is not ideal, but at least it works.
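Roughly, the lazy initialization looks like the following (an illustrative sketch with made-up names, not PEFT's actual class; the DoRA output rescaling itself is omitted):

```python
import torch
from torch import nn


class LazyDoraLinearSketch(nn.Module):
    def __init__(self, base_layer: nn.Linear, lora_A: nn.Linear, lora_B: nn.Linear, scaling: float = 1.0):
        super().__init__()
        self.base_layer, self.lora_A, self.lora_B = base_layer, lora_A, lora_B
        self.scaling = scaling
        # not initialized eagerly: doing so would require dequantizing the base
        # weight, which trips up FSDP when the params are flattened
        self.lora_magnitude_vector = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.lora_magnitude_vector is None:
            # lazy init on the first forward: compute the weight norm only now
            eye = torch.eye(self.lora_A.in_features, dtype=x.dtype, device=x.device)
            lora_weight = self.lora_B(self.lora_A(eye)).T
            weight_norm = (self.base_layer.weight + self.scaling * lora_weight).norm(p=2, dim=1)
            self.lora_magnitude_vector = nn.Parameter(weight_norm)
        # (the actual DoRA rescaling of the output is omitted for brevity)
        return self.base_layer(x) + self.scaling * self.lora_B(self.lora_A(x))
```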
Apart from this, only some minor changes were required. It was apparently not necessary to use `ModuleDict` instead of `ParamDict`, which was my first attempt (it also works, but is a much bigger refactor and changes the `state_dict`).

Note: The change is only implemented for `lora.Linear`, not yet for `lora.Conv2d`.

TODOs
more of a reminder for myself
One remaining problem with this implementation is the lazy init of the DoRA weight, which is required as explained above. We need to put a placeholder value there, otherwise we cannot successfully load trained DoRA models. However, if this placeholder value has `requires_grad=True`, we get an error because `requires_grad` is not consistent when params are flattened by FSDP. When we set `requires_grad=False`, then even if we later set the real value with `requires_grad=True`, it is still not updated.

We could try updating `fsdp_auto_wrap_policy` so that the DoRA weight is not included in the same unit as the non-trainable weights. This works and allows us to initialize the FSDP model, but later we get `RuntimeError: CUDA error: an illegal memory access was encountered`.
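For reference, the kind of auto wrap policy tweak meant here could look roughly like this (a hypothetical sketch built on `torch.distributed.fsdp.wrap.lambda_auto_wrap_policy`, not the code actually tried; as noted above, this route still ran into a CUDA illegal memory access):

```python
import functools

from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy


def _owns_dora_magnitude(module) -> bool:
    # Put any module that directly holds the DoRA magnitude parameter into its own
    # FSDP unit, so it is not flattened together with frozen (non-trainable) weights.
    return any(
        "lora_magnitude_vector" in name for name, _ in module.named_parameters(recurse=False)
    )


dora_wrap_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=_owns_dora_magnitude)
```

In practice this would have to be combined with the existing transformer-layer wrap policy rather than replace it.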
Changes to the scripts in `examples/sft` for this experiment:

I could successfully run FSDP training with DoRA on a 2xT4 machine based on the PEFT examples in the `sft` folder. I made a few small changes to the scripts:

- `run_peft_qlora_fsdp.sh`
- `utils.py`

Training log