-
Notifications
You must be signed in to change notification settings - Fork 1.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
update to prepare_model_for_kbit_training
#728
Conversation
The documentation is not available anymore as the PR was closed or merged. |
from deprecated `prepare_model_for_int8_training` and add `use_gradient_checkpointing=args.gradient_checkpointing` to automatically follow the gradient checkpointing choice is also the workaround for huggingface#694
973429d
to
16ef09b
Compare
calling model.gradient_checkpointing_enable() twice causes issues this workaround calls it in prepare_model_for_kbit_training and then changes the arg to false to make sure it isn't called again in huggingface trainer inner loop also changes stack_llama_2 sft trainer to use correct device map for ddp training so that you can test this issue
I've realized this fix actually causes an issue. Calling I've made a workaround here in To demonstrate the issue, I fixed |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks good to me! Thanks for deepdiving and explaining, I left one question, let me know what do you think
model, use_gradient_checkpointing=args.gradient_checkpointing | ||
) | ||
|
||
args = dataclasses.replace(args, gradient_checkpointing=False) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why this change here and not above?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we do want to call gradient_checkpointing_enable
once, we just don't want to call it twice. We will call it in 'prepare_for_kbit_trainingbut this change makes sure we don't call it in
Trainer`
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perfect makes sense!
@mnoukhov thanks again for your work on this ! Would you be happy to fix the merge conflicts? After that we should be good to merge! |
Pulled and should be ready to merge! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot for this great effort!
* update to `prepare_model_for_kbit_training` from deprecated `prepare_model_for_int8_training` and add `use_gradient_checkpointing=args.gradient_checkpointing` to automatically follow the gradient checkpointing choice is also the workaround for huggingface#694 * workaround for gradient checkpointing issue calling model.gradient_checkpointing_enable() twice causes issues this workaround calls it in prepare_model_for_kbit_training and then changes the arg to false to make sure it isn't called again in huggingface trainer inner loop also changes stack_llama_2 sft trainer to use correct device map for ddp training so that you can test this issue
* update to `prepare_model_for_kbit_training` from deprecated `prepare_model_for_int8_training` and add `use_gradient_checkpointing=args.gradient_checkpointing` to automatically follow the gradient checkpointing choice is also the workaround for huggingface#694 * workaround for gradient checkpointing issue calling model.gradient_checkpointing_enable() twice causes issues this workaround calls it in prepare_model_for_kbit_training and then changes the arg to false to make sure it isn't called again in huggingface trainer inner loop also changes stack_llama_2 sft trainer to use correct device map for ddp training so that you can test this issue
since
peft
has deprecatedprepare_model_for_int8_training
also add
use_gradient_checkpointing=args.gradient_checkpointing
to automatically follow the gradient checkpointing choice in training argsFor RewardTrainer, this is the workaround to #480 proposed by #694.
Concurrently @lewtun is working on #726 which adds the
use_gradient_checkpointing
for RewardTrainer. I'm happy to wait until it is merged to merge this.