[mxfp8 moe] add support for fbgemm 2d-3d mx8mx8bf16 grouped gemm #2848
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2848
✅ You can merge normally! (1 unrelated failure) As of commit f70cc90 with merge base 8722c0c, the only failing job was likely due to flakiness present on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This PR needs to be approved by an authorized maintainer before merge.
A_scale: torch.Tensor,
B_t_mx: torch.Tensor,
B_t_scale: torch.Tensor,
B_mx: torch.Tensor,
nit: the `_mx` suffix should denote a combination of raw data and scale; if `B_mx` is just the raw data, it would be better to call it something else.
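For illustration only, a hypothetical signature following that convention; the function and parameter names below are made up for this sketch and are not the names used in the PR:

```python
import torch

# Hypothetical naming sketch: `_data` for the raw quantized tensor,
# `_scale` for the e8m0 scales, and `_mx` reserved for a parameter that
# really bundles data and scale together.
def mx8mx8bf16_grouped_mm_2d_3d(
    A_data: torch.Tensor,     # raw mxfp8 data (instead of `A_mx`)
    A_scale: torch.Tensor,    # e8m0 scales for A
    B_t_data: torch.Tensor,   # raw mxfp8 data for B^T (instead of `B_t_mx`)
    B_t_scale: torch.Tensor,  # e8m0 scales for B^T
) -> torch.Tensor:
    ...
```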
blocked_scales: Tensor
start_row_after_padding: Tensor of shape (num_groups,) which contains the start row after padding for each group.
"""
from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import _to_blocked
is this the same function as the one we have in torchao?
ao/torchao/prototype/mx_formats/utils.py, line 18 at 6f035e8:
def to_blocked(input_matrix, use_triton_kernel: bool = False) -> Tensor:
Will test if they're the same and replace if we can. I was having trouble getting the kernel working without CUDA errors, so I was trying to minimize differences between the fbgemm unit test code and this torchao code path.
should be the exact same
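If they are indeed the same, a minimal parity check along these lines could confirm it before swapping the import. This assumes fbgemm's `_to_blocked` accepts a single 2D scale matrix, the same as torchao's `to_blocked`; the actual signature should be verified first.

```python
import torch
from torchao.prototype.mx_formats.utils import to_blocked
from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import _to_blocked

# e8m0 scales are stored as uint8; the shape here is arbitrary for the check.
scales = torch.randint(0, 255, (256, 64), dtype=torch.uint8, device="cuda")

out_torchao = to_blocked(scales, use_triton_kernel=False)
out_fbgemm = _to_blocked(scales)  # assumed signature; verify against fbgemm

# Compare flattened buffers in case the two helpers return different shapes
# for the same padded/swizzled layout.
assert torch.equal(out_torchao.reshape(-1), out_fbgemm.reshape(-1))
```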
vkuzo left a comment:
stamping since this is prototype
Stacked PRs:
[mxfp8 moe] add support for fbgemm 2d-3d mx8mx8bf16 grouped gemm
Summary
- Uses the fbgemm 2d-3d mx8mx8bf16 grouped GEMM for both `output = input @ weight^T` and `grad_input = grad_output @ weight`.
- Adds `to_blocked_per_group_2d` (for input scales) and `to_blocked_per_group_3d` (for weight scales). These are PyTorch reference implementations that are not performant; we can implement equivalent Triton kernels for them later.
- Computes per-group starting rows from the group sizes of the `x` tensor. The result indexes into the `x_scales` tensor and must be size `len(group_sizes) + 1`, where the first starting row is always 0 and each value corresponds to the starting row of group[i] in the `x_scales` tensor AFTER padding (one possible formulation is sketched below).
- Updates `_emulated_mxfp8_scaled_grouped_mm_2d_3d` to have the same function signature and input constraints as the fbgemm API.
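For illustration, a minimal sketch of how those per-group start rows could be computed, assuming each group's scale rows are padded up to a multiple of 128 (the tile height the blocked scale layout uses). The padding granularity and the helper name here are assumptions, not the PR's actual implementation.

```python
import torch

def start_rows_after_padding(group_sizes: torch.Tensor, alignment: int = 128) -> torch.Tensor:
    # Round each group's row count up to the alignment, then take a cumulative
    # sum with a leading 0 so entry i is the first (padded) row of group i.
    padded = ((group_sizes + alignment - 1) // alignment) * alignment
    starts = torch.zeros(len(group_sizes) + 1, dtype=torch.int64, device=group_sizes.device)
    starts[1:] = torch.cumsum(padded, dim=0)
    return starts

# Example: groups of 100, 300, and 50 rows -> tensor([0, 128, 512, 640])
print(start_rows_after_padding(torch.tensor([100, 300, 50])))
```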
Test plan
- `pytest test/prototype/moe_training/test_scaled_grouped_mm.py -k mx`
- `pytest test/prototype/moe_training/test_training.py -k mx`