
Commit 72933f9

[mxfp8 moe training] add triton kernel for mxfp8 quantization along dim0
stack-info: PR: #3128, branch: danielvegamyhre/stack/75
1 parent cd21d0e commit 72933f9

File tree

benchmarks/mx_formats/cast_bench.py
test/prototype/mx_formats/test_kernels.py
torchao/prototype/mx_formats/kernels.py

3 files changed: +230 -2 lines changed
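For orientation, a minimal usage sketch of the op this commit adds (shapes, device, and dtype here are illustrative; per the asserts in kernels.py below, both dimensions currently must be multiples of 128 and the input must be contiguous):

import torch
from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0

x = torch.randn(256, 512, dtype=torch.bfloat16, device="cuda")
# rowwise (dim0) mxfp8 cast: fp8 data keeps x's shape, one e8m0 scale per 1x32 block
data, scales = triton_to_mxfp8_dim0(x, inner_block_size=32)
assert data.dtype == torch.float8_e4m3fn and data.shape == (256, 512)
assert scales.dtype == torch.float8_e8m0fnu and scales.shape == (256, 512 // 32, 1)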

benchmarks/mx_formats/cast_bench.py

Lines changed: 18 additions & 0 deletions
@@ -13,6 +13,7 @@
 
 from torchao.prototype.mx_formats.config import ScaleCalculationMode
 from torchao.prototype.mx_formats.kernels import (
+    triton_to_mxfp8_dim0,
     triton_to_mxfp8_dim1,
 )
 from torchao.prototype.mx_formats.mx_tensor import to_mx
@@ -97,6 +98,7 @@ def run(
         "dim0_mxfp8_floor",
         "dim0_mxfp4_floor",
         "dim0_mxfp8_rceil",
+        "dim0_mxfp8_triton_floor",
         "dim1_mxfp8_floor",
         "dim1_mxfp8_rceil",
         "dim1_mxfp8_triton_floor",
@@ -222,6 +224,22 @@ def run(
         bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
         bps = (bytes_r + bytes_w) / (time_us / 1e6)
 
+    elif mode == "dim0_mxfp8_triton_floor":
+        y_d0, s_d0 = triton_to_mxfp8_dim0(x, inner_block_size=BLOCK_SIZE)
+
+        for _ in range(2):
+            __ = triton_to_mxfp8_dim0(x, inner_block_size=BLOCK_SIZE)
+        time_us = benchmark_cuda_function_in_microseconds(
+            lambda x, b: triton_to_mxfp8_dim0(x, inner_block_size=BLOCK_SIZE),
+            x,
+            BLOCK_SIZE,
+        )
+        assert y_d0.dtype == torch.float8_e4m3fn
+        assert s_d0.dtype == torch.float8_e8m0fnu
+        bytes_r = x.numel() * bytes_per_el_bf16
+        bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
+        bps = (bytes_r + bytes_w) / (time_us / 1e6)
+
     elif mode == "dim1_mxfp8_floor":
         to_mx_dim1_reference_c = torch.compile(to_mx_dim1_reference)
         y_d1, s_d1 = to_mx_dim1_reference_c(x, BLOCK_SIZE)
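For intuition on the bandwidth number this new mode reports, the arithmetic below is illustrative only (a hypothetical 16384 x 16384 bf16 input and a hypothetical 500 us kernel time, not measured output):

numel = 16384 * 16384
bytes_r = numel * 2                  # bf16 input read
bytes_w = numel * 1 + numel // 32    # fp8 data written + one e8m0 scale byte per 32 elements
time_us = 500.0                      # hypothetical kernel time
bps = (bytes_r + bytes_w) / (time_us / 1e6)
print(f"{bps / 1e12:.2f} TB/s")      # ~1.63 TB/s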

test/prototype/mx_formats/test_kernels.py

Lines changed: 33 additions & 0 deletions
@@ -37,6 +37,7 @@
     pack_uint6,
     triton_f6_e2m3_to_bf16,
     triton_f6_e3m2_to_bf16,
+    triton_to_mxfp8_dim0,
    triton_to_mxfp8_dim1,
     triton_to_mxfp8_dim1_reference,
     unpack_uint4,
@@ -431,6 +432,23 @@ def test_fp6_e3m2_pack_unpack():
     assert torch.all(orig_vals_f6_packed_unpacked == orig_vals)
 
 
+def triton_to_mxfp8_dim0_reference(
+    x_hp: torch.Tensor, block_size
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    A reference version of `triton_to_mxfp8_dim0` for rowwise quantization.
+    """
+    from torchao.prototype.mx_formats.mx_tensor import to_mx
+
+    # cast across dim0 (rowwise) - no transpose needed
+    scale_e8m0_dim0, x_hp_d0_normalized = to_mx(x_hp, torch.float8_e4m3fn, block_size)
+    scale_e8m0_dim0 = scale_e8m0_dim0.view(torch.float8_e8m0fnu)
+    return (
+        x_hp_d0_normalized,
+        scale_e8m0_dim0.unsqueeze(-1),
+    )
+
+
 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
 @pytest.mark.skipif(
     not is_sm_at_least_89(),
@@ -446,6 +464,21 @@ def test_triton_mxfp8_dim1_randn(M, K):
     torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
 
 
+@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
+@pytest.mark.skipif(
+    not is_sm_at_least_89(),
+    reason="float8 in triton requires CUDA capability 8.9 or greater",
+)
+@pytest.mark.parametrize("M", (256, 2048, 131072))
+@pytest.mark.parametrize("K", (256, 5120, 7168))
+def test_triton_mxfp8_dim0_randn(M, K):
+    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(x, block_size=32)
+    x_mx_t, x_s_t = triton_to_mxfp8_dim0(x, inner_block_size=32)
+    torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
+    torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize(
     "shape",

torchao/prototype/mx_formats/kernels.py

Lines changed: 179 additions & 2 deletions
@@ -829,6 +829,8 @@ def _(uint8_data):
     import triton.language as tl
     from torch.library import triton_op, wrap_triton
 
+    print("importing triton ops")
+
     @triton.jit
     def _triton_calculate_scale(x, axis):
         # There is no good support for accessing globals from a jit'ed triton
@@ -891,13 +893,13 @@ def _get_mxfp8_dim1_kernel_autotune_configs():
 
     @triton.autotune(
         configs=_get_mxfp8_dim1_kernel_autotune_configs(),
-        key=["n_rows", "n_cols", "INNER_BLOCK_SIZE"],
+        key=["n_cols", "INNER_BLOCK_SIZE"],
     )
     @triton.jit
     def to_mxfp8_dim1_kernel(
         x_ptr,  # pointer to input tensor
         output_col_major_ptr,  # pointer to column-major output tensor (column-normalized)
-        col_scale_ptr,  # pointer to store column-wise maximum absolute values
+        col_scale_ptr,  # pointer to store scales
         n_rows,  # number of rows in the tensor
         n_cols,  # number of columns in the tensor
         ROW_TILE_SIZE: tl.constexpr,
@@ -1038,6 +1040,175 @@ def to_mxfp8_dim1_kernel(
         # TODO(future): mask this store
         tl.store(col_scale_start_ptr + col_scale_indices, col_scale_e8m0)
 
+    @triton.autotune(
+        configs=_get_mxfp8_dim1_kernel_autotune_configs(),
+        key=["n_rows", "n_cols", "INNER_BLOCK_SIZE"],
+    )
+    @triton.jit
+    def to_mxfp8_dim0_kernel(
+        x_ptr,
+        output_ptr,
+        row_scale_ptr,
+        n_rows,
+        n_cols,
+        ROW_TILE_SIZE: tl.constexpr,
+        COL_TILE_SIZE: tl.constexpr,
+        INNER_BLOCK_SIZE: tl.constexpr,  # should be 32 for MX
+    ):
+        """
+        Quantizes a high precision tensor to mxfp8 rowwise (1x32 scaling granularity).
+
+        This is the counterpart to to_mxfp8_dim1_kernel which does columnwise quantization.
+        Instead of transposing and scaling across columns, this kernel scales across rows.
+        """
+
+        BLOCKS_PER_COL_TILE: tl.constexpr = COL_TILE_SIZE // INNER_BLOCK_SIZE
+
+        # Get program ID
+        pid_row = tl.program_id(0)
+        pid_col = tl.program_id(1)
+
+        # Calculate starting row and column for this tile
+        start_row = pid_row * ROW_TILE_SIZE
+        start_col = pid_col * COL_TILE_SIZE
+
+        # Create offsets for the block
+        row_offsets = tl.arange(0, ROW_TILE_SIZE)
+        col_offsets = tl.arange(0, COL_TILE_SIZE)
+
+        # Compute global row/col positions
+        rows = start_row + row_offsets[:, None]
+        cols = start_col + col_offsets[None, :]
+
+        # Create masks for out-of-bounds accesses
+        row_mask = rows < n_rows
+        col_mask = cols < n_cols
+        mask = row_mask & col_mask
+
+        # Compute memory offsets for row-major layout (rows, cols)
+        row_major_offsets = (rows * n_cols + cols).to(tl.int32)
+
+        # Load the entire block in a single operation
+        # shape: (ROW_TILE_SIZE, COL_TILE_SIZE)
+        x_block = tl.load(x_ptr + row_major_offsets, mask=mask)
+
+        # Reshape to inner tile size for rowwise scaling
+        # shape: (ROW_TILE_SIZE, COL_TILE_SIZE) -> (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE)
+        x_block_r = x_block.reshape(
+            ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE
+        )
+
+        # Calculate the absolute values of elements in the block
+        x_block_abs_r = tl.abs(x_block_r)
+
+        # Find the maximum absolute value for each row (across columns)
+        # shape: (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,)
+        row_scale_r, row_scale_e8m0_r = _triton_calculate_scale(x_block_abs_r, axis=1)
+
+        # Divide each row by scale
+        # Broadcasting row_scale to match x_block's shape
+        # x_block_r shape (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE)
+        # row_scale shape (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,) -> (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, 1)
+        row_normalized_r = x_block_r / row_scale_r[:, None]
+
+        # Reshape back to original tile size
+        row_normalized = tl.reshape(row_normalized_r, ROW_TILE_SIZE, COL_TILE_SIZE)
+
+        # Quantize to float8
+        row_normalized = row_normalized.to(tl.float8e4nv)
+
+        # Store the row-normalized result in row-major format
+        tl.store(output_ptr + row_major_offsets, row_normalized, mask=mask)
+
+        # reshape row_scale_e8m0_r for proper storage
+        # shape: (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,)
+        row_scale_e8m0 = row_scale_e8m0_r.reshape(ROW_TILE_SIZE * BLOCKS_PER_COL_TILE)
+
+        row_scale_start_offsets = (
+            (pid_row * ROW_TILE_SIZE * (n_cols // COL_TILE_SIZE))
+            * BLOCKS_PER_COL_TILE  # number of blocks seen so far
+            + pid_col * BLOCKS_PER_COL_TILE  # increment BLOCKS_PER_COL_TILE
+        )
+
+        row_scale_start_ptr = row_scale_ptr + row_scale_start_offsets
+
+        # calculate row_scale_indices
+        row_scale_indices = tl.arange(0, ROW_TILE_SIZE * BLOCKS_PER_COL_TILE)
+
+        # How many values are in all the other rows for this col_pid, need to jump
+        # over them for every BLOCKS_PER_COL_TILE values
+        jump_vals_per_row = (n_cols - COL_TILE_SIZE) // INNER_BLOCK_SIZE
+
+        # example transformation (specifics depend on tile sizes):
+        # [0, 1, 2, 3, 4, 5, 6, 7] -> [0, 1, 4, 5, 8, 9, 12, 13]
+        row_scale_indices = row_scale_indices + (
+            (row_scale_indices // BLOCKS_PER_COL_TILE) * jump_vals_per_row
+        )
+
+        # Store the scales
+        tl.store(row_scale_start_ptr + row_scale_indices, row_scale_e8m0)
+
+    @triton_op("torchao::triton_to_mxfp8_dim0", mutates_args={})
+    def triton_to_mxfp8_dim0(
+        x: torch.Tensor, inner_block_size: int = 32
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Input:
+        * `x` - input tensor, in row major memory layout
+        * `inner_block_size` - size of tiles to scale across, default is 32 for MX recipes
+
+        Output:
+        * `output`: the `float8_e4m3fn` values of `x` cast to mxfp8 across dim0 (rowwise)
+        * `row_scale`: the `e8m0` values of `x_scale` used to cast `x` to mxfp8 across dim0
+        """
+        assert x.is_contiguous(), "`x` must be contiguous"
+        assert inner_block_size <= 32
+
+        # Get tensor shape
+        n_rows, n_cols = x.shape
+
+        # Masking of loads and stores is not well tested yet, so for now enforce
+        # shapes which do not need masking. Note that this condition depends on max values of
+        # ROW_TILE_SIZE and COL_TILE_SIZE, which are autotuned above.
+        # TODO(future): implement and test masking and remove this restriction
+        max_row_tile_size = 128
+        max_col_tile_size = 128
+        assert n_rows % max_row_tile_size == 0, "unsupported"
+        assert n_cols % max_col_tile_size == 0, "unsupported"
+
+        # Create output tensors
+        output = torch.empty(
+            (n_rows, n_cols), dtype=torch.float8_e4m3fn, device=x.device
+        )
+
+        # Create scale tensors for rowwise scaling
+        row_scale = torch.empty(
+            (n_rows, n_cols // inner_block_size, 1),
+            dtype=torch.uint8,
+            device=x.device,
+        )
+
+        # Calculate grid dimensions based on tile size
+        grid = lambda META: (
+            triton.cdiv(n_rows, META["ROW_TILE_SIZE"]),
+            triton.cdiv(n_cols, META["COL_TILE_SIZE"]),
+        )
+
+        # Launch the kernel
+        wrap_triton(to_mxfp8_dim0_kernel)[grid](
+            x_ptr=x,
+            output_ptr=output,
+            row_scale_ptr=row_scale,
+            n_rows=n_rows,
+            n_cols=n_cols,
+            INNER_BLOCK_SIZE=inner_block_size,
+        )
+
+        return (
+            output,
+            row_scale.view(torch.float8_e8m0fnu),
+        )
+
     @triton_op("torchao::triton_to_mxfp8_dim1", mutates_args={})
     def triton_to_mxfp8_dim1(
         x: torch.Tensor, inner_block_size: int = 32
@@ -1459,6 +1630,12 @@ def _(scale_tensor):
         return scale_tensor.new_empty((padded_rows, padded_cols))
 else:
 
+    def triton_to_mxfp8_dim0(
+        x: torch.Tensor,
+        inner_block_size=32,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise AssertionError("needs torch version 2.8+ and triton")
+
     def triton_to_mxfp8_dim1(
         x, inner_block_size=32
     ) -> Tuple[torch.Tensor, torch.Tensor]:
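The trickiest part of the new kernel is how each program scatters its per-block scales into the flat (n_rows, n_cols // 32) scale buffer. The plain-PyTorch sketch below mirrors the row_scale_start_offsets / row_scale_indices arithmetic with small hypothetical tile sizes (the real ROW_TILE_SIZE / COL_TILE_SIZE are chosen by the autotuner):

import torch

# hypothetical configuration for illustration only
ROW_TILE_SIZE, COL_TILE_SIZE, INNER_BLOCK_SIZE = 2, 64, 32
n_rows, n_cols = 4, 128
BLOCKS_PER_COL_TILE = COL_TILE_SIZE // INNER_BLOCK_SIZE  # 2

def scale_positions(pid_row, pid_col):
    # mirrors row_scale_start_offsets + row_scale_indices in to_mxfp8_dim0_kernel
    start = (
        pid_row * ROW_TILE_SIZE * (n_cols // COL_TILE_SIZE)
    ) * BLOCKS_PER_COL_TILE + pid_col * BLOCKS_PER_COL_TILE
    idx = torch.arange(ROW_TILE_SIZE * BLOCKS_PER_COL_TILE)
    jump_vals_per_row = (n_cols - COL_TILE_SIZE) // INNER_BLOCK_SIZE
    idx = idx + (idx // BLOCKS_PER_COL_TILE) * jump_vals_per_row
    return (start + idx).tolist()

# tile (0, 0) covers rows 0-1, cols 0-63: its scales land at flat offsets [0, 1, 4, 5]
# of the (4, 4) scale buffer; tile (0, 1) covers cols 64-127 and fills [2, 3, 6, 7]
print(scale_positions(0, 0))  # [0, 1, 4, 5]
print(scale_positions(0, 1))  # [2, 3, 6, 7]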
