Use checkpoint to enforce the recomputation of fp8 weight #936
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/936
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 40584c8 with merge base 9229df9.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This pull request was exported from Phabricator. Differential Revision: D63345959
Summary:
Pull Request resolved: pytorch#936

The issue: when using float8 training with FSDP, we have these tensors in the forward/backward graph:
- Without fp8-all-gather: original_weight (all-gather output, sharded) - fp8_weight - fp8_weight_transpose (needed in backward)
- With fp8-all-gather: original_weight (sharded) - fp8_weight (all-gather output, sharded) - fp8_weight_transpose (needed in backward)

`torch.compile` decides how to partition the graph and which tensors to save for backward. Both with and without fp8-all-gather, it decides to save "fp8_weight_transpose" for backward. This is fine in the single-GPU case, where computing fp8_weight and fp8_weight_transpose in forward can be fused into one kernel. However, when we use FSDP to shard the weights, the weight itself is sharded but the "fp8_weight_transpose" tensors are not, so saving them for backward results in high memory utilization.

To fix it, we have different options:
- In the user code, enforce which tensors to save for backward. `save_for_backward` in a custom autograd.Function is one way to specify which tensors to save; however, torch.compile ignores what is manually saved for backward in a custom autograd.Function and just runs the partitioner.
- **[This PR]** Use `torch.utils.checkpoint`, which is the API that compile does promise to respect today. It instructs compile to only save the checkpoint's inputs for backward (the weight and activation), and not the intermediate values from the float8 cast (see the sketch below).
- Rely on torch.compile to find a partition that optimizes both computation and memory. This may be a longer-term solution to fix in compile.

Differential Revision: D63345959
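To illustrate the approach, here is a minimal sketch (not the code in this PR; the module and helper names are hypothetical): wrapping the weight cast in `torch.utils.checkpoint` means only the checkpoint's inputs are saved for backward, and the cast is recomputed during the backward pass.

```python
import torch
import torch.nn as nn
from torch.utils import checkpoint


class Fp8RecomputeLinear(nn.Module):
    """Hypothetical stand-in for the float8 linear touched by this PR."""

    def __init__(self, in_features, out_features, force_recompute_fp8_weight_in_bwd=True):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.force_recompute_fp8_weight_in_bwd = force_recompute_fp8_weight_in_bwd

    def cast_weight_for_matmul(self, weight):
        # Stand-in for the real float8 cast: the actual code quantizes the
        # weight to float8 and builds the transposed copy used by the matmul.
        return weight.t().contiguous()

    def forward(self, x):
        if self.force_recompute_fp8_weight_in_bwd:
            # checkpoint saves only its inputs (self.weight) for backward and
            # recomputes the cast there, so the torch.compile partitioner does
            # not keep the (unsharded) cast weight alive between forward and
            # backward.
            weight_t = checkpoint.checkpoint(
                self.cast_weight_for_matmul, self.weight, use_reentrant=False
            )
        else:
            weight_t = self.cast_weight_for_matmul(self.weight)
        return x @ weight_t


model = torch.compile(Fp8RecomputeLinear(16, 32))
y = model(torch.randn(4, 16))
y.sum().backward()
```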
thank you! left a couple of nit comments on the internal version
weight_scale = self.get_weight_scale(self.weight)

if self.force_recompute_fp8_weight_in_bwd:
    weight_fp8_t = checkpoint.checkpoint(
Will this `force_recompute_fp8_weight_in_bwd` flag help ensure that the checkpoint is only done in compile?
In eager, if someone has an outer AC context, the semantics are to do recursive checkpointing, i.e., a single tensor is recomputed multiple times. Nesting AC within SAC also hasn't been tested.
In the case of compile, the behavior is not really defined, but it might be possible that it would behave as you want, e.g. nested SAC policies where the inner policy overrides the outer policy, but I have not tested it.
We should also add a note here that even though this might work today, in the future we'd want to replace it with some nicer API.
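To make the nesting concern concrete, here is a minimal eager-mode sketch (hypothetical module, not code from this repo): an outer user-level checkpoint wraps a layer that internally checkpoints its weight cast, so during backward the outer recomputation re-runs the inner checkpoint and the cast is computed multiple times.

```python
import torch
import torch.nn as nn
from torch.utils import checkpoint


class InnerRecomputeLinear(nn.Module):
    # Hypothetical layer that checkpoints its own weight cast, as this PR does
    # for the fp8 cast.
    def __init__(self, dim):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(dim, dim))

    def cast(self, w):
        return w.t().contiguous()  # stand-in for the fp8 cast + transpose

    def forward(self, x):
        w_t = checkpoint.checkpoint(self.cast, self.weight, use_reentrant=False)
        return x @ w_t


block = InnerRecomputeLinear(8)
x = torch.randn(2, 8, requires_grad=True)

# Outer, user-level activation checkpointing around the whole block. In eager
# this nests checkpoints: during backward the outer region is recomputed, and
# inside that recomputation the inner checkpoint runs again, so the weight
# cast ends up being executed multiple times.
y = checkpoint.checkpoint(block, x, use_reentrant=False)
y.sum().backward()
```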
Thanks for the comment.
The most common use case of float8 is with compile, because eager performance is very bad.
Maybe I can add a warning that if "force_recompute_fp8_weight_in_bwd" is enabled, it's recommended to rely on torch.compile for activation checkpointing; otherwise, if users want to use customized AC, they should be sure to handle the checkpointing of weights themselves (see the sketch below)?
I'll also add the note that even though this might work today, in the future we'd want to replace it with some nicer API. Shall we also open an issue to track the longer-term solution?
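For illustration, a minimal sketch of what such a warning could look like (the placement and wording are assumptions, not the code that landed):

```python
import warnings

force_recompute_fp8_weight_in_bwd = True  # e.g. read from the Float8 config

if force_recompute_fp8_weight_in_bwd:
    # Hypothetical warning text; the real check would live where the config
    # is consumed.
    warnings.warn(
        "force_recompute_fp8_weight_in_bwd is enabled: it is recommended to "
        "rely on torch.compile for activation checkpointing. If you apply a "
        "customized AC wrapper on top, make sure to handle checkpointing of "
        "the fp8 weights yourself."
    )
```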
Reviewed By: vkuzo
Differential Revision: D63345959
Pull Request resolved: pytorch#936