Basically, we expect the trainable params to be float32 in mixed precision training, i.e., when fp16 or bf16 is enabled.

The DeepSpeed or FSDP engine automatically casts the dtype of both trainable and non-trainable params during training, so we do not need to set `torch_dtype` when initializing models under DeepSpeed or FSDP. The models should be loaded in `float32` on the CPU. (It seems that DeepSpeed ZeRO-3 initializes models in `float16` or `bfloat16` on CUDA.) Conversely, if we do not use DeepSpeed or FSDP, the trainer cannot automatically cast the dtype of model params, so we need to handle the dtype and device map manually when initializing models (DoRA initialization requires CUDA devices for float16 models). In that case the models should be loaded in `float16` or `bfloat16` on CUDA, and then we cast the trainable params to float32 for training stability. A minimal sketch of this loading logic is given below, after the related materials.

Empirical results:
Related materials:
[1] huggingface/peft#1249
[2] huggingface/peft#1336
[3] huggingface/peft#1706
[4] huggingface/trl#1644
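For concreteness, here is a minimal sketch of the loading logic described above. The function name and the `use_deepspeed_or_fsdp` / `compute_dtype` arguments are illustrative only, not actual LLaMA-Factory or PEFT APIs:

```python
# Illustrative sketch only: `load_base_model`, `use_deepspeed_or_fsdp`, and
# `compute_dtype` are hypothetical names, not real LLaMA-Factory/PEFT APIs.
import torch
from transformers import AutoModelForCausalLM


def load_base_model(model_name: str, use_deepspeed_or_fsdp: bool, compute_dtype: torch.dtype):
    if use_deepspeed_or_fsdp:
        # The DeepSpeed/FSDP engine casts params itself, so load in float32
        # on the CPU: no torch_dtype, no device_map.
        return AutoModelForCausalLM.from_pretrained(model_name)
    # Without an engine, load directly in half precision on CUDA
    # (DoRA initialization needs CUDA devices for float16 models);
    # the trainable adapter params are cast to float32 afterwards.
    return AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=compute_dtype,
        device_map={"": torch.cuda.current_device()},
    )
```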
cc: @BenjaminBossan
We ran a small experiment on the dtype of adapter weights; there may be some useful information here.
Btw, using `cast_trainable_params_to_fp32` (i.e., `param.data = param.data.to(torch.float32)`) may cause hanging in the DeepSpeed ZeRO-3 and FSDP examples.
Thanks for the ping. Also pinging @pacman100.
I'm not sure I 100% understand the conclusion. So do you think the changes in huggingface/peft#1706 are correct or is there something else we should do?
@BenjaminBossan Just to share my findings, I think the changes in huggingface/peft#1706 should be fine. The information above is merely a set of experimental observations, and we don't have a clear conclusion (since DeepSpeed's behaviour is too complex to understand).
Loading the model on the CPU consumes a large amount of RAM in distributed training, so we rolled back the dtype and device settings for non-ZeRO-3 DeepSpeed cases: 31a0564
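In other words, relative to the earlier sketch, the float32-on-CPU path is now kept only for ZeRO-3. A rough illustration, assuming transformers' `is_deepspeed_zero3_enabled` helper and a hypothetical `compute_dtype` argument:

```python
import torch
from transformers import AutoModelForCausalLM
from transformers.integrations import is_deepspeed_zero3_enabled


def init_kwargs(compute_dtype: torch.dtype) -> dict:
    if is_deepspeed_zero3_enabled():
        # ZeRO-3 shards and casts the params itself: float32 on the CPU.
        return {}
    # Non-ZeRO-3 DeepSpeed (and plain training): half precision on the GPU,
    # avoiding a full float32 copy in host RAM per process.
    return {
        "torch_dtype": compute_dtype,
        "device_map": {"": torch.cuda.current_device()},
    }


# Usage sketch:
# model = AutoModelForCausalLM.from_pretrained(model_name, **init_kwargs(torch.bfloat16))
```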
Okay, thanks for letting me know. Don't hesitate to open an issue if you think we should adjust something on the PEFT side.
Doc: https://huggingface.co/docs/accelerate/concept_guides/fsdp_and_deepspeed#on-differences-in-data-precision-handling