-
Notifications
You must be signed in to change notification settings - Fork 171
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Use checkpoint to enforece the recomputation of fp8 weight (#936)
Summary: Pull Request resolved: #936 The issue: When using float8 training with FSDP, we have these tensors in the forward_backward graph: - Without fp8-all-gather: original_weight (all-gather output, sharded) - fp8_weight - fp8_weight_transpose (needed in backward) - With fp8-all-gather: original_weight (sharded) - fp8_weight (all-gather output, sharded) - fp8_weight_transpose (needed in backward) `torch.compile` decides how to partition the graph and which tensors to save for backward. In both the case of with and without fp8-all-gather, it decides to save "fp8_weight_transpose" for backward. It's good in single GPU case, and compute both fp8_weight and fp_weight_transpose in forawrd can be fused into one kernel. However, if we use FSDP to shard the weights, although the weight itself is sharded, the "fp8_weight_transpose" tensors are not. Saving it for backward costs a high memory utilization. ---- To fix it, we have different options: - In the user code, enforce which tensors to save for backward - The `save_for_backward` in custom autograd.Function is one way to specify which tensors to save. However, torch.compile will ignore what are manually saved for backward in a custom autograd.Function, and just run the partitioner. - **[This pr]** Using "torch.utils.checkpoint", which is the API that compile does promise to respect today. It would instruct compile to only save its inputs for backward (the weight and activation), and not the intermediate values from the float8 cast. - Rely on torch.compile to find the best partition that optimizes both computation and memory. It may be a very longer-term solution to fix in compile. Differential Revision: D63345959
- Loading branch information
1 parent
fbe97a0
commit 0c14069
Showing
2 changed files
with
77 additions
and
52 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters