
New changes in PEFT breaks FSDP (RLHF) #1836

Closed
6 of 8 tasks
maziyarpanahi opened this issue Aug 20, 2024 · 12 comments · Fixed by #1848
Labels
bug Something isn't working

Comments

@maziyarpanahi
Contributor

Please check that this issue hasn't been reported before.

  • I searched previous Bug Reports and didn't find any similar reports.

Expected Behavior

FSDP is supported and works properly with DPO up to commit 608a2f3. However, after that commit, it crashes.

Current behaviour

With the latest changes on main, it takes a very long time before reaching the dataset Map step, and then it crashes.

Steps to reproduce

  • launch a RunPod instance with 8x A100 or H100 with 640G memory (or 320G)
  • choose winglian/axolotl-runpod:main-latest template
  • follow these steps
rm -rf axolotl
git clone https://github.com/OpenAccess-AI-Collective/axolotl && \
cd axolotl && \
git checkout e299312 && \
pip install setuptools && \
pip install -e .[flash-attn,deepspeed] && \
cd ..
  • then preprocess and train with the config.yaml below

Config yaml

base_model: arcee-ai/Arcee-Nova
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer

load_in_8bit: false
load_in_4bit: true
strict: false

save_safetensors: true

rl: dpo
chat_template: chatml
datasets:
  - path: Intel/orca_dpo_pairs
    split: train
    type: chatml.intel

dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./models/Arcee-Nova-DPO-v0.1

adapter: qlora
lora_model_dir:

sequence_len: 1800
sample_packing: false
pad_to_sequence_len: false

lora_r: 64
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
# lora_modules_to_save:
#   - embed_tokens
#   - lm_head
#lora_target_modules:
#  - gate
#  - q_proj
#  - k_proj
#  - v_proj
#  - o_proj
#  - w1
#  - w2
#  - w3

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 5e-5
train_on_inputs: false
group_by_length: false

bf16: auto
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 50
evals_per_epoch: 1
eval_table_size:
eval_table_max_new_tokens: 128
save_steps: 100
debug:
weight_decay: 0.05
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_limit_all_gathers: true
  fsdp_sync_module_states: true
  fsdp_offload_params: true
  fsdp_use_orig_params: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
  fsdp_state_dict_type: FULL_STATE_DICT
special_tokens:
  pad_token: "<|endoftext|>"
  eos_token: "<|im_end|>"

Possible solution

Check out the commit prior to the change that broke FSDP: git checkout 608a2f3

Which Operating Systems are you using?

  • Linux
  • macOS
  • Windows

Python Version

3.10

axolotl branch-commit

e299312

Acknowledgements

  • My issue title is concise, descriptive, and in title casing.
  • I have searched the existing issues to make sure this bug has not been reported yet.
  • I am using the latest version of axolotl.
  • I have provided enough information for the maintainers to reproduce and diagnose the issue.
@maziyarpanahi maziyarpanahi added the bug Something isn't working label Aug 20, 2024
@winglian
Collaborator

For whoever looks into this, the error is:
ValueError: Must flatten tensors with uniform dtype but got torch.bfloat16 and torch.float32

We usually run ensure_dtype to make sure all the modules are in the same dtype, but something is still in float32.
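
(For illustration, here is a minimal sketch of what such a uniform-dtype pass does, assuming the stray tensors are ordinary float32 floating-point parameters or buffers; the helper name and the bf16 target are illustrative, not axolotl's actual ensure_dtype implementation.)

import torch
import torch.nn as nn

def cast_mismatched_floats(model: nn.Module, dtype: torch.dtype = torch.bfloat16) -> None:
    # Cast any leftover float32 params/buffers to the target dtype so that
    # FSDP can flatten tensors with a uniform dtype.
    for param in model.parameters():
        if param.is_floating_point() and param.dtype == torch.float32:
            param.data = param.data.to(dtype)
    for buf in model.buffers():
        if buf.is_floating_point() and buf.dtype == torch.float32:
            buf.data = buf.data.to(dtype)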

@winglian
Collaborator

also, here's the diff on transformers between the version from the working commit and v4.43.1: huggingface/transformers@0fdea86...v4.43.1

@winglian
Collaborator

@maziyarpanahi just to clarify, does "check out the commit prior to the change that broke FSDP: git checkout 608a2f3" mean that 608a2f3 on axolotl works or is broken?

@winglian
Collaborator

also, did you reinstall transformers to whatever version is in requirements.txt at the axolotl commit that works?

@winglian
Collaborator

I tried reverting axolotl to the git SHA you pointed out, as well as downgrading transformers, and Qwen2 still doesn't RL finetune. If you can provide more info on how to reproduce a working finetune on Qwen2, I think I can track down what is wrong from there.

@maziyarpanahi
Contributor Author

@maziyarpanahi just to clarify, does "check out the commit prior to the change that broke FSDP: git checkout 608a2f3" mean that 608a2f3 on axolotl works or is broken?

608a2f3 works (and anything before it).

also, did you reinstall transformers to whatever version is in requirements.txt at the axolotl commit that works?

Not just transformers; I checked out 608a2f3 and did a full axolotl installation. I am guessing transformers also got re-installed.

I tried reverting axolotl to the git SHA you pointed out, as well as downgrading transformers, and Qwen2 still doesn't RL finetune. If you can provide more info on how to reproduce a working finetune on Qwen2, I think I can track down what is wrong from there.

that's strange; I launch a new RunPod and just run this script to make DPO work:

# remove the existing axolotl directory
rm -rf axolotl
# clone and install
git clone https://github.com/OpenAccess-AI-Collective/axolotl && \
cd axolotl && \
git checkout 608a2f3 && \
pip install setuptools && \
pip install -e .[flash-attn,deepspeed] && \
cd ..

@winglian
Collaborator

winglian commented Aug 21, 2024

everything was first downgraded to accelerate-0.32.0 axolotl-0.4.1 bitsandbytes-0.43.1 datasets-2.19.1 deepspeed-0.14.3+bc48371c flash-attn-2.6.1 fsspec-2024.3.1 gcsfs-2024.3.1 peft-0.11.1 s3fs-2024.3.1 scikit-learn-1.2.2 transformers-4.43.1

upgrade individually & test:
transformers==4.43.2 ✅
transformers==4.43.3 ✅
transformers==4.44.0 ✅
axolotl@main ✅
accelerate==0.33.0 ✅
peft==0.12.0 ❌

peft@39c60ffca9c1d1cc606a16654cfe9cd66b363a70 ❌
peft@47745d57c2ab110ce854f76b279ac03ead63c12c ✅
peft@3cf5359f112fedae2ffd28412cfc95076263e5d3 ❌
peft@cb7aedd9ba6642dda543d176ead5b5247d112e2e ✅

@BenjaminBossan it seems huggingface/peft#1742 is the problematic changeset that breaks qlora+DPO+FSDP. Also, I checked regular bf16 lora+DPO+FSDP and that works fine, so it's only a 4-bit quantization issue with peft.

@winglian
Collaborator

adding this to the end of peft.tuners.tuners_utils._move_adapter_to_device_of_base_layer seems to fix the issue when using peft@3cf5359f112fedae2ffd28412cfc95076263e5d3 or peft@main, although I don't know the ramifications of what that might break for the other adapter types.

self.to(device, dtype)

@winglian
Collaborator

@maziyarpanahi in the meantime, you can simply git clone the peft repo, uninstall peft, then do a pip install -e . in the peft directory, and then modify src/peft/tuners/tuners_utils.py per the previous comment.
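
(As an alternative to editing the installed sources, a hypothetical monkeypatch along the same lines could be applied before loading the model. This is only a sketch: the assumed signature (self, adapter_name, device=None) and the way the dtype is derived from the base layer's weight may not match every PEFT version.)

from peft.tuners.tuners_utils import BaseTunerLayer

_orig_move = BaseTunerLayer._move_adapter_to_device_of_base_layer

def _move_and_align(self, adapter_name, device=None):
    # run PEFT's own placement logic first
    _orig_move(self, adapter_name, device=device)
    base_weight = self.get_base_layer().weight
    # for bnb 4-bit weights the float dtype is assumed to live on quant_state;
    # plain weights expose it directly
    dtype = getattr(getattr(base_weight, "quant_state", None), "dtype", base_weight.dtype)
    # move/cast the whole LoRA-wrapped layer, mirroring the suggested
    # self.to(device, dtype) line; nn.Module.to(dtype=...) leaves
    # non-floating-point (quantized) storage untouched
    self.to(base_weight.device, dtype)

BaseTunerLayer._move_adapter_to_device_of_base_layer = _move_and_align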

@winglian winglian changed the title New changes in transformers breaks FSDP (RLHF) New changes in PEFT breaks FSDP (RLHF) Aug 21, 2024
@winglian
Collaborator

Also, I believe this issue is specific to the Qwen2 models, as Llama didn't have this issue.

@BenjaminBossan

BenjaminBossan commented Aug 22, 2024

Thanks for making me aware of this issue. This should not happen, so let's do our best to resolve this.

I cannot spin up a machine as suggested for reproducing the issue. Therefore, I tried something with my local setup of 2 GPUs. I basically took the peft/examples/sft script and made some changes to use Qwen/Qwen2-1.5B / Qwen/Qwen2-7B, 4bit bitsandbytes, and bf16 with FSDP. Unfortunately, I cannot replicate the error, training works for me. I'm not sure if it's due to the smaller model, the lack of DPO, or something else.
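
(For concreteness, a rough sketch of that kind of setup is below; the exact script and flags aren't shown in this thread, so the values, e.g. nf4 quantization with bf16 compute and bf16 quant storage, are illustrative assumptions rather than the actual reproduction code.)

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

# illustrative 4-bit config for QLoRA + FSDP; values are assumptions
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_storage=torch.bfloat16,  # a uniform storage dtype matters for FSDP flattening
)

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-1.5B",
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
)

lora_config = LoraConfig(
    r=32,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules="all-linear",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)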

@winglian what is the bnb config that is effectively used when using this axolotl example?

adding this to the end of peft.tuners.tuners_utils._move_adapter_to_device_of_base_layer seems to fix the issue when using peft@3cf5359f112fedae2ffd28412cfc95076263e5d3 or peft@main, although I don't know the ramifications of what that might break for the other adapter types.

self.to(device, dtype)

To give more context, the idea of that PR was to allow having multiple, say, LoRA adapters that reside on different devices. Before that, we would just move the whole module to the device (and potentially dtype) of the base model. The change you suggested would effectively restore that behavior. My suspicion of what must be happening is that some part of the Qwen module is not converted anymore after the change.

@winglian Would it be possible for you to store all the devices and dtypes from before and after your additional line and check if they're different, in which case they get printed? So something like this:

        # detect if some parameters have not yet been moved to the correct dtype/device
        dtypes_before = [(n, p.dtype) for n, p in self.named_parameters()]
        devices_before = [(n, p.device) for n, p in self.named_parameters()]
        self.to(device, dtype)  # <= the line you added
        dtypes_after = [(n, p.dtype) for n, p in self.named_parameters()]
        devices_after = [(n, p.device) for n, p in self.named_parameters()]
        if (dtypes_before != dtypes_after) or (devices_before != devices_after):
            print(self)
            print("dtypes before", dtypes_before)
            print("dtypes after", dtypes_after)
            print("devices before", devices_before)
            print("devices after", devices_after)
            1/0  # raise error to exit

For me, this does not detect any change and thus passes.

Edit: Made some changes to the snippet to also show the param names.

@winglian
Collaborator

looks like the bias modules were missed

lora.Linear4bit(
  (base_layer): Linear4bit(in_features=1536, out_features=1536, bias=True)
  (lora_dropout): ModuleDict(
    (default): Dropout(p=0.05, inplace=False)
  )
  (lora_A): ModuleDict(
    (default): Linear(in_features=1536, out_features=32, bias=False)
  )
  (lora_B): ModuleDict(
    (default): Linear(in_features=32, out_features=1536, bias=False)
  )
  (lora_embedding_A): ParameterDict()
  (lora_embedding_B): ParameterDict()
  (lora_magnitude_vector): ModuleDict()
)
dtypes before [('base_layer.weight', torch.bfloat16), ('base_layer.bias', torch.float32), ('lora_A.default.weight', torch.bfloat16), ('lora_B.default.weight', torch.bfloat16)]
dtypes after [('base_layer.weight', torch.bfloat16), ('base_layer.bias', torch.bfloat16), ('lora_A.default.weight', torch.bfloat16), ('lora_B.default.weight', torch.bfloat16)]
devices before [('base_layer.weight', device(type='cpu')), ('base_layer.bias', device(type='cpu')), ('lora_A.default.weight', device(type='cpu')), ('lora_B.default.weight', device(type='cpu'))]
devices after [('base_layer.weight', device(type='cpu')), ('base_layer.bias', device(type='cpu')), ('lora_A.default.weight', device(type='cpu')), ('lora_B.default.weight', device(type='cpu'))]
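
(A small, hypothetical helper along these lines can be used to list any parameters that are still in float32 before FSDP wraps the model; the function name is illustrative.)

import torch
import torch.nn as nn

def find_float32_params(model: nn.Module):
    # return (name, dtype) for every parameter still in float32, e.g. a bias
    # that was missed and would break FSDP's uniform-dtype flattening
    return [
        (name, p.dtype)
        for name, p in model.named_parameters()
        if p.is_floating_point() and p.dtype == torch.float32
    ]

# usage, assuming `model` is the PEFT-wrapped model:
# print(find_float32_params(model))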
