float8 moe training conversion API prototype #2275

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 1 commit into main
Conversation

@danielvegamyhre (Contributor) commented May 30, 2025

Stacked PRs:


float8 moe training conversion API prototype

  • convert_moe_to_float8_training will recursively swap nn.Parameter data tensors to a tensor subclass that overrides grouped_mm with dynamic float8 quantization + the scaled grouped mm prototype (a usage sketch follows below). Context: see the implementation of GroupedExperts here.
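A minimal usage sketch of the conversion flow described above. The import path and the toy module are illustrative assumptions (the real target is torchtitan's GroupedExperts); only convert_moe_to_float8_training and the recursive parameter swap it performs come from this PR.

import torch
import torch.nn as nn

# Assumed import location within the torchao prototype.
from torchao.prototype.scaled_grouped_mm import convert_moe_to_float8_training


class ToyRoutedExperts(nn.Module):
    # Toy stand-in for torchtitan's GroupedExperts: routed experts whose
    # forward pass calls torch._grouped_mm on a stacked 3D expert weight.
    def __init__(self, num_experts: int = 4, dim: int = 256):
        super().__init__()
        self.w = nn.Parameter(torch.randn(num_experts, dim, dim, dtype=torch.bfloat16))

    def forward(self, x: torch.Tensor, offs: torch.Tensor) -> torch.Tensor:
        return torch._grouped_mm(x, self.w, offs=offs)


model = ToyRoutedExperts().cuda()

# Recursively swaps each nn.Parameter's data tensor for the ScaledGroupedMMTensor
# subclass, so subsequent torch._grouped_mm calls are rerouted to dynamic float8
# quantization + the scaled grouped mm kernel. Training then proceeds as usual
# (single GPU only, per the limitations below).
convert_moe_to_float8_training(model)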

Testing

  • Tested with torchtitan (see PR with MoE conversion API) and confirmed single GPU training works as expected.

Limitations

  • Only supports single-GPU training. I tried with FSDP=2 and hit this issue, which seems to be related to a known issue that is being addressed.
  • Only performs the grouped_mm override for routed experts (see the condition here). For shared experts, I'll need to update the torchao prototype to support a 3D A tensor.

pytorch-bot bot commented May 30, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2275

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 2b4361d with merge base d963a88:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@danielvegamyhre (Contributor, Author) commented:
@drisspg @vkuzo for review when you have a chance

from torchao.prototype.scaled_grouped_mm import _scaled_grouped_mm


class ScaledGroupedMMTensor(torch.Tensor):
Contributor commented:
@bdhirsh I suggested this torch_function approach to stay above autograd in order to call into our autograd func. AFAIK compile's torch_function subclass support has gotten much better, but I just wanted to double check with you.

Contributor commented:
+1. How's compile looking on this now?

@danielvegamyhre (Contributor, Author) commented May 30, 2025:
  • llama4 (what I tested with in torchtitan) cannot compile yet because the backward hooks implementing the auxiliary loss don't play nicely with compile. Tianyu is aware of this and is working on a new approach, I believe.
  • In this minimal GroupedExperts example, it cannot compile with fullgraph=True because _grouped_mm is not traceable yet: torch._dynamo.exc.Unsupported: .... Explanation: Dynamo does not know how to trace the builtin torch._VariableFunctionsClass._grouped_mm. (I believe Brian is aware of this and working on it.)
  • The minimal example can compile with graph breaks.

This was tested with the latest PyTorch nightly.

Contributor commented:
(1) Yep, torch_function subclass compile support has improved in the last few months (thanks to Ryan / Lazos).

(2) On the grouped_mm issue, pytorch/pytorch#153384 should help. I'm going to land it early next week, but it's a small change, so feel free to rebase on it if you want to test sooner.
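For readers following the thread: a minimal sketch of the torch_function approach discussed above, assuming _scaled_grouped_mm mirrors torch._grouped_mm's signature. The override body is illustrative, not the PR's exact code.

import torch
from torchao.prototype.scaled_grouped_mm import _scaled_grouped_mm


class ScaledGroupedMMTensor(torch.Tensor):
    # Tensor subclass that intercepts torch._grouped_mm above autograd via
    # __torch_function__ and reroutes it to the differentiable
    # _scaled_grouped_mm prototype, which performs dynamic float8 quantization
    # and defines its own backward.
    grouped_mm_func_name = "_grouped_mm"

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Match by name so this works even on builds where torch._grouped_mm
        # cannot be referenced directly.
        if func.__name__ == cls.grouped_mm_func_name:
            return _scaled_grouped_mm(*args, **kwargs)
        return super().__torch_function__(func, types, args, kwargs)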

return root_module


def convert_moe_to_float8_training(
Contributor commented:
You can also use the quantize_ API to match the rest of torchao. We want to eventually migrate the float8 training API to that.

@danielvegamyhre (Contributor, Author) commented:
Added a TODO for this.
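For context on the suggestion above, a sketch of what a quantize_-based entry point could look like. quantize_ and its filter_fn convention are existing torchao APIs; MoETrainingConfig is a hypothetical placeholder for a future float8 MoE training config, and model is the MoE module to convert (e.g. the toy module from the earlier sketch).

import torch.nn as nn
from torchao.quantization import quantize_

# Hypothetical config standing in for a future float8 MoE training config;
# today the prototype exposes convert_moe_to_float8_training instead.
config = MoETrainingConfig()

def moe_expert_filter(module: nn.Module, fqn: str) -> bool:
    # Only convert routed-experts modules, mirroring the condition the
    # current conversion API applies to skip shared experts.
    return "experts" in fqn

quantize_(model, config, filter_fn=moe_expert_filter)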

Labels: CLA Signed · topic: not user facing