[mxfp8 moe training] add triton kernel for mxfp8 quantization along dim0 #3128
Conversation
Force-pushed from 009a6b8 to c62b0f0 (stack-info: PR: #3128, branch: danielvegamyhre/stack/75).
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3128
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 527317f with merge base cd21d0e.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from c62b0f0 to ff0f0c7 (stack-info: PR: #3128, branch: danielvegamyhre/stack/75).
def to_mxfp8_dim0_kernel(
    x_ptr,  # pointer to input tensor
    output_ptr,  # pointer to output tensor (row-normalized)
    row_scale_ptr,  # pointer to store row-wise maximum absolute values
should the comment say scale or max_abs? also row-wise should really say "across dim0"?
the dim1 kernel said "column-wise" here, so I just changed it to say "row-wise" here for consistency; dim0/dim1 is also fine. The comment definitely shouldn't say "maximum absolute values" though, I will remove that from both kernels.
)

# Store the scales
tl.store(row_scale_start_ptr + row_scale_indices, row_scale_e8m0)
should this either have a mask or a TODO to add masking?
I added masking in #3131 where I refactored/simplified the kernel, but I'll go ahead and move it here
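For reference, a minimal sketch of the masked version of that store, reusing the names from the quoted hunk; `num_scales` (the total count of valid e8m0 scale slots) is a hypothetical bound for illustration, not a variable from the PR:

scale_mask = row_scale_indices < num_scales  # hypothetical bound on valid scale slots
tl.store(row_scale_start_ptr + row_scale_indices, row_scale_e8m0, mask=scale_mask)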
Force-pushed from ff0f0c7 to ec2dbf9 (stack-info: PR: #3128, branch: danielvegamyhre/stack/75).
Force-pushed from ec2dbf9 to 72933f9 (stack-info: PR: #3128, branch: danielvegamyhre/stack/75).
import triton.language as tl
from torch.library import triton_op, wrap_triton

print("importing triton ops")
remove please
Whoops, I was debugging an import error, thanks
)

# Store the scales
tl.store(row_scale_start_ptr + row_scale_indices, row_scale_e8m0)
is this storing in swizzled layout?
No, we have separate kernels for that. Low-hanging fruit could be to write directly to the blocked swizzled format here, like Cursor mentioned they're doing in this blog: https://cursor.com/blog/kernels
It's a non-zero perf win for activation casting in inference. We can leave it as a follow-up, but it should be pretty easy to just have a constexpr path do this.
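A rough sketch of what a constexpr-gated scale store could look like; the flag name, offset variables, and the swizzled-offset computation are assumptions for illustration, not the PR's code, and the actual blocked layout would need to match what the downstream GEMM expects:

import triton
import triton.language as tl

@triton.jit
def store_scales(
    scale_ptr,
    scale_vals,                    # e8m0 scales computed by this program
    linear_offsets,                # offsets for the plain row-major scale layout
    swizzled_offsets,              # offsets precomputed for the blocked/swizzled layout
    mask,
    WRITE_SWIZZLED: tl.constexpr,  # hypothetical compile-time flag
):
    # The constexpr flag is resolved at compile time, so the untaken branch is
    # removed from the generated kernel and adds no runtime cost.
    if WRITE_SWIZZLED:
        tl.store(scale_ptr + swizzled_offsets, scale_vals, mask=mask)
    else:
        tl.store(scale_ptr + linear_offsets, scale_vals, mask=mask)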
Force-pushed from 72933f9 to cbf6277 (stack-info: PR: #3128, branch: danielvegamyhre/stack/75).
Force-pushed from cbf6277 to 527317f (stack-info: PR: #3128, branch: danielvegamyhre/stack/75).
Stacked PRs:
[mxfp8 moe training] add triton kernel for mxfp8 quantization along dim0
Summary
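As background: mxfp8 quantization groups every 32 contiguous elements into a block that shares one power-of-two (e8m0) scale, with the elements themselves stored as float8_e4m3. A hedged eager-mode reference sketch of that computation follows; it is not the Triton kernel, and its rounding choice is one reasonable option rather than the PR's exact recipe:

import torch

def mxfp8_quantize_reference(x: torch.Tensor, block_size: int = 32):
    # Assumes x.numel() is divisible by block_size; blocks are taken along the
    # last (contiguous) dimension for simplicity.
    x_blocks = x.reshape(-1, block_size).to(torch.float32)
    amax = x_blocks.abs().amax(dim=1, keepdim=True)
    # Pick a power-of-two scale so the block max fits under the e4m3 max (448).
    scale_exp = torch.ceil(torch.log2(amax / 448.0)).clamp(-127, 127)
    scale = torch.exp2(scale_exp)
    data_fp8 = (x_blocks / scale).to(torch.float8_e4m3fn).reshape(x.shape)
    e8m0_scales = (scale_exp.squeeze(1) + 127).to(torch.uint8)  # biased exponent, like e8m0
    return e8m0_scales, data_fp8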
Test plan
pytest test/prototype/mx_formats/test_kernels.py
Benchmarks
existing torch.compile/to_mx() benchmark:
new triton dim0 mxfp8 kernel (~10% higher peak memory bandwidth utilization):
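A sketch of how such a bandwidth comparison is typically measured; the callables passed in below are placeholders, not the PR's actual APIs:

import torch
from triton.testing import do_bench

def bench_gbps(quantize_fn, x: torch.Tensor, block_size: int = 32) -> float:
    """Achieved memory bandwidth in GB/s for an mxfp8 quantization function."""
    ms = do_bench(lambda: quantize_fn(x))
    # bytes read (bf16 input) + bytes written (fp8 data + one e8m0 scale per block)
    bytes_moved = x.numel() * x.element_size() + x.numel() + x.numel() // block_size
    return bytes_moved / (ms * 1e-3) / 1e9

# Example usage (placeholder callables):
#   x = torch.randn(16384, 16384, device="cuda", dtype=torch.bfloat16)
#   bench_gbps(compiled_to_mx, x)       # existing torch.compile/to_mx() path
#   bench_gbps(triton_mxfp8_dim0, x)    # new triton dim0 kernel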