From 5eb1a784a133bfdbaeb94c8908a40c285fd5b849 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Wed, 28 May 2025 07:40:04 +0000 Subject: [PATCH 1/2] Refactor convolution example to streamline configuration and remove unused code * Updated the `check_hopper` function to properly check for CUDA availability and compute capability. * Removed the `get_configs` and `get_best_config` functions, simplifying the example by eliminating unused autotuning logic. * Adjusted argument parsing in the `main` function to directly compile the convolution kernel without autotuning options. * Cleaned up the code for better readability and maintainability. --- examples/convolution/example_convolution.py | 170 ++------------------ 1 file changed, 9 insertions(+), 161 deletions(-) diff --git a/examples/convolution/example_convolution.py b/examples/convolution/example_convolution.py index b4d286eda..e05ced8bf 100644 --- a/examples/convolution/example_convolution.py +++ b/examples/convolution/example_convolution.py @@ -4,77 +4,15 @@ import tilelang from tilelang.autotuner import * import tilelang.language as T -import itertools import argparse -from tilelang.carver.template import ConvTemplate -from tilelang.carver.arch import CUDA -from tilelang.carver.roller.rasterization import NoRasterization def check_hopper(): - # if not torch.cuda.is_available(): - # return None - # props = torch.cuda.get_device_properties(0) - # compute_capability = props.major, props.minor - # return compute_capability == (9, 0) - return False - - -def get_configs(N, C, H, W, F, K, S, D, P, with_roller=False): - if with_roller: - arch = CUDA("cuda") - topk = 10 - carve_template = ConvTemplate( - N=N, - C=C, - H=H, - W=W, - F=F, - K=K, - S=S, - D=D, - P=P, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float", - ).with_arch(arch) - - func = carve_template.equivalent_function() - assert func is not None, "Function is None" - roller_hints = carve_template.recommend_hints(topk=topk) - if roller_hints is None: - raise ValueError("No Roller Hints Found for TensorCore Scheduling") - configs = [] - for hint in roller_hints: - config = {} - block_m, block_n = hint.block - warp_m, warp_n = hint.warp - block_rows, block_cols = block_m // warp_m, block_n // warp_n - config["block_M"] = block_m - config["block_N"] = block_n - config["block_K"] = hint.rstep[0] - config["num_stages"] = hint.pipeline_stage - config["thread_num"] = block_rows * block_cols * 32 - config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization - configs.append(config) - for config in configs: - print(config) - else: - block_M = [64, 128, 256] - block_N = [64, 128, 256] - block_K = [32, 64] - num_stages = [0, 1, 2, 3] - threads = [128, 256] - _configs = list(itertools.product(block_M, block_N, block_K, num_stages, threads)) - - configs = [{ - 'block_M': c[0], - 'block_N': c[1], - 'block_K': c[2], - 'num_stages': c[3], - 'thread_num': c[4] - } for c in _configs] - return configs + if not torch.cuda.is_available(): + return None + props = torch.cuda.get_device_properties(0) + compute_capability = props.major, props.minor + return compute_capability == (9, 0) def ref_program(stride, padding, dilation): @@ -89,80 +27,6 @@ def main(A, B): return main -def get_best_config(N, C, H, W, F, K, S, D, P, with_roller): - KH, KW = K, K - OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 - OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 - - dtype = "float16" - accum_dtype = "float" - is_hopper = check_hopper() - - def kernel( - block_M=None, - block_N=None, - block_K=None, 
- num_stages=None, - thread_num=None, - ): - - @T.prim_func - def main( - data: T.Tensor((N, H, W, C), dtype), - kernel: T.Tensor((KH, KW, C, F), dtype), - out: T.Tensor((N, OH, OW, F), dtype), - ): - with T.Kernel( - T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), - threads=thread_num) as (bx, by): - data_shared = T.alloc_shared((block_M, block_K), dtype) - kernel_shared = T.alloc_shared((block_K, block_N), dtype) - out_local = T.alloc_fragment((block_M, block_N), accum_dtype) - out_shared = T.alloc_shared((block_M, block_N), dtype) - - kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data) - out_flat = T.Tensor((N * OH * OW, F), dtype, out.data) - - T.annotate_layout({ - out_shared: tilelang.layout.make_swizzled_layout(out_shared), - data_shared: tilelang.layout.make_swizzled_layout(data_shared), - kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared), - }) - - T.clear(out_local) - for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages): - if is_hopper: - T.c2d_im2col(data, data_shared, by, k_iter, KH, S, D, P) - else: - for i, j in T.Parallel(block_M, block_K): - k = k_iter * block_K + j - m = by * block_M + i - access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P - access_w = m % OW * S + k // C % KW * D - P - in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and - (access_w < W)) - data_shared[i, j] = T.if_then_else( - in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0) - T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared) - T.gemm(data_shared, kernel_shared, out_local) - - T.copy(out_local, out_shared) - T.copy(out_shared, out_flat[by * block_M, bx * block_N]) - - return main - - autotuner = AutoTuner.from_kernel( - kernel=kernel, configs=get_configs(N, C, H, W, F, K, S, D, P, - with_roller)).set_compile_args( - out_idx=[2], - supply_type=tilelang.TensorSupplyType.Integer, - ref_prog=ref_program(S, P, D), - skip_check=False, - target="auto", - ) - return autotuner.run(warmup=10, rep=10) - - def convolution(N, C, H, @@ -243,30 +107,14 @@ def main(argv=None): parser.add_argument('--s', type=int, default=1, help='s') parser.add_argument('--d', type=int, default=1, help='d') parser.add_argument('--p', type=int, default=1, help='p') - parser.add_argument( - "--use_autotune", - action="store_true", - default=True, - help="Whether to use autotune for matmul configs") - parser.add_argument( - "--with_roller", - action="store_true", - default=True, - help="Whether to enable BitBLAS roller for search space") args = parser.parse_args(argv) N, C, H, W, F, K, S, D, P = args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p a = torch.randn(N, H, W, C).cuda().half() b = torch.randn(K, K, C, F).cuda().half() - use_autotune = args.use_autotune - with_roller = args.with_roller - if use_autotune: - result = get_best_config(N, C, H, W, F, K, S, D, P, with_roller) - print(f"best latency {result.latency}") - kernel = result.kernel - else: - kernel = tilelang.compile( - convolution(N, C, H, W, F, K, S, D, P, 64, 128, 32, 3, 256), out_idx=[2]) + + kernel = tilelang.compile( + convolution(N, C, H, W, F, K, S, D, P, 64, 128, 32, 3, 256), out_idx=[2]) out_c = kernel(a, b) ref_c = ref_program(S, P, D)(a, b) @@ -274,4 +122,4 @@ def main(argv=None): if __name__ == "__main__": - main() \ No newline at end of file + main() From afcb10ceffc4f99fe9ad4bd29631b38febffff95 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Wed, 28 May 2025 17:35:28 +0800 Subject: 
 [PATCH 2/2] Name the tile-size constants in
 examples/convolution/example_convolution.py

Replace the inline 64, 128, 32, 3, 256 arguments to convolution() with
named block_m/block_n/block_k/num_stages/threads variables so the
compile call documents its configuration.

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 examples/convolution/example_convolution.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/examples/convolution/example_convolution.py b/examples/convolution/example_convolution.py
index e05ced8bf..cb32d22ff 100644
--- a/examples/convolution/example_convolution.py
+++ b/examples/convolution/example_convolution.py
@@ -113,8 +113,14 @@ def main(argv=None):
     a = torch.randn(N, H, W, C).cuda().half()
     b = torch.randn(K, K, C, F).cuda().half()
 
+    block_m = 64
+    block_n = 128
+    block_k = 32
+    num_stages = 3
+    threads = 256
+
     kernel = tilelang.compile(
-        convolution(N, C, H, W, F, K, S, D, P, 64, 128, 32, 3, 256), out_idx=[2])
+        convolution(N, C, H, W, F, K, S, D, P, block_m, block_n, block_k, num_stages, threads), out_idx=[2])
 
     out_c = kernel(a, b)
     ref_c = ref_program(S, P, D)(a, b)
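
A minimal end-to-end sketch of how the simplified example is driven after
both patches. It assumes example_convolution.py is importable from the
working directory; the problem shape and the assert_close tolerances are
illustrative stand-ins (the comparison at the tail of main() falls outside
the hunks shown), while the tile configuration matches the constants that
PATCH 2/2 names.

import torch
import tilelang

from example_convolution import check_hopper, convolution, ref_program

# Illustrative NHWC problem: batch, channels, height, width, filters,
# kernel size, stride, dilation, padding.
N, C, H, W, F, K, S, D, P = 64, 128, 56, 56, 256, 3, 1, 1, 1

# The fixed tile configuration hard-coded by PATCH 2/2.
block_m, block_n, block_k, num_stages, threads = 64, 128, 32, 3, 256

# out_idx=[2] makes the compiled kernel allocate and return its third
# parameter (`out`), so callers pass only the two inputs.
kernel = tilelang.compile(
    convolution(N, C, H, W, F, K, S, D, P, block_m, block_n, block_k,
                num_stages, threads),
    out_idx=[2])

a = torch.randn(N, H, W, C).cuda().half()  # NHWC activations
b = torch.randn(K, K, C, F).cuda().half()  # (KH, KW, C, F) weights

out_c = kernel(a, b)
ref_c = ref_program(S, P, D)(a, b)

# fp16 inputs with float accumulation need some slack; these rtol/atol
# values are illustrative.
torch.testing.assert_close(out_c, ref_c, rtol=1e-2, atol=1e-2)

# After PATCH 1/2, check_hopper() returns None when CUDA is unavailable,
# otherwise whether the device reports compute capability (9, 0).
print("Hopper (sm_90) device:", bool(check_hopper()))

Pinning one known-good configuration keeps the example deterministic and
quick to build, at the cost of the per-shape search that the removed
AutoTuner/roller path performed.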