Commit c62b0f0

[mxfp8 moe training] add triton kernel for mxfp8 quantization along dim0
stack-info: PR: #3128, branch: danielvegamyhre/stack/75
1 parent cd21d0e · commit c62b0f0

File tree: 3 files changed, +220 -0 lines
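
For orientation, here is a minimal usage sketch of the new op added by this commit. The call signature, output dtypes, and scale shape follow the code added to kernels.py below; the input shape is an arbitrary example (both dimensions must currently be multiples of 128 per the wrapper's shape checks):

    import torch
    from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0

    # quantize a bf16 tensor rowwise to mxfp8, one e8m0 scale per 1x32 block
    x = torch.randn(256, 1024, dtype=torch.bfloat16, device="cuda")
    data, scales = triton_to_mxfp8_dim0(x, inner_block_size=32)

    assert data.dtype == torch.float8_e4m3fn     # quantized values, same shape as x
    assert scales.dtype == torch.float8_e8m0fnu  # one scale per 1x32 block
    assert scales.shape == (256, 1024 // 32, 1)  # (n_rows, n_cols // block_size, 1)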

benchmarks/mx_formats/cast_bench.py (18 additions, 0 deletions)

@@ -13,6 +13,7 @@
 
 from torchao.prototype.mx_formats.config import ScaleCalculationMode
 from torchao.prototype.mx_formats.kernels import (
+    triton_to_mxfp8_dim0,
     triton_to_mxfp8_dim1,
 )
 from torchao.prototype.mx_formats.mx_tensor import to_mx
@@ -97,6 +98,7 @@ def run(
         "dim0_mxfp8_floor",
         "dim0_mxfp4_floor",
         "dim0_mxfp8_rceil",
+        "dim0_mxfp8_triton_floor",
         "dim1_mxfp8_floor",
         "dim1_mxfp8_rceil",
         "dim1_mxfp8_triton_floor",
@@ -222,6 +224,22 @@ def run(
         bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
         bps = (bytes_r + bytes_w) / (time_us / 1e6)
 
+    elif mode == "dim0_mxfp8_triton_floor":
+        y_d0, s_d0 = triton_to_mxfp8_dim0(x, inner_block_size=BLOCK_SIZE)
+
+        for _ in range(2):
+            __ = triton_to_mxfp8_dim0(x, inner_block_size=BLOCK_SIZE)
+        time_us = benchmark_cuda_function_in_microseconds(
+            lambda x, b: triton_to_mxfp8_dim0(x, inner_block_size=BLOCK_SIZE),
+            x,
+            BLOCK_SIZE,
+        )
+        assert y_d0.dtype == torch.float8_e4m3fn
+        assert s_d0.dtype == torch.float8_e8m0fnu
+        bytes_r = x.numel() * bytes_per_el_bf16
+        bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
+        bps = (bytes_r + bytes_w) / (time_us / 1e6)
+
     elif mode == "dim1_mxfp8_floor":
         to_mx_dim1_reference_c = torch.compile(to_mx_dim1_reference)
         y_d1, s_d1 = to_mx_dim1_reference_c(x, BLOCK_SIZE)
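
The bandwidth estimate for the new mode follows the same accounting as the surrounding modes: the kernel reads M * K bf16 elements (2 bytes each) and writes M * K fp8 elements plus M * K / 32 e8m0 scales (1 byte each). A rough worked example with assumed numbers (a 16384x16384 input and a hypothetical 100 µs kernel time, neither taken from this diff):

    M, K, BLOCK_SIZE = 16384, 16384, 32
    bytes_per_el_bf16, bytes_per_el_fp8 = 2, 1

    bytes_r = M * K * bytes_per_el_bf16                           # 536,870,912 bytes read
    bytes_w = (M * K + M * K // BLOCK_SIZE) * bytes_per_el_fp8    # 276,824,064 bytes written
    time_us = 100.0                                               # assumed measured time
    bps = (bytes_r + bytes_w) / (time_us / 1e6)                   # ~8.1e12 bytes/s, i.e. roughly 8.1 TB/s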

test/prototype/mx_formats/test_kernels.py (33 additions, 0 deletions)

@@ -37,6 +37,7 @@
     pack_uint6,
     triton_f6_e2m3_to_bf16,
     triton_f6_e3m2_to_bf16,
+    triton_to_mxfp8_dim0,
     triton_to_mxfp8_dim1,
     triton_to_mxfp8_dim1_reference,
     unpack_uint4,
@@ -431,6 +432,23 @@ def test_fp6_e3m2_pack_unpack():
     assert torch.all(orig_vals_f6_packed_unpacked == orig_vals)
 
 
+def triton_to_mxfp8_dim0_reference(
+    x_hp: torch.Tensor, block_size
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    A reference version of `triton_to_mxfp8_dim0` for rowwise quantization.
+    """
+    from torchao.prototype.mx_formats.mx_tensor import to_mx
+
+    # cast across dim0 (rowwise) - no transpose needed
+    scale_e8m0_dim0, x_hp_d0_normalized = to_mx(x_hp, torch.float8_e4m3fn, block_size)
+    scale_e8m0_dim0 = scale_e8m0_dim0.view(torch.float8_e8m0fnu)
+    return (
+        x_hp_d0_normalized,
+        scale_e8m0_dim0.unsqueeze(-1),
+    )
+
+
 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
 @pytest.mark.skipif(
     not is_sm_at_least_89(),
@@ -446,6 +464,21 @@ def test_triton_mxfp8_dim1_randn(M, K):
     torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
 
 
+@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
+@pytest.mark.skipif(
+    not is_sm_at_least_89(),
+    reason="float8 in triton requires CUDA capability 8.9 or greater",
+)
+@pytest.mark.parametrize("M", (256, 2048))
+@pytest.mark.parametrize("K", (256, 2048))
+def test_triton_mxfp8_dim0_randn(M, K):
+    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(x, block_size=32)
+    x_mx_t, x_s_t = triton_to_mxfp8_dim0(x, inner_block_size=32)
+    torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
+    torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize(
     "shape",

torchao/prototype/mx_formats/kernels.py (169 additions, 0 deletions)

@@ -1038,6 +1038,175 @@ def to_mxfp8_dim1_kernel(
     # TODO(future): mask this store
     tl.store(col_scale_start_ptr + col_scale_indices, col_scale_e8m0)
 
+
+@triton.autotune(
+    configs=_get_mxfp8_dim1_kernel_autotune_configs(),
+    key=["n_rows", "n_cols", "INNER_BLOCK_SIZE"],
+)
+@triton.jit
+def to_mxfp8_dim0_kernel(
+    x_ptr,  # pointer to input tensor
+    output_ptr,  # pointer to output tensor (row-normalized)
+    row_scale_ptr,  # pointer to store the row-wise e8m0 scales
+    n_rows,  # number of rows in the tensor
+    n_cols,  # number of columns in the tensor
+    ROW_TILE_SIZE: tl.constexpr,
+    COL_TILE_SIZE: tl.constexpr,
+    INNER_BLOCK_SIZE: tl.constexpr,  # should be 32 for MX
+):
+    """
+    Quantizes a high precision tensor to mxfp8 rowwise (1x32 scaling granularity).
+
+    This is the counterpart to to_mxfp8_dim1_kernel, which does columnwise
+    quantization. Instead of transposing and scaling across columns, this kernel
+    scales across rows.
+    """
+
+    BLOCKS_PER_COL_TILE: tl.constexpr = COL_TILE_SIZE // INNER_BLOCK_SIZE
+
+    # Get program IDs
+    pid_row = tl.program_id(0)
+    pid_col = tl.program_id(1)
+
+    # Calculate starting row and column for this tile
+    start_row = pid_row * ROW_TILE_SIZE
+    start_col = pid_col * COL_TILE_SIZE
+
+    # Create offsets for the block
+    row_offsets = tl.arange(0, ROW_TILE_SIZE)
+    col_offsets = tl.arange(0, COL_TILE_SIZE)
+
+    # Compute global row/col positions
+    rows = start_row + row_offsets[:, None]
+    cols = start_col + col_offsets[None, :]
+
+    # Create masks for out-of-bounds accesses
+    row_mask = rows < n_rows
+    col_mask = cols < n_cols
+    mask = row_mask & col_mask
+
+    # Compute memory offsets for row-major layout (rows, cols)
+    row_major_offsets = (rows * n_cols + cols).to(tl.int32)
+
+    # Load the entire block in a single operation
+    # shape: (ROW_TILE_SIZE, COL_TILE_SIZE)
+    x_block = tl.load(x_ptr + row_major_offsets, mask=mask)
+
+    # Reshape to inner tile size for rowwise scaling
+    # shape: (ROW_TILE_SIZE, COL_TILE_SIZE) -> (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE)
+    x_block_r = x_block.reshape(
+        ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE
+    )
+
+    # Calculate the absolute values of elements in the block
+    x_block_abs_r = tl.abs(x_block_r)
+
+    # Find the maximum absolute value for each row (across columns)
+    # shape: (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,)
+    row_scale_r, row_scale_e8m0_r = _triton_calculate_scale(x_block_abs_r, axis=1)
+
+    # Divide each row by its scale
+    # Broadcasting row_scale to match x_block_r's shape:
+    # x_block_r shape (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE)
+    # row_scale shape (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,) -> (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, 1)
+    row_normalized_r = x_block_r / row_scale_r[:, None]
+
+    # Reshape back to the original tile size
+    row_normalized = tl.reshape(row_normalized_r, ROW_TILE_SIZE, COL_TILE_SIZE)
+
+    # Quantize to float8
+    row_normalized = row_normalized.to(tl.float8e4nv)
+
+    # Store the row-normalized result in row-major format
+    tl.store(output_ptr + row_major_offsets, row_normalized, mask=mask)
+
+    # Reshape row_scale_e8m0_r for proper storage
+    # shape: (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,)
+    row_scale_e8m0 = row_scale_e8m0_r.reshape(ROW_TILE_SIZE * BLOCKS_PER_COL_TILE)
+
+    row_scale_start_offsets = (
+        (pid_row * ROW_TILE_SIZE * (n_cols // COL_TILE_SIZE))
+        * BLOCKS_PER_COL_TILE  # number of blocks seen so far
+        + pid_col * BLOCKS_PER_COL_TILE  # increment BLOCKS_PER_COL_TILE
+    )
+
+    row_scale_start_ptr = row_scale_ptr + row_scale_start_offsets
+
+    # Calculate row_scale_indices
+    row_scale_indices = tl.arange(0, ROW_TILE_SIZE * BLOCKS_PER_COL_TILE)
+
+    # How many values are in all the other rows for this col_pid; we need to jump
+    # over them for every BLOCKS_PER_COL_TILE values
+    jump_vals_per_row = (n_cols - COL_TILE_SIZE) // INNER_BLOCK_SIZE
+
+    # Example transformation (specifics depend on tile sizes):
+    # [0, 1, 2, 3, 4, 5, 6, 7] -> [0, 1, 4, 5, 8, 9, 12, 13]
+    row_scale_indices = row_scale_indices + (
+        (row_scale_indices // BLOCKS_PER_COL_TILE) * jump_vals_per_row
+    )
+
+    # Store the scales
+    tl.store(row_scale_start_ptr + row_scale_indices, row_scale_e8m0)
+
+
+@triton_op("torchao::triton_to_mxfp8_dim0", mutates_args={})
+def triton_to_mxfp8_dim0(
+    x: torch.Tensor, inner_block_size: int = 32
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Input:
+    * `x` - input tensor, in row major memory layout
+    * `inner_block_size` - size of tiles to scale across, default is 32 for MX recipes
+
+    Output:
+    * `output`: the `float8_e4m3fn` values of `x` cast to mxfp8 across dim0 (rowwise)
+    * `row_scale`: the `e8m0` scale values used to cast `x` to mxfp8 across dim0
+    """
+    assert x.is_contiguous(), "`x` must be contiguous"
+    assert inner_block_size <= 32
+
+    # Get tensor shape
+    n_rows, n_cols = x.shape
+
+    # Masking of loads and stores is not well tested yet, so for now enforce
+    # shapes which do not need masking. Note that this condition depends on the max
+    # values of ROW_TILE_SIZE and COL_TILE_SIZE, which are autotuned above.
+    # TODO(future): implement and test masking and remove this restriction
+    max_row_tile_size = 128
+    max_col_tile_size = 128
+    assert n_rows % max_row_tile_size == 0, "unsupported"
+    assert n_cols % max_col_tile_size == 0, "unsupported"
+
+    # Create the output tensor
+    output = torch.empty(
+        (n_rows, n_cols), dtype=torch.float8_e4m3fn, device=x.device
+    )
+
+    # Create the scale tensor for rowwise scaling
+    row_scale = torch.empty(
+        (n_rows, n_cols // inner_block_size, 1),
+        dtype=torch.uint8,
+        device=x.device,
+    )
+
+    # Calculate grid dimensions based on tile size
+    grid = lambda META: (
+        triton.cdiv(n_rows, META["ROW_TILE_SIZE"]),
+        triton.cdiv(n_cols, META["COL_TILE_SIZE"]),
+    )
+
+    # Launch the kernel
+    wrap_triton(to_mxfp8_dim0_kernel)[grid](
+        x_ptr=x,
+        output_ptr=output,
+        row_scale_ptr=row_scale,
+        n_rows=n_rows,
+        n_cols=n_cols,
+        INNER_BLOCK_SIZE=inner_block_size,
+    )
+
+    return (
+        output,
+        row_scale.view(torch.float8_e8m0fnu),
+    )
+
+
 @triton_op("torchao::triton_to_mxfp8_dim1", mutates_args={})
 def triton_to_mxfp8_dim1(
     x: torch.Tensor, inner_block_size: int = 32
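
To make the scale-store indexing in to_mxfp8_dim0_kernel concrete, here is a plain-Python sketch of the index transformation, using illustrative tile sizes rather than the autotuned ones: each program writes ROW_TILE_SIZE * BLOCKS_PER_COL_TILE scales, and jump_vals_per_row skips over the scale slots that belong to the other column tiles covering the same rows.

    # Illustrative sizes (assumptions, not the autotuned values): a 4x64 tile over a
    # tensor with n_cols = 128 and INNER_BLOCK_SIZE = 32.
    ROW_TILE_SIZE, COL_TILE_SIZE, INNER_BLOCK_SIZE, n_cols = 4, 64, 32, 128
    BLOCKS_PER_COL_TILE = COL_TILE_SIZE // INNER_BLOCK_SIZE           # 2 scale blocks per row in this tile
    jump_vals_per_row = (n_cols - COL_TILE_SIZE) // INNER_BLOCK_SIZE  # 2 blocks owned by other col tiles

    local = list(range(ROW_TILE_SIZE * BLOCKS_PER_COL_TILE))          # [0, 1, 2, 3, 4, 5, 6, 7]
    store_offsets = [i + (i // BLOCKS_PER_COL_TILE) * jump_vals_per_row for i in local]
    print(store_offsets)  # [0, 1, 4, 5, 8, 9, 12, 13], matching the example in the kernel comment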
