213 changes: 159 additions & 54 deletions aiter/ops/triton/_triton_kernels/fused_mxfp4_quant.py
@@ -10,97 +10,202 @@ def _rmsmorm_op(row, weight, n_cols, epsilon):
row_norm = tl.sum(row_norm, axis=-1)
norm_factor = tl.math.rsqrt((row_norm / n_cols) + epsilon)

rms_norm = row * norm_factor * weight
rms_norm = row * norm_factor[:, None] * weight
return rms_norm
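Review note: with the `[:, None]` broadcast fix above, `_rmsmorm_op` is standard row-wise RMSNorm. A minimal NumPy sketch of what it computes (reference only, not the Triton path):

import numpy as np

def rmsnorm_reference(x, weight, eps):
    # x: (M, N) rows; weight: (N,) per-column gain.
    # Matches the rsqrt(sum(row^2) / n_cols + eps) scaling in _rmsmorm_op.
    norm_factor = 1.0 / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x * norm_factor * weight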


@triton.heuristics(
{
"EVEN_M_N": lambda args: args["M"] % args["BLOCK_SIZE_M"] == 0
and args["N1"] % (args["BLOCK_SIZE_N"]) == 0,
"EVEN_M_N2": lambda args: args["M"] % args["BLOCK_SIZE_M"] == 0
and args["N2"] % (args["BLOCK_SIZE_N2"]) == 0,
}
)
@triton.jit
def _fused_rms_mxfp4_quant_kernel(
inp1_ptr,
weight1_ptr,
inp2_ptr,
weight2_ptr,
x1_ptr,
w1_ptr,
x2_ptr,
w2_ptr,
res1_ptr,
out1_fp4_ptr,
out1_bs_ptr,
out2_ptr,
out_res1_ptr,
eps1,
eps2,
n_rows,
inp1_n_cols,
inp2_n_cols,
inp1_row_stride,
inp2_row_stride,
res1_row_stride,
out1_fp4_row_stride,
out1_bs_row_stride,
out1_bs_col_stride,
out2_row_stride,
out_res1_row_stride,
BLOCK_SIZE: tl.constexpr,
M,
N1,
N2,
x1_stride_m,
x2_stride_m,
res1_stride_m,
out1_fp4_stride_m,
out1_bs_stride_m,
out1_bs_stride_n,
out2_stride_m,
out_res1_stride_m,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_N2: tl.constexpr,
MXFP4_QUANT_BLOCK_SIZE: tl.constexpr,
SKIP_SECOND_INPUT: tl.constexpr,
HAS_SECOND_INPUT: tl.constexpr,
FIRST_INPUT_RES: tl.constexpr,
SCALE_N: tl.constexpr,
SCALE_M_PAD: tl.constexpr,
SCALE_N_PAD: tl.constexpr,
SHUFFLE: tl.constexpr,
SHUFFLE_PAD: tl.constexpr,
EVEN_M_N: tl.constexpr,
EVEN_M_N2: tl.constexpr,
):
# TODO: XCD remapping where every 32-token block should share the same XCD
# TODO: debug for large M
# TODO: investigate cache_modifier='.cg' on tl.store
pid = tl.program_id(0)
NUM_QUANT_BLOCKS: tl.constexpr = BLOCK_SIZE // MXFP4_QUANT_BLOCK_SIZE
block_inds = tl.arange(0, BLOCK_SIZE)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)

if pid >= num_pid_m:
if HAS_SECOND_INPUT:
pid -= num_pid_m
x_offs_m = pid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
x_offs_n2 = tl.arange(0, BLOCK_SIZE_N2)
mask2 = None
other2 = None
if not EVEN_M_N2:
mask2 = (x_offs_m < M)[:, None] & (x_offs_n2 < N2)[None, :]
other2 = 0.0

x2 = tl.load(
x2_ptr + x_offs_m[:, None] * x2_stride_m + x_offs_n2[None, :],
mask=mask2,
other=other2,
cache_modifier=".cg",
).to(tl.float32)

w_mask2 = None
w_other2 = None
if not EVEN_M_N2:
w_mask2 = x_offs_n2 < N2
w_other2 = 0.0

w2 = tl.load(w2_ptr + x_offs_n2, mask=w_mask2, other=w_other2).to(
tl.float32
)

norm2 = _rmsmorm_op(x2, w2, N2, eps2)

mask1 = block_inds < inp1_n_cols
inp1 = tl.load(
inp1_ptr + pid * inp1_row_stride + block_inds,
tl.store(
out2_ptr + x_offs_m[:, None] * out2_stride_m + x_offs_n2[None, :],
norm2.to(out2_ptr.type.element_ty),
mask=mask2,
cache_modifier=".cg",
)
return

x_offs_n = tl.arange(0, BLOCK_SIZE_N)
NUM_QUANT_BLOCKS: tl.constexpr = BLOCK_SIZE_N // MXFP4_QUANT_BLOCK_SIZE
x_offs_m = pid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)

mask1 = None
other1 = None
if not EVEN_M_N:
mask1 = (x_offs_m < M)[:, None] & (x_offs_n < N1)[None, :]
other1 = 0.0

x1 = tl.load(
x1_ptr + x_offs_m[:, None] * x1_stride_m + x_offs_n[None, :],
mask=mask1,
other=0.0,
other=other1,
cache_modifier=".cg",
).to(tl.float32)

if FIRST_INPUT_RES:
res1 = tl.load(
res1_ptr + pid * res1_row_stride + block_inds,
res1_ptr + x_offs_m[:, None] * res1_stride_m + x_offs_n[None, :],
mask=mask1,
other=0.0,
other=other1,
cache_modifier=".cg",
).to(tl.float32)
inp1 = inp1 + res1
x1 = x1 + res1

w_mask1 = None
w_other1 = None
if not EVEN_M_N:
w_mask1 = x_offs_n < N1
w_other1 = 0.0

w1 = tl.load(weight1_ptr + block_inds, mask=mask1, other=0.0).to(tl.float32)
w1 = tl.load(w1_ptr + x_offs_n, mask=w_mask1, other=w_other1).to(tl.float32)

norm1 = _rmsmorm_op(inp1, w1, inp1_n_cols, eps1)
out1_fp4, out1_block_scales = _mxfp4_quant_op(
norm1, BLOCK_SIZE, 1, MXFP4_QUANT_BLOCK_SIZE
norm1 = _rmsmorm_op(x1, w1, N1, eps1)
out1_fp4, bs_e8m0 = _mxfp4_quant_op(
norm1, BLOCK_SIZE_N, BLOCK_SIZE_M, MXFP4_QUANT_BLOCK_SIZE
)
out1_fp4 = tl.ravel(out1_fp4)
out1_block_scales = tl.ravel(out1_block_scales)

# store the results
half_block_inds = tl.arange(0, BLOCK_SIZE // 2)
half_x_offs_n = tl.arange(0, BLOCK_SIZE_N // 2)
out_mask1 = None
if not EVEN_M_N:
out_mask1 = (x_offs_m < M)[:, None] & (half_x_offs_n < (N1 // 2))[None, :]

tl.store(
out1_fp4_ptr + pid * out1_fp4_row_stride + half_block_inds,
out1_fp4_ptr + x_offs_m[:, None] * out1_fp4_stride_m + half_x_offs_n[None, :],
out1_fp4,
mask=half_block_inds < (inp1_n_cols // 2),
mask=out_mask1,
cache_modifier=".cg",
)
bs_inds = tl.arange(0, NUM_QUANT_BLOCKS)
num_bs_cols = (inp1_n_cols + MXFP4_QUANT_BLOCK_SIZE - 1) // MXFP4_QUANT_BLOCK_SIZE

bs_offs_m = pid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
bs_offs_n = tl.arange(0, NUM_QUANT_BLOCKS)
num_bs_cols = (N1 + MXFP4_QUANT_BLOCK_SIZE - 1) // MXFP4_QUANT_BLOCK_SIZE
if SHUFFLE:
bs_offs_0 = bs_offs_m[:, None] // 32
bs_offs_1 = bs_offs_m[:, None] % 32
bs_offs_2 = bs_offs_1 % 16
bs_offs_1 = bs_offs_1 // 16
bs_offs_3 = bs_offs_n[None, :] // 8
bs_offs_4 = bs_offs_n[None, :] % 8
bs_offs_5 = bs_offs_4 % 4
bs_offs_4 = bs_offs_4 // 4
bs_offs = (
bs_offs_1
+ bs_offs_4 * 2
+ bs_offs_2 * 2 * 2
+ bs_offs_5 * 2 * 2 * 16
+ bs_offs_3 * 2 * 2 * 16 * 4
+ bs_offs_0 * 2 * 16 * SCALE_N_PAD
)
bs_mask_127 = (bs_offs_m < M)[:, None] & (bs_offs_n < num_bs_cols)[None, :]
bs_e8m0 = tl.where(bs_mask_127, bs_e8m0, 127)
else:
bs_offs = (
bs_offs_m[:, None] * out1_bs_stride_m
+ bs_offs_n[None, :] * out1_bs_stride_n
)

bs_mask = None
if not EVEN_M_N:
if SHUFFLE_PAD:
bs_mask = (bs_offs_m < SCALE_M_PAD)[:, None] & (bs_offs_n < SCALE_N_PAD)[
None, :
]
else:
bs_mask = (bs_offs_m < M)[:, None] & (bs_offs_n < SCALE_N)[None, :]

tl.store(
out1_bs_ptr + pid * out1_bs_row_stride + bs_inds * out1_bs_col_stride,
out1_block_scales,
mask=bs_inds < num_bs_cols,
out1_bs_ptr + bs_offs,
bs_e8m0.to(out1_bs_ptr.type.element_ty),
mask=bs_mask,
cache_modifier=".cg",
)
if not SKIP_SECOND_INPUT:
mask2 = block_inds < inp2_n_cols
inp2 = tl.load(
inp2_ptr + pid * inp2_row_stride + block_inds,
mask=mask2,
other=0.0,
cache_modifier=".cg",
).to(tl.float32)
w2 = tl.load(weight2_ptr + block_inds, mask=mask2, other=0.0).to(tl.float32)
norm2 = _rmsmorm_op(inp2, w2, inp2_n_cols, eps2)
tl.store(out2_ptr + pid * out2_row_stride + block_inds, norm2, mask=mask2)

if FIRST_INPUT_RES:
inp1 = inp1.to(out_res1_ptr.dtype.element_ty)
tl.store(
out_res1_ptr + pid * out_res1_row_stride + block_inds, inp1, mask=mask1
out_res1_ptr + x_offs_m[:, None] * out_res1_stride_m + x_offs_n[None, :],
x1.to(out_res1_ptr.dtype.element_ty),
mask=mask1,
cache_modifier=".cg",
)
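Review note: the SHUFFLE branch stores each 32-row by 8-column tile of e8m0 scales as 256 contiguous entries, with 32-row bands strided by 32 * SCALE_N_PAD (the `bs_offs_0 * 2 * 16 * SCALE_N_PAD` term). A quick NumPy check that the in-tile index arithmetic is a permutation — a sketch assuming offsets count scale elements:

import numpy as np

# One 32x8 scale tile (bs_offs_0 = bs_offs_3 = 0 for m < 32, n < 8):
m = np.arange(32)[:, None]  # token rows within the tile
n = np.arange(8)[None, :]   # scale columns within the tile
off = (m // 16) + (n // 4) * 2 + (m % 16) * 4 + (n % 4) * 64
assert sorted(off.ravel().tolist()) == list(range(256))  # each slot used exactly once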


21 changes: 9 additions & 12 deletions aiter/ops/triton/activation.py
@@ -19,6 +19,7 @@ def act_mul_and_mxfp4_quant(
activation: Literal["silu", "gelu", "gelu_tanh"],
scaling_mode: str = "even",
shuffle: bool = False,
scale_shuffle_padding: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Apply the activation function and quantize the result to MX FP4 format.
@@ -53,22 +54,18 @@
x_fp4 = torch.empty((M, N_half // 2), dtype=torch.uint8, device=x.device)
scaleN_valid = triton.cdiv(N_half, MXFP4_QUANT_BLOCK_SIZE)
# Pad scale M to a multiple of 256 and scale N to a multiple of 8
if shuffle:
use_scale_shuffle_padding = shuffle or scale_shuffle_padding
if use_scale_shuffle_padding:
scaleM = triton.cdiv(M, 256) * 256
scaleN = triton.cdiv(scaleN_valid, 8) * 8
blockscale_e8m0 = torch.empty(
(scaleM, scaleN),
dtype=torch.uint8,
device=x.device,
)
else:
scaleM = M
scaleN = scaleN_valid
blockscale_e8m0 = torch.empty(
(scaleN, scaleM),
dtype=torch.uint8,
device=x.device,
).T
blockscale_e8m0 = torch.empty(
(scaleM, scaleN),
dtype=torch.uint8,
device=x.device,
)
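A worked example of the padded allocation, assuming MXFP4_QUANT_BLOCK_SIZE == 32 (the MXFP4 block size): for M = 100 and N_half = 1000,

import triton

scaleN_valid = triton.cdiv(1000, 32)       # 32 valid scale columns
scaleM = triton.cdiv(100, 256) * 256       # rows padded to 256
scaleN = triton.cdiv(scaleN_valid, 8) * 8  # 32 columns, already a multiple of 8

so blockscale_e8m0 is allocated as (256, 32), with only the top-left (100, 32) region holding valid scales; the rest is padding for the (optionally shuffled) layout.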

# for large N values
if M <= 32:
@@ -116,7 +113,7 @@ def act_mul_and_mxfp4_quant(
SCALING_MODE=0,
ACTIVATION=activation,
scaleN=scaleN_valid,
scaleM_pad=scaleM,
scaleM_pad=(scaleM if use_scale_shuffle_padding else 1),
scaleN_pad=scaleN,
SHUFFLE=shuffle,
NUM_ITER=NUM_ITER,
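Review note: a hedged usage sketch of the new flag, assuming the elided first parameter is the input tensor x of shape (M, 2 * N_half) (hidden by the truncated hunk above). Passing scale_shuffle_padding=True requests the padded (scaleM, scaleN) scale buffer without the shuffled element order:

import torch
from aiter.ops.triton.activation import act_mul_and_mxfp4_quant

x = torch.randn(128, 2048, dtype=torch.bfloat16, device="cuda")  # illustrative sizes
x_fp4, blockscale_e8m0 = act_mul_and_mxfp4_quant(
    x, activation="silu", shuffle=False, scale_shuffle_padding=True
)
# x_fp4: (128, 512) packed uint8; blockscale_e8m0: (256, 32) padded e8m0 scales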
@@ -0,0 +1,86 @@
{
"small": {
"BLOCK_SIZE_M": 8,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 512,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 6,
"matrix_instr_nonkdim": 16,
"cache_modifier": ".cg",
"NUM_KSPLIT": 8
},
"small_M16": {
"BLOCK_SIZE_M": 8,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 1024,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 2,
"matrix_instr_nonkdim": 16,
"cache_modifier": null,
"NUM_KSPLIT": 4
},
"medium_M32": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 512,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 2,
"matrix_instr_nonkdim": 16,
"cache_modifier": ".cg",
"NUM_KSPLIT": 8
},
"medium_M64": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 4,
"matrix_instr_nonkdim": 16,
"cache_modifier": null,
"NUM_KSPLIT": 8
},
"medium_M128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 1024,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 2,
"matrix_instr_nonkdim": 16,
"cache_modifier": null,
"NUM_KSPLIT": 4
},
"large": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 1024,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 1,
"matrix_instr_nonkdim": 16,
"cache_modifier": null,
"NUM_KSPLIT": 4
},
"xlarge": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 1024,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 1,
"matrix_instr_nonkdim": 16,
"cache_modifier": null,
"NUM_KSPLIT": 4
}
}
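Review note: the new JSON appears to be a tuned-config table bucketed by M (the filename is not visible in this hunk). A hypothetical selector, purely illustrative — the thresholds and file name are assumptions, not from this PR:

import json

def pick_config(configs: dict, M: int) -> dict:
    # Hypothetical M-based bucketing inferred from the key names above.
    if M <= 8:
        return configs["small"]
    if M <= 16:
        return configs["small_M16"]
    if M <= 32:
        return configs["medium_M32"]
    if M <= 64:
        return configs["medium_M64"]
    if M <= 128:
        return configs["medium_M128"]
    if M <= 1024:
        return configs["large"]
    return configs["xlarge"]

with open("gemm_afp4_wfp4_configs.json") as f:  # hypothetical filename
    cfg = pick_config(json.load(f), M=64)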