Add an option to use fp8-all-gather only without fp8 computation. #1093


Merged: 1 commit merged into pytorch:main on Oct 31, 2024

Conversation

@y-sq (Contributor) commented Oct 16, 2024

Summary:
The implementation reuses the `WeightWithDynamicFloat8CastTensor` class and the `Float8Linear` module.

I added an if-else branch in the existing `Float8Linear` module to reuse the existing logic for handling the different casting cases, such as the pre-/post-forward handling for delayed scaling and the amax pre-computation for fp8 all-gather.

Differential Revision: D63056142
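
To make the branch described above concrete, here is a minimal, hypothetical sketch. It is not the actual torchao code; the flag name `use_fp8_all_gather_only` and the `_fp8_forward` helper are assumptions made only for illustration:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CommOnlyFloat8Linear(nn.Linear):
    """Toy linear layer: the weight may be all-gathered in fp8 by the sharding
    layer, but the matmul itself can stay in high precision."""

    def __init__(self, in_features, out_features, bias=True,
                 use_fp8_all_gather_only=False):
        super().__init__(in_features, out_features, bias=bias)
        # When True, skip the fp8 compute path and rely only on the fp8
        # all-gather of the weight (handled outside this module).
        self.use_fp8_all_gather_only = use_fp8_all_gather_only

    def forward(self, x):
        if self.use_fp8_all_gather_only:
            # The weight was communicated in fp8; cast it back to the
            # activation dtype and do the matmul in high precision.
            weight = self.weight
            if weight.dtype != x.dtype:
                weight = weight.to(x.dtype)
            return F.linear(x, weight, self.bias)
        # Otherwise fall through to the usual fp8 compute path (casting both
        # activations and weight to fp8 before the matmul).
        return self._fp8_forward(x)

    def _fp8_forward(self, x):
        # Stand-in for the existing fp8 computation; a real implementation
        # would quantize to float8 and call a scaled fp8 matmul here.
        return F.linear(x, self.weight, self.bias)
```

In the all-gather-only mode, only the communicated weight arrives quantized; the matrix multiply itself stays in high precision.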


pytorch-bot bot commented Oct 16, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1093

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 5c5f4f2 with merge base ae77f40:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Oct 16, 2024
@facebook-github-bot (Contributor) commented

This pull request was exported from Phabricator. Differential Revision: D63056142

@weifengpy (Contributor) commented

The CI error looks relevant. Maybe take a look before landing?

  ImportError while importing test module '/pytorch/ao/test/float8/test_fsdp2/test_fsdp2_fp8_comm_only.py'.
  Hint: make sure your test modules/packages have valid Python names.
  Traceback:
  /opt/conda/envs/venv/lib/python3.9/importlib/__init__.py:127: in import_module
      return _bootstrap._gcd_import(name[level:], package, level)
  test/float8/test_fsdp2/test_fsdp2_fp8_comm_only.py:14: in <module>
      from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
  E   ModuleNotFoundError: No module named 'torch.distributed._composable.fsdp'
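
The failure above comes from torch builds that do not ship `torch.distributed._composable.fsdp`. As an illustrative sketch (not necessarily the fix applied in this PR), a test module can guard the import and skip itself on such builds:

```python
import pytest

try:
    from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard
except ImportError:
    # Older torch builds don't have the composable FSDP2 APIs; skip the whole
    # test module instead of failing at import time.
    pytest.skip(
        "torch.distributed._composable.fsdp is not available in this torch build",
        allow_module_level=True,
    )
```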

@y-sq force-pushed the export-D63056142 branch from 7857bb1 to 2950a2e on October 31, 2024 05:48
y-sq added a commit to y-sq/ao that referenced this pull request Oct 31, 2024
Add an option to use fp8-all-gather only without fp8 computation (pytorch#1093)

Summary:

The implementation reuses the `WeightWithDynamicFloat8CastTensor` class and the `Float8Linear` module.

I added an if-else branch in the existing `Float8Linear` module to reuse the existing logic for handling the different casting cases, such as the pre-/post-forward handling for delayed scaling and the amax pre-computation for fp8 all-gather.

Reviewed By: weifengpy

Differential Revision: D63056142
@facebook-github-bot (Contributor) commented

This pull request was exported from Phabricator. Differential Revision: D63056142

y-sq added a commit to y-sq/ao that referenced this pull request Oct 31, 2024
@y-sq force-pushed the export-D63056142 branch from 2950a2e to 0f3ef23 on October 31, 2024 05:48
@facebook-github-bot (Contributor) commented

This pull request was exported from Phabricator. Differential Revision: D63056142

@y-sq force-pushed the export-D63056142 branch from 0f3ef23 to 5c5f4f2 on October 31, 2024 07:07
@facebook-github-bot (Contributor) commented

This pull request was exported from Phabricator. Differential Revision: D63056142

@facebook-github-bot facebook-github-bot merged commit 2761917 into pytorch:main Oct 31, 2024
17 of 19 checks passed
yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024
vkuzo added a commit that referenced this pull request Dec 19, 2024
for now, delete the float8-all-gather-only functionality from float8 training

Summary:

In #1093 we added a config option, off by default, to use only float8 all-gather for training and do the matrix multiply in high precision.

This seems generally useful for communication-bound workloads, but we can probably think of a cleaner way to add this functionality (such as a weight wrapper tensor subclass). The current implementation adds non-trivial complexity and doesn't jibe well with where we want to take this codebase.

Since no one is using this internally or externally yet and we haven't talked about it in the release notes, I think we should do a BC-breaking delete as a one-off. However, if people have concerns, let me know and we can talk about less aggressive options.

Test Plan:

```
./test/float8/test_everything.sh
```
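
As a side note on the "weight wrapper tensor subclass" alternative mentioned above, here is a rough, hypothetical sketch of the shape such a wrapper could take; the class name and all details are assumptions, not code from this repository. The idea is that compute ops always see the high-precision weight, while the fp8 cast for communication would be wired into FSDP2's all-gather extension points rather than into the module's forward path:

```python
import torch
from torch.utils._pytree import tree_map


class Fp8CommOnlyWeight(torch.Tensor):
    """Hypothetical wrapper: compute always sees the high-precision weight;
    only the communication layer would see an fp8 payload."""

    @staticmethod
    def __new__(cls, hp_weight: torch.Tensor):
        # Advertise the high-precision shape/dtype to autograd and modules.
        return torch.Tensor._make_wrapper_subclass(
            cls, hp_weight.shape, dtype=hp_weight.dtype, device=hp_weight.device
        )

    def __init__(self, hp_weight: torch.Tensor):
        self._hp_weight = hp_weight

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Unwrap before every op so all math stays in high precision.
        def unwrap(t):
            return t._hp_weight if isinstance(t, Fp8CommOnlyWeight) else t

        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
```

Compared with the if-else branch inside `Float8Linear`, this would keep the comm-only behavior out of the linear module entirely.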
vkuzo added a commit that referenced this pull request Dec 19, 2024
vkuzo added a commit that referenced this pull request Dec 20, 2024
for now, delete the float8-all-gather-only functionality from float8 training (#1451)

amdfaa pushed a commit that referenced this pull request Jan 10, 2025
Labels: CLA Signed, fb-exported
3 participants