Add support for gradient checkpointing for LLM fine-tuning #3613
Conversation
Nice description! Do you have a sense for which LLMs support gradient checkpointing and which ones don't?
I think almost all models coming from the
This PR adds support in the finetune trainer to optionally enable `gradient_checkpointing`.

What is gradient checkpointing?
Gradient checkpointing works by recomputing the activations of the model during the backward pass, rather than storing them in memory during the forward pass. This is a tradeoff between compute and memory, as the activations need to be recomputed during the backward pass, but the memory footprint is reduced. This is set to false by default because it is not always beneficial to use gradient checkpointing, and it can sometimes slow down training.
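The recompute-on-backward idea can be sketched with PyTorch's `torch.utils.checkpoint` utilities (a minimal illustration of the technique, not the code this PR adds; the toy model and shapes are made up):

```python
import torch
from torch.utils.checkpoint import checkpoint_sequential

# Hypothetical toy model: a deep stack of linear layers whose activations
# would normally all be kept in memory during the forward pass.
layers = torch.nn.Sequential(*[torch.nn.Linear(64, 64) for _ in range(8)])
x = torch.randn(4, 64, requires_grad=True)

# With checkpointing, only the activations at segment boundaries are stored;
# the intermediate activations are recomputed during the backward pass,
# trading extra compute for a smaller memory footprint.
out = checkpoint_sequential(layers, 2, x, use_reentrant=False)
out.sum().backward()  # backward triggers the recomputation
```

Note that `use_reentrant=False` is the recommended mode in recent PyTorch versions; older releases may not accept the keyword.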
How can you use gradient checkpointing in the config?
Gradient checkpointing is disabled by default. To enable it, simply set `enable_gradient_checkpointing` to `True`.

When should I enable gradient checkpointing?
This is useful when training very large models that run into out of memory errors very quickly during training. It is particularly helpful when doing non-quantization based training (adapter based or full fine-tuning).
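Assuming a YAML trainer config of the kind this project uses, enabling the option might look like the following sketch (the surrounding section names are illustrative; only `enable_gradient_checkpointing` comes from this PR):

```yaml
trainer:
  type: finetune
  # Trade extra backward-pass compute for lower activation memory.
  enable_gradient_checkpointing: true
```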