diff --git a/testing/python/autotune/test_tilelang_autotune.py b/testing/python/autotune/test_tilelang_autotune.py
new file mode 100644
index 000000000..813f3176d
--- /dev/null
+++ b/testing/python/autotune/test_tilelang_autotune.py
@@ -0,0 +1,297 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import itertools
+import logging
+
+import tilelang as tl
+import tilelang.testing
+import tilelang.language as T
+from tilelang.autotuner import autotune, jit
+
+# Configure logger
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+
+def ref_program(A, B):
+    """
+    A reference matrix multiplication program, used to compare performance.
+
+    Parameters
+    ----------
+    A : numpy.ndarray
+        The matrix with shape (M, K).
+    B : numpy.ndarray
+        The matrix with shape (N, K).
+
+    Returns
+    -------
+    numpy.ndarray
+        The result of A @ B.T, shape (M, N).
+    """
+    return A @ B.T
+
+
+def get_configs(M, N, K, with_roller=False):
+    """
+    Generate a list of configuration dictionaries that will be used for tuning.
+
+    Parameters
+    ----------
+    M, N, K : int
+        The matmul problem sizes.
+    with_roller : bool
+        Whether to use the BitBLAS roller to deduce the search space.
+
+    Returns
+    -------
+    list of dict
+        Each configuration dict includes various block sizes, pipeline stages,
+        thread numbers, and other parameters to explore during autotuning.
+    """
+    if with_roller:
+        from bitblas.base.utils import get_roller_hints_from_func
+        from bitblas.ops.general_matmul.tirscript import matmul_select_implementation
+        from bitblas.base.arch import CUDA
+        from bitblas.base.roller.rasterization import NoRasterization
+        arch = CUDA("cuda")
+        topk = 20
+
+        # Simple TIR compute expression
+        ir_module = matmul_select_implementation(
+            M=M,
+            N=N,
+            K=K,
+            in_dtype="float16",
+            out_dtype="float16",
+            accum_dtype="float16",
+        )
+
+        roller_hints = get_roller_hints_from_func(
+            ir_module,
+            arch,
+            topk,
+            tensorcore_only=True,
+            allow_gemv=True,
+        )
+
+        if roller_hints is None:
+            raise ValueError("No Roller Hints Found for TensorCore Scheduling")
+        configs = []
+        for hint in roller_hints:
+            config = {}
+            block_m, block_n = hint.block
+            warp_m, warp_n = hint.warp
+            config["block_M"] = block_m
+            config["block_N"] = block_n
+            config["block_K"] = hint.rstep[0]
+            config["num_stages"] = 0
+            config["thread_num"] = (block_m * block_n) // (warp_m * warp_n) * 32
+            config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization
+            configs.append(config)
+        for config in configs:
+            logger.debug(config)
+    else:
+        block_M = [64]
+        block_N = [64]
+        block_K = [32]
+        num_stages = [0, 1]
+        thread_num = [128]
+        enable_rasterization = [False]
+
+        _configs = list(
+            itertools.product(
+                block_M,
+                block_N,
+                block_K,
+                num_stages,
+                thread_num,
+                enable_rasterization,
+            ))
+
+        configs = [
+            {
+                "block_M": c[0],
+                "block_N": c[1],
+                "block_K": c[2],
+                "num_stages": c[3],
+                "thread_num": c[4],
+                "enable_rasteration": c[5],  # keep param name for backward-compat
+            } for c in _configs
+        ]
+    return configs
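+
+
+# Illustrative sketch (assumed helper, not exercised by the tests below): with
+# the default candidate lists above, the non-roller search space is the
+# Cartesian product of the values, i.e. 1 * 1 * 1 * 2 * 1 * 1 = 2 configs.
+def _example_default_configs():
+    """Returns the two default configs, e.g. {"block_M": 64, "block_N": 64,
+    "block_K": 32, "num_stages": 0, "thread_num": 128,
+    "enable_rasteration": False} and its num_stages=1 twin."""
+    return get_configs(1024, 1024, 1024, with_roller=False)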
+
+
+def matmul(M, N, K, with_roller):
+    """
+    Create an autotuned matrix multiplication kernel for matrices of shape:
+      - A: (M, K)
+      - B: (N, K)
+      - C: (M, N)
+
+    Parameters
+    ----------
+    M : int
+        The dimension M of the matrix multiplication.
+    N : int
+        The dimension N of the matrix multiplication.
+    K : int
+        The dimension K of the matrix multiplication.
+    with_roller : bool
+        Whether to use the BitBLAS roller to deduce the search space.
+
+    Returns
+    -------
+    (best_latency, best_config, ref_latency)
+        best_latency : float
+            The best latency found among the tuned configurations.
+        best_config : dict
+            The parameter configuration that yielded best_latency.
+        ref_latency : float
+            The baseline latency of the reference program (for computing speedup).
+    """
+
+    # Decorate the kernel with autotune & jit, specifying:
+    # - the tuning config list
+    # - profiling keys
+    # - warmup and repetition counts for better measurement
+    # - a reference program for correctness verification
+    # - an auto-selected profiler backend
+    # - an auto-detected compilation target (modify as needed for your hardware)
+    if with_roller:
+        # Check that BitBLAS is installed before building the roller search space.
+        try:
+            import bitblas  # noqa: F401
+        except ImportError as e:
+            raise ImportError(
+                "BitBLAS is not installed. Please install it via 'pip install bitblas'.") from e
+
+    @autotune(
+        configs=get_configs(M, N, K, with_roller),
+        keys=[
+            "block_M",
+            "block_N",
+            "block_K",
+            "num_stages",
+            "thread_num",
+            "enable_rasteration",
+        ],
+        warmup=3,
+        rep=5,
+    )
+    @jit(
+        out_idx=[2],
+        supply_type=tl.TensorSupplyType.Integer,
+        ref_prog=ref_program,
+        skip_check=True,
+        profiler="auto",
+        target="auto",
+    )
+    def kernel(
+        block_M=None,
+        block_N=None,
+        block_K=None,
+        num_stages=None,
+        thread_num=None,
+        enable_rasteration=None,
+    ):
+        """
+        The actual kernel to compute C = A @ B^T.
+
+        Parameters
+        ----------
+        block_M : int
+            Block size in the M dimension.
+        block_N : int
+            Block size in the N dimension.
+        block_K : int
+            Block size in the K dimension.
+        num_stages : int
+            Number of pipelined stages (for asynchronous loads).
+        thread_num : int
+            Number of threads to use per block.
+        enable_rasteration : bool
+            Whether to enable the rasterization (swizzling) optimization.
+
+        Returns
+        -------
+        Function
+            A TVM Tensor Language function (T.prim_func) that computes matmul.
+        """
+        # Use half-precision inputs to reduce memory bandwidth, and
+        # accumulate in float32 for better numerical accuracy.
+        dtype = "float16"
+        accum_dtype = "float"
+
+        @T.prim_func
+        def main(
+                A: T.Buffer((M, K), dtype),
+                B: T.Buffer((N, K), dtype),
+                C: T.Buffer((M, N), dtype),
+        ):
+            """
+            The compiled TVM function for block-level matrix multiplication.
+
+            - We divide the entire (M, N) domain into blocks of shape
+              (block_M, block_N).
+            - Each block has its own allocated shared memory for sub-blocks
+              of A and B.
+            - The partial results go into C_local, and then we copy them back
+              to global memory C.
+            """
+            # Bind the x-dimension to the block index in N and
+            # the y-dimension to the block index in M.
+            with T.Kernel(
+                    T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
+
+                # Allocate shared memory for an A sub-block of shape (block_M, block_K)
+                A_shared = T.alloc_shared((block_M, block_K), dtype)
+                # Allocate shared memory for a B sub-block of shape (block_N, block_K)
+                B_shared = T.alloc_shared((block_N, block_K), dtype)
+                # Allocate a local fragment for intermediate accumulation
+                C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
+
+                # Enable (or disable) the swizzling optimization
+                T.use_swizzle(panel_size=10, enable=enable_rasteration)
+
+                # Clear out the accumulation buffer
+                T.clear(C_local)
+
+                # Loop over sub-blocks in the K dimension, pipelined by num_stages
+                for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
+                    # Load a sub-block of A from global memory into A_shared
+                    T.copy(
+                        A[by * block_M, k * block_K],
+                        A_shared,
+                    )
+                    # Load a sub-block of B from global memory into B_shared
+                    T.copy(
+                        B[bx * block_N, k * block_K],
+                        B_shared,
+                    )
+                    # Perform a partial matrix multiplication:
+                    #   C_local += A_shared @ B_shared^T
+                    T.gemm(
+                        A_shared,
+                        B_shared,
+                        C_local,
+                        transpose_B=True,
+                    )
+                # Write back the results from C_local to global memory C
+                T.copy(C_local, C[by * block_M, bx * block_N])
+
+        return main
+
+    return kernel()
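+
+
+# Minimal usage sketch (assumed helper, mirroring test_autotune_matmul below):
+# calling matmul(...) compiles and profiles every configuration from
+# get_configs and returns the autotuner's result for the best one.
+def _example_autotune_small_matmul():
+    """For illustration only; tunes a smaller problem size."""
+    return matmul(1024, 1024, 1024, with_roller=False)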
+
+
+def test_autotune_get_configs():
+    get_configs(8192, 8192, 8192, with_roller=False)
+
+
+def test_autotune_matmul():
+    matmul(8192, 8192, 8192, with_roller=False)
+
+
+if __name__ == "__main__":
+    tilelang.testing.main()
diff --git a/testing/python/jit/test_tilelang_jit_gemm.py b/testing/python/jit/test_tilelang_jit_gemm.py
index 1963bdaae..ec7baacd0 100644
--- a/testing/python/jit/test_tilelang_jit_gemm.py
+++ b/testing/python/jit/test_tilelang_jit_gemm.py
@@ -127,6 +127,123 @@ def test_gemm_f16f16f16_nn():
     )
 
 
+def matmul_jit_kernel(
+    M,
+    N,
+    K,
+    block_M,
+    block_N,
+    block_K,
+    trans_A,
+    trans_B,
+    in_dtype,
+    out_dtype,
+    accum_dtype,
+    num_stages,
+    threads,
+):
+    A_shape = (K, M) if trans_A else (M, K)
+    B_shape = (N, K) if trans_B else (K, N)
+    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
+    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
+
+    import tilelang.language as T
+
+    @T.prim_func
+    def main(
+            A: T.Buffer(A_shape, in_dtype),
+            B: T.Buffer(B_shape, in_dtype),
+            C: T.Buffer((M, N), out_dtype),
+    ):
+        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
+            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
+            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
+            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
+            T.clear(C_local)
+            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
+                if trans_A:
+                    T.copy(A[k * block_K, by * block_M], A_shared)
+                else:
+                    T.copy(A[by * block_M, k * block_K], A_shared)
+                if trans_B:
+                    T.copy(B[bx * block_N, k * block_K], B_shared)
+                else:
+                    T.copy(B[k * block_K, bx * block_N], B_shared)
+                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
+            T.copy(C_local, C[by * block_M, bx * block_N])
+
+    return main
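+
+
+# Compilation sketch (assumed helper, mirroring run_gemm_jit_kernel below):
+# wrap the generated prim_func in a JITKernel so it can be called directly
+# on torch tensors through the DLPack execution backend.
+def _example_compile_matmul_jit_kernel():
+    program = matmul_jit_kernel(512, 1024, 768, 128, 256, 32, False, False,
+                                "float16", "float16", "float16", 2, 128)
+    return tilelang.JITKernel(program, out_idx=-1, execution_backend="dl_pack")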
+
+
+def run_gemm_jit_kernel(
+    M,
+    N,
+    K,
+    trans_A,
+    trans_B,
+    in_dtype,
+    out_dtype,
+    dtypeAccum,
+    block_M,
+    block_N,
+    block_K,
+    num_stages=3,
+    num_threads=128,
+):
+    program = matmul_jit_kernel(
+        M,
+        N,
+        K,
+        block_M,
+        block_N,
+        block_K,
+        trans_A,
+        trans_B,
+        in_dtype,
+        out_dtype,
+        dtypeAccum,
+        num_stages,
+        num_threads,
+    )
+
+    matmul_kernel = tilelang.JITKernel(program, out_idx=-1, execution_backend="dl_pack")
+
+    A = torch.randn(M, K, dtype=getattr(torch, in_dtype)).cuda()
+    B = torch.randn(K, N, dtype=getattr(torch, in_dtype)).cuda()
+
+    if trans_A:
+        A = A.T
+    if trans_B:
+        B = B.T
+
+    def ref_program(A, B):
+        C = torch.matmul(A.to(torch.float), B.to(torch.float))
+        C = C.to(getattr(torch, out_dtype))
+        return C
+
+    ref_C = ref_program(A, B)
+    C = matmul_kernel(A, B)
+
+    tilelang.testing.torch_assert_close(C, ref_C, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
+
+
+def test_gemm_jit_kernel():
+    run_gemm_jit_kernel(
+        512,
+        1024,
+        768,
+        False,
+        False,
+        "float16",
+        "float16",
+        "float16",
+        128,
+        256,
+        32,
+        2,
+    )
+
+
 if __name__ == "__main__":
-    # tilelang.testing.main()
-    test_gemm_f16f16f16_nn()
+    tilelang.testing.main()
diff --git a/testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py b/testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py
new file mode 100644
index 000000000..fc0c13201
--- /dev/null
+++ b/testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py
@@ -0,0 +1,49 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from tilelang import tvm as tvm
+import tilelang as tl
+import tilelang.language as T
+import tilelang.testing
+
+
+def safe_memory_access_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_offset: int = 2):
+    dtype = "float32"
+
+    @T.prim_func
+    def main(A: T.Buffer((M, N), dtype="float32"),):
+        with T.Kernel(1, 1, threads=M) as (bx, by):
+            A_shared = T.alloc_shared((M, N), dtype=dtype)
+            tid = T.get_thread_binding()
+            for j in T.serial(N):
+                A_shared[tid, j] = A[tid + M_offset, j + N_offset]
+
+    @T.prim_func
+    def expected(A: T.Buffer((M, N), dtype="float32"),):
+        with T.Kernel(1, 1, threads=M) as (bx, by):
+            A_shared = T.alloc_shared((M, N), dtype=dtype)
+            tid = T.get_thread_binding()
+
+            T.reads(A[tid + M_offset, N_offset:N + N_offset])
+            for j in T.serial(N):
+                A_shared[tid, j] = T.if_then_else(
+                    j + N_offset < N,
+                    T.if_then_else(tid + M_offset < M, A[tid + M_offset, j + N_offset],
+                                   T.float32(0)), T.float32(0))
+
+    return main, expected
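+
+
+# Illustration (comment only): with M_offset = N_offset = 2, the load
+# A[tid + 2, j + 2] can index past the (M, N) bounds for the last two rows
+# and columns, so the pass guards the access with nested T.if_then_else and
+# substitutes 0.0 for out-of-bounds reads, as encoded in `expected` above.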
+
+
+def assert_safe_memory_access(M: int = 64, N: int = 64):
+    func, expected = safe_memory_access_legalize(M, N)
+    mod = tvm.IRModule({func.attrs["global_symbol"]: func})
+    transformed = tl.transform.LegalizeSafeMemoryAccess()(mod)
+    tvm.ir.assert_structural_equal(transformed["main"].body, expected.body)
+
+
+def test_safe_memory_access():
+    assert_safe_memory_access(64, 64)
+
+
+if __name__ == "__main__":
+    tilelang.testing.main()
diff --git a/testing/python/transform/test_tilelang_transform_legalize_vectorized_loop.py b/testing/python/transform/test_tilelang_transform_legalize_vectorized_loop.py
new file mode 100644
index 000000000..9e6702333
--- /dev/null
+++ b/testing/python/transform/test_tilelang_transform_legalize_vectorized_loop.py
@@ -0,0 +1,47 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from tilelang import tvm as tvm
+import tilelang as tl
+import tilelang.language as T
+import tilelang.testing
+
+
+def vectorize_access_legalize(M: int = 64, N: int = 64):
+    dtype = "float32"
+    vec_len = 8
+
+    @T.prim_func
+    def main(A: T.Buffer((M, N, vec_len), dtype="float32"),):
+        with T.Kernel(1, 1, threads=M) as (bx, by):
+            A_shared = T.alloc_shared((M, N, vec_len), dtype=dtype)
+            tid = T.get_thread_binding()
+            for j in T.serial(N):
+                for v in T.vectorized(vec_len):
+                    A_shared[tid, j, v] = A[tid, j, v]
+
+    @T.prim_func
+    def expected(A: T.Buffer((M, N, vec_len), dtype="float32"),):
+        with T.Kernel(1, 1, threads=M) as (bx, by):
+            A_shared = T.alloc_shared((M, N, vec_len), dtype=dtype)
+            tid = T.get_thread_binding()
+            for j, v_2 in T.grid(N, vec_len // 4):
+                for vec in T.vectorized(4):
+                    A_shared[tid, j, v_2 * 4 + vec] = A[tid, j, v_2 * 4 + vec]
+
+    return main, expected
+
+
+def assert_vectorize_access(M: int = 64, N: int = 64):
+    func, expected = vectorize_access_legalize(M, N)
+    mod = tvm.IRModule({func.attrs["global_symbol"]: func})
+    transformed = tl.transform.LegalizeVectorizedLoop()(mod)
+    tvm.ir.assert_structural_equal(transformed["main"].body, expected.body)
+
+
+def test_vectorize_access():
+    assert_vectorize_access(64, 64)
+
+
+if __name__ == "__main__":
+    tilelang.testing.main()