@danielvegamyhre danielvegamyhre commented Jul 6, 2025

Add bfloat16 support for colwise scaling

  • Add templating so that TMA loads read the input data correctly for the given input data type (fp32 or bf16)
  • Always convert amax to fp32. This way we can use the same "amax to scale" function for both fp32 and bf16 input tensors.
  • Add tests for bfloat16 inputs
  • Update benchmark script to operate on bfloat16 inputs for dim1 CUDA cast, so we can compare the Triton and CUDA implementations head to head.
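
To illustrate why converting amax to fp32 lets one "amax to scale" function serve both dtypes, here is a minimal Python sketch. It is not the kernel code from this PR: the helper names are hypothetical, and it assumes a floor-based E8M0 scale with a float8_e4m3fn element format (largest power-of-two exponent 8), which the PR description does not state. The key point is that bf16 is the upper 16 bits of an fp32 value, so widening is exact and the exponent field is unchanged.

```python
import struct

# Assumption: elements are cast to float8_e4m3fn, whose max value 448 = 1.75 * 2**8.
E4M3_MAX_POW2 = 8


def bf16_bits_to_fp32(bits16: int) -> float:
    """Widen a bfloat16 bit pattern to fp32 by appending 16 zero mantissa bits."""
    return struct.unpack("<f", struct.pack("<I", (bits16 & 0xFFFF) << 16))[0]


def amax_to_e8m0_scale(amax_fp32: float) -> int:
    """Hypothetical floor-based scale: biased fp32 exponent of amax minus the
    element format's max power-of-two exponent, clamped to the E8M0 range."""
    bits = struct.unpack("<I", struct.pack("<f", amax_fp32))[0]
    biased_exp = (bits >> 23) & 0xFF  # 8-bit exponent field of the fp32 value
    return max(0, min(254, biased_exp - E4M3_MAX_POW2))


# The same function handles an amax that started life as fp32 or as bf16.
scale_from_fp32 = amax_to_e8m0_scale(3.0)
scale_from_bf16 = amax_to_e8m0_scale(bf16_bits_to_fp32(0x4040))  # 0x4040 is 3.0 in bf16
```

Because the widening only appends zero bits, both calls see the same exponent and return the same biased scale.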

Test plan

  • pytest test_kernels.py -k cuda_mx

Performance

Perf of this initial implementation is somewhat worse than Triton: about 14% lower mem bw utilization in the runs below. I'm guessing this is at least partly because I'm currently casting the bfloat16 inputs to fp32 when moving them from SMEM to registers. I did this for simplicity, to reuse the same "amax to scale" func as fp32. As a next step, I'll see if I can add a custom scale calc func for bf16 to avoid the cast.

Triton dim1 cast with bf16 input: ~2127 GB/s mem bw

(ao) [danvm@devgpu031.atn1 ~/private-torchao/benchmarks/mx_formats (work)]$ python cast_bench.py --mode dim1_mx_triton
M 16384 K 16384 BLOCK_SIZE 32
GPU: NVIDIA B200
torch version: 2.9.0.dev20250702+cu128
triton version: 3.3.1
mode: dim1_mx_triton
time_us 382.4959993362427
mem_bw_gbps 2127.3293770706896

CUDA dim1 cast with bf16 input: ~1825 GB/s mem bw

(ao) [danvm@devgpu031.atn1 ~/private-torchao/benchmarks/mx_formats (work)]$ python cast_bench.py --mode dim1_cuda
M 16384 K 16384 BLOCK_SIZE 32
GPU: NVIDIA B200
torch version: 2.9.0.dev20250702+cu128
triton version: 3.3.1
mode: dim1_cuda
time_us 445.82399725914
mem_bw_gbps 1825.148446477705
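
For readers who want to sanity-check the mem_bw_gbps figures, here is a rough sketch of one way to derive them from time_us. The byte accounting is an assumption on my part, not taken from cast_bench.py: 2 bytes read per bf16 input element, 1 byte written per output element, and one 1-byte scale written per block of BLOCK_SIZE elements. The function name is hypothetical.

```python
def estimate_mem_bw_gbps(m: int, k: int, block_size: int, time_us: float) -> float:
    """Estimate achieved memory bandwidth in GB/s (1e9 bytes per second)."""
    n = m * k
    bytes_read = n * 2                  # assumed: bf16 input, 2 bytes per element
    bytes_written = n * 1               # assumed: 1-byte output elements
    bytes_scales = n // block_size      # assumed: one 1-byte scale per block
    total_bytes = bytes_read + bytes_written + bytes_scales
    return total_bytes / (time_us * 1e-6) / 1e9


# Timings copied from the two benchmark runs above.
triton_bw = estimate_mem_bw_gbps(16384, 16384, 32, 382.4959993362427)
cuda_bw = estimate_mem_bw_gbps(16384, 16384, 32, 445.82399725914)
cuda_vs_triton = cuda_bw / triton_bw
```

Under this accounting the estimates land close to the reported values, and the CUDA kernel reaches roughly 86% of the Triton bandwidth.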


Next steps

  • Add custom scale calc func for bf16 to avoid the unnecessary cast and improve perf
  • Investigate segfaults when input dims aren't divisible by 32

stack-info: PR: #10, branch: danielvegamyhre/stack/4