float8 moe training conversion API prototype #2275
base: main
Conversation
stack-info: PR: #2275, branch: danielvegamyhre/stack/1
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2275
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 2b4361d with merge base d963a88.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
from torchao.prototype.scaled_grouped_mm import _scaled_grouped_mm
...
class ScaledGroupedMMTensor(torch.Tensor):
@bdhirsh I suggested this torch_function approach to stay above autograd in order to call into our autograd func. AFAIK compile's torch_function subclass support has gotten much better, but I just wanted to double-check with you.
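For context, a minimal sketch of the torch_function interception pattern being described; the op name (`torch._grouped_mm`), dispatch condition, and construction details are assumptions, and the PR's actual subclass may differ:

```python
import torch
from torchao.prototype.scaled_grouped_mm import _scaled_grouped_mm

class ScaledGroupedMMTensor(torch.Tensor):
    """Sketch only: a tensor subclass whose __torch_function__ override runs
    above autograd, so the differentiable _scaled_grouped_mm autograd function
    handles both forward and backward. Instance construction is elided."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Intercept grouped mm before autograd records it and route it to the
        # dynamic-quant + scaled grouped mm prototype (op name is an assumption).
        if func is torch._grouped_mm:
            return _scaled_grouped_mm(*args, **kwargs)
        return super().__torch_function__(func, types, args, kwargs)
```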
+1. How's compile looking on this now?
- llama4 (what I tested with in torchtitan) cannot compile yet because the backward hooks implementing the auxiliary loss do not play nicely with compile. Tianyu is aware of this and is working on a new approach, I believe.
- In this minimal GroupedExperts example, it cannot compile with `fullgraph=True` because `_grouped_mm` is not traceable yet: `torch._dynamo.exc.Unsupported: ... Explanation: Dynamo does not know how to trace the builtin torch._VariableFunctionsClass._grouped_mm.` (I believe Brian is aware of this and working on it.)
- The minimal example can compile with graph breaks.

This was tested with the latest PyTorch nightly.
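A minimal illustration of the status described above (the wrapper function name is illustrative):

```python
import torch

def compile_experts(experts: torch.nn.Module) -> torch.nn.Module:
    # Default mode compiles today, with graph breaks around _grouped_mm.
    return torch.compile(experts)
    # fullgraph=True currently raises torch._dynamo.exc.Unsupported because
    # Dynamo cannot trace torch._grouped_mm yet:
    # return torch.compile(experts, fullgraph=True)
```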
(1) Yep, torch_function subclass compile support has improved in the last few months (thanks to Ryan / Lazos).
(2) On the grouped_mm issue, pytorch/pytorch#153384 should help. I'm going to land it early next week, but it's a small change, so feel free to rebase on it if you want to test sooner.
return root_module
...
def convert_moe_to_float8_training(
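A minimal sketch of the recursive parameter swap this function performs, assuming the ScaledGroupedMMTensor subclass from the thread above wraps an existing tensor (constructor and module-filtering details are assumptions):

```python
import torch.nn as nn

def convert_moe_to_float8_training(root_module: nn.Module) -> nn.Module:
    # Sketch: re-register every parameter with its data wrapped in
    # ScaledGroupedMMTensor (as sketched earlier in this thread) so grouped_mm
    # calls on it are intercepted and routed to dynamic float8 quant + the
    # scaled grouped mm prototype.
    for module in root_module.modules():
        for name, param in list(module.named_parameters(recurse=False)):
            wrapped = ScaledGroupedMMTensor(param.data)  # assumed constructor
            module.register_parameter(
                name,
                nn.Parameter(wrapped, requires_grad=param.requires_grad),
            )
    return root_module
```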
You can also use the `quantize_` API to match the rest of torchao. We want to eventually migrate the float8 training API to that.
Added todo for this
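For reference, a sketch of the `quantize_` pattern used elsewhere in torchao, which the MoE float8 training conversion could eventually plug into; the MoE-specific config name in the comment is hypothetical:

```python
import torch.nn as nn
from torchao.quantization import quantize_, int8_weight_only

model = nn.Sequential(nn.Linear(128, 128))

# quantize_ walks the model and applies the given config to matching modules.
quantize_(model, int8_weight_only())

# Hypothetically, the MoE float8 training conversion could be exposed the same
# way (the config name below is illustrative, not an existing torchao API):
# quantize_(model, MoEFloat8TrainingConfig(),
#           filter_fn=lambda mod, fqn: "experts" in fqn)
```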
Stacked PRs:
- float8 moe training conversion API prototype
`convert_moe_to_float8_training` will recursively swap nn.Parameter data tensors to a tensor subclass, which has an override for grouped_mm => dynamic quant + the scaled grouped mm prototype. Context: see the implementation of GroupedExperts here.

Testing
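A rough usage sketch of the conversion flow; the toy module, shapes, import paths, and the `torch._grouped_mm` call signature are assumptions, not the PR's actual test code:

```python
import torch
import torch.nn as nn

# Toy stand-in for an MoE experts module whose forward pass calls
# torch._grouped_mm, analogous to torchtitan's GroupedExperts.
class ToyGroupedExperts(nn.Module):
    def __init__(self, num_experts: int, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Parameter(torch.randn(num_experts, dim, hidden_dim))

    def forward(self, x: torch.Tensor, offs: torch.Tensor) -> torch.Tensor:
        # After conversion, this grouped_mm call is intercepted by the
        # ScaledGroupedMMTensor parameters and routed to the float8
        # scaled grouped mm prototype.
        return torch._grouped_mm(x, self.w1, offs=offs)

# convert_moe_to_float8_training as provided by this PR's prototype
# (import path omitted here).
experts = ToyGroupedExperts(num_experts=4, dim=256, hidden_dim=512)
experts = experts.to(device="cuda", dtype=torch.bfloat16)
experts = convert_moe_to_float8_training(experts)

x = torch.randn(1024, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True)
offs = torch.tensor([256, 512, 768, 1024], device="cuda", dtype=torch.int32)
out = experts(x, offs)
out.sum().backward()
```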
Limitations