FIX: Sft train script FSDP QLoRA embedding mean resizing error (huggingface#2151)

Resizing the embedding layer with mean_resizing=True, the default introduced in
transformers > 4.45, will result in an error. This is because with FSDP + QLoRA
the embedding matrix can be on the meta device, in which case mean resizing
fails. Therefore, when these conditions are detected, the script sets
mean_resizing=False.

Also updated the recommended package versions to newer releases that I have
verified to work.
BenjaminBossan authored and yaswanth19 committed Oct 20, 2024
1 parent 1c701dc commit 0a8f42e
Showing 4 changed files with 16 additions and 3 deletions.
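
To make the failure mode described in the commit message concrete, here is an illustrative sketch (not part of the commit): under FSDP + QLoRA the input-embedding weight can still be a meta tensor when the resize happens, and a meta tensor carries only shape and dtype, no values, so the statistics that mean resizing needs cannot be computed from it.

```python
import torch

# Toy stand-in for an input-embedding weight that has not been materialized yet,
# as can happen under FSDP + QLoRA; a "meta" tensor has a shape but holds no data.
weight = torch.empty(32000, 4096, device="meta")

print(weight.is_meta)   # True
print(weight.device)    # meta
# Mean resizing needs the mean (and covariance) of the existing embedding rows;
# there are no actual values here to compute them from, hence the failure.
```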
2 changes: 1 addition & 1 deletion docs/source/accelerate/deepspeed.md
@@ -168,7 +168,7 @@ You can also refer this blog post [Falcon 180B Finetuning using 🤗 PEFT and De
# Use PEFT QLoRA and DeepSpeed with ZeRO3 for finetuning large models on multiple GPUs
In this section, we will look at how to use QLoRA and DeepSpeed Stage-3 for finetuning 70B llama model on 2X40GB GPUs.
-For this, we first need `bitsandbytes>=0.43.0`, `accelerate>=0.28.0`, `transformers>4.38.2`, `trl>0.7.11` and `peft>0.9.0`. We need to set `zero3_init_flag` to true when using Accelerate config. Below is the config which can be found at [deepspeed_config_z3_qlora.yaml](https://github.com/huggingface/peft/blob/main/examples/sft/configs/deepspeed_config_z3_qlora.yaml):
+For this, we first need `bitsandbytes>=0.43.3`, `accelerate>=1.0.1`, `transformers>4.44.2`, `trl>0.11.4` and `peft>0.13.0`. We need to set `zero3_init_flag` to true when using Accelerate config. Below is the config which can be found at [deepspeed_config_z3_qlora.yaml](https://github.com/huggingface/peft/blob/main/examples/sft/configs/deepspeed_config_z3_qlora.yaml):
```yml
compute_environment: LOCAL_MACHINE
2 changes: 1 addition & 1 deletion docs/source/accelerate/fsdp.md
@@ -166,7 +166,7 @@ In the above example, the memory consumed per GPU is 72-80 GB (90-98%) as seen

In this section, we will look at how to use QLoRA and FSDP for finetuning 70B llama model on 2X24GB GPUs. [Answer.AI](https://www.answer.ai/) in collaboration with bitsandbytes and Hugging Face 🤗 open sourced code enabling the usage of FSDP+QLoRA and explained the whole process in their insightful blogpost [You can now train a 70b language model at home](https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html). This is now integrated in Hugging Face ecosystem.

-For this, we first need `bitsandbytes>=0.43.0`, `accelerate>=0.28.0`, `transformers>4.38.2`, `trl>0.7.11` and `peft>0.9.0`. We need to set `fsdp_cpu_ram_efficient_loading=true`, `fsdp_use_orig_params=false` and `fsdp_offload_params=true`(cpu offloading) when using Accelerate config. When not using accelerate launcher, you can alternately set the environment variable `export FSDP_CPU_RAM_EFFICIENT_LOADING=true`. Here, we will be using accelerate config and below is the config which can be found at [fsdp_config_qlora.yaml](https://github.com/huggingface/peft/blob/main/examples/sft/configs/fsdp_config_qlora.yaml):
+For this, we first need `bitsandbytes>=0.43.3`, `accelerate>=1.0.1`, `transformers>4.44.2`, `trl>0.11.4` and `peft>0.13.0`. We need to set `fsdp_cpu_ram_efficient_loading=true`, `fsdp_use_orig_params=false` and `fsdp_offload_params=true`(cpu offloading) when using Accelerate config. When not using accelerate launcher, you can alternately set the environment variable `export FSDP_CPU_RAM_EFFICIENT_LOADING=true`. Here, we will be using accelerate config and below is the config which can be found at [fsdp_config_qlora.yaml](https://github.com/huggingface/peft/blob/main/examples/sft/configs/fsdp_config_qlora.yaml):

```yml
compute_environment: LOCAL_MACHINE
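As a side note (not part of the diff): the fsdp.md text above mentions that, when the accelerate launcher is not used, `FSDP_CPU_RAM_EFFICIENT_LOADING` can be set as an environment variable instead of in the accelerate config. A minimal sketch of doing the same from Python, assuming it runs before the model is loaded:

```python
import os

# Equivalent of `export FSDP_CPU_RAM_EFFICIENT_LOADING=true` for runs that do not
# go through `accelerate launch`; set it before loading the model.
os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = "true"
```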
2 changes: 2 additions & 0 deletions examples/sft/README.md
@@ -29,4 +29,6 @@ When you have access to multiple GPUs, it would be better to use normal LoRA wit
## Multi-GPU SFT with LoRA and FSDP
When you have access to multiple GPUs, it would be better to use normal LoRA with DeepSpeed/FSDP. To use LoRA with DeepSpeed, refer the docs at [PEFT with FSDP](https://huggingface.co/docs/peft/accelerate/fsdp).

## Tip

Generally try to upgrade to the latest package versions for best results, especially when it comes to `bitsandbytes`, `accelerate`, `transformers`, `trl`, and `peft`.
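
As a convenience (not from the repository), here is a small sketch that checks an installed environment against the minimum versions recommended in the updated docs above; the requirement strings are copied from those diffs.

```python
from importlib.metadata import PackageNotFoundError, version

from packaging.specifiers import SpecifierSet

# Minimum versions recommended in the updated deepspeed.md/fsdp.md docs.
requirements = {
    "bitsandbytes": ">=0.43.3",
    "accelerate": ">=1.0.1",
    "transformers": ">4.44.2",
    "trl": ">0.11.4",
    "peft": ">0.13.0",
}

for package, spec in requirements.items():
    try:
        installed = version(package)
    except PackageNotFoundError:
        print(f"{package}: not installed (need {spec})")
        continue
    status = "OK" if installed in SpecifierSet(spec) else "too old"
    print(f"{package}: {installed} (need {spec}) -> {status}")
```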
13 changes: 12 additions & 1 deletion examples/sft/utils.py
@@ -1,7 +1,9 @@
import os
from enum import Enum

+import packaging.version
import torch
+import transformers
from datasets import DatasetDict, load_dataset, load_from_disk
from datasets.builder import DatasetGenerationError
from transformers import (
@@ -169,8 +171,17 @@ def create_and_prepare_model(args, data_args, training_args):
            trust_remote_code=True,
        )
        tokenizer.chat_template = chat_template
+
        # make embedding resizing configurable?
-        model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)
+        # Transformers 4.46.0+ uses mean_resizing=True by default, which fails with QLoRA + FSDP because the
+        # embedding could be on the meta device; therefore, we set mean_resizing=False in that case (i.e. the
+        # status quo ante). See https://github.com/huggingface/accelerate/issues/1620.
+        uses_transformers_4_46 = packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.46.0")
+        uses_fsdp = os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
+        if (bnb_config is not None) and uses_fsdp and uses_transformers_4_46:
+            model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8, mean_resizing=False)
+        else:
+            model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)
    else:
        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
        tokenizer.pad_token = tokenizer.eos_token
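For anyone adapting this outside the example script, a hedged, self-contained sketch of the same guard; the function name and the `quantized` flag are illustrative placeholders, not names from the commit.

```python
import os

import packaging.version
import transformers


def resize_embeddings_safely(model, tokenizer, quantized: bool) -> None:
    """Resize token embeddings, skipping mean resizing when it could fail.

    Mirrors the guard added in examples/sft/utils.py above: with QLoRA + FSDP and
    transformers >= 4.46.0, pass mean_resizing=False because the embedding weights
    may still be on the meta device.
    """
    new_transformers = packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.46.0")
    fsdp_active = os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
    if quantized and fsdp_active and new_transformers:
        model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8, mean_resizing=False)
    else:
        model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)
```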
