
Commit 50e789d

[Feature] Support None type as input for T.ptr and T.Tensor (#1114)
* [Feature] Support None type as input for T.ptr and T.Tensor
* lint
* lint
* lint
* lint fix
1 parent a148d62 commit 50e789d

3 files changed: +120 -2 lines changed

Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
import torch
from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl
import tilelang.language as T
from tilelang.utils import map_torch_type


@tl.jit
def ptr_null_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
            a_ptr: T.ptr,
            b_ptr: T.ptr,
            c_ptr: T.ptr,
            bias_ptr: T.ptr,
            m: T.int32,
            n: T.int32,
            k: T.int32,
            with_bias: T.bool,
    ):
        A = T.make_tensor(a_ptr, (m, k), dtype)
        B = T.make_tensor(b_ptr, (k, n), dtype)
        C = T.make_tensor(c_ptr, (m, n), accum_dtype)
        Bias = T.make_tensor(bias_ptr, (n), accum_dtype)

        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)

            for ko in T.Pipelined(T.ceildiv(k, block_K), num_stages=3):
                # Copy tile of A
                T.copy(A[by * block_M, ko * block_K], A_shared)
                T.copy(B[bx * block_N, ko * block_K], B_shared)
                T.gemm(A_shared, B_shared, C_local, transpose_B=True)

            if with_bias:
                for i, j in T.Parallel(block_M, block_N):
                    C_local[i, j] += Bias[bx * block_N + j]

            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


@tl.jit
def tensor_null_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), accum_dtype),
            Bias: T.Tensor((N), accum_dtype),
            with_bias: T.bool,
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)

            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                # Copy tile of A
                T.copy(A[by * block_M, ko * block_K], A_shared)
                T.copy(B[bx * block_N, ko * block_K], B_shared)
                T.gemm(A_shared, B_shared, C_local, transpose_B=True)

            if with_bias:
                for i, j in T.Parallel(block_M, block_N):
                    C_local[i, j] += Bias[bx * block_N + j]

            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def run_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    func = ptr_null_test(M, N, K, block_M, block_N, block_K, dtype, accum_dtype)

    a = torch.randn(M, K, device="cuda", dtype=map_torch_type(dtype))
    b = torch.randn(N, K, device="cuda", dtype=map_torch_type(dtype))
    c = torch.zeros(M, N, device="cuda", dtype=map_torch_type(accum_dtype))
    d = torch.randn(N, device="cuda", dtype=map_torch_type(accum_dtype))

    func(a, b, c, None, M, N, K, False)

    ref_no_bias = (a @ b.T).to(map_torch_type(accum_dtype))
    ref_with_bias = ref_no_bias + d

    torch.testing.assert_close(c, ref_no_bias, atol=1e-2, rtol=1e-2)

    func(a, b, c, d, M, N, K, True)

    torch.testing.assert_close(c, ref_with_bias, atol=1e-2, rtol=1e-2)

    func = tensor_null_test(M, N, K, block_M, block_N, block_K, dtype, accum_dtype)
    func(a, b, c, None, False)
    torch.testing.assert_close(c, ref_no_bias, atol=1e-2, rtol=1e-2)
    func(a, b, c, d, True)
    torch.testing.assert_close(c, ref_with_bias, atol=1e-2, rtol=1e-2)


def test_nullptr():
    run_test(1024, 1024, 1024, 128, 128, 32)


if __name__ == "__main__":
    tilelang.testing.main()
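A side note on the reference computation in run_test: b is allocated with shape (N, K) and the kernel calls T.gemm(..., transpose_B=True), so the matching PyTorch reference is a @ b.T, with the bias broadcast across rows. A small CPU-only sketch of the same arithmetic (plain torch, no TileLang or GPU required; shapes here are arbitrary):

import torch

# Same layout convention as the test: A is (M, K), B is stored as (N, K),
# so a matmul against B.T yields the (M, N) output the kernel produces.
M, N, K = 4, 3, 5
a = torch.randn(M, K)
b = torch.randn(N, K)
bias = torch.randn(N)

ref_no_bias = a @ b.T               # (M, N), mirrors gemm with transpose_B=True
ref_with_bias = ref_no_bias + bias  # bias broadcasts along the N dimension

assert ref_with_bias.shape == (M, N)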

tilelang/jit/adapter/cython/cython_wrapper.pyx

Lines changed: 2 additions & 0 deletions
@@ -251,6 +251,8 @@ cdef class CythonKernelWrapper:
             if dtype not in dtype_to_ctype:
                 raise ValueError(f"Unsupported tensor dtype: {dtype}")
             call_args.append(dtype_to_ctype[dtype](tensor))
+        elif tensor is None:
+            call_args.append(ctypes.c_void_p(0))
         else:
             raise ValueError(f"Unsupported tensor type: {type(tensor)}")
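For readers less familiar with the ctypes side, a minimal standalone sketch of what the new branch appends to call_args: a None argument becomes a NULL pointer, which ctypes spells as c_void_p(0). Plain Python, runnable anywhere; the data_ptr comparison in the comment is illustrative, not the wrapper's exact code path.

import ctypes

# What the wrapper now forwards for an argument that was passed as None.
null_ptr = ctypes.c_void_p(0)
print(null_ptr.value)  # None: ctypes reports a NULL c_void_p as None
print(bool(null_ptr))  # False: NULL is falsy, any real address is truthy

# A real tensor argument would instead be passed by its device address,
# e.g. ctypes.c_void_p(tensor.data_ptr()) for a torch.Tensor (illustrative only).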

tilelang/language/allocate.py

Lines changed: 2 additions & 2 deletions
@@ -14,11 +14,11 @@
 with the appropriate memory scope.
 """

+from __future__ import annotations
 from tilelang import tvm as tvm
 from tvm.script import tir as T
 from tvm.tir import PrimExpr
 from tvm.script.parser.tir import block_attr
-from typing import Union


 def alloc_shared(shape, dtype, scope="shared.dyn"):
@@ -67,7 +67,7 @@ def alloc_fragment(shape, dtype, scope="local.fragment"):
     return T.alloc_buffer(shape, dtype, scope=scope)


-def alloc_var(dtype, *args, scope="local.var", init: Union[PrimExpr] = None):
+def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None):
     """Allocate a single-element variable buffer.

     Args:
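The from __future__ import annotations line added above is what makes the PEP 604 spelling PrimExpr | None safe on interpreters older than Python 3.10: annotations are stored as strings and never evaluated at definition time. A small self-contained sketch of the same pattern (the PrimExpr class and alloc_var_sketch function here are stand-ins for illustration, not the real tvm.tir.PrimExpr or tilelang API):

from __future__ import annotations  # annotations become lazily-evaluated strings


class PrimExpr:  # stand-in class, only for this sketch
    pass


def alloc_var_sketch(dtype: str, init: PrimExpr | None = None):
    # "PrimExpr | None" is never evaluated when the function is defined,
    # so this spelling also runs on Python 3.8/3.9.
    return init if init is not None else f"default {dtype} var"


print(alloc_var_sketch("float32"))              # falls back to the None default
print(alloc_var_sketch("float32", PrimExpr()))  # accepts an explicit expression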
