 # @tilelang.jit(target="cuda")
 # target currently can be "cuda" or "hip" or "cpu".
 # if not specified, it will be inferred from the input tensors during compile time
-@tilelang.jit(execution_backend="tvm_ffi", pass_configs={
-    tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
-    tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-})
+@tilelang.jit
 def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

     @T.prim_func
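The comments above note that target can be "cuda", "hip", or "cpu", and is otherwise inferred from the input tensors at compile time. If the removed pass configuration is still needed (for example while debugging TMA lowering or warp specialization on a particular GPU), it can be combined with an explicit target. A hedged sketch that reuses only the identifiers visible in the removed lines; whether these flags are actually required depends on the target:

# Sketch (assumption: target, execution_backend and pass_configs can be combined in one decorator call).
@tilelang.jit(
    target="cuda",  # or "hip" / "cpu"; omit to infer from the inputs at compile time
    execution_backend="tvm_ffi",
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    },
)
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    ...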
@@ -51,7 +48,7 @@ def matmul_relu_kernel(
     return matmul_relu_kernel


-M = T.dynamic("m")  # M = T.dynamic("m") if you want to use dynamic shape
+M = 1024  # M = T.dynamic("m") if you want to use dynamic shape
 N = 1024
 K = 1024
 block_M = 128
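The inline comment above keeps the dynamic-shape option visible. A minimal sketch of that variant, assuming T.dynamic("m") marks M as a symbolic dimension and that the jitted kernel then accepts inputs with any runtime row count (not verified here):

# Sketch: compile once with a symbolic M, reuse the kernel for any row count at run time.
M_dyn = T.dynamic("m")
dyn_matmul_relu_kernel = matmul(M_dyn, N, K, block_M, block_N, block_K)

The rest of the example would proceed unchanged, with the first input allocated as torch.randn(m, K, ...) for whatever m is used at run time.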
@@ -60,11 +57,10 @@ def matmul_relu_kernel(

 # 1. Define the kernel (matmul) and compile/lower it into an executable module
 matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K)
-
+print(matmul_relu_kernel.get_kernel_source())
 # 3. Test the kernel in Python with PyTorch data
 import torch

-M = 0
 # Create random input tensors on the GPU
 a = torch.randn(M, K, device="cuda", dtype=torch.float16)
 b = torch.randn(K, N, device="cuda", dtype=torch.float16)
@@ -81,3 +77,13 @@ def matmul_relu_kernel(
 torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
 print("Kernel output matches PyTorch reference.")

+# 4. Retrieve and inspect the generated CUDA source (optional)
+# cuda_source = matmul_relu_kernel.get_kernel_source()
+# print("Generated CUDA kernel:\n", cuda_source)
+
+# 5. Profile latency with the kernel's profiler
+profiler = matmul_relu_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
+
+latency = profiler.do_bench()
+
+print(f"Latency: {latency} ms")