18 changes: 18 additions & 0 deletions benchmarks/mx_formats/cast_bench.py
@@ -13,6 +13,7 @@

from torchao.prototype.mx_formats.config import ScaleCalculationMode
from torchao.prototype.mx_formats.kernels import (
triton_to_mxfp8_dim0,
triton_to_mxfp8_dim1,
)
from torchao.prototype.mx_formats.mx_tensor import to_mx
@@ -97,6 +98,7 @@ def run(
"dim0_mxfp8_floor",
"dim0_mxfp4_floor",
"dim0_mxfp8_rceil",
"dim0_mxfp8_triton_floor",
"dim1_mxfp8_floor",
"dim1_mxfp8_rceil",
"dim1_mxfp8_triton_floor",
@@ -222,6 +224,22 @@ def run(
bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
bps = (bytes_r + bytes_w) / (time_us / 1e6)

elif mode == "dim0_mxfp8_triton_floor":
y_d0, s_d0 = triton_to_mxfp8_dim0(x, inner_block_size=BLOCK_SIZE)

for _ in range(2):
__ = triton_to_mxfp8_dim0(x, inner_block_size=BLOCK_SIZE)
time_us = benchmark_cuda_function_in_microseconds(
lambda x, b: triton_to_mxfp8_dim0(x, inner_block_size=BLOCK_SIZE),
x,
BLOCK_SIZE,
)
assert y_d0.dtype == torch.float8_e4m3fn
assert s_d0.dtype == torch.float8_e8m0fnu
bytes_r = x.numel() * bytes_per_el_bf16
bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
bps = (bytes_r + bytes_w) / (time_us / 1e6)

elif mode == "dim1_mxfp8_floor":
to_mx_dim1_reference_c = torch.compile(to_mx_dim1_reference)
y_d1, s_d1 = to_mx_dim1_reference_c(x, BLOCK_SIZE)
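
For context on the new "dim0_mxfp8_triton_floor" mode above: the reported bandwidth counts one bf16 read of the input plus one fp8-sized write of the quantized data and the e8m0 scales. A minimal standalone sketch of that arithmetic, using a hypothetical shape and timing rather than measured values:

M, K = 16384, 16384  # hypothetical input shape
BLOCK_SIZE = 32
bytes_per_el_bf16, bytes_per_el_fp8 = 2, 1

bytes_r = M * K * bytes_per_el_bf16  # read the bf16 input once
bytes_w = (M * K + M * K // BLOCK_SIZE) * bytes_per_el_fp8  # write fp8 data plus one e8m0 scale per 32 elements
time_us = 1000.0  # hypothetical kernel time in microseconds
bps = (bytes_r + bytes_w) / (time_us / 1e6)  # achieved bytes per second
print(f"{bps / 1e9:.1f} GB/s")
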
33 changes: 33 additions & 0 deletions test/prototype/mx_formats/test_kernels.py
@@ -37,6 +37,7 @@
pack_uint6,
triton_f6_e2m3_to_bf16,
triton_f6_e3m2_to_bf16,
triton_to_mxfp8_dim0,
triton_to_mxfp8_dim1,
triton_to_mxfp8_dim1_reference,
unpack_uint4,
@@ -431,6 +432,23 @@ def test_fp6_e3m2_pack_unpack():
assert torch.all(orig_vals_f6_packed_unpacked == orig_vals)


def triton_to_mxfp8_dim0_reference(
x_hp: torch.Tensor, block_size
) -> tuple[torch.Tensor, torch.Tensor]:
"""
A reference version of `triton_to_mxfp8_dim0` for rowwise quantization.
"""
from torchao.prototype.mx_formats.mx_tensor import to_mx

# cast across dim0 (rowwise) - no transpose needed
scale_e8m0_dim0, x_hp_d0_normalized = to_mx(x_hp, torch.float8_e4m3fn, block_size)
scale_e8m0_dim0 = scale_e8m0_dim0.view(torch.float8_e8m0fnu)
return (
x_hp_d0_normalized,
scale_e8m0_dim0,
)


@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
not is_sm_at_least_89(),
@@ -446,6 +464,21 @@ def test_triton_mxfp8_dim1_randn(M, K):
torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)


@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
not is_sm_at_least_89(),
reason="float8 in triton requires CUDA capability 8.9 or greater",
)
@pytest.mark.parametrize("M", (256, 2048, 131072))
@pytest.mark.parametrize("K", (256, 5120, 7168))
def test_triton_mxfp8_dim0_randn(M, K):
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(x, block_size=32)
x_mx_t, x_s_t = triton_to_mxfp8_dim0(x, inner_block_size=32)
torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
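
The new dim0 test mirrors the existing dim1 test: it checks bitwise equality between the Triton kernel and the `to_mx` reference path. A rough standalone version of the same check, assuming a CUDA device with compute capability 8.9+ and a recent Triton install:

import torch
from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0
from torchao.prototype.mx_formats.mx_tensor import to_mx

x = torch.randn(256, 256, dtype=torch.bfloat16, device="cuda")
scale_ref, data_ref = to_mx(x, torch.float8_e4m3fn, 32)
data_t, scale_t = triton_to_mxfp8_dim0(x, inner_block_size=32)
torch.testing.assert_close(data_t, data_ref, rtol=0, atol=0)
torch.testing.assert_close(scale_t, scale_ref.view(torch.float8_e8m0fnu), rtol=0, atol=0)
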


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize(
"shape",
178 changes: 176 additions & 2 deletions torchao/prototype/mx_formats/kernels.py
@@ -891,13 +891,13 @@ def _get_mxfp8_dim1_kernel_autotune_configs():

@triton.autotune(
configs=_get_mxfp8_dim1_kernel_autotune_configs(),
key=["n_rows", "n_cols", "INNER_BLOCK_SIZE"],
key=["n_cols", "INNER_BLOCK_SIZE"],
)
@triton.jit
def to_mxfp8_dim1_kernel(
x_ptr, # pointer to input tensor
output_col_major_ptr, # pointer to column-major output tensor (column-normalized)
col_scale_ptr, # pointer to store column-wise maximum absolute values
col_scale_ptr, # pointer to store scales
n_rows, # number of rows in the tensor
n_cols, # number of columns in the tensor
ROW_TILE_SIZE: tl.constexpr,
@@ -1038,6 +1038,174 @@ def to_mxfp8_dim1_kernel(
# TODO(future): mask this store
tl.store(col_scale_start_ptr + col_scale_indices, col_scale_e8m0)

@triton.autotune(
configs=_get_mxfp8_dim1_kernel_autotune_configs(),
key=["n_cols", "INNER_BLOCK_SIZE"],
)
@triton.jit
def to_mxfp8_dim0_kernel(
x_ptr, # pointer to input tensor
output_ptr, # pointer to output tensor (row-normalized)
row_scale_ptr, # pointer to store row-wise e8m0 scales
n_rows, # number of rows in the tensor
n_cols, # number of columns in the tensor
ROW_TILE_SIZE: tl.constexpr,
COL_TILE_SIZE: tl.constexpr,
INNER_BLOCK_SIZE: tl.constexpr, # should be 32 for MX
):
"""
Quantizes a high precision tensor to mxfp8 rowwise (1x32 scaling granularity).

This is the counterpart to to_mxfp8_dim1_kernel, which does columnwise quantization.
Instead of transposing and scaling across columns, this kernel scales across rows.
"""

BLOCKS_PER_COL_TILE: tl.constexpr = COL_TILE_SIZE // INNER_BLOCK_SIZE

# Get program ID
pid_row = tl.program_id(0)
pid_col = tl.program_id(1)

# Calculate starting row and column for this tile
start_row = pid_row * ROW_TILE_SIZE
start_col = pid_col * COL_TILE_SIZE

# Create offsets for the block
row_offsets = tl.arange(0, ROW_TILE_SIZE)
col_offsets = tl.arange(0, COL_TILE_SIZE)

# Compute global row/col positions
rows = start_row + row_offsets[:, None]
cols = start_col + col_offsets[None, :]

# Create masks for out-of-bounds accesses
row_mask = rows < n_rows
col_mask = cols < n_cols
mask = row_mask & col_mask

# Compute memory offsets for row-major layout (rows, cols)
row_major_offsets = (rows * n_cols + cols).to(tl.int32)

# Load the entire block in a single operation
# shape: (ROW_TILE_SIZE, COL_TILE_SIZE)
x_block = tl.load(x_ptr + row_major_offsets, mask=mask)

# Reshape to inner tile size for rowwise scaling
# shape: (ROW_TILE_SIZE, COL_TILE_SIZE) -> (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE)
x_block_r = x_block.reshape(
ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE
)

# Calculate the absolute values of elements in the block
x_block_abs_r = tl.abs(x_block_r)

# Find the maximum absolute value for each row (across columns)
# shape: (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,)
row_scale_r, row_scale_e8m0_r = _triton_calculate_scale(x_block_abs_r, axis=1)

# Divide each row by scale
# Broadcasting row_scale to match x_block's shape
# x_block_r shape (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE)
# row_scale shape (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,) -> (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, 1)
row_normalized_r = x_block_r / row_scale_r[:, None]

# Reshape back to original tile size
row_normalized = tl.reshape(row_normalized_r, ROW_TILE_SIZE, COL_TILE_SIZE)

# Quantize to float8
row_normalized = row_normalized.to(tl.float8e4nv)

# Store the row-normalized result in row-major format
tl.store(output_ptr + row_major_offsets, row_normalized, mask=mask)

# For rowwise quantization, scale tensor has shape (n_rows, n_cols // INNER_BLOCK_SIZE)
# Calculate base offset for this tile's scales
scales_per_row = n_cols // INNER_BLOCK_SIZE

# Create row and column indices for scale storage
scale_row_indices = tl.arange(0, ROW_TILE_SIZE)[:, None] + (
pid_row * ROW_TILE_SIZE
)
scale_col_indices = tl.arange(0, BLOCKS_PER_COL_TILE)[None, :] + (
pid_col * BLOCKS_PER_COL_TILE
)

# Calculate linear indices into scale tensor
scale_offsets = scale_row_indices * scales_per_row + scale_col_indices

# Create masks for valid scale indices
scale_row_mask = scale_row_indices < n_rows
scale_col_mask = scale_col_indices < scales_per_row
scale_mask = scale_row_mask & scale_col_mask

# Reshape scale values to the 2D scale layout for this tile
row_scale_e8m0_2d = row_scale_e8m0_r.reshape(ROW_TILE_SIZE, BLOCKS_PER_COL_TILE)

# Store the scales with proper masking
tl.store(row_scale_ptr + scale_offsets, row_scale_e8m0_2d, mask=scale_mask)
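
To make the reshape and scale-layout bookkeeping in the kernel above concrete, here is a small host-side sketch of the same index math with hypothetical tile sizes (the autotuner chooses the real ones):

import torch

# hypothetical tile configuration, for illustration only
ROW_TILE_SIZE, COL_TILE_SIZE, INNER_BLOCK_SIZE = 4, 64, 32
BLOCKS_PER_COL_TILE = COL_TILE_SIZE // INNER_BLOCK_SIZE  # 2

tile = torch.arange(ROW_TILE_SIZE * COL_TILE_SIZE, dtype=torch.float32).reshape(
    ROW_TILE_SIZE, COL_TILE_SIZE
)
# row-major reshape: each row of the result is one contiguous 1x32 scaling block
blocks = tile.reshape(ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE)
amax_per_block = blocks.abs().amax(dim=1)  # one value per 1x32 block
scales_2d = amax_per_block.reshape(ROW_TILE_SIZE, BLOCKS_PER_COL_TILE)
print(blocks.shape, scales_2d.shape)  # torch.Size([8, 32]) torch.Size([4, 2])
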

@triton_op("torchao::triton_to_mxfp8_dim0", mutates_args={})
def triton_to_mxfp8_dim0(
x: torch.Tensor, inner_block_size: int = 32
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Input:
* `x` - input tensor, in row major memory layout
* `inner_block_size` - size of tiles to scale across, default is 32 for MX recipes

Output:
* `output`: the `float8_e4m3fn` values of `x` cast to mxfp8 across dim0 (rowwise)
* `row_scale`: the `e8m0` scales used to cast `x` to mxfp8 across dim0 (rowwise)
"""
assert x.is_contiguous(), "`x` must be contiguous"
assert inner_block_size <= 32

# Reshape tensor to 2d if necessary and get shape
x_orig_shape = x.shape
x = x.reshape(-1, x.shape[-1])
n_rows, n_cols = x.shape

assert n_cols % inner_block_size == 0, (
"columns must be divisible by inner block size"
)

# Create output tensors
output = torch.empty(
(n_rows, n_cols), dtype=torch.float8_e4m3fn, device=x.device
)

# Create scale tensors for rowwise scaling
row_scale = torch.empty(
(n_rows, n_cols // inner_block_size),
dtype=torch.uint8,
device=x.device,
)

# Calculate grid dimensions based on tile size
grid = lambda META: (
triton.cdiv(n_rows, META["ROW_TILE_SIZE"]),
triton.cdiv(n_cols, META["COL_TILE_SIZE"]),
)

# Launch the kernel
wrap_triton(to_mxfp8_dim0_kernel)[grid](
x_ptr=x,
output_ptr=output,
row_scale_ptr=row_scale,
n_rows=n_rows,
n_cols=n_cols,
INNER_BLOCK_SIZE=inner_block_size,
)

# Reshape output and scales back to match the original leading dimensions
output = output.reshape(x_orig_shape)
row_scale = row_scale.reshape(*x_orig_shape[:-1], row_scale.shape[-1])

return (
output,
row_scale.view(torch.float8_e8m0fnu),
)

@triton_op("torchao::triton_to_mxfp8_dim1", mutates_args={})
def triton_to_mxfp8_dim1(
x: torch.Tensor, inner_block_size: int = 32
@@ -1459,6 +1627,12 @@ def _(scale_tensor):
return scale_tensor.new_empty((padded_rows, padded_cols))
else:

def triton_to_mxfp8_dim0(
x: torch.Tensor,
inner_block_size=32,
) -> Tuple[torch.Tensor, torch.Tensor]:
raise AssertionError("needs torch version 2.8+ and triton")

def triton_to_mxfp8_dim1(
x, inner_block_size=32
) -> Tuple[torch.Tensor, torch.Tensor]:
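
A short usage sketch for the new `triton_to_mxfp8_dim0` op, assuming an environment the kernel supports (torch 2.8+, Triton, CUDA compute capability 8.9+); the shapes and dtypes follow from the wrapper above:

import torch
from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0

x = torch.randn(2048, 4096, dtype=torch.bfloat16, device="cuda")
data, scales = triton_to_mxfp8_dim0(x, inner_block_size=32)
print(data.shape, data.dtype)      # torch.Size([2048, 4096]) torch.float8_e4m3fn
print(scales.shape, scales.dtype)  # torch.Size([2048, 128]) torch.float8_e8m0fnu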