Commit 83a20c7

[mxfp8 moe training] add triton kernel for blocked swizzled 3d weight scales (#2894)
1 parent 6176322 commit 83a20c7

File tree

4 files changed: +307 −1 lines changed

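
For orientation, here is a minimal usage sketch of the kernel this commit adds. It mirrors the new test in the diff below; the to_mx import path is an assumption (the test file imports it outside the lines shown here), and the sketch is illustrative, not part of the commit.

import torch

from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
    torch_to_blocked_per_group_3d,
    triton_mx_block_rearrange_per_group_3d,
)
from torchao.prototype.mx_formats.mx_tensor import to_mx  # assumed import path

# One of the (e, n, k) shapes the new test parametrizes over.
e, n, k, block_size = 2, 8192, 5120, 32
weights = torch.randn(e, n, k // block_size, device="cuda")

# to_mx returns (e8m0 scales, fp8 data); only the scales are rearranged here.
weight_scales, _ = to_mx(weights, elem_dtype=torch.float8_e4m3fn, block_size=block_size)

ref_out_scales = torch_to_blocked_per_group_3d(weight_scales)              # PyTorch reference
triton_out_scales = triton_mx_block_rearrange_per_group_3d(weight_scales)  # new Triton kernel
assert torch.allclose(ref_out_scales, triton_out_scales, atol=0, rtol=0)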
Lines changed: 160 additions & 0 deletions
@@ -0,0 +1,160 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
+
+from dataclasses import dataclass
+from typing import List
+
+import torch
+from tabulate import tabulate
+from tqdm import tqdm
+
+from benchmarks.utils import benchmark_cuda_function_in_microseconds
+from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
+    torch_to_blocked_per_group_3d,
+    triton_mx_block_rearrange_per_group_3d,
+)
+
+device = torch.device("cuda")
+
+# Needed since changing args to function causes recompiles
+torch._dynamo.config.cache_size_limit = 1000
+
+
+@dataclass(frozen=True)
+class ExperimentConfig:
+    input_shape: tuple[int]
+
+
+@dataclass(frozen=True)
+class ExperimentResult:
+    torch_time_us: float
+    triton_time_us: float
+    torch_mem_bw_gbps: float
+    triton_mem_bw_gbps: float
+
+
+@dataclass(frozen=True)
+class Experiment:
+    config: ExperimentConfig
+    result: ExperimentResult
+
+
+def get_configs() -> List[ExperimentConfig]:
+    # Llama4 shapes. Input activations are scaled along K dim.
+    block_size = 32
+    input_shapes = [
+        # w1, w3 scaled along K (fwd)
+        (1, 8192, 5120 // block_size),
+        (2, 8192, 5120 // block_size),
+        (4, 8192, 5120 // block_size),
+        (8, 8192, 5120 // block_size),
+        (16, 8192, 5120 // block_size),
+        # w2 scaled along K (fwd)
+        (1, 5120, 8192 // block_size),
+        (2, 5120, 8192 // block_size),
+        (4, 5120, 8192 // block_size),
+        (8, 5120, 8192 // block_size),
+        (16, 5120, 8192 // block_size),
+    ]
+    configs = []
+    for shape in input_shapes:
+        configs.append(
+            ExperimentConfig(
+                input_shape=shape,
+            )
+        )
+    return configs
+
+
+def run_experiment(config: ExperimentConfig) -> ExperimentResult:
+    input_tensor = torch.randint(
+        low=0,
+        high=256,
+        size=config.input_shape,
+        dtype=torch.uint8,
+        device=device,
+    )
+
+    def warmup(fn, *args, **kwargs):
+        for _ in range(5):
+            fn(*args, **kwargs)
+
+    E, N, K = config.input_shape
+
+    # bench torch
+    compiled_run_torch = torch.compile(torch_to_blocked_per_group_3d)
+    warmup(compiled_run_torch, input_tensor)
+    torch_time_us = benchmark_cuda_function_in_microseconds(
+        compiled_run_torch,
+        input_tensor,
+    )
+
+    # bench triton
+    triton_out_scales = triton_mx_block_rearrange_per_group_3d(input_tensor)
+    warmup(triton_mx_block_rearrange_per_group_3d, input_tensor)
+    triton_time_us = benchmark_cuda_function_in_microseconds(
+        triton_mx_block_rearrange_per_group_3d,
+        input_tensor,
+    )
+
+    # mem bw calculations
+    bytes_per_input_el = torch.finfo(torch.float8_e8m0fnu).bits / 8
+    bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8
+
+    read_bytes = input_tensor.numel() * bytes_per_input_el
+    write_bytes = triton_out_scales.numel() * bytes_per_output_el
+
+    torch_mem_bw_gbps = ((read_bytes + write_bytes) / 1e9) / (torch_time_us / 1e6)
+    triton_mem_bw_gbps = ((read_bytes + write_bytes) / 1e9) / (triton_time_us / 1e6)
+
+    return ExperimentResult(
+        torch_time_us=torch_time_us,
+        triton_time_us=triton_time_us,
+        torch_mem_bw_gbps=torch_mem_bw_gbps,
+        triton_mem_bw_gbps=triton_mem_bw_gbps,
+    )
+
+
+def print_results(experiments: List[Experiment]):
+    headers = [
+        "input_shape",
+        "torch_time_us",
+        "triton_time_us",
+        "torch_mem_bw_gbps",
+        "triton_mem_bw_gbps",
+        "triton_speedup",
+    ]
+    rows = []
+    for experiment in experiments:
+        input_shape = f"({experiment.config.input_shape[0]}, {experiment.config.input_shape[1]}, {experiment.config.input_shape[2]})"
+        rows.append(
+            [
+                input_shape,
+                experiment.result.torch_time_us,
+                experiment.result.triton_time_us,
+                round(experiment.result.torch_mem_bw_gbps, 3),
+                round(experiment.result.triton_mem_bw_gbps, 3),
+                f"{experiment.result.torch_time_us / experiment.result.triton_time_us:.2f}x",
+            ]
+        )
+    print(tabulate(rows, headers=headers))
+
+
+def main():
+    torch.random.manual_seed(123)
+    configs = get_configs()
+    results = []
+    for config in tqdm(configs):
+        result = run_experiment(config)
+        results.append(Experiment(config=config, result=result))
+
+    # Use Tabulate to print results
+    print_results(results)
+
+
+if __name__ == "__main__":
+    main()
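
A note on the bandwidth figures this script reports: the rearrangement is treated as a pure copy, one byte read per input scale and one byte written per output scale (E8M0 and FP8 elements are both one byte, so the two finfo lookups give the same element size). A quick worked example of the GB/s formula for the largest w1/w3 shape, with a purely hypothetical runtime:

# Hypothetical numbers, only to illustrate the formula used in run_experiment above.
E, N, K_scaled = 16, 8192, 5120 // 32      # largest w1/w3 config: (16, 8192, 160)
read_bytes = E * N * K_scaled              # 1 byte per E8M0 scale, about 21.0 MB
write_bytes = read_bytes                   # these shapes need no padding, so output size matches
time_us = 100.0                            # assumed runtime, not a measured result
gbps = ((read_bytes + write_bytes) / 1e9) / (time_us / 1e6)
print(f"{gbps:.1f} GB/s")                  # about 419.4 GB/s for these assumed numbers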

test/prototype/moe_training/test_kernels.py

Lines changed: 26 additions & 0 deletions
@@ -24,7 +24,9 @@
 from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
     compute_per_group_blocked_scale_offsets,
     torch_to_blocked_per_group_2d,
+    torch_to_blocked_per_group_3d,
     triton_mx_block_rearrange_per_group_2d,
+    triton_mx_block_rearrange_per_group_3d,
 )
 from torchao.prototype.moe_training.utils import (
     _is_column_major,
@@ -240,3 +242,27 @@ def test_mxfp8_per_group_blocked_scales_2d(
     assert torch.allclose(ref_out_scales, triton_out_scales, atol=0, rtol=0), (
         "blocked scales not equal"
     )
+
+
+@skip_if_rocm("ROCm enablement in progress")
+@pytest.mark.parametrize("e,n,k", [(1, 8192, 5120), (2, 8192, 5120), (8, 5120, 8192)])
+def test_mxfp8_per_group_blocked_scales_3d(
+    e: int,
+    n: int,
+    k: int,
+):
+    device = "cuda"
+    block_size = 32
+    weights = torch.randn(e, n, k // block_size, device=device)
+    weight_scales, _ = to_mx(
+        weights, elem_dtype=torch.float8_e4m3fn, block_size=block_size
+    )
+
+    # torch reference
+    ref_out_scales = torch_to_blocked_per_group_3d(weight_scales)
+
+    # triton kernel
+    triton_out_scales = triton_mx_block_rearrange_per_group_3d(weight_scales)
+    assert torch.allclose(ref_out_scales, triton_out_scales, atol=0, rtol=0), (
+        "blocked scales not equal"
+    )

torchao/prototype/moe_training/kernels/mxfp8_blocked_scales.py

Lines changed: 120 additions & 0 deletions
@@ -211,6 +211,7 @@ def triton_scale_swizzle_per_group_2d(
     # We track how many row blocks we have iterated through.
     block_row_id = 0
     current_start_row = input_group_start_row
+
     # TODO: Investigate if it is possible and beneficial to parallelize along
     # row blocks as well, and get rid of this loop.
     while current_start_row < input_group_end_row:
@@ -237,3 +238,122 @@ def triton_scale_swizzle_per_group_2d(
         # Update row block id to next block
         block_row_id += 1
         current_start_row += BLOCK_ROWS
+
+
+def triton_mx_block_rearrange_per_group_3d(scale_tensor: torch.Tensor) -> torch.Tensor:
+    """
+    Rearranges an E8M0 tensor scale to block-scaled swizzle format.
+
+    This format is suitable for Tmem as described in NVIDIA documentation:
+    https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
+
+    Args:
+        scale_tensor: Input tensor in row-major format with 8-bit elements
+
+    Returns:
+        Rearranged tensor in block-scaled swizzle format
+    """
+    assert scale_tensor.ndim == 3, "scales tensor must be 3d"
+    assert scale_tensor.element_size() == 1, (
+        "Expected element size to be 1 byte (8 bits)"
+    )
+
+    num_groups, rows, cols = scale_tensor.shape
+    input_stride_dim0 = scale_tensor.stride(0)
+    input_stride_dim1 = scale_tensor.stride(1)
+    input_stride_dim2 = scale_tensor.stride(2)
+
+    # Calculate blocks needed and allocate output tensor
+    num_row_blocks = triton.cdiv(rows, 128)
+    num_col_blocks = triton.cdiv(cols, 4)
+    padded_rows = num_row_blocks * 128
+    padded_cols = num_col_blocks * 4
+    output = scale_tensor.new_empty((num_groups, padded_rows * padded_cols))
+    output_stride_dim0 = output.stride(0)
+
+    # We probably want handle multiple blocks per tile but for now keep it simple
+    BLOCK_ROWS, BLOCK_COLS = 128, 4
+
+    # Output block stride for the rearranged format
+    output_block_stride = BLOCK_ROWS * BLOCK_COLS * (padded_cols // BLOCK_COLS)
+
+    grid = lambda META: (
+        num_groups,
+        num_row_blocks,
+        num_col_blocks,
+    )
+
+    triton_scale_swizzle_per_group_3d[grid](
+        scale_tensor.view(torch.uint8),
+        input_stride_dim0,
+        input_stride_dim1,
+        input_stride_dim2,
+        output.view(torch.uint8),
+        output_stride_dim0,
+        output_block_stride,
+        rows,
+        cols,
+        BLOCK_ROWS=BLOCK_ROWS,
+        BLOCK_COLS=BLOCK_COLS,
+    )
+
+    return output
+
+
+@triton.jit
+def triton_scale_swizzle_per_group_3d(
+    input_ptr,
+    input_stride_dim0,
+    input_stride_dim1,
+    input_stride_dim2,
+    output_ptr,
+    output_stride_dim0,
+    output_block_stride,
+    scale_rows,
+    scale_cols,
+    BLOCK_ROWS: tl.constexpr,
+    BLOCK_COLS: tl.constexpr,
+):
+    pid_group = tl.program_id(0)
+    pid_row = tl.program_id(1)
+    pid_col = tl.program_id(2)
+
+    # Update base pointers based on this group id
+    input_ptr += pid_group * input_stride_dim0
+    output_ptr += pid_group * output_stride_dim0
+
+    rows = tl.arange(0, BLOCK_ROWS)[:, None]
+    cols = tl.arange(0, BLOCK_COLS)[None, :]
+
+    # Calculate starting row and column for this tile
+    start_row = pid_row * BLOCK_ROWS
+    start_col = pid_col * BLOCK_COLS
+    global_rows = start_row + rows
+    global_cols = start_col + cols
+
+    mask = (global_rows < scale_rows) & (global_cols < scale_cols)
+
+    input_scales = tl.load(
+        input_ptr + global_rows * input_stride_dim1 + global_cols * input_stride_dim2,
+        mask=mask,
+        other=0.0,
+    )
+
+    r_div_32 = rows // 32
+    r_mod_32 = rows % 32
+
+    # 2) Rearrange to (32, 4, 4) then to final (32, 16) coordinates
+    dest_indices = r_mod_32 * 16 + r_div_32 * 4 + cols
+
+    # Flatten
+    dest_indices_flat = tl.reshape(dest_indices, (BLOCK_ROWS * BLOCK_COLS))
+    scales_flat = tl.reshape(input_scales, (BLOCK_ROWS * BLOCK_COLS))
+
+    # Calculate block offset using provided output block stride
+    LOCAL_NUMEL = BLOCK_ROWS * BLOCK_COLS
+    block_offset = pid_col * LOCAL_NUMEL + (pid_row * output_block_stride)
+
+    tl.store(
+        output_ptr + block_offset + dest_indices_flat,
+        scales_flat,
+    )
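
To make the new swizzle easier to follow: within each 128x4 tile, the element at (r, c) lands at flat offset (r % 32) * 16 + (r // 32) * 4 + c, i.e. the tile is reinterpreted as (32, 4, 4) with the four 32-row sub-blocks forming the middle axis. Below is a small pure-PyTorch sketch of that per-tile mapping (illustrative only; the reference the tests compare against is torch_to_blocked_per_group_3d, which is defined outside this diff):

import torch

def swizzle_one_tile(tile: torch.Tensor) -> torch.Tensor:
    # Mirrors dest_indices in triton_scale_swizzle_per_group_3d for a single 128x4 tile.
    assert tile.shape == (128, 4)
    rows = torch.arange(128)[:, None]
    cols = torch.arange(4)[None, :]
    dest = (rows % 32) * 16 + (rows // 32) * 4 + cols   # a bijection onto 0..511
    out = torch.empty(128 * 4, dtype=tile.dtype)
    out[dest.reshape(-1)] = tile.reshape(-1)            # scatter into swizzled order
    return out

# Equivalent view-based form: (128, 4) -> (4, 32, 4) -> (32, 4, 4) -> flatten.
tile = torch.arange(512, dtype=torch.int32).reshape(128, 4)
assert torch.equal(swizzle_one_tile(tile), tile.reshape(4, 32, 4).permute(1, 0, 2).reshape(-1))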

torchao/prototype/mx_formats/kernels.py

Lines changed: 1 addition & 1 deletion
@@ -1179,7 +1179,7 @@ def triton_scale_swizzle(
 @torch.library.custom_op("torchao::triton_mx_block_rearrange", mutates_args=())
 def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
     """
-    Rearranges an E8M0 tensor scale from row-major format to block-scaled swizzle format.
+    Rearranges an E8M0 tensor scale to block-scaled swizzle format.

     This format is suitable for Tmem as described in NVIDIA documentation:
     https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
