Skip to content

mx: small speedup with dim0 cast #1980

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 72 commits into from
Apr 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
72 commits
Select commit Hold shift + click to select a range
af6ae2f
Update
vkuzo Mar 21, 2025
45120de
Update
vkuzo Mar 21, 2025
5527e72
Update
vkuzo Mar 21, 2025
478b9e1
Update
vkuzo Mar 21, 2025
571775d
Update
vkuzo Mar 21, 2025
fd30558
Update
vkuzo Mar 21, 2025
b0cd056
Update
vkuzo Mar 21, 2025
26b49fd
Update
vkuzo Mar 21, 2025
ba10a02
Update
vkuzo Mar 21, 2025
483cdfd
Update
vkuzo Mar 21, 2025
32005c9
Update
vkuzo Mar 21, 2025
e341c2e
Update
vkuzo Mar 24, 2025
7ecd79f
Update
vkuzo Mar 24, 2025
ca3c4cf
Update
vkuzo Mar 24, 2025
0de11cf
Update
vkuzo Mar 24, 2025
912e4dc
Update
vkuzo Mar 24, 2025
fb5662a
Update
vkuzo Mar 25, 2025
f245d64
Update
vkuzo Mar 26, 2025
9e5b8f8
Update
vkuzo Mar 26, 2025
e5bdecb
Update
vkuzo Mar 26, 2025
4c2ad8c
Update
vkuzo Mar 26, 2025
c1ceef1
Update
vkuzo Mar 26, 2025
65bfff0
Update
vkuzo Mar 26, 2025
0ff3a93
Update
vkuzo Mar 26, 2025
71a5548
Update
vkuzo Mar 26, 2025
0576d0d
Update
vkuzo Mar 26, 2025
f98453f
Update
vkuzo Mar 27, 2025
81dc214
Update
vkuzo Mar 27, 2025
5d60f24
Update
vkuzo Mar 27, 2025
a313055
Update
vkuzo Mar 27, 2025
798abfc
Update
vkuzo Mar 27, 2025
4933b66
Update
vkuzo Mar 27, 2025
d9e60c1
Update
vkuzo Mar 27, 2025
884f065
Update
vkuzo Mar 27, 2025
41b1f9d
Update
vkuzo Mar 27, 2025
5cc2755
Update
vkuzo Mar 27, 2025
af1f386
Update
vkuzo Mar 27, 2025
8691bd4
Update
vkuzo Mar 27, 2025
1a0993d
Update
vkuzo Mar 27, 2025
b053f97
Update
vkuzo Mar 27, 2025
9e335ce
Update
vkuzo Mar 27, 2025
87756f9
Update
vkuzo Mar 28, 2025
d0a0fd1
Update
vkuzo Mar 28, 2025
cf9dfe4
Update
vkuzo Mar 28, 2025
beafdd9
Update
vkuzo Mar 28, 2025
45abedf
Update
vkuzo Mar 28, 2025
af87eee
Update
vkuzo Mar 28, 2025
db67393
Update
vkuzo Mar 28, 2025
a679de7
Update
vkuzo Mar 28, 2025
28dedc0
Update
vkuzo Mar 28, 2025
1ffb62b
Update
vkuzo Mar 28, 2025
02d5065
Update
vkuzo Mar 28, 2025
d1bf83a
Update
vkuzo Mar 28, 2025
84c77d7
Update
vkuzo Mar 28, 2025
749564b
Update
vkuzo Mar 28, 2025
f63479e
Update
vkuzo Mar 28, 2025
8b0f250
Update
vkuzo Mar 28, 2025
c603f09
Update
vkuzo Mar 28, 2025
42fb0e9
Update
vkuzo Mar 28, 2025
a16f576
Update
vkuzo Mar 28, 2025
b890654
Update
vkuzo Mar 28, 2025
35179ab
Update
vkuzo Mar 28, 2025
83e1e2e
Update
vkuzo Mar 28, 2025
c41cc19
Update
vkuzo Mar 28, 2025
8ba3018
Update
vkuzo Mar 28, 2025
f437e00
Update
vkuzo Mar 28, 2025
a62e00b
Update
vkuzo Mar 28, 2025
195d904
Update
vkuzo Mar 28, 2025
bdb0996
Update
vkuzo Mar 28, 2025
8a5050e
Update
vkuzo Apr 1, 2025
635ca65
Update
vkuzo Apr 1, 2025
0d5b763
Update
vkuzo Apr 1, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions benchmarks/mx_formats/cast_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def run(
)

assert y_d0.dtype == torch.float8_e4m3fn
assert s_d0.dtype == torch.uint8
assert s_d0.dtype == torch.float8_e8m0fnu
bytes_r = x.numel() * bytes_per_el_bf16
bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
bps = (bytes_r + bytes_w) / (time_us / 1e6)
Expand All @@ -166,7 +166,7 @@ def run(
)

assert y_d1.dtype == torch.float8_e4m3fn
assert s_d1.dtype == torch.uint8
assert s_d1.dtype == torch.float8_e8m0fnu
bytes_r = x.numel() * bytes_per_el_bf16
bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
bps = (bytes_r + bytes_w) / (time_us / 1e6)
Expand Down
4 changes: 1 addition & 3 deletions test/prototype/mx_formats/test_custom_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,9 +459,7 @@ def test_fp6_e3m2_pack_unpack():
)
@pytest.mark.parametrize("M", (256, 2048))
@pytest.mark.parametrize("K", (256, 2048))
# @pytest.mark.parametrize("M", (256,))
# @pytest.mark.parametrize("K", (256,))
def test_triton_mxfp8_dim1(M, K):
def test_triton_mxfp8_dim1_randn(M, K):
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
x_mx_ref, x_s_ref = triton_to_mxfp8_dim1_reference(x, block_size=32)
x_mx_t, x_s_t = triton_to_mxfp8_dim1(x, inner_block_size=32)
Expand Down
4 changes: 3 additions & 1 deletion torchao/prototype/mx_formats/custom_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -1126,7 +1126,9 @@ def _triton_calculate_scale(x, axis):
scale_e8m0_biased = scale_e8m0_unbiased + e8m0_exponent_bias
scale_e8m0_biased = scale_e8m0_biased.to(tl.uint8)

# TODO(future PR): add NaN handling here
# TODO(future PR): add NaN handling here,
# https://github.com/pytorch/pytorch/pull/100572 will likely be useful to
# get proper NaN propagation working

# Calculate the scale in floating point.
scale_fp = (scale_e8m0_biased.to(tl.int32) << fp32_mbits).to(
Expand Down
1 change: 0 additions & 1 deletion torchao/prototype/mx_formats/mx_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,6 @@ def to_mx(
# Calculate the scale for different modes
max_abs_int32 = (max_abs + eps).view(hp_int_dtype)
extracted_pow2 = ((max_abs_int32 >> hp_mbits) & 0b11111111) - hp_exp_bias
extracted_pow2 = extracted_pow2.to(data_hp.dtype)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the main change in this PR; the rest are unrelated cleanups.


if scaling_mode in (ScaleCalculationMode.FLOOR, ScaleCalculationMode.EVEN):
scale_e8m0_unbiased = extracted_pow2 - target_max_pow2
Expand Down
Loading