Use checkpoint to enforce the recomputation of fp8 weight #936

Open
wants to merge 1 commit into base: main

Conversation

@y-sq (Contributor) commented Sep 24, 2024

Summary:
The issue:
When using float8 training with FSDP, the forward/backward graph contains these tensors:

  • Without fp8-all-gather:
    original_weight (all-gather output, unsharded) - fp8_weight - fp8_weight_transpose (needed in backward)
  • With fp8-all-gather:
    original_weight (sharded) - fp8_weight (all-gather output, unsharded) - fp8_weight_transpose (needed in backward)

torch.compile decides how to partition the graph and which tensors to save for backward. In both cases, with and without fp8-all-gather, it chooses to save "fp8_weight_transpose" for backward. This is fine in the single-GPU case, since computing both fp8_weight and fp8_weight_transpose in the forward pass can be fused into one kernel. However, when FSDP shards the weights, the weight itself is sharded while the "fp8_weight_transpose" tensors are not, so saving them for backward results in high memory utilization.
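To make the cost concrete, here is a rough, purely illustrative calculation (the weight shape and shard count below are hypothetical, not taken from this PR):

```python
# Illustrative memory arithmetic (hypothetical sizes): for one linear layer
# with an 8192 x 8192 weight sharded over 8 GPUs, the sharded fp8 weight is
# ~8 MiB per rank, but an unsharded fp8_weight_transpose saved for backward
# is ~64 MiB per rank, and that cost repeats for every linear layer.
shards = 8
rows, cols = 8192, 8192
fp8_bytes = 1  # one byte per float8 element

sharded_fp8_weight_mib = rows * cols * fp8_bytes / shards / 2**20  # 8.0 MiB
unsharded_transpose_mib = rows * cols * fp8_bytes / 2**20          # 64.0 MiB
print(sharded_fp8_weight_mib, unsharded_transpose_mib)
```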


To fix this, there are several options:

  • In the user code, enforce which tensors are saved for backward:
    • save_for_backward in a custom autograd.Function is one way to specify which tensors to save. However, torch.compile ignores what is manually saved for backward in a custom autograd.Function and simply runs its partitioner.
    • [This PR] Use "torch.utils.checkpoint", which is the API that compile promises to respect today. It instructs compile to save only the checkpoint inputs for backward (the weight and activation), and not the intermediate values from the float8 cast (see the sketch after this list).
  • Rely on torch.compile to find a partition that optimizes both computation and memory. That is a longer-term solution that would have to be addressed in compile itself.
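
A minimal sketch of the idea, assuming hypothetical names (cast_to_low_precision, CheckpointedCastLinear) rather than the actual torchao code, and with the float8 cast stubbed out as a bfloat16 round-trip so the example runs anywhere: wrapping the weight cast in torch.utils.checkpoint tells the torch.compile partitioner to save only the checkpoint inputs (the weight and the activation) for backward and to recompute everything inside the checkpointed region, which in the real float8 code includes fp8_weight_transpose, during the backward pass.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint


def cast_to_low_precision(weight: torch.Tensor) -> torch.Tensor:
    # Stand-in for the real float8 cast (scaling + quantization to fp8).
    # A bfloat16 round-trip keeps the sketch differentiable and runnable.
    return weight.to(torch.bfloat16).to(weight.dtype)


class CheckpointedCastLinear(nn.Linear):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Only the checkpoint input (self.weight) and the activation are saved
        # for backward; the cast output is recomputed instead of being stored.
        w_cast = checkpoint(cast_to_low_precision, self.weight, use_reentrant=False)
        return F.linear(x, w_cast, self.bias)


if __name__ == "__main__":
    layer = torch.compile(CheckpointedCastLinear(64, 64))
    out = layer(torch.randn(4, 64))
    out.sum().backward()
```

The trade-off is the usual one for checkpointing: the cast is computed twice, once in forward and once in backward, in exchange for not keeping the large intermediate alive between the two passes.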

Differential Revision: D63345959

@pytorch-bot (bot) commented Sep 24, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/936

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 0c14069 with merge base fbe97a0:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label Sep 24, 2024

@facebook-github-bot commented:

This pull request was exported from Phabricator. Differential Revision: D63345959

y-sq added a commit to y-sq/ao that referenced this pull request Sep 24, 2024

@weifengpy (Contributor) commented:

cc @H-Huang @tianyu-l as this might be related to the PP + float8 memory increase in torch.compile

y-sq added a commit to y-sq/ao that referenced this pull request Sep 25, 2024

Labels: CLA Signed, fb-exported
Projects: None yet
3 participants