Skip to content

Commit 2b3bd54

Browse files
committed
lint fix
1 parent 0794c29 commit 2b3bd54

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

examples/quickstart.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,7 @@
55
# @tilelang.jit(target="cuda")
66
# target currently can be "cuda" or "hip" or "cpu".
77
# if not specified, it will be inferred from the input tensors during compile time
8-
@tilelang.jit(execution_backend="tvm_ffi", pass_configs={
9-
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER:True,
10-
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
11-
})
8+
@tilelang.jit
129
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
1310

1411
@T.prim_func
@@ -51,7 +48,7 @@ def matmul_relu_kernel(
5148
return matmul_relu_kernel
5249

5350

54-
M = T.dynamic("m") # M = T.dynamic("m") if you want to use dynamic shape
51+
M = 1024 # M = T.dynamic("m") if you want to use dynamic shape
5552
N = 1024
5653
K = 1024
5754
block_M = 128
@@ -60,11 +57,10 @@ def matmul_relu_kernel(
6057

6158
# 1. Define the kernel (matmul) and compile/lower it into an executable module
6259
matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K)
63-
60+
print(matmul_relu_kernel.get_kernel_source())
6461
# 3. Test the kernel in Python with PyTorch data
6562
import torch
6663

67-
M = 0
6864
# Create random input tensors on the GPU
6965
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
7066
b = torch.randn(K, N, device="cuda", dtype=torch.float16)
@@ -81,3 +77,13 @@ def matmul_relu_kernel(
8177
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
8278
print("Kernel output matches PyTorch reference.")
8379

80+
# 4. Retrieve and inspect the generated CUDA source (optional)
81+
# cuda_source = matmul_relu_kernel.get_kernel_source()
82+
# print("Generated CUDA kernel:\n", cuda_source)
83+
84+
# 5. Profile latency with the kernel
85+
profiler = matmul_relu_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
86+
87+
latency = profiler.do_bench()
88+
89+
print(f"Latency: {latency} ms")

tilelang/jit/adapter/tvm_ffi.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,6 @@ def _process_dynamic_symbolic(self) -> dict[tir.Var, tuple[int, int]]:
122122
dynamic_symbolic_map[stride] = (1, i, j)
123123
return dynamic_symbolic_map
124124

125-
126125
def _convert_torch_func(self) -> Callable[..., Any]:
127126
# Capture thunks that reflect Torch's current stream and device.
128127
# These are evaluated at call time to align TVM execution with the
@@ -264,7 +263,6 @@ def func(*inputs: torch.Tensor | Any):
264263

265264
return func
266265

267-
268266
@classmethod
269267
def from_database(cls,
270268
params: list[TensorType],

0 commit comments

Comments
 (0)