add bfloat16 support for colwise scaling #10
Merged
stack-info: PR: #10, branch: danielvegamyhre/stack/4
danielvegamyhre added a commit that referenced this pull request on Jul 6, 2025
Stacked PRs:
Add bfloat16 support for colwise scaling
Test plan
pytest test_kernels.py -k cuda_mx
Performance
Performance of this initial implementation is slightly worse than Triton (slightly lower memory bandwidth utilization). I suspect this is at least partly because the bfloat16 inputs are currently cast to fp32 when moving from SMEM to registers. This was done for simplicity, so the same "amax to scale" function as the fp32 path can be reused. As a next step, I'll see if I can add a custom scale calculation function for bf16 to avoid this.
Triton dim1 cast with bf16 input: ~2127 GB/s memory bandwidth
CUDA dim1 cast with bf16 input: ~1825 GB/s memory bandwidth