From 4288406fb0ae9bb1592f3aa390201b1827e21ef3 Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Thu, 7 Aug 2025 12:55:21 +0000 Subject: [PATCH 1/8] [MXFP4] Dequantize FP4 kernel example, MX scale todo --- .../example_dequant_gemm_mxfp4_hopper.py | 303 ++++++++++++++++++ 1 file changed, 303 insertions(+) create mode 100644 examples/dequantize_gemm/example_dequant_gemm_mxfp4_hopper.py diff --git a/examples/dequantize_gemm/example_dequant_gemm_mxfp4_hopper.py b/examples/dequantize_gemm/example_dequant_gemm_mxfp4_hopper.py new file mode 100644 index 000000000..64ab8ff73 --- /dev/null +++ b/examples/dequantize_gemm/example_dequant_gemm_mxfp4_hopper.py @@ -0,0 +1,303 @@ +import tilelang +import tilelang.language as T +from tvm import tir +import argparse +import torch + +tilelang.disable_cache() + +torch.manual_seed(0) + +def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): + assert nbit == 4 + assert dtype == "bfloat16" + assert val.dtype == "uint8" + mask = tir.const((1 << nbit) - 1, "uint16") + f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask + s = f4 >> tir.const(3, "uint16") + e_f4 = (f4 & tir.const(6, "uint16")) >> 1 + # Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126 + e_bf16 = e_f4 + tir.const(126, "uint16") + m_f4 = f4 & tir.const(1, "uint16") + val_bf16 = tir.reinterpret( + "bfloat16", ((((s << 8) | e_bf16) << 7) | (m_f4 << tir.const(6, "uint16"))).astype("uint16") + ) + return val_bf16 + +def torch_convert(tensor): + + def print_bit(name, val): + val_cpu = val.cpu().item() + binary_repr = f'{val_cpu:032b}' + print(name, binary_repr) + + def _convert(val, pos): + assert val.dtype == torch.uint8 + # val = val.view(torch.int8) + mask = (1 << 4) - 1 + f4 = ((val >> (pos * 4)) & mask).to(torch.int16) + s = f4 >> 3 + e_f4 = (f4 & 6) >> 1 + e_f16 = e_f4 + 126 + m_f4 = f4 & 1 + m_f16 = m_f4 + val_f16 = (((e_f16 | (s << 8)) << 7) | (m_f16 << 6)) & 0xFFFF + lower_16_bits = (val_f16 & 0xFFFF).to(torch.uint16) + return lower_16_bits.view(torch.bfloat16) + + N = tensor.shape[0] + K = tensor.shape[1] + new_tensor = torch.empty(N, K * 2, dtype=torch.bfloat16, device=tensor.device) + for i in range(new_tensor.shape[0]): + for j in range(new_tensor.shape[1]): + new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2) + return new_tensor + + +@tilelang.jit(out_idx=[1]) +def convert(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128): + num_elems_per_byte = 8 // num_bits + storage_dtype = "uint8" + B_shape = (N, K // num_elems_per_byte) + B_shared_shape = (block_N, block_K // num_elems_per_byte) + B_dequantize_shared_shape = (block_N, block_K) + + @T.prim_func + def main( + B: T.Tensor(B_shape, storage_dtype), + C: T.Tensor((N, K), in_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx): + B_shared = T.alloc_shared(B_shared_shape, storage_dtype) + B_local = T.alloc_fragment(B_shared_shape, storage_dtype) + B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype) + + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1): + T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared) + T.copy(B_shared, B_local) + for i, j in T.Parallel(block_N, block_K): + B_dequantize_local[i, j] = _tir_u8_to_f4_to_bf16( + num_bits, + B_local[i, j // num_elems_per_byte], + j % num_elems_per_byte, + dtype=in_dtype, + ) + T.copy(B_dequantize_local, C[bx * block_N, k * block_K]) + + return main + + +def test_fp4_bf16_convert_close(): + N, K = 256, 256 + block_N, block_K = 64, 64 + kernel = convert( 
+ N, + K, + block_N, + block_K, + "bfloat16", + ) + + B = torch.randint(0, 16, (N, K // 2), dtype=torch.uint8, device="cuda").to(torch.uint8) + tl_out = kernel(B) + ref_out = torch_convert(B) + assert torch.allclose(tl_out, ref_out, rtol=0.01, atol=0.01), (tl_out, ref_out) + print("Convert Pass") + + +def get_configs(): + block_M = [128] + block_N = [128, 256] + block_K = [128] + num_stages = [2] + threads = [256] + splits = [1] + _configs = list(itertools.product(block_M, block_N, block_K, num_stages, threads, splits)) + + configs = [{ + 'block_M': c[0], + 'block_N': c[1], + 'block_K': c[2], + 'num_stages': c[3], + 'threads': c[4], + 'split': c[5] + } for c in _configs] + return configs + + +def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): + + @tilelang.jit(out_idx=[2]) + def kernel_func(block_M, block_N, block_K, num_stages, threads, split=1): + num_elems_per_byte = 8 // num_bits + storage_dtype = "uint8" + A_shape = (M, K) + B_shape = (N, K // num_elems_per_byte) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K // num_elems_per_byte) + B_dequantize_shared_shape = (block_N, block_K) + assert K % (block_K * split) == 0 + KK = K // split + + @T.prim_func + def main_split( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + Ct: T.Tensor((N, M), out_dtype), + ): + SplitC = T.alloc_buffer([ + split, (N + block_N - 1) // block_N * block_N, + (M + block_M - 1) // block_M * block_M + ], out_dtype) + with T.Kernel( + T.ceildiv(N, block_N), T.ceildiv(M, block_M), split, + threads=threads) as (bx, by, bz): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, storage_dtype) + B_local = T.alloc_fragment(B_shared_shape, storage_dtype) + B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype) + B_dequantize_prev_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype) + Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype) + Ct_shared = T.alloc_shared((block_N, block_M), out_dtype) + + T.annotate_layout({ + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared), + }) + + T.clear(Ct_local) + for k in T.Pipelined(K // (block_K * split), num_stages=num_stages): + T.copy(A[by * block_M, KK * bz + k * block_K], A_shared) + T.copy(B[bx * block_N, (KK * bz + k * block_K) // num_elems_per_byte], B_shared) + T.copy(B_shared, B_local) + for i, j in T.Parallel(block_N, block_K): + B_dequantize_local[i, j] = _tir_u8_to_f4_to_bf16( + num_bits, + B_local[i, j // num_elems_per_byte], + j % num_elems_per_byte, + dtype=in_dtype, + ) + T.copy(B_dequantize_local, B_dequantize_prev_local) + T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True) + T.copy(Ct_local, SplitC[bz, bx * block_N:(bx + 1) * block_N, + by * block_M:(by + 1) * block_M]) + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M)) as (bx, by): + acc = T.alloc_fragment((block_N, block_M), out_dtype) + T.clear(acc) + for k in range(split): + for i, j in T.Parallel(block_N, block_M): + acc[i, j] += SplitC[k, bx * block_N + i, by * block_M + j] + T.copy(acc, Ct[bx * block_N, by * block_M]) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + Ct: T.Tensor((N, M), out_dtype), + ): + with T.Kernel( + T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, 
storage_dtype) + B_local = T.alloc_fragment(B_shared_shape, storage_dtype) + B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype) + B_dequantize_prev_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype) + Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype) + Ct_shared = T.alloc_shared((block_N, block_M), out_dtype) + + T.annotate_layout({ + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared), + }) + + T.clear(Ct_local) + for k in T.Pipelined(K // block_K, num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared) + T.copy(B_shared, B_local) + for i, j in T.Parallel(block_N, block_K): + B_dequantize_local[i, j] = _tir_u8_to_f4_to_bf16( + num_bits, + B_local[i, j // num_elems_per_byte], + j % num_elems_per_byte, + dtype=in_dtype, + ) + T.copy(B_dequantize_local, B_dequantize_prev_local) + T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True) + T.copy(Ct_local, Ct_shared) + T.copy(Ct_shared, Ct[bx * block_N:(bx + 1) * block_N, + by * block_M:(by + 1) * block_M]) + + if split == 1: + return main + else: + return main_split + + if tune: + @autotune( + configs=get_configs(), + keys=["block_M", "block_N", "block_K", "num_stages", "threads", "split"], + warmup=10, + rep=10) + @tilelang.jit(out_idx=[2]) + def kernel(block_M=None, + block_N=None, + block_K=None, + num_stages=None, + threads=None, + split=None): + return kernel_func(block_M, block_N, block_K, num_stages, threads, split) + return kernel() + else: + def kernel(block_M, block_N, block_K, num_stages, threads, split=1): + return kernel_func(block_M, block_N, block_K, num_stages, threads, split) + return kernel + + +def ref_program(A, qB): + dtypeC = "bfloat16" + B = torch_convert(qB) + C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + C = C.to(torch.__getattribute__(dtypeC)) + return C.transpose(0, 1) + + +def main(m=256, n=256, k=256, tune=False): + total_flops = 2 * m * n * k + + if (not tune): + kernel = matmul( + m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, tune=tune)( + block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1) + profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer) + profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) + print("All checks pass.") + latency = profiler.do_bench(ref_program, warmup=500) + print("Ref: {:.2f} ms".format(latency)) + print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = profiler.do_bench(warmup=500) + print("Tile-lang: {:.2f} ms".format(latency)) + print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + else: + best_result = matmul(m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, tune=tune) + best_latency = best_result.latency + best_config = best_result.config + print(f"Best latency: {best_latency}") + print(f"Best TFlops: {total_flops / best_latency * 1e-9}") + print(f"Best config: {best_config}") + + +def test_convert(): + test_fp4_bf16_convert_close() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--m', type=int, default=256, help='M') + parser.add_argument('--n', type=int, default=256, help='N') + parser.add_argument('--k', type=int, default=256, help='K') + parser.add_argument('--tune', action='store_true', help='tune configs') + args = parser.parse_args() + M, N, K = args.m, args.n, args.k + main(M, N, K, args.tune) + # test_convert() 
From f502d08a198677564223457aba937951f077e87e Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Thu, 7 Aug 2025 13:13:54 +0000 Subject: [PATCH 2/8] [BugFix] Fix the bug of fp4&fp16 exponential bias --- .../example_dequant_gemm_fp4_hopper.py | 22 ++++++++++++------- .../example_dequant_gemm_mxfp4_hopper.py | 4 ++-- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py b/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py index 668f58a96..64d99f38b 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py +++ b/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py @@ -6,22 +6,26 @@ import torch import argparse +tilelang.disable_cache() + def _tir_u8_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): assert nbit == 4 assert dtype == "float16" assert val.dtype == "uint8" # e_f4 == 0 -> e_f16 = 0 - # e_f4 != 0 -> e_f16 = e_f4 + 8 = e_f4 | (1000)_2 - # s1e2n1 + # e_f4 != 0 -> e_f16 = e_f4 + ExponentialBias(f16, f4) = e_f4 + (2^4 - 2^1) = e_f4 + 14 + # s1e2m1 mask = tir.const((1 << nbit) - 1, "uint16") f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask s = f4 >> tir.const(3, "uint16") - e_f4 = f4 & tir.const(7, "uint16") - e_f16 = e_f4 | tir.const(8, "uint16") + e_f4 = (f4 & tir.const(6, "uint16")) >> tir.const(1, "uint16") + e_f16 = e_f4 + tir.const(14, "uint16") + m_f4 = f4 & tir.const(1, "uint16") + m_f16 = m_f4 val_f16 = tir.reinterpret( "float16", - ((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16")).astype("uint16")) + ((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16") | m_f16 << tir.const(9, "uint16")).astype("uint16")) # return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float16"), val_f16) return val_f16 @@ -39,9 +43,11 @@ def _convert(val, pos): mask = (1 << 4) - 1 f4 = ((val >> (pos * 4)) & mask).to(torch.int16) s = f4 >> 3 - e_f4 = f4 & 7 - e_f16 = e_f4 | 8 - val_f16 = ((e_f16 | (s << 5)) << 10) & 0xFFFF + e_f4 = (f4 & 6) >> 1 + e_f16 = e_f4 + 14 + m_f4 = f4 & 1 + m_f16 = m_f4 + val_f16 = (((e_f16 | (s << 5)) << 10) | (m_f16 << 9)) & 0xFFFF lower_16_bits = (val_f16 & 0xFFFF).to(torch.uint16) return lower_16_bits.view(torch.float16) diff --git a/examples/dequantize_gemm/example_dequant_gemm_mxfp4_hopper.py b/examples/dequantize_gemm/example_dequant_gemm_mxfp4_hopper.py index 64ab8ff73..2a5b8b096 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_mxfp4_hopper.py +++ b/examples/dequantize_gemm/example_dequant_gemm_mxfp4_hopper.py @@ -15,12 +15,12 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype mask = tir.const((1 << nbit) - 1, "uint16") f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask s = f4 >> tir.const(3, "uint16") - e_f4 = (f4 & tir.const(6, "uint16")) >> 1 + e_f4 = (f4 & tir.const(6, "uint16")) >> tir.const(1, "uint16") # Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126 e_bf16 = e_f4 + tir.const(126, "uint16") m_f4 = f4 & tir.const(1, "uint16") val_bf16 = tir.reinterpret( - "bfloat16", ((((s << 8) | e_bf16) << 7) | (m_f4 << tir.const(6, "uint16"))).astype("uint16") + "bfloat16", ((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16")) | (m_f4 << tir.const(6, "uint16"))).astype("uint16") ) return val_bf16 From 544c6bc722cbbea9c7eea4f84830599717a5a930 Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Fri, 8 Aug 2025 02:53:08 +0000 Subject: [PATCH 3/8] [MXFP4] Add group scale factor for BF16xMXFP4 gemm --- 
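Notes: this patch applies the group scale by adding the uint8 Scale byte directly to the stored bf16 exponent (clamped at 255), which is the same as multiplying the decoded value by 2**scale; if the Scale byte is meant to be an OCP MX E8M0 exponent, its bias of 127 would still need to be subtracted — presumably the "MX scale todo" from patch 1. A minimal sketch of that identity (plain Python; the helper name is illustrative, not from the patch):

    import struct

    def bf16_from_fields(s: int, e: int, m7: int) -> float:
        # assemble a bf16 bit pattern (1/8/7-bit fields) and read it back via float32
        return struct.unpack(">f", struct.pack(">I", ((((s << 8) | e) << 7) | m7) << 16))[0]

    s, e_f4, m_f4, scale = 0, 1, 1, 3            # f4 code 0b0011 = 1.5, group scale 2**3
    plain = bf16_from_fields(s, e_f4 + 126, m_f4 << 6)
    scaled = bf16_from_fields(s, min(e_f4 + 126 + scale, 255), m_f4 << 6)
    assert plain == 1.5 and scaled == plain * 2.0 ** scale   # exponent add == multiply by 2**scale
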
.../example_dequant_gemm_mxfp4_hopper.py | 120 ++++++++++++++++-- 1 file changed, 106 insertions(+), 14 deletions(-) diff --git a/examples/dequantize_gemm/example_dequant_gemm_mxfp4_hopper.py b/examples/dequantize_gemm/example_dequant_gemm_mxfp4_hopper.py index 2a5b8b096..7c395b438 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_mxfp4_hopper.py +++ b/examples/dequantize_gemm/example_dequant_gemm_mxfp4_hopper.py @@ -2,13 +2,15 @@ import tilelang.language as T from tvm import tir import argparse +import itertools import torch tilelang.disable_cache() torch.manual_seed(0) -def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): + +def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, dtype: str): assert nbit == 4 assert dtype == "bfloat16" assert val.dtype == "uint8" @@ -18,20 +20,24 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype e_f4 = (f4 & tir.const(6, "uint16")) >> tir.const(1, "uint16") # Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126 e_bf16 = e_f4 + tir.const(126, "uint16") + # Scale is the exponential part, within the representation of uint8 + # To handle the overflow, we use the max function to limit the exponential part to 8 bits + e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16")) m_f4 = f4 & tir.const(1, "uint16") val_bf16 = tir.reinterpret( "bfloat16", ((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16")) | (m_f4 << tir.const(6, "uint16"))).astype("uint16") ) return val_bf16 -def torch_convert(tensor): + +def torch_convert(tensor, scale_size=None, Scale=None): def print_bit(name, val): val_cpu = val.cpu().item() binary_repr = f'{val_cpu:032b}' print(name, binary_repr) - def _convert(val, pos): + def _convert(val, pos, scale=None): assert val.dtype == torch.uint8 # val = val.view(torch.int8) mask = (1 << 4) - 1 @@ -39,6 +45,8 @@ def _convert(val, pos): s = f4 >> 3 e_f4 = (f4 & 6) >> 1 e_f16 = e_f4 + 126 + if scale is not None: + e_f16 = min(e_f16 + scale, (1 << 8) - 1) m_f4 = f4 & 1 m_f16 = m_f4 val_f16 = (((e_f16 | (s << 8)) << 7) | (m_f16 << 6)) & 0xFFFF @@ -50,11 +58,14 @@ def _convert(val, pos): new_tensor = torch.empty(N, K * 2, dtype=torch.bfloat16, device=tensor.device) for i in range(new_tensor.shape[0]): for j in range(new_tensor.shape[1]): - new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2) + if scale_size is not None: + new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2, Scale[i][j // scale_size]) + else: + new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2) return new_tensor -@tilelang.jit(out_idx=[1]) +@tilelang.jit(out_idx=[-1]) def convert(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128): num_elems_per_byte = 8 // num_bits storage_dtype = "uint8" @@ -80,6 +91,48 @@ def main( num_bits, B_local[i, j // num_elems_per_byte], j % num_elems_per_byte, + 0, # No scale for test + dtype=in_dtype, + ) + T.copy(B_dequantize_local, C[bx * block_N, k * block_K]) + + return main + + +@tilelang.jit(out_idx=[-1]) +def convert_scale(N, K, block_N, block_K, in_dtype, num_bits=4, scale_size=32, threads=128): + num_elems_per_byte = 8 // num_bits + storage_dtype = "uint8" + B_shape = (N, K // num_elems_per_byte) + B_shared_shape = (block_N, block_K // num_elems_per_byte) + B_dequantize_shared_shape = (block_N, block_K) + Scale_shape = (N, K // scale_size) + Scale_shared_shape = (block_N, block_K // scale_size) + + @T.prim_func + def main( + B: T.Tensor(B_shape, storage_dtype), + Scale: 
T.Tensor(Scale_shape, storage_dtype), + C: T.Tensor((N, K), in_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx): + B_shared = T.alloc_shared(B_shared_shape, storage_dtype) + B_local = T.alloc_fragment(B_shared_shape, storage_dtype) + B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype) + Scale_shared = T.alloc_shared(Scale_shared_shape, storage_dtype) + Scale_local = T.alloc_fragment(Scale_shared_shape, storage_dtype) + + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1): + T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared) + T.copy(B_shared, B_local) + T.copy(Scale[bx * block_N, k * block_K // scale_size], Scale_shared) + T.copy(Scale_shared, Scale_local) + for i, j in T.Parallel(block_N, block_K): + B_dequantize_local[i, j] = _tir_u8_to_f4_to_bf16( + num_bits, + B_local[i, j // num_elems_per_byte], + j % num_elems_per_byte, + Scale_local[i, j // scale_size], # Scale is the exponential part, within the representation of uint8 dtype=in_dtype, ) T.copy(B_dequantize_local, C[bx * block_N, k * block_K]) @@ -105,6 +158,20 @@ def test_fp4_bf16_convert_close(): print("Convert Pass") +def test_fp4_bf16_convert_scale_close(): + N, K = 256, 256 + block_N, block_K = 64, 64 + kernel = convert_scale( + N, K, block_N, block_K, "bfloat16", scale_size=32) + + B = torch.randint(0, 16, (N, K // 2), dtype=torch.uint8, device="cuda").to(torch.uint8) + Scale = torch.randint(0, 1, (N, K // 32), dtype=torch.uint8, device="cuda").to(torch.uint8) + tl_out = kernel(B, Scale) + ref_out = torch_convert(B, scale_size=32, Scale=Scale) + assert torch.allclose(tl_out, ref_out, rtol=0.01, atol=0.01), (tl_out, ref_out) + print("Convert Scale Pass") + + def get_configs(): block_M = [128] block_N = [128, 256] @@ -125,17 +192,19 @@ def get_configs(): return configs -def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): +def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, scale_size=32, tune=False): - @tilelang.jit(out_idx=[2]) + @tilelang.jit(out_idx=[-1]) def kernel_func(block_M, block_N, block_K, num_stages, threads, split=1): num_elems_per_byte = 8 // num_bits storage_dtype = "uint8" A_shape = (M, K) B_shape = (N, K // num_elems_per_byte) + Scale_shape = (N, K // scale_size) A_shared_shape = (block_M, block_K) B_shared_shape = (block_N, block_K // num_elems_per_byte) B_dequantize_shared_shape = (block_N, block_K) + Scale_shared_shape = (block_N, block_K // scale_size) assert K % (block_K * split) == 0 KK = K // split @@ -143,6 +212,7 @@ def kernel_func(block_M, block_N, block_K, num_stages, threads, split=1): def main_split( A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, storage_dtype), + Scale: T.Tensor(Scale_shape, storage_dtype), Ct: T.Tensor((N, M), out_dtype), ): SplitC = T.alloc_buffer([ @@ -159,6 +229,8 @@ def main_split( B_dequantize_prev_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype) Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype) Ct_shared = T.alloc_shared((block_N, block_M), out_dtype) + Scale_shared = T.alloc_shared(Scale_shared_shape, storage_dtype) + Scale_local = T.alloc_fragment(Scale_shared_shape, storage_dtype) T.annotate_layout({ B_shared: tilelang.layout.make_swizzled_layout(B_shared), @@ -170,11 +242,14 @@ def main_split( T.copy(A[by * block_M, KK * bz + k * block_K], A_shared) T.copy(B[bx * block_N, (KK * bz + k * block_K) // num_elems_per_byte], B_shared) T.copy(B_shared, B_local) + T.copy(Scale[bx * block_N, (KK * bz + k * block_K) // scale_size], 
Scale_shared) + T.copy(Scale_shared, Scale_local) for i, j in T.Parallel(block_N, block_K): B_dequantize_local[i, j] = _tir_u8_to_f4_to_bf16( num_bits, B_local[i, j // num_elems_per_byte], j % num_elems_per_byte, + Scale_local[i, j // scale_size], dtype=in_dtype, ) T.copy(B_dequantize_local, B_dequantize_prev_local) @@ -193,6 +268,7 @@ def main_split( def main( A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, storage_dtype), + Scale: T.Tensor(Scale_shape, storage_dtype), Ct: T.Tensor((N, M), out_dtype), ): with T.Kernel( @@ -204,10 +280,13 @@ def main( B_dequantize_prev_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype) Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype) Ct_shared = T.alloc_shared((block_N, block_M), out_dtype) + Scale_shared = T.alloc_shared((block_N, block_M // scale_size), storage_dtype) + Scale_local = T.alloc_fragment((block_N, block_M // scale_size), storage_dtype) T.annotate_layout({ B_shared: tilelang.layout.make_swizzled_layout(B_shared), Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared), + Scale_shared: tilelang.layout.make_swizzled_layout(Scale_shared), }) T.clear(Ct_local) @@ -215,11 +294,14 @@ def main( T.copy(A[by * block_M, k * block_K], A_shared) T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared) T.copy(B_shared, B_local) + T.copy(Scale[bx * block_N, k * block_K // scale_size], Scale_shared) + T.copy(Scale_shared, Scale_local) for i, j in T.Parallel(block_N, block_K): B_dequantize_local[i, j] = _tir_u8_to_f4_to_bf16( num_bits, B_local[i, j // num_elems_per_byte], j % num_elems_per_byte, + Scale_local[i, j // scale_size], dtype=in_dtype, ) T.copy(B_dequantize_local, B_dequantize_prev_local) @@ -239,7 +321,7 @@ def main( keys=["block_M", "block_N", "block_K", "num_stages", "threads", "split"], warmup=10, rep=10) - @tilelang.jit(out_idx=[2]) + @tilelang.jit(out_idx=[-1]) def kernel(block_M=None, block_N=None, block_K=None, @@ -262,24 +344,32 @@ def ref_program(A, qB): return C.transpose(0, 1) -def main(m=256, n=256, k=256, tune=False): +def ref_program_scale(A, qB, Scale): + dtypeC = "bfloat16" + B = torch_convert(qB, scale_size=32, Scale=Scale) + C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + C = C.to(torch.__getattribute__(dtypeC)) + return C.transpose(0, 1) + + +def main(m=256, n=256, k=256, scale_size=32, tune=False): total_flops = 2 * m * n * k if (not tune): kernel = matmul( - m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, tune=tune)( + m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, scale_size=scale_size, tune=tune)( block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1) profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer) - profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) + profiler.assert_allclose(ref_program_scale, rtol=0.01, atol=0.01) print("All checks pass.") - latency = profiler.do_bench(ref_program, warmup=500) + latency = profiler.do_bench(ref_program_scale, warmup=500) print("Ref: {:.2f} ms".format(latency)) print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9)) latency = profiler.do_bench(warmup=500) print("Tile-lang: {:.2f} ms".format(latency)) print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) else: - best_result = matmul(m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, tune=tune) + best_result = matmul(m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, scale_size=scale_size, tune=tune) best_latency = best_result.latency best_config = best_result.config print(f"Best latency: 
{best_latency}") @@ -289,6 +379,7 @@ def main(m=256, n=256, k=256, tune=False): def test_convert(): test_fp4_bf16_convert_close() + test_fp4_bf16_convert_scale_close() if __name__ == "__main__": @@ -296,8 +387,9 @@ def test_convert(): parser.add_argument('--m', type=int, default=256, help='M') parser.add_argument('--n', type=int, default=256, help='N') parser.add_argument('--k', type=int, default=256, help='K') + parser.add_argument('--scale_size', type=int, default=32, help='scale size, the exponential part, within the representation of uint8') parser.add_argument('--tune', action='store_true', help='tune configs') args = parser.parse_args() M, N, K = args.m, args.n, args.k - main(M, N, K, args.tune) # test_convert() + main(M, N, K, args.scale_size, args.tune) From 5962eb6aab0c3f35a53f1432c4c354a5cb27f2ba Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Fri, 8 Aug 2025 02:55:54 +0000 Subject: [PATCH 4/8] [Lint] --- .../example_dequant_gemm_fp4_hopper.py | 6 +- .../example_dequant_gemm_mxfp4_hopper.py | 60 ++++++++++++++----- 2 files changed, 47 insertions(+), 19 deletions(-) diff --git a/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py b/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py index 64d99f38b..c291f7ece 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py +++ b/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py @@ -23,9 +23,9 @@ def _tir_u8_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: e_f16 = e_f4 + tir.const(14, "uint16") m_f4 = f4 & tir.const(1, "uint16") m_f16 = m_f4 - val_f16 = tir.reinterpret( - "float16", - ((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16") | m_f16 << tir.const(9, "uint16")).astype("uint16")) + val_f16 = tir.reinterpret("float16", + ((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16") + | m_f16 << tir.const(9, "uint16")).astype("uint16")) # return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float16"), val_f16) return val_f16 diff --git a/examples/dequantize_gemm/example_dequant_gemm_mxfp4_hopper.py b/examples/dequantize_gemm/example_dequant_gemm_mxfp4_hopper.py index 7c395b438..46ea29a12 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_mxfp4_hopper.py +++ b/examples/dequantize_gemm/example_dequant_gemm_mxfp4_hopper.py @@ -1,5 +1,6 @@ import tilelang import tilelang.language as T +from tilelang.autotuner import * from tvm import tir import argparse import itertools @@ -10,7 +11,8 @@ torch.manual_seed(0) -def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, dtype: str): +def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, + dtype: str): assert nbit == 4 assert dtype == "bfloat16" assert val.dtype == "uint8" @@ -23,10 +25,10 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale # Scale is the exponential part, within the representation of uint8 # To handle the overflow, we use the max function to limit the exponential part to 8 bits e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16")) - m_f4 = f4 & tir.const(1, "uint16") - val_bf16 = tir.reinterpret( - "bfloat16", ((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16")) | (m_f4 << tir.const(6, "uint16"))).astype("uint16") - ) + m_f4 = f4 & tir.const(1, "uint16") + val_bf16 = tir.reinterpret("bfloat16", + ((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16")) + | (m_f4 << tir.const(6, "uint16"))).astype("uint16")) return val_bf16 @@ 
-91,7 +93,7 @@ def main( num_bits, B_local[i, j // num_elems_per_byte], j % num_elems_per_byte, - 0, # No scale for test + 0, # No scale for test dtype=in_dtype, ) T.copy(B_dequantize_local, C[bx * block_N, k * block_K]) @@ -111,9 +113,9 @@ def convert_scale(N, K, block_N, block_K, in_dtype, num_bits=4, scale_size=32, t @T.prim_func def main( - B: T.Tensor(B_shape, storage_dtype), - Scale: T.Tensor(Scale_shape, storage_dtype), - C: T.Tensor((N, K), in_dtype), + B: T.Tensor(B_shape, storage_dtype), + Scale: T.Tensor(Scale_shape, storage_dtype), + C: T.Tensor((N, K), in_dtype), ): with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx): B_shared = T.alloc_shared(B_shared_shape, storage_dtype) @@ -121,7 +123,7 @@ def main( B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype) Scale_shared = T.alloc_shared(Scale_shared_shape, storage_dtype) Scale_local = T.alloc_fragment(Scale_shared_shape, storage_dtype) - + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1): T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared) T.copy(B_shared, B_local) @@ -132,7 +134,9 @@ def main( num_bits, B_local[i, j // num_elems_per_byte], j % num_elems_per_byte, - Scale_local[i, j // scale_size], # Scale is the exponential part, within the representation of uint8 + Scale_local[ + i, j // + scale_size], # Scale is the exponential part, within the representation of uint8 dtype=in_dtype, ) T.copy(B_dequantize_local, C[bx * block_N, k * block_K]) @@ -161,8 +165,7 @@ def test_fp4_bf16_convert_close(): def test_fp4_bf16_convert_scale_close(): N, K = 256, 256 block_N, block_K = 64, 64 - kernel = convert_scale( - N, K, block_N, block_K, "bfloat16", scale_size=32) + kernel = convert_scale(N, K, block_N, block_K, "bfloat16", scale_size=32) B = torch.randint(0, 16, (N, K // 2), dtype=torch.uint8, device="cuda").to(torch.uint8) Scale = torch.randint(0, 1, (N, K // 32), dtype=torch.uint8, device="cuda").to(torch.uint8) @@ -316,6 +319,7 @@ def main( return main_split if tune: + @autotune( configs=get_configs(), keys=["block_M", "block_N", "block_K", "num_stages", "threads", "split"], @@ -329,10 +333,13 @@ def kernel(block_M=None, threads=None, split=None): return kernel_func(block_M, block_N, block_K, num_stages, threads, split) + return kernel() else: + def kernel(block_M, block_N, block_K, num_stages, threads, split=1): return kernel_func(block_M, block_N, block_K, num_stages, threads, split) + return kernel @@ -357,7 +364,15 @@ def main(m=256, n=256, k=256, scale_size=32, tune=False): if (not tune): kernel = matmul( - m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, scale_size=scale_size, tune=tune)( + m, + n, + k, + "bfloat16", + "bfloat16", + "float32", + num_bits=4, + scale_size=scale_size, + tune=tune)( block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1) profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer) profiler.assert_allclose(ref_program_scale, rtol=0.01, atol=0.01) @@ -369,7 +384,16 @@ def main(m=256, n=256, k=256, scale_size=32, tune=False): print("Tile-lang: {:.2f} ms".format(latency)) print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) else: - best_result = matmul(m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, scale_size=scale_size, tune=tune) + best_result = matmul( + m, + n, + k, + "bfloat16", + "bfloat16", + "float32", + num_bits=4, + scale_size=scale_size, + tune=tune) best_latency = best_result.latency best_config = best_result.config print(f"Best latency: {best_latency}") @@ -387,7 +411,11 @@ 
def test_convert(): parser.add_argument('--m', type=int, default=256, help='M') parser.add_argument('--n', type=int, default=256, help='N') parser.add_argument('--k', type=int, default=256, help='K') - parser.add_argument('--scale_size', type=int, default=32, help='scale size, the exponential part, within the representation of uint8') + parser.add_argument( + '--scale_size', + type=int, + default=32, + help='scale size, the exponential part, within the representation of uint8') parser.add_argument('--tune', action='store_true', help='tune configs') args = parser.parse_args() M, N, K = args.m, args.n, args.k From 8df24e9fb16b9310dc5ef940c23f232a4e7b11ae Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Fri, 8 Aug 2025 03:10:35 +0000 Subject: [PATCH 5/8] [Test] Add test script for BF16xMXFP4 gemm --- examples/dequantize_gemm/test_example_dequantize_gemm.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/examples/dequantize_gemm/test_example_dequantize_gemm.py b/examples/dequantize_gemm/test_example_dequantize_gemm.py index e662cbd66..63d6d3eca 100644 --- a/examples/dequantize_gemm/test_example_dequantize_gemm.py +++ b/examples/dequantize_gemm/test_example_dequantize_gemm.py @@ -2,6 +2,7 @@ import example_dequant_gemv_fp16xint4 import example_dequant_gemm_fp4_hopper +import example_dequant_gemm_mxfp4_hopper @tilelang.testing.requires_cuda @@ -15,5 +16,10 @@ def test_example_dequant_gemm_fp4_hopper(): example_dequant_gemm_fp4_hopper.main() +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_example_dequant_gemm_mxfp4_hopper(): + example_dequant_gemm_mxfp4_hopper.main() + if __name__ == "__main__": tilelang.testing.main() From 6047af73fcc80d73a696e03f2bf58ef79dcc1e56 Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Fri, 8 Aug 2025 03:10:59 +0000 Subject: [PATCH 6/8] [Lint] --- examples/dequantize_gemm/test_example_dequantize_gemm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/dequantize_gemm/test_example_dequantize_gemm.py b/examples/dequantize_gemm/test_example_dequantize_gemm.py index 63d6d3eca..6f66c799e 100644 --- a/examples/dequantize_gemm/test_example_dequantize_gemm.py +++ b/examples/dequantize_gemm/test_example_dequantize_gemm.py @@ -21,5 +21,6 @@ def test_example_dequant_gemm_fp4_hopper(): def test_example_dequant_gemm_mxfp4_hopper(): example_dequant_gemm_mxfp4_hopper.main() + if __name__ == "__main__": tilelang.testing.main() From 4f2a4ceecf36e211880151b2301dc0034e0d6201 Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Fri, 8 Aug 2025 03:20:00 +0000 Subject: [PATCH 7/8] [BugFix] Fix the shape of scale tensor --- .../dequantize_gemm/example_dequant_gemm_mxfp4_hopper.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/dequantize_gemm/example_dequant_gemm_mxfp4_hopper.py b/examples/dequantize_gemm/example_dequant_gemm_mxfp4_hopper.py index 46ea29a12..bc318a860 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_mxfp4_hopper.py +++ b/examples/dequantize_gemm/example_dequant_gemm_mxfp4_hopper.py @@ -238,6 +238,7 @@ def main_split( T.annotate_layout({ B_shared: tilelang.layout.make_swizzled_layout(B_shared), Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared), + Scale_shared: tilelang.layout.make_swizzled_layout(Scale_shared), }) T.clear(Ct_local) @@ -283,8 +284,8 @@ def main( B_dequantize_prev_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype) Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype) Ct_shared = T.alloc_shared((block_N, block_M), out_dtype) - Scale_shared = 
T.alloc_shared((block_N, block_M // scale_size), storage_dtype) - Scale_local = T.alloc_fragment((block_N, block_M // scale_size), storage_dtype) + Scale_shared = T.alloc_shared((block_N, block_K // scale_size), storage_dtype) + Scale_local = T.alloc_fragment((block_N, block_K // scale_size), storage_dtype) T.annotate_layout({ B_shared: tilelang.layout.make_swizzled_layout(B_shared), From 55a58ec4f261ab59e94b1266e70b8b3145bea417 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Sun, 10 Aug 2025 21:31:46 +0800 Subject: [PATCH 8/8] Update example_dequant_gemm_fp4_hopper.py --- examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py b/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py index c291f7ece..f36f02908 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py +++ b/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py @@ -6,8 +6,6 @@ import torch import argparse -tilelang.disable_cache() - def _tir_u8_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): assert nbit == 4
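
For reviewing the Scale path end to end, here is a vectorized equivalent of the looped torch_convert(tensor, scale_size, Scale) reference from the MXFP4 example — a sketch only (the function name is new, and it assumes the torch build exposes torch.uint16, which the looped reference already relies on). It is useful because the per-element Python loop is slow at these matrix sizes.

    import torch

    def torch_convert_vectorized(tensor: torch.Tensor, scale_size=None, Scale=None) -> torch.Tensor:
        # vectorized restatement of the bit math in torch_convert / _tir_u8_to_f4_to_bf16
        assert tensor.dtype == torch.uint8
        N, K_half = tensor.shape
        lo = (tensor & 0xF).to(torch.int32)
        hi = (tensor >> 4).to(torch.int32)
        f4 = torch.stack([lo, hi], dim=-1).reshape(N, K_half * 2)  # j even -> low nibble, j odd -> high
        s, e, m = f4 >> 3, (f4 & 6) >> 1, f4 & 1
        e_bf16 = e + 126
        if Scale is not None:
            scale = Scale.to(torch.int32).repeat_interleave(scale_size, dim=1)
            e_bf16 = torch.clamp(e_bf16 + scale, max=(1 << 8) - 1)
        bits = (((s << 8) | e_bf16) << 7) | (m << 6)
        return (bits & 0xFFFF).to(torch.uint16).view(torch.bfloat16)

On the shapes used in the tests (B of 256x128 uint8 with a 256x8 Scale), torch_convert_vectorized(B, 32, Scale) should reproduce torch_convert(B, scale_size=32, Scale=Scale) element for element, just without the Python double loop.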