
Conversation

danielvegamyhre
Contributor

@danielvegamyhre commented Oct 7, 2025

Stacked PRs:

  • [mxfp8 moe training] add triton kernel for mxfp8 quantization along dim0

Summary

  • torch.compile codegen has had on-and-off perf regressions during the development of mxfp8 moe training; it would be nice to have a simple Triton kernel with consistently good perf that we can use.

Test plan

  • pytest test/prototype/mx_formats/test_kernels.py

Benchmarks

Existing torch.compile/to_mx() benchmark:

(torch) [danvm@devgpu031.atn1 ~/ao/benchmarks (main)]$ CUDA_VISIBLE_DEVICES=7 python mx_formats/cast_bench.py --mode dim0_mxfp8_floor
M 16384 K 16384 BLOCK_SIZE 32
GPU: NVIDIA B200
torch version: 2.10.0.dev20251008+cu128
triton version: 3.5.0
mode: dim0_mxfp8_floor
time_us 156.5759927034378
mem_bw_gbps 5196.805474139168

New triton dim0 mxfp8 kernel (~12% higher achieved memory bandwidth):

(torch) [danvm@devgpu031.atn1 ~/ao/benchmarks (danielvegamyhre/stack/76)]$ CUDA_VISIBLE_DEVICES=7 python mx_formats/cast_bench.py --mode dim0_mxfp8_triton_floor
M 16384 K 16384 BLOCK_SIZE 32
GPU: NVIDIA B200
torch version: 2.10.0.dev20251008+cu128
triton version: 3.5.0
mode: dim0_mxfp8_triton_floor
time_us 140.35199582576752
mem_bw_gbps 5797.530496182742
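
For reference, comparing the two runs above: 156.6 µs vs 140.4 µs is roughly a 1.12x speedup, consistent with the increase in achieved memory bandwidth from 5196.8 to 5797.5 GB/s.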

danielvegamyhre added a commit that referenced this pull request Oct 7, 2025
stack-info: PR: #3128, branch: danielvegamyhre/stack/75
@danielvegamyhre force-pushed the danielvegamyhre/stack/75 branch from 009a6b8 to c62b0f0 on October 7, 2025 17:34

pytorch-bot bot commented Oct 7, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3128

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 527317f with merge base cd21d0e:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla bot added the CLA Signed label Oct 7, 2025
@danielvegamyhre added the mx, topic: not user facing, and moe labels Oct 7, 2025
danielvegamyhre added a commit that referenced this pull request Oct 15, 2025
stack-info: PR: #3128, branch: danielvegamyhre/stack/75
@danielvegamyhre force-pushed the danielvegamyhre/stack/75 branch from c62b0f0 to ff0f0c7 on October 15, 2025 15:24
def to_mxfp8_dim0_kernel(
x_ptr, # pointer to input tensor
output_ptr, # pointer to output tensor (row-normalized)
row_scale_ptr, # pointer to store row-wise maximum absolute values
Contributor

Should the comment say "scale" or "max_abs"? Also, shouldn't "row-wise" really say "across dim0"?

Contributor Author

The dim1 kernel said "column-wise" here, so I just changed it to say "row-wise" here for consistency; dim0/dim1 is also fine. The comment shouldn't say "maximum absolute values" for sure though; I will remove that from both kernels.
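
For concreteness, one hypothetical phrasing of the corrected parameter comment; this is illustrative wording only, not the exact comment that landed:

row_scale_ptr,  # pointer to output e8m0 scales (computed along dim0)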

)

# Store the scales
tl.store(row_scale_start_ptr + row_scale_indices, row_scale_e8m0)
Contributor

Should this either have a mask or a TODO to add masking?

Contributor Author

I added masking in #3131, where I refactored/simplified the kernel, but I'll go ahead and move it here.
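
As context for the masking discussion, here is a minimal, self-contained Triton sketch of the pattern being suggested: guard the scale load/store with a mask so lanes past the end of the buffer do nothing. The names and block size are illustrative and are not taken from this PR's kernel.

import torch
import triton
import triton.language as tl


@triton.jit
def _store_scales_masked_kernel(
    scale_vals_ptr,  # computed e8m0 scale values to copy
    scale_out_ptr,   # output buffer for the scales
    n_scales,        # total number of scales; may not be a multiple of BLOCK
    BLOCK: tl.constexpr,
):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_scales  # lanes past the end of the buffer are inactive
    vals = tl.load(scale_vals_ptr + offs, mask=mask)
    # masked store, analogous to the row_scale_e8m0 store above
    tl.store(scale_out_ptr + offs, vals, mask=mask)


def store_scales_masked(vals: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(vals)
    n = vals.numel()
    BLOCK = 128
    grid = (triton.cdiv(n, BLOCK),)
    _store_scales_masked_kernel[grid](vals, out, n, BLOCK=BLOCK)
    return out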

danielvegamyhre added a commit that referenced this pull request Oct 15, 2025
stack-info: PR: #3128, branch: danielvegamyhre/stack/75
@danielvegamyhre force-pushed the danielvegamyhre/stack/75 branch from ff0f0c7 to ec2dbf9 on October 15, 2025 15:43
danielvegamyhre added a commit that referenced this pull request Oct 15, 2025
stack-info: PR: #3128, branch: danielvegamyhre/stack/75
@danielvegamyhre force-pushed the danielvegamyhre/stack/75 branch from ec2dbf9 to 72933f9 on October 15, 2025 15:48
import triton.language as tl
from torch.library import triton_op, wrap_triton

print("importing triton ops")
Contributor

remove please

Contributor Author

Whoops, was debugging an import error, thanks.

)

# Store the scales
tl.store(row_scale_start_ptr + row_scale_indices, row_scale_e8m0)
Contributor

Is this storing in swizzled layout?

Contributor Author

No, we have separate kernels for that. Low-hanging fruit would be to write directly to the blocked swizzled format here, like Cursor mentioned they're doing in this blog: https://cursor.com/blog/kernels

Contributor

It's a non-zero perf win for activation casting in inference. We can leave it as a follow-up, but it should be pretty easy to just have a constexpr path do this.
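
For reference, a tl.constexpr flag is specialized at compile time, so a kernel can pick its store layout with no runtime branch. The sketch below only illustrates that pattern: it uses a plain transposed layout as a stand-in for the actual blocked swizzled scale format, and none of the names come from this PR.

import torch
import triton
import triton.language as tl


@triton.jit
def _copy_scales_kernel(
    in_ptr, out_ptr,
    n_rows, n_cols,
    BLOCK: tl.constexpr,
    WRITE_TRANSPOSED: tl.constexpr,  # compile-time flag; stands in for a "write swizzled layout" path
):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_rows * n_cols
    vals = tl.load(in_ptr + offs, mask=mask)
    if WRITE_TRANSPOSED:
        # resolved at compile time; dead-code-eliminated when the flag is False
        row = offs // n_cols
        col = offs % n_cols
        tl.store(out_ptr + col * n_rows + row, vals, mask=mask)
    else:
        tl.store(out_ptr + offs, vals, mask=mask)


def copy_scales(x: torch.Tensor, transposed: bool) -> torch.Tensor:
    out_shape = (x.shape[1], x.shape[0]) if transposed else x.shape
    out = torch.empty(out_shape, dtype=x.dtype, device=x.device)
    n = x.numel()
    BLOCK = 128
    grid = (triton.cdiv(n, BLOCK),)
    _copy_scales_kernel[grid](x, out, x.shape[0], x.shape[1], BLOCK=BLOCK, WRITE_TRANSPOSED=transposed)
    return out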

danielvegamyhre added a commit that referenced this pull request Oct 15, 2025
stack-info: PR: #3128, branch: danielvegamyhre/stack/75
@danielvegamyhre force-pushed the danielvegamyhre/stack/75 branch from 72933f9 to cbf6277 on October 15, 2025 16:00
stack-info: PR: #3128, branch: danielvegamyhre/stack/75
@danielvegamyhre force-pushed the danielvegamyhre/stack/75 branch from cbf6277 to 527317f on October 15, 2025 16:08
@danielvegamyhre merged commit 664124a into main Oct 15, 2025
18 checks passed
