Refactor DoRA to make it work with FSDP #1797

Closed

Conversation

@BenjaminBossan (Member) commented May 23, 2024

Description

With only a few changes, I managed to get DoRA running with FSDP.

Implementation

The first secret ingredient was to not use lora_B.weight @ lora_A.weight, as this sidesteps the __call__ of the module, which FSDP requires to do some of its magic (thanks @pacman100 for sharing this secret knowledge with me). Instead, we now do:

x_eye = torch.eye(lora_A.weight.shape[1], device=lora_A.weight.device)
lora_weight = lora_B(lora_A(x_eye)).T

This should have the same output but uses __call__. It is probably a little bit slower though.
(I ran all DoRA tests with an added assert that the two expressions return the same results to make sure.)
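
For reference, here is a small standalone sketch (not the PEFT code; shapes are arbitrary) showing why the two expressions agree:

import torch
import torch.nn as nn

# Arbitrary shapes for illustration only.
in_features, out_features, r = 16, 32, 8
lora_A = nn.Linear(in_features, r, bias=False)
lora_B = nn.Linear(r, out_features, bias=False)

# Direct weight product; bypasses __call__ and therefore FSDP's module hooks.
direct = lora_B.weight @ lora_A.weight  # shape (out_features, in_features)

# Routing an identity matrix through __call__ recovers the same matrix.
x_eye = torch.eye(lora_A.weight.shape[1], device=lora_A.weight.device)
via_call = lora_B(lora_A(x_eye)).T

assert torch.allclose(direct, via_call, atol=1e-6)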

The second necessary step, required to make DoRA work with bitsandbytes (QDoRA), was to initialize the DoRA parameters lazily. Normally, we initialize all the parameters when we initialize the layer. However, DoRA initialization requires us to dequantize the bnb weights, and FSDP doesn't like that for some reason: it results in FSDP trying to flatten int params.

Instead of being initialized eagerly, the DoRA weights are now initialized lazily during the forward step. This is not ideal, but at least it works.
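
In simplified terms, the lazy init boils down to the following sketch (made-up class and attribute names; the real code lives on the LoRA layers and additionally dequantizes the bnb weight):

import torch
import torch.nn as nn

class LazyDoraLinear(nn.Module):
    # Illustration of deferring DoRA init to the first forward pass.
    def __init__(self, base_layer: nn.Linear):
        super().__init__()
        self.base_layer = base_layer
        self.lora_magnitude_vector = None  # placeholder until the first forward

    def _init_dora(self) -> None:
        # For QDoRA, the bnb weight would have to be dequantized here first.
        weight = self.base_layer.weight.detach()
        # Per-output-channel norm; the real code uses the norm of the combined
        # base weight plus the scaled LoRA delta.
        weight_norm = weight.norm(p=2, dim=1)
        self.lora_magnitude_vector = nn.Parameter(weight_norm)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.lora_magnitude_vector is None:
            self._init_dora()  # lazy initialization, runs only once
        return self.base_layer(x)  # the DoRA scaling itself is omitted for brevity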

Apart from this, only some minor changes were required. It was apparently not necessary to use ModuleDict instead of ParameterDict, which was my first attempt (that approach also works, but it is a much bigger refactor and changes the state_dict).

Note: The change is only implemented for lora.Linear, not yet for lora.Conv2d.

TODOs

more of a reminder for myself

One remaining problem with this implementation is the lazy init of the DoRA weight, which is required as explained above. We need to put a placeholder value there, otherwise we cannot successfully load trained DoRA models. However, if this placeholder has requires_grad=True, we get an error because requires_grad is not consistent across the params that FSDP flattens into one unit. If we instead set requires_grad=False, then even if we later assign the real value with requires_grad=True, the parameter is still not updated.
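
To illustrate the constraint with a toy setup (not the actual PEFT code):

import torch
import torch.nn as nn

layer = nn.Linear(8, 8, bias=False)
layer.weight.requires_grad_(False)  # frozen base weight
# Hypothetical placeholder for the lazily initialized DoRA magnitude:
layer.dora_placeholder = nn.Parameter(torch.zeros(8), requires_grad=True)

# If `layer` is wrapped as a single FSDP unit, its flat parameter mixes
# requires_grad=True and requires_grad=False tensors, which triggers the error
# described above. Creating the placeholder with requires_grad=False avoids
# that error, but then swapping in the real value later does not make it
# trainable.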

We could try updating fsdp_auto_wrap_policy so that the DoRA weight is not included in the same unit as the non-trainable weights. This works and allows us to initialize the FSDP model, but then we later get RuntimeError: CUDA error: an illegal memory access was encountered.
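
For context, a rough sketch of what such a tweak could look like (illustrative only, using torch's lambda_auto_wrap_policy; this is not the exact policy that was tried):

import functools

from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy

def wrap_trainable_only(module) -> bool:
    # Put modules whose own (non-recursive) parameters are all trainable into
    # their own FSDP unit, separate from the frozen base weights.
    own_params = list(module.parameters(recurse=False))
    return len(own_params) > 0 and all(p.requires_grad for p in own_params)

auto_wrap_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=wrap_trainable_only)
# This policy would then be passed to FSDP (or accelerate's FSDP plugin) in
# place of the default fsdp_auto_wrap_policy.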

Changes to the scripts in examples/sft for this experiment:

I could successfully run FSDP training with DoRA on a 2xT4 machine based on the PEFT examples in the sft folder. I made a few small changes to the scripts:

  • run_peft_qlora_fsdp.sh
3c3
< --model_name_or_path "meta-llama/Llama-2-70b-hf" \
---
> --model_name_or_path "facebook/opt-125m" \
9c9
< --max_seq_len 2048 \
---
> --max_seq_len 256 \
16,19c16
< --push_to_hub \
< --hub_private_repo True \
< --hub_strategy "every_save" \
< --bf16 True \
---
> --bf16 False \
26c23
< --output_dir "llama-sft-qlora-fsdp" \
---
> --output_dir "/tmp/llama-sft-qlora-fsdp" \
33c30
< --use_flash_attn True \
---
> --use_flash_attn False \
41,42c38,39
< --bnb_4bit_compute_dtype "bfloat16" \
< --bnb_4bit_quant_storage_dtype "bfloat16"
\ No newline at end of file
---
> --bnb_4bit_compute_dtype "float32" \
> --bnb_4bit_quant_storage_dtype "float32"
  • utils.py
141a142,143
>         use_dora = bool(int(os.environ.get("USE_DORA"))) 
>         print("*"*20, "use dora:", use_dora)
147a150
>             use_dora=use_dora,
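
(With this utils.py change, DoRA is toggled via the environment, presumably by exporting something like USE_DORA=1 before launching the script; note that os.environ.get("USE_DORA") returns None when the variable is unset, so it has to be set to 0 or 1.)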

Training log

***** Running training *****
  Num examples = 48,106
  Num Epochs = 1
  Instantaneous batch size per device = 2
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 2
  Total optimization steps = 6,013
  Number of trainable parameters = 663,552
{'loss': 2.8955, 'grad_norm': 0.8952780365943909, 'learning_rate': 9.999982939289716e-05, 'epoch': 0.0}                                                  
{'loss': 2.7834, 'grad_norm': 0.9351043105125427, 'learning_rate': 9.999931757275294e-05, 'epoch': 0.0}                                                  
{'loss': 2.8371, 'grad_norm': 0.9547858834266663, 'learning_rate': 9.999846454306009e-05, 'epoch': 0.0}                                                  
{'loss': 2.6302, 'grad_norm': 0.8862084150314331, 'learning_rate': 9.999727030964001e-05, 'epoch': 0.0}                                                  
{'loss': 2.8222, 'grad_norm': 1.0635148286819458, 'learning_rate': 9.999573488064242e-05, 'epoch': 0.0}                                                  
{'loss': 2.8932, 'grad_norm': 0.9164345860481262, 'learning_rate': 9.999385826654554e-05, 'epoch': 0.0}                                                  
{'loss': 2.7997, 'grad_norm': 0.9818331599235535, 'learning_rate': 9.999164048015593e-05, 'epoch': 0.01}                                                 
{'loss': 2.8716, 'grad_norm': 1.1899254322052002, 'learning_rate': 9.998908153660838e-05, 'epoch': 0.01}                                                 
{'loss': 2.7126, 'grad_norm': 0.914239227771759, 'learning_rate': 9.998618145336587e-05, 'epoch': 0.01}                                                  
{'loss': 2.7325, 'grad_norm': 1.0930601358413696, 'learning_rate': 9.998294025021936e-05, 'epoch': 0.01}                                                 
{'loss': 2.6289, 'grad_norm': 0.9594664573669434, 'learning_rate': 9.997935794928776e-05, 'epoch': 0.01}                                                 
{'loss': 2.7155, 'grad_norm': 1.0767607688903809, 'learning_rate': 9.997543457501773e-05, 'epoch': 0.01}                                                 
{'loss': 2.6928, 'grad_norm': 1.0032813549041748, 'learning_rate': 9.997117015418345e-05, 'epoch': 0.01}                                                 
{'loss': 2.645, 'grad_norm': 1.168573260307312, 'learning_rate': 9.996656471588657e-05, 'epoch': 0.01}                                                   
{'loss': 2.6361, 'grad_norm': 1.0172500610351562, 'learning_rate': 9.996161829155588e-05, 'epoch': 0.01}                                                 
{'loss': 2.8198, 'grad_norm': 1.0332210063934326, 'learning_rate': 9.995633091494722e-05, 'epoch': 0.01}                                                 
{'loss': 2.72, 'grad_norm': 1.0368295907974243, 'learning_rate': 9.995070262214313e-05, 'epoch': 0.01}                                                   
{'loss': 2.7599, 'grad_norm': 1.1201494932174683, 'learning_rate': 9.994473345155267e-05, 'epoch': 0.01}                                                 
{'loss': 2.5562, 'grad_norm': 1.09053373336792, 'learning_rate': 9.993842344391118e-05, 'epoch': 0.02}                                                   
{'loss': 2.7267, 'grad_norm': 1.0421711206436157, 'learning_rate': 9.993177264227992e-05, 'epoch': 0.02}                                                 
{'loss': 2.7317, 'grad_norm': 1.1445740461349487, 'learning_rate': 9.992478109204589e-05, 'epoch': 0.02}                                                 
{'loss': 2.7285, 'grad_norm': 1.1827253103256226, 'learning_rate': 9.991744884092137e-05, 'epoch': 0.02}                                                 
{'loss': 2.6398, 'grad_norm': 1.0205488204956055, 'learning_rate': 9.990977593894374e-05, 'epoch': 0.02}                                                 
{'loss': 2.599, 'grad_norm': 1.026229977607727, 'learning_rate': 9.990176243847507e-05, 'epoch': 0.02}                                                   
{'loss': 2.6922, 'grad_norm': 1.0691609382629395, 'learning_rate': 9.989340839420176e-05, 'epoch': 0.02}                                                 
{'loss': 2.6191, 'grad_norm': 1.0511975288391113, 'learning_rate': 9.988471386313418e-05, 'epoch': 0.02}                                                 
{'loss': 2.8006, 'grad_norm': 1.0725524425506592, 'learning_rate': 9.987567890460628e-05, 'epoch': 0.02}                                                 
{'loss': 2.6992, 'grad_norm': 1.1552668809890747, 'learning_rate': 9.98663035802752e-05, 'epoch': 0.02}                                                  
{'loss': 2.6909, 'grad_norm': 1.0143084526062012, 'learning_rate': 9.985658795412079e-05, 'epoch': 0.02}                                                 
{'loss': 2.5932, 'grad_norm': 1.0270745754241943, 'learning_rate': 9.984653209244525e-05, 'epoch': 0.02}                                                 
{'loss': 2.6703, 'grad_norm': 1.142627477645874, 'learning_rate': 9.983613606387265e-05, 'epoch': 0.03}                                                  
{'loss': 2.5607, 'grad_norm': 1.0679482221603394, 'learning_rate': 9.982539993934844e-05, 'epoch': 0.03}                                                 
{'loss': 2.6616, 'grad_norm': 1.0160728693008423, 'learning_rate': 9.981432379213898e-05, 'epoch': 0.03}                                                 
{'loss': 2.6676, 'grad_norm': 1.163256049156189, 'learning_rate': 9.980290769783103e-05, 'epoch': 0.03}                                                  
{'loss': 2.7486, 'grad_norm': 1.0142438411712646, 'learning_rate': 9.979115173433128e-05, 'epoch': 0.03}                                                 
{'loss': 2.5851, 'grad_norm': 1.120428442955017, 'learning_rate': 9.977905598186578e-05, 'epoch': 0.03}                                                  
{'loss': 2.7134, 'grad_norm': 1.0812691450119019, 'learning_rate': 9.976662052297935e-05, 'epoch': 0.03}                                                 
{'loss': 2.6443, 'grad_norm': 1.1822906732559204, 'learning_rate': 9.97538454425351e-05, 'epoch': 0.03}                                                  
{'loss': 2.4838, 'grad_norm': 1.39701247215271, 'learning_rate': 9.974073082771382e-05, 'epoch': 0.03}                                                   
{'loss': 2.6165, 'grad_norm': 1.145723581314087, 'learning_rate': 9.972727676801338e-05, 'epoch': 0.03}                                                  
{'loss': 2.7184, 'grad_norm': 1.1211601495742798, 'learning_rate': 9.971348335524808e-05, 'epoch': 0.03}                                                 
{'loss': 2.6372, 'grad_norm': 1.024659276008606, 'learning_rate': 9.969935068354807e-05, 'epoch': 0.03}                                                  
{'loss': 2.6092, 'grad_norm': 1.1009517908096313, 'learning_rate': 9.968487884935878e-05, 'epoch': 0.04}                                                 
{'loss': 2.5787, 'grad_norm': 1.0226079225540161, 'learning_rate': 9.967006795144006e-05, 'epoch': 0.04}                                                 
{'loss': 2.6843, 'grad_norm': 1.041178822517395, 'learning_rate': 9.96549180908657e-05, 'epoch': 0.04}                                                   
{'loss': 2.6676, 'grad_norm': 1.0906169414520264, 'learning_rate': 9.963942937102269e-05, 'epoch': 0.04}                                                 
{'loss': 2.5395, 'grad_norm': 1.0848253965377808, 'learning_rate': 9.962360189761041e-05, 'epoch': 0.04}                                                 
{'loss': 2.6135, 'grad_norm': 1.0817619562149048, 'learning_rate': 9.960743577864004e-05, 'epoch': 0.04}                                                 
{'loss': 2.7096, 'grad_norm': 1.1412609815597534, 'learning_rate': 9.959093112443378e-05, 'epoch': 0.04}                                                 
{'loss': 2.664, 'grad_norm': 1.0265809297561646, 'learning_rate': 9.957408804762409e-05, 'epoch': 0.04}                                                  
{'loss': 2.5626, 'grad_norm': 1.040503978729248, 'learning_rate': 9.955690666315289e-05, 'epoch': 0.04}                                                  
{'loss': 2.6042, 'grad_norm': 1.0967168807983398, 'learning_rate': 9.953938708827086e-05, 'epoch': 0.04}                                                 
{'loss': 2.5536, 'grad_norm': 1.0510889291763306, 'learning_rate': 9.952152944253651e-05, 'epoch': 0.04}                                                 
{'loss': 2.5904, 'grad_norm': 1.0969544649124146, 'learning_rate': 9.950333384781552e-05, 'epoch': 0.04}                                                 
{'loss': 2.6716, 'grad_norm': 1.1797243356704712, 'learning_rate': 9.94848004282798e-05, 'epoch': 0.05}                                                  
{'loss': 2.5291, 'grad_norm': 1.082594633102417, 'learning_rate': 9.946592931040666e-05, 'epoch': 0.05}                                                  
{'loss': 2.5242, 'grad_norm': 1.0399702787399292, 'learning_rate': 9.944672062297795e-05, 'epoch': 0.05}                                                 
  5%|█████▎                                                                                                         | 286/6013 [08:50<2:49:19,  1.77s/it]


@pacman100 (Contributor) left a comment

Thank you, @BenjaminBossan, the fix is neat and doesn't involve a lot of code changes! ✨

We need to init the dora params with dummy values at the start.
This was a silly idea.
@BenjaminBossan (Member, Author)

Update: Unfortunately, I could not get past the last few issues mentioned above. Therefore, I'll close this PR in favor of #1806. That PR has more code changes, so it would have been nice to get this one working, but I did not succeed.

BenjaminBossan added a commit that referenced this pull request May 31, 2024
This PR moves all the DoRA functionality into a separate module class.
Essentially, this is necessary because otherwise, the DoRA parameter
lives on the lora.Linear layer as a parameter, not a separate module.
Since the FSDP auto wrap policy operates on the level of modules, not
parameters, there is no way to modify the auto wrap policy to wrap the
DoRA parameter; it must be its own module.

If not for this reason, #1797 would be preferable, since the number of
code changes is smaller overall. In this PR, the changes are more
numerous, but the majority only involve moving code around rather than
changing any actual logic.

Since we introduce a new submodule, extra steps are required to
ensure that old DoRA state dicts can still be loaded correctly. This
involves a fairly trivial extra remapping step in
set_peft_model_state_dict. The test for this is performed via the new
regression DoRA tests introduced in #1792.

Similarly, there is a remapping step involved in
get_peft_model_state_dict to ensure that when new state dicts with DoRA
are saved, they still conform to the old format.

An additional required change was to make a defensive copy of the base
layer before dequantizing its weight in order to calculate the weight
norm for DoRA. Without this defensive copy, some side-effect is
triggered in FSDP that results in

> ValueError: Cannot flatten integer dtype tensors

even though the compute dtype of bnb is correctly set to float.

Creating a fully functioning deepcopy does not currently work with 8bit
BNB, but there is a fix. Once the next BNB release is out, 8bit BNB will
be tested and enabled.

While working on this, I also noticed a small bug where dropout was not
correctly applied when using QDoRA. This is now also fixed.

This PR was tested successfully with FSDP and (Q)DoRA using the scripts
in examples/sft/ with a modification to enable DoRA.
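
To give a rough idea of the refactor (simplified, with illustrative names rather than the actual PEFT classes): the DoRA magnitude becomes a parameter on a small dedicated module, which a module-level auto wrap policy can then target.

import torch
import torch.nn as nn

class DoraMagnitudeLayer(nn.Module):
    # Holds the trainable DoRA magnitude as its own module (illustrative only).
    def __init__(self, weight_norm: torch.Tensor):
        super().__init__()
        self.weight = nn.Parameter(weight_norm)  # per-output-channel magnitude m

# On the LoRA layer, instead of storing a bare nn.Parameter, something like:
#     self.lora_magnitude_vector = nn.ModuleDict()
#     self.lora_magnitude_vector[adapter_name] = DoraMagnitudeLayer(weight_norm)
# Because the magnitude now lives in its own submodule, a module-level auto wrap
# policy can place it in its own FSDP unit, away from the frozen base weights,
# and the state_dict remapping mentioned above keeps the saved format unchanged.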