Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@
import torch
from tabulate import tabulate
from tqdm import tqdm
from utils import benchmark_cuda_function_in_microseconds

from benchmarks.utils import benchmark_cuda_function_in_microseconds
from torchao.float8.config import ScalingGranularity
from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
torch_to_blocked_per_group_2d,
torch_to_blocked_per_group_3d,
)
from torchao.prototype.moe_training.utils import generate_jagged_offs
from torchao.prototype.mx_formats.mx_tensor import to_mx
from torchao.prototype.mx_formats.utils import (
to_blocked_per_group_2d,
to_blocked_per_group_3d,
)

device = torch.device("cuda")

Expand Down Expand Up @@ -50,9 +50,9 @@ class Experiment:
def get_configs() -> List[ExperimentConfig]:
# Llama4 shapes
M = [16640]
K = [5120]
N = [8192]
E = [16]
K = [2048, 5120, 8192]
N = [2048, 5120, 8192]
E = [1, 2, 4, 8]
configs = []
for e, m, n, k in itertools.product(
E,
Expand Down Expand Up @@ -196,10 +196,10 @@ def bench_mxfp8_grouped_mm(A, B_t, offs, block_size=32) -> float:

# Convert scales for each group to blocked format.
Mg, K = A_fp8.shape
A_scales_blocked, starting_row_after_padding = to_blocked_per_group_2d(
A_scales_blocked, starting_row_after_padding = torch_to_blocked_per_group_2d(
A_scales, offs, Mg, K
)
B_scales_blocked = to_blocked_per_group_3d(B_scales)
B_scales_blocked = torch_to_blocked_per_group_3d(B_scales)

# From this, we compute `group_sizes` and `starting_row_after_padding`:
# group_sizes = [32, 32, 64]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py

import itertools
from dataclasses import dataclass
from typing import List

import torch
from tabulate import tabulate
from tqdm import tqdm

from benchmarks.utils import benchmark_cuda_function_in_microseconds
from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
compute_per_group_blocked_scale_offsets,
torch_to_blocked_per_group_2d,
triton_mx_block_rearrange_per_group_2d,
)
from torchao.prototype.moe_training.utils import generate_jagged_offs

device = torch.device("cuda")

# Needed since changing args to function causes recompiles
torch._dynamo.config.cache_size_limit = 1000


@dataclass(frozen=True)
class ExperimentConfig:
input_shape: tuple[int]
num_groups: int


@dataclass(frozen=True)
class ExperimentResult:
torch_time_us: float
triton_time_us: float
torch_mem_bw_gbps: float
triton_mem_bw_gbps: float


@dataclass(frozen=True)
class Experiment:
config: ExperimentConfig
result: ExperimentResult


def get_configs() -> List[ExperimentConfig]:
# Llama4 shapes. Input activations are scaled along K dim.
block_size = 32
input_shapes = [
(16640, 5120 // block_size),
]
num_groups = [16]
configs = []
for shape, groups in itertools.product(
input_shapes,
num_groups,
):
configs.append(
ExperimentConfig(
input_shape=shape,
num_groups=groups,
)
)
return configs


def run_experiment(config: ExperimentConfig) -> ExperimentResult:
input_shape, num_groups = config.input_shape, config.num_groups
input_tensor = torch.randint(
low=0,
high=256,
size=input_shape,
dtype=torch.uint8,
device=device,
)

Mg, K = input_shape
input_group_offsets = generate_jagged_offs(num_groups, Mg, multiple_of=32)

# bench torch
compiled_run_torch = torch.compile(torch_to_blocked_per_group_2d)
torch_out_scales, torch_group_offs = compiled_run_torch(
input_tensor, input_group_offsets, Mg, K
)
torch_time_us = benchmark_cuda_function_in_microseconds(
compiled_run_torch,
input_tensor,
input_group_offsets,
Mg,
K,
)

# bench triton
_, output_group_offsets = compute_per_group_blocked_scale_offsets(
input_group_offsets
)
triton_out_scales = triton_mx_block_rearrange_per_group_2d(
input_tensor,
input_group_offsets,
output_group_offsets,
)
triton_time_us = benchmark_cuda_function_in_microseconds(
triton_mx_block_rearrange_per_group_2d,
input_tensor,
input_group_offsets,
output_group_offsets,
)

# mem bw calculations
bytes_per_input_el = torch.finfo(torch.float8_e8m0fnu).bits / 8
bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8

read_bytes = input_tensor.numel() * bytes_per_input_el
write_bytes = triton_out_scales.numel() * bytes_per_output_el

torch_mem_bw_gbps = ((read_bytes + write_bytes) / 1e9) / (torch_time_us / 1e6)
triton_mem_bw_gbps = ((read_bytes + write_bytes) / 1e9) / (triton_time_us / 1e6)

return ExperimentResult(
torch_time_us=torch_time_us,
triton_time_us=triton_time_us,
torch_mem_bw_gbps=torch_mem_bw_gbps,
triton_mem_bw_gbps=triton_mem_bw_gbps,
)


def print_results(experiments: List[Experiment]):
headers = [
"input_shape",
"torch_time_us",
"triton_time_us",
"torch_mem_bw_gbps",
"triton_mem_bw_gbps",
"triton_speedup",
]
rows = []
for experiment in experiments:
input_shape = (
f"({experiment.config.input_shape[0]}, {experiment.config.input_shape[1]})"
)
rows.append(
[
input_shape,
experiment.result.torch_time_us,
experiment.result.triton_time_us,
round(experiment.result.torch_mem_bw_gbps, 3),
round(experiment.result.triton_mem_bw_gbps, 3),
f"{experiment.result.torch_time_us / experiment.result.triton_time_us:.2f}x",
]
)
print(tabulate(rows, headers=headers))


def main():
torch.random.manual_seed(123)
configs = get_configs()
results = []
for config in tqdm(configs):
result = run_experiment(config)
results.append(Experiment(config=config, result=result))

# Use Tabulate to print results
print_results(results)


if __name__ == "__main__":
main()
45 changes: 45 additions & 0 deletions test/prototype/moe_training/test_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,19 @@
triton_fp8_per_group_colwise_scales,
triton_fp8_per_group_rowwise_scales,
)
from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
compute_per_group_blocked_scale_offsets,
torch_to_blocked_per_group_2d,
triton_mx_block_rearrange_per_group_2d,
)
from torchao.prototype.moe_training.utils import (
_is_column_major,
generate_jagged_offs,
torch_to_3d_rowwise_float8_transpose_rhs,
torch_to_float8_per_group_colwise,
torch_to_float8_per_group_rowwise,
)
from torchao.prototype.mx_formats.mx_tensor import to_mx
from torchao.testing.utils import skip_if_rocm


Expand Down Expand Up @@ -195,3 +202,41 @@ def test_fp8_rowwise_3d_transpose_rhs_reduction(round_scales_to_power_of_2: bool
assert ref_fp8.shape == triton_fp8.shape, "output shapes not equal"
assert ref_fp8.stride() == triton_fp8.stride(), "output strides not equal"
assert torch.allclose(ref_fp8, triton_fp8, rtol=0, atol=0), "fp8 data not equal"


@skip_if_rocm("ROCm enablement in progress")
@pytest.mark.parametrize(
"m,k,n_groups", [(256, 256, 4), (16640, 5120, 16), (16640, 8192, 16)]
)
def test_mxfp8_per_group_blocked_scales_2d(
m: int,
k: int,
n_groups: int,
):
device = "cuda"
block_size = 32
input_data = torch.randn(m, k, device=device)
e8m0_scales, _ = to_mx(
input_data, elem_dtype=torch.float8_e4m3fn, block_size=block_size
)
input_group_offsets = generate_jagged_offs(
n_groups, m, multiple_of=block_size, device=device
)

# torch reference
ref_out_scales, _ = torch_to_blocked_per_group_2d(
e8m0_scales, input_group_offsets, m, k, block_size=block_size
)

# triton kernel
_, output_group_offsets = compute_per_group_blocked_scale_offsets(
input_group_offsets
)
triton_out_scales = triton_mx_block_rearrange_per_group_2d(
e8m0_scales,
input_group_offsets,
output_group_offsets,
)
assert torch.allclose(ref_out_scales, triton_out_scales, atol=0, rtol=0), (
"blocked scales not equal"
)
2 changes: 1 addition & 1 deletion torchao/prototype/moe_training/kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@
from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
triton_fp8_per_group_rowwise_scales as triton_fp8_per_group_rowwise_scales,
)
from torchao.prototype.moe_training.kernels.mxfp8 import (
from torchao.prototype.moe_training.kernels.mxfp8_gemms import (
fbgemm_mxfp8_grouped_mm_2d_3d as fbgemm_mxfp8_grouped_mm_2d_3d,
)
Loading
Loading