
[fix] Fix the activation checkpointing when using SwiGLUPackedFusedOp #1127

Open
wants to merge 2 commits into main
Conversation

@warpuv commented Oct 11, 2024

According to the docs (https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function), the forward() method should not be called directly; the apply() method has to be used instead. After removing the direct forward() call, activation checkpointing starts working.
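As a minimal sketch of the distinction (a toy Function named Square standing in for the fused op, not the actual xformers code):

```python
import torch

class Square(torch.autograd.Function):
    """Toy custom op standing in for SwiGLUPackedFusedOp."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out


x = torch.randn(4, requires_grad=True)

# Correct: apply() records a node in the autograd graph, so backward()
# (and utilities built on autograd, such as activation checkpointing) work.
y = Square.apply(x)
y.sum().backward()

# Incorrect: calling forward() directly just runs the code. No autograd
# node is created for Square, its custom backward() is never used, and the
# caller has to fake the ctx argument.
# y = Square.forward(None, x)
```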

What does this PR do?

Fixes #1126

Before submitting

  • Did you have fun?
    • Make sure you had fun coding 🙃
  • Did you read the contributor guideline?
  • Was this discussed/approved via a GitHub issue? (not needed for typos or doc improvements)
    • N/A
  • Did you make sure to update the docs?
    • N/A
  • Did you write any new necessary tests?
    • N/A
  • Did you update the changelog? (if needed)
    • N/A

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

@facebook-github-bot added the CLA Signed label (authors need to sign the CLA before a PR can be reviewed) on Oct 11, 2024
@pansershrek commented

@danthe3rd Can you review this PR, please? This change fixes the integration with FSDP + activation_checkpointing

@lw (Contributor) left a comment
Thanks for the PR!

It seems you're basically trying to revert the change introduced in #706. I don't have full context on that old PR, but I wonder whether there are ways to achieve both goals at once.

For example, we could take your PR and also replace the check if (ctx == nullptr) above with if (x.requires_grad). I believe this should get us everything we want.
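For illustration, here is a hypothetical Python analogue of that suggestion. SwiGLULike, its shapes, and the silu(x @ w) math are made up (the real op is a packed C++ extension with more weights and biases); the point is only the placement of the requires_grad check: save activations only when a backward pass could actually need them.

```python
import torch
import torch.nn.functional as F

class SwiGLULike(torch.autograd.Function):
    """Toy stand-in for the fused op (shapes and math are illustrative)."""

    @staticmethod
    def forward(ctx, x, w):
        out = F.silu(x @ w)
        # Python analogue of replacing `if (ctx == nullptr)` with a
        # requires_grad check: skip saving activations in pure inference.
        if x.requires_grad:
            ctx.save_for_backward(x, w)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors
        z = x @ w
        s = torch.sigmoid(z)
        dz = grad_out * (s + z * s * (1 - s))  # derivative of silu(z)
        return dz @ w.t(), x.t() @ dz
```

The force-push and commit message below explain why gating on requires_grad interacts badly with activation checkpointing.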

@warpuv force-pushed the main branch 2 times, most recently from 6033c6c to bb5ef21 on October 17, 2024 at 17:27
An if conditional on the x.requires_grad state (used to switch behavior between inference and training modes) changes how forward() behaves during recomputation, which breaks activation checkpointing: in the recomputation phase x is detached with requires_grad == False, so a different number of tensors is saved via save_for_backward().
@warpuv (Author) commented Oct 17, 2024


@lw thank you for your review and suggestions.
The if conditional on x.requires_grad changes the behavior of forward() during recomputation: x is detached in the recomputation phase, so x.requires_grad has a different value and save_for_backward() is not called.
I have pushed an alternative solution using torch::GradMode::is_enabled(); I believe both goals are achieved this way.
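A hypothetical Python analogue of that alternative, with torch.is_grad_enabled() playing the role of torch::GradMode::is_enabled(); the names are invented and the exact placement of the check in the C++ extension may differ. Here forward() saves unconditionally (unlike the requires_grad-gated sketch above) and all inference/training gating happens in a dispatcher, where grad mode reflects the caller's context.

```python
import torch
import torch.nn.functional as F

class SwiGLULike(torch.autograd.Function):
    """Toy stand-in for the fused op; forward() always saves, and the
    inference/training decision is made outside, in swiglu_like()."""

    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        return F.silu(x @ w)

    @staticmethod
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors
        z = x @ w
        s = torch.sigmoid(z)
        dz = grad_out * (s + z * s * (1 - s))  # derivative of silu(z)
        return dz @ w.t(), x.t() @ dz


def swiglu_like(x, w):
    # Grad mode, unlike x.requires_grad, looks the same during normal
    # training and during checkpoint recomputation (torch.utils.checkpoint
    # re-runs the forward under torch.enable_grad()), so the same branch is
    # taken in both cases and save_for_backward() stores the same tensors.
    if torch.is_grad_enabled():
        return SwiGLULike.apply(x, w)
    # Under torch.no_grad() (pure inference) skip the autograd path entirely.
    return F.silu(x @ w)
```

The intent matches the discussion above: pure inference still skips the autograd bookkeeping, while ordinary training and checkpoint recomputation take the same branch.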

Successfully merging this pull request may close this issue: Activation checkpointing is not working on SwiGLU