Commit 778dbd2

[Feature] Add CTypes JIT kernel support (#100)
* [Feature] Add CTypes JIT kernel support for dynamic shapes and multi-stream execution
  - Enhance CtypesKernelAdapter to handle dynamic symbolic shapes
  - Add support for multi-stream kernel execution in the CTypes backend
  - Implement dynamic shape handling in test_tilelang_jit_gemm_ctypes.py
  - Add a symbolic shape utility function in tilelang.language
  - Update the profiler to improve flexibility in benchmark selection

* Remove redundant thread binding in GEMM kernel implementations
  - Remove the unnecessary `thread_binding` line in GEMM kernel functions
  - Clean up code in `examples/gemm/README.md` and `testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py`
  - Improve readability by removing the redundant thread binding annotation

* Fix indentation in the int4 GEMM kernel test file
  - Correct indentation for function calls in `test_tilelang_kernel_int4_gemm_mma.py`
  - Remove extra indentation in the `mma_emitter.ldmatrix_a()` and `mma_emitter.ldmatrix_b()` calls
  - Improve code formatting for better readability
1 parent 15b926a commit 778dbd2
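
Taken together, these changes let a TileLang program be compiled through the new ctypes execution backend and launched with symbolic (dynamic) shapes, on non-default CUDA streams, and under the profiler's torch-based benchmark path. A minimal sketch, assuming the `matmul` program builder defined in the test file below (block sizes, dtypes, and runtime shapes are illustrative):

import torch
import tilelang
import tilelang.language as T

# A symbolic "m" dimension lets one compiled kernel serve many runtime sizes.
# `matmul` is the program builder from testing/python/jit/test_tilelang_jit_gemm_ctypes.py.
program = matmul(
    T.symbolic("m"), 1024, 768, 128, 256, 32, False, False,
    "float16", "float16", "float16", 2, 128)

# Compile through the ctypes execution backend added by this commit.
kernel = tilelang.JITKernel(program, execution_backend="ctypes")

a = torch.randn(512, 768, dtype=torch.float16).cuda()   # (m, k), with m = 512 at runtime
b = torch.randn(768, 1024, dtype=torch.float16).cuda()  # (k, n)
c = torch.empty(512, 1024, dtype=torch.float16).cuda()  # (m, n) output

# Kernels can also be launched on a non-default CUDA stream.
with torch.cuda.stream(torch.cuda.Stream()):
    kernel(a, b, c)

# Benchmark either via the torch-based profiler or the default TVM path.
profiler = kernel.get_profiler()
latency_ms = profiler.do_bench(func=kernel, profiler="torch")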

8 files changed, +326 −253 lines changed


examples/gemm/README.md

Lines changed: 0 additions & 2 deletions
@@ -339,8 +339,6 @@ def tl_matmul(
             B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
             C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
 
-            thread_binding = T.thread_binding(0, threads, "threadIdx.x")
-
             T.annotate_layout({
                 A_shared: make_swizzle_layout(A_shared),
                 B_shared: make_swizzle_layout(B_shared),

testing/python/jit/test_tilelang_jit_gemm_ctypes.py

Lines changed: 167 additions & 2 deletions
@@ -2,6 +2,7 @@
 # Licensed under the MIT License.
 
 from tilelang import tvm as tvm
+import tilelang.language as T
 import tilelang.testing
 import tilelang
 import torch
@@ -27,8 +28,6 @@ def matmul(
     A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
     B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
 
-    import tilelang.language as T
-
     @T.prim_func
     def main(
             A: T.Buffer(A_shape, in_dtype),
@@ -235,5 +234,171 @@ def test_gemm_jit_kernel():
     )
 
 
+def run_ctypes_kernel_do_bench(M,
+                               N,
+                               K,
+                               trans_A,
+                               trans_B,
+                               in_dtype,
+                               out_dtype,
+                               dtypeAccum,
+                               block_M,
+                               block_N,
+                               block_K,
+                               num_stages=3,
+                               num_threads=128):
+    program = matmul(
+        M,
+        N,
+        K,
+        block_M,
+        block_N,
+        block_K,
+        trans_A,
+        trans_B,
+        in_dtype,
+        out_dtype,
+        dtypeAccum,
+        num_stages,
+        num_threads,
+    )
+
+    matmul_kernel = tilelang.JITKernel(program, execution_backend="ctypes")
+
+    profiler = matmul_kernel.get_profiler()
+
+    ctypes_latency = profiler.do_bench(func=matmul_kernel, profiler="torch")
+    print(f"Ctypes Latency: {ctypes_latency} ms")
+
+    assert ctypes_latency is not None
+
+    tvm_latency = profiler.do_bench()
+    print(f"TVM Latency: {tvm_latency} ms")
+
+    assert tvm_latency is not None
+
+
+def test_ctypes_kernel_do_bench():
+    run_ctypes_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128,
+                               256, 32, 2)
+
+
+def run_ctypes_kernel_multi_stream(M,
+                                   N,
+                                   K,
+                                   trans_A,
+                                   trans_B,
+                                   in_dtype,
+                                   out_dtype,
+                                   dtypeAccum,
+                                   block_M,
+                                   block_N,
+                                   block_K,
+                                   num_stages=3,
+                                   num_threads=128):
+    program = matmul(
+        M,
+        N,
+        K,
+        block_M,
+        block_N,
+        block_K,
+        trans_A,
+        trans_B,
+        in_dtype,
+        out_dtype,
+        dtypeAccum,
+        num_stages,
+        num_threads,
+    )
+
+    matmul_kernel = tilelang.JITKernel(program, execution_backend="ctypes")
+
+    tensor_a = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
+    tensor_b = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()
+
+    if trans_A:
+        tensor_a = tensor_a.T
+    if trans_B:
+        tensor_b = tensor_b.T
+    tensor_c = torch.randn(M, N, dtype=torch.__getattribute__(out_dtype)).cuda()
+
+    num_streams = 4
+    for _ in range(num_streams):
+        stream = torch.cuda.Stream()
+        with torch.cuda.stream(stream):
+            matmul_kernel(tensor_a, tensor_b, tensor_c)
+
+
+def test_ctypes_kernel_multi_stream():
+    run_ctypes_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16",
+                                   128, 256, 32, 2)
+
+
+def run_ctypes_dynamic_shape(M,
+                             N,
+                             K,
+                             trans_A,
+                             trans_B,
+                             in_dtype,
+                             out_dtype,
+                             dtypeAccum,
+                             block_M,
+                             block_N,
+                             block_K,
+                             num_stages=3,
+                             num_threads=128):
+    program = matmul(
+        M,
+        N,
+        K,
+        block_M,
+        block_N,
+        block_K,
+        trans_A,
+        trans_B,
+        in_dtype,
+        out_dtype,
+        dtypeAccum,
+        num_stages,
+        num_threads,
+    )
+
+    matmul_kernel = tilelang.JITKernel(program, execution_backend="ctypes")
+    if isinstance(M, T.Var):
+        M = 1024
+    if isinstance(N, T.Var):
+        N = 1024
+    if isinstance(K, T.Var):
+        K = 768
+    tensor_a = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
+    tensor_b = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()
+
+    if trans_A:
+        tensor_a = tensor_a.T
+    if trans_B:
+        tensor_b = tensor_b.T
+    tensor_c = torch.randn(M, N, dtype=torch.__getattribute__(out_dtype)).cuda()
+
+    matmul_kernel(tensor_a, tensor_b, tensor_c)
+
+    tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float))
+    tilelang.testing.torch_assert_close(
+        tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
+
+
+def test_ctypes_dynamic_shape():
+    run_ctypes_dynamic_shape(
+        T.symbolic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
+
+    run_ctypes_dynamic_shape(
+        T.symbolic("m"), T.symbolic("n"), 768, False, False, "float16", "float16", "float16", 128,
+        256, 32, 2)
+
+    run_ctypes_dynamic_shape(
+        T.symbolic("m"), T.symbolic("n"), T.symbolic("k"), False, False, "float16", "float16",
+        "float16", 128, 256, 32, 2)
+
+
 if __name__ == "__main__":
     tilelang.testing.main()

testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py

Lines changed: 4 additions & 8 deletions
@@ -109,8 +109,6 @@ def main(
             B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
             C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
 
-            thread_binding = T.thread_binding(0, threads, "threadIdx.x")
-
             T.annotate_layout({
                 A_shared: make_swizzle_layout(A_shared),
                 B_shared: make_swizzle_layout(B_shared),
@@ -138,14 +136,14 @@ def main(
                     A_local,
                     A_shared,
                     ki,
-                    )
+                )
 
                 # Load B into fragment
                 mma_emitter.ldmatrix_b(
                     B_local,
                     B_shared,
                     ki,
-                    )
+                )
 
                 # Perform Matrix Multiplication
                 mma_emitter.mma(A_local, B_local, C_local)
@@ -294,8 +292,6 @@ def main(
             B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
             C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
 
-            thread_binding = T.thread_binding(0, threads, "threadIdx.x")
-
            T.annotate_layout({
                 A_shared: make_swizzle_layout(A_shared),
                 B_shared: make_swizzle_layout(B_shared),
@@ -325,14 +321,14 @@ def main(
                     A_local,
                     A_shared,
                     ki,
-                    )
+                )
 
                 # Load B into fragment
                 mma_emitter.ldmatrix_b(
                     B_local,
                     B_shared,
                     ki,
-                    )
+                )
 
                 # Perform Matrix Multiplication
                 mma_emitter.mma(A_local, B_local, C_local)