
mx: small speedup with dim0 cast #1980


Merged

merged 72 commits into main on Apr 1, 2025
Conversation

@vkuzo (Contributor) commented Mar 28, 2025

Summary:

Removes the unnecessary cast to bfloat16 in the MX dim0 casting code.
This is a 2.6% speedup on a 16k by 16k shape:
https://www.internalfb.com/phabricator/paste/view/P1769373804
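
To make the change concrete, here is a rough sketch of a dim0 MX cast (illustrative only, not the actual torchao kernel; the helper name and scaling details are my assumptions):

```python
# Hypothetical sketch of a dim0 MX cast; not the torchao implementation.
import torch

BLOCK_SIZE = 32
F8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3

def to_mx_dim0_sketch(x: torch.Tensor):
    # One power-of-two (e8m0-style) scale per 32 contiguous elements;
    # zero-amax and NaN edge cases are omitted in this sketch.
    x_blocks = x.reshape(-1, BLOCK_SIZE).float()
    amax = x_blocks.abs().amax(dim=1, keepdim=True)
    scale = torch.exp2(torch.floor(torch.log2(amax)) - 8)  # 8 = floor(log2(448))
    # Before this PR the scaled values took a detour through bfloat16,
    # roughly (x_blocks / scale).to(torch.bfloat16).to(torch.float8_e4m3fn);
    # casting straight to fp8 removes that extra rounding pass.
    data = (x_blocks / scale).clamp(-F8_MAX, F8_MAX).to(torch.float8_e4m3fn)
    return data.reshape(x.shape), scale

data, scale = to_mx_dim0_sketch(torch.randn(16, 64, dtype=torch.bfloat16))
```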

Note: this PR also includes a couple of cleanups around the e8m0 dtype and
NaN handling that I found while writing this change. I'm landing them here
rather than in a separate PR since they are all safe.
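
For context on the e8m0 cleanup: e8m0 is an 8-bit exponent-only format with no sign or mantissa bits, and its all-ones bit pattern 0xFF encodes NaN. A minimal illustration of that mapping, reading the exponent field straight out of fp32 bits (truncating, so not a rounding-correct conversion):

```python
import torch

def fp32_exponent_bits(x: torch.Tensor) -> torch.Tensor:
    # The 8 biased-exponent bits of fp32 are exactly an e8m0 value;
    # masking with 0xFF also discards the sign bit.
    return ((x.view(torch.int32) >> 23) & 0xFF).to(torch.uint8)

print(fp32_exponent_bits(torch.tensor([1.0, 2.0, float("nan")])))
# tensor([127, 128, 255], dtype=torch.uint8): NaN lands on 0xFF
```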

Test Plan:

```bash
(pytorch) [vasiliy@devgpu023.atn1 ~/local/ao (20250321_mx_dim1_triton_kernel)]$ python benchmarks/mx_formats/cast_bench.py --mode dim0_mx --M 16384 --K 16384
M 16384 K 16384 BLOCK_SIZE 32
GPU: NVIDIA B200
torch version: 2.8.0a0+git25309a1
triton version: 3.3.0
mode: dim0_mx
time_us 152.90741052631583
mem_bw_gbps 5321.488168553876

(pytorch) [vasiliy@devgpu023.atn1 ~/local/ao (20250321_mx_dim1_triton_kernel)]$ python benchmarks/mx_formats/cast_bench.py --mode dim0_mx --M 16384 --K 16384
M 16384 K 16384 BLOCK_SIZE 32
GPU: NVIDIA B200
torch version: 2.8.0a0+git25309a1
triton version: 3.3.0
mode: dim0_mx
time_us 149.03950980392162
mem_bw_gbps 5459.5924065404415
```
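
As a sanity check, the quoted 2.6% falls out of the two timings above, and the bandwidth lines match a traffic model of one bf16 read plus one fp8 write plus one e8m0 scale byte per 32 elements (my reading of what the benchmark counts, not something the script prints):

```python
# Reproduce mem_bw_gbps and the 2.6% from the logged timings.
M = K = 16384
bytes_moved = M * K * (2 + 1) + M * K // 32  # bf16 in, fp8 out, scales
for t_us in (152.90741052631583, 149.03950980392162):
    print(f"{bytes_moved / (t_us * 1e-6) / 1e9:.1f} GBps")
# -> 5321.5 and 5459.6, matching the two mem_bw_gbps values
print(f"{152.90741052631583 / 149.03950980392162 - 1:.1%}")  # -> 2.6%
```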


vkuzo added 30 commits March 21, 2025 06:59
vkuzo added 5 commits March 28, 2025 13:00
vkuzo added a commit that referenced this pull request Mar 28, 2025
vkuzo added 5 commits March 28, 2025 13:02
vkuzo added a commit that referenced this pull request Mar 28, 2025
vkuzo added 4 commits March 28, 2025 13:03
vkuzo added a commit that referenced this pull request Mar 28, 2025
vkuzo added 3 commits March 28, 2025 13:03
vkuzo added a commit that referenced this pull request Mar 28, 2025
vkuzo added a commit that referenced this pull request Apr 1, 2025
vkuzo added 2 commits April 1, 2025 09:40
vkuzo added a commit that referenced this pull request Apr 1, 2025
@vkuzo changed the base branch from gh/vkuzo/85/head to main on April 1, 2025 16:41
vkuzo added a commit that referenced this pull request Apr 1, 2025
@vkuzo merged commit aafc1ba into main on Apr 1, 2025
46 of 50 checks passed
Labels

CLA Signed (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed)
topic: performance (use this tag if the PR improves the performance of a feature)