
Commit f2f8a28

minor fix
1 parent 65448b5 commit f2f8a28

8 files changed (+12 -37 lines)


examples/blocksparse_gemm/example_blocksparse_gemm.py

Lines changed: 0 additions & 1 deletion
@@ -166,7 +166,6 @@ def main():
             enable_rasteration=DEFAULT_ENABLE_RASTERIZATION)
         block_M, block_N, block_K = DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K
         print(f"Using default kernel with block size ({block_M}, {block_N}, {block_K})")
-        print(kernel.get_kernel_source())
     # Create block mask with desired sparsity
     mask_shape = (M // block_M, N // block_N, K // block_K)
     block_mask = torch.rand(mask_shape).cuda() > sparsity
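The deleted line dumped the full generated CUDA source on every run of the example. If the source is still wanted occasionally, it can be fetched on demand from the compiled kernel object; a minimal sketch, assuming `kernel` is the object compiled earlier in main() and using a hypothetical TL_DUMP_SOURCE environment flag:

    import os

    # Only print the generated CUDA source when explicitly requested,
    # instead of on every invocation of the example.
    if os.environ.get("TL_DUMP_SOURCE"):
        print(kernel.get_kernel_source())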
Lines changed: 2 additions & 2 deletions
@@ -1,3 +1,4 @@
+import tilelang.testing
 import example_blocksparse_gemm


@@ -6,5 +7,4 @@ def test_example_blocksparse_gemm():


 if __name__ == "__main__":
-    # tilelang.testing.main()
-    test_example_blocksparse_gemm()
+    tilelang.testing.main()
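Switching the __main__ block to tilelang.testing.main() runs every test_* function in the module rather than a single hand-picked one. A sketch of the resulting test-file shape; the body of the test is not shown in this diff, so the call into the example's main() below is an assumption for illustration:

    import tilelang.testing

    import example_blocksparse_gemm


    def test_example_blocksparse_gemm():
        # Run the example end to end; the real test body is not part of this diff.
        example_blocksparse_gemm.main()


    if __name__ == "__main__":
        # Discover and run all test_* functions in this module.
        tilelang.testing.main()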

examples/gdn/test_example_gdn_compilation.py

Lines changed: 1 addition & 10 deletions
@@ -107,7 +107,6 @@ def test_example_chunk_o_compilation():


 def test_example_chunk_o_bwd_compilation():
-    tilelang.disable_cache()
     from example_chunk_o_bwd import tilelang_chunk_o_bwd_dqkwg, prepare_input
     Q, K, V, h, G, dO, dh, dv, W = prepare_input(B, S, H, DK, DV, chunk_size,
                                                  getattr(torch, input_dtype),
@@ -118,13 +117,6 @@ def test_example_chunk_o_bwd_compilation():
     kernel = tilelang_chunk_o_bwd_dqkwg(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype,
                                         gate_dtype, state_dtype, chunk_size, 1.0, use_g, True,
                                         block_DK, block_DV, threads, num_stages)
-    # print(kernel.get_kernel_source())
-    kernel = tilelang_chunk_o_bwd_dqkwg(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype,
-                                        gate_dtype, state_dtype, chunk_size, 1.0, use_g, True,
-                                        block_DK, block_DV, threads, num_stages)
-    kernel = tilelang_chunk_o_bwd_dqkwg(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype,
-                                        gate_dtype, state_dtype, chunk_size, 1.0, use_g, True,
-                                        block_DK, block_DV, threads, num_stages)

     dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv,
                                                                 W)  # noqa: F841
@@ -197,5 +189,4 @@ def test_example_chunk_delta_bwd_compilation():


 if __name__ == "__main__":
-    # tilelang.testing.main()
-    test_example_chunk_o_bwd_compilation()
+    tilelang.testing.main()
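tilelang.disable_cache() is a process-wide switch, so leaving it inside a single test also slows every compilation that runs after it; the same applies to compiling the identical kernel three times in a row. If cache-free compilation is ever needed again for debugging, a hedged sketch of keeping it local and opt-in (the TL_NO_CACHE environment flag is hypothetical):

    import os
    import tilelang

    # Opt into cache-free compilation only when chasing a stale-cache problem;
    # the default (cache enabled) is what the tests above now rely on.
    if os.environ.get("TL_NO_CACHE"):
        tilelang.disable_cache()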

examples/quickstart.py

Lines changed: 6 additions & 12 deletions
@@ -5,7 +5,10 @@
 # @tilelang.jit(target="cuda")
 # target currently can be "cuda" or "hip" or "cpu".
 # if not specified, it will be inferred from the input tensors during compile time
-@tilelang.jit
+@tilelang.jit(execution_backend="tvm_ffi", pass_configs={
+    tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER:True,
+    tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+})
 def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

     @T.prim_func
@@ -48,7 +51,7 @@ def matmul_relu_kernel(
     return matmul_relu_kernel


-M = 1024  # M = T.dynamic("m") if you want to use dynamic shape
+M = T.dynamic("m")  # M = T.dynamic("m") if you want to use dynamic shape
 N = 1024
 K = 1024
 block_M = 128
@@ -61,6 +64,7 @@ def matmul_relu_kernel(
 # 3. Test the kernel in Python with PyTorch data
 import torch

+M = 0
 # Create random input tensors on the GPU
 a = torch.randn(M, K, device="cuda", dtype=torch.float16)
 b = torch.randn(K, N, device="cuda", dtype=torch.float16)
@@ -77,13 +81,3 @@ def matmul_relu_kernel(
 torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
 print("Kernel output matches PyTorch reference.")

-# 4. Retrieve and inspect the generated CUDA source (optional)
-# cuda_source = jit_kernel.get_kernel_source()
-# print("Generated CUDA kernel:\n", cuda_source)
-
-# 5.Profile latency with kernel
-profiler = matmul_relu_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
-
-latency = profiler.do_bench()
-
-print(f"Latency: {latency} ms")

testing/python/debug/test_tilelang_debug_print.py

Lines changed: 0 additions & 1 deletion
@@ -13,7 +13,6 @@ def program(Q: T.Tensor((M, N), dtype)):
             shared_buf = T.alloc_shared([M, N], dtype)
             T.print(shared_buf)

-    tilelang.disable_cache()
     jit_kernel = tilelang.compile(program, target="cuda", execution_backend="tvm_ffi")
     profiler = jit_kernel.get_profiler()
     profiler.run_once()
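Only a fragment of the kernel appears in this hunk. A self-contained sketch of a debug-print kernel in the same spirit is shown below; the single-block launch configuration and sizes are assumptions, only the alloc_shared/T.print lines and the compile call come from the diff:

    import tilelang
    import tilelang.language as T

    M, N, dtype = 64, 64, "float16"

    @T.prim_func
    def program(Q: T.Tensor((M, N), dtype)):
        # One block is enough to demonstrate printing a shared buffer;
        # Q is unused here, mirroring the fragment above.
        with T.Kernel(1, threads=128) as bx:
            shared_buf = T.alloc_shared([M, N], dtype)
            T.print(shared_buf)

    jit_kernel = tilelang.compile(program, target="cuda", execution_backend="tvm_ffi")
    jit_kernel.get_profiler().run_once()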

testing/python/jit/test_tilelang_jit_nullptr.py

Lines changed: 1 addition & 4 deletions
@@ -83,14 +83,12 @@ def main(


 def run_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
-    tilelang.disable_cache()
     kernel = ptr_null_test(M, N, K, block_M, block_N, block_K, dtype, accum_dtype)

     a = torch.randn(M, K, device="cuda", dtype=map_torch_type(dtype))
     b = torch.randn(N, K, device="cuda", dtype=map_torch_type(dtype))
     c = torch.zeros(M, N, device="cuda", dtype=map_torch_type(accum_dtype))
     d = torch.randn(N, device="cuda", dtype=map_torch_type(accum_dtype))
-    print(kernel.get_host_source())
     kernel(a, b, c, None, M, N, K, False)

     ref_no_bias = (a @ b.T).to(map_torch_type(accum_dtype))
@@ -114,5 +112,4 @@ def test_nullptr():


 if __name__ == "__main__":
-    # tilelang.testing.main()
-    run_test(1024, 1024, 1024, 128, 128, 32)
+    tilelang.testing.main()
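The call kept above, kernel(a, b, c, None, M, N, K, False), passes None where the bias tensor d would go, so the kernel receives a null pointer and the reference is computed without bias. A hedged sketch of the two call shapes this test is built around; the with-bias call is an assumption inferred from the d tensor and the ref_no_bias name, it is not shown in this hunk:

    # With bias: d is a real tensor and the flag asks the kernel to apply it (assumed call).
    kernel(a, b, c, d, M, N, K, True)
    # Without bias: None becomes a null pointer inside the kernel (call shown in the diff).
    kernel(a, b, c, None, M, N, K, False)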

testing/python/jit/test_tilelang_jit_nvrtc.py

Lines changed: 2 additions & 6 deletions
@@ -364,8 +364,7 @@ def run_nvrtc_dynamic_shape(M,
         num_threads,
     )

-    matmul_kernel = tilelang.compile(program, execution_backend="tvm_ffi")
-    print(matmul_kernel.get_host_source())
+    matmul_kernel = tilelang.compile(program, execution_backend="nvrtc")
     if isinstance(M, T.Var):
         M = 1024
     if isinstance(N, T.Var):
@@ -583,7 +582,4 @@ def kernel(


 if __name__ == "__main__":
-    # tilelang.testing.main()
-    tilelang.disable_cache()
-    run_nvrtc_dynamic_shape(
-        T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
+    tilelang.testing.main()
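The fix restores the nvrtc execution backend that this file is meant to exercise; the tvm_ffi backend remains in use elsewhere (for example in test_tilelang_debug_print.py above). Both compile forms, exactly as they appear in this commit, assuming program is any @T.prim_func such as the matmul built earlier in this file:

    # NVRTC backend, as restored in run_nvrtc_dynamic_shape:
    matmul_kernel = tilelang.compile(program, execution_backend="nvrtc")

    # TVM FFI backend, as used by the debug-print test:
    jit_kernel = tilelang.compile(program, target="cuda", execution_backend="tvm_ffi")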

testing/python/language/test_tilelang_language_alloc.py

Lines changed: 0 additions & 1 deletion
@@ -113,7 +113,6 @@ def run_alloc_var_with_initializer(

     kernel = tilelang.compile(program, out_idx=[1])
     code = kernel.get_kernel_source()
-    print(code)
     assert f"= {init_value};" in code


