Skip to content

Commit a7e2027

Browse files
committed
Type consistency on tvm datatype
1. isinstance(tl.float32, tvm.DataType) == True 2. Allow `tl.float32` as function annotations 3. Allow `tl.float32` as argument to be passed to `tl.alloc` or other functions
1 parent 61bfbdd commit a7e2027

File tree

10 files changed

+860
-220
lines changed

10 files changed

+860
-220
lines changed

testing/python/jit/test_tilelang_jit_parcompile.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
from tilelang import tvm
21
import tilelang.testing
32
import tilelang
43
import torch
54

5+
66
@tilelang.jit(
77
out_idx=-1, # create the output tensor during runtime
88
verbose=True,
@@ -61,7 +61,6 @@ def test_par_compile():
6161
(2048, 2048, 2048, 256, 256, 64),
6262
(4096, 4096, 4096, 64, 64, 128),
6363
]
64-
ker = matmul_kernel_jit(1024, 1024, 1024, 128, 128, 32)
6564
kernels = matmul_kernel_jit.par_compile(configs)
6665
for (M, N, K, _, _, _), kernel in zip(configs, kernels):
6766
A = torch.randn(M, K, dtype=torch.float16).cuda()
@@ -72,4 +71,4 @@ def test_par_compile():
7271

7372

7473
if __name__ == "__main__":
75-
tilelang.testing.main()
74+
tilelang.testing.main()
Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
import tilelang
2+
import tilelang.language as T
3+
import torch
4+
import tilelang.testing
5+
import tvm
6+
7+
def test_argument():
8+
@T.prim_func
9+
def test_argument(
10+
t_1: T.bool,
11+
t_2: T.short,
12+
t_3: T.int,
13+
t_4: T.long,
14+
t_5: T.half,
15+
t_6: T.float,
16+
t_7: T.long,
17+
t_8: T.int8,
18+
t_9: T.int16,
19+
t_10: T.int32,
20+
t_11: T.int64,
21+
t_12: T.uint8,
22+
t_13: T.uint16,
23+
t_14: T.uint32,
24+
t_15: T.uint64,
25+
t_16: T.float8_e4m3fn,
26+
t_17: T.float8_e4m3fnuz,
27+
t_18: T.float8_e5m2,
28+
t_19: T.float8_e5m2fnuz,
29+
t_20: T.float8_e8m0fnu,
30+
t_21: T.float16,
31+
t_22: T.bfloat16,
32+
t_23: T.float32,
33+
t_24: T.float64,
34+
):
35+
pass
36+
37+
38+
def test_expr():
39+
from tilelang.language.v2.dtypes import _all_dtypes
40+
errors = []
41+
for name in _all_dtypes:
42+
dtype = getattr(T, name)
43+
assert isinstance(dtype, tvm.DataType), f"{dtype} is not tvm.DataType"
44+
try:
45+
dtype(1.0)
46+
dtype()
47+
except TypeError as e:
48+
pass
49+
except Exception as e:
50+
errors.append(name)
51+
assert not errors
52+
53+
54+
def test_var_decl_sugar():
55+
@T.prim_func
56+
def test_var_decl_sugar():
57+
with T.Kernel(128, 128) as (bx, by):
58+
var_1: T.bool = 1.0
59+
var_2: T.short = 1.0
60+
var_3: T.int = 1.0
61+
var_4: T.long = 1.0
62+
var_5: T.half = 1.0
63+
var_6: T.float = 1.0
64+
var_7: T.long = 1.0
65+
var_8: T.int8 = 1.0
66+
var_9: T.int16 = 1.0
67+
var_10: T.int32 = 1.0
68+
var_11: T.int64 = 1.0
69+
var_12: T.uint8 = 1.0
70+
var_13: T.uint16 = 1.0
71+
var_14: T.uint32 = 1.0
72+
var_15: T.uint64 = 1.0
73+
var_16: T.float8_e4m3fn = 1.0
74+
var_17: T.float8_e4m3fnuz = 1.0
75+
var_18: T.float8_e5m2 = 1.0
76+
var_19: T.float8_e5m2fnuz = 1.0
77+
var_20: T.float8_e8m0fnu = 1.0
78+
var_21: T.float16 = 1.0
79+
var_22: T.bfloat16 = 1.0
80+
var_23: T.float32 = 1.0
81+
var_24: T.float64 = 1.0
82+
var_1: T.bool = var_1
83+
var_2: T.short = var_2
84+
var_3: T.int = var_3
85+
var_4: T.long = var_4
86+
var_5: T.half = var_5
87+
var_6: T.float = var_6
88+
var_7: T.long = var_7
89+
var_8: T.int8 = var_8
90+
var_9: T.int16 = var_9
91+
var_10: T.int32 = var_10
92+
var_11: T.int64 = var_11
93+
var_12: T.uint8 = var_12
94+
var_13: T.uint16 = var_13
95+
var_14: T.uint32 = var_14
96+
var_15: T.uint64 = var_15
97+
var_16: T.float8_e4m3fn = var_16
98+
var_17: T.float8_e4m3fnuz = var_17
99+
var_18: T.float8_e5m2 = var_18
100+
var_19: T.float8_e5m2fnuz = var_19
101+
var_20: T.float8_e8m0fnu = var_20
102+
var_21: T.float16 = var_21
103+
var_22: T.bfloat16 = var_22
104+
var_23: T.float32 = var_23
105+
var_24: T.float64 = var_24
106+
107+
s = test_var_decl_sugar.script()
108+
for i in range(1, 25):
109+
assert f'var_{i}_1' in s
110+
assert f'tl.local_var_init' in s
111+
112+
def test_dtype_str_repr():
113+
@T.prim_func
114+
def test_str_repr():
115+
buf_1 = T.alloc_buffer((1,), dtype=T.bool, scope='shared')
116+
buf_2 = T.alloc_buffer((1,), dtype=T.short, scope='shared')
117+
buf_3 = T.alloc_buffer((1,), dtype=T.int, scope='shared')
118+
buf_4 = T.alloc_buffer((1,), dtype=T.long, scope='shared')
119+
buf_5 = T.alloc_buffer((1,), dtype=T.half, scope='shared')
120+
buf_6 = T.alloc_buffer((1,), dtype=T.float, scope='shared')
121+
buf_7 = T.alloc_buffer((1,), dtype=T.long, scope='shared')
122+
buf_8 = T.alloc_buffer((1,), dtype=T.int8, scope='shared')
123+
buf_9 = T.alloc_buffer((1,), dtype=T.int16, scope='shared')
124+
buf_10 = T.alloc_buffer((1,), dtype=T.int32, scope='shared')
125+
buf_11 = T.alloc_buffer((1,), dtype=T.int64, scope='shared')
126+
buf_12 = T.alloc_buffer((1,), dtype=T.uint8, scope='shared')
127+
buf_13 = T.alloc_buffer((1,), dtype=T.uint16, scope='shared')
128+
buf_14 = T.alloc_buffer((1,), dtype=T.uint32, scope='shared')
129+
buf_15 = T.alloc_buffer((1,), dtype=T.uint64, scope='shared')
130+
buf_16 = T.alloc_buffer((1,), dtype=T.float8_e4m3fn, scope='shared')
131+
buf_17 = T.alloc_buffer((1,), dtype=T.float8_e4m3fnuz, scope='shared')
132+
buf_18 = T.alloc_buffer((1,), dtype=T.float8_e5m2, scope='shared')
133+
buf_19 = T.alloc_buffer((1,), dtype=T.float8_e5m2fnuz, scope='shared')
134+
buf_20 = T.alloc_buffer((1,), dtype=T.float8_e8m0fnu, scope='shared')
135+
buf_21 = T.alloc_buffer((1,), dtype=T.float16, scope='shared')
136+
buf_22 = T.alloc_buffer((1,), dtype=T.bfloat16, scope='shared')
137+
buf_23 = T.alloc_buffer((1,), dtype=T.float32, scope='shared')
138+
buf_24 = T.alloc_buffer((1,), dtype=T.float64, scope='shared')
139+
140+
def test_torch_eq():
141+
dtypes = [
142+
T.bool,
143+
T.short,
144+
T.int,
145+
T.long,
146+
T.half,
147+
T.float,
148+
T.long,
149+
T.int8,
150+
T.int16,
151+
T.int32,
152+
T.int64,
153+
T.uint8,
154+
T.uint16,
155+
T.uint32,
156+
T.uint64,
157+
T.float8_e4m3fn,
158+
T.float8_e4m3fnuz,
159+
T.float8_e5m2,
160+
T.float8_e5m2fnuz,
161+
T.float8_e8m0fnu,
162+
T.float16,
163+
T.bfloat16,
164+
T.float32,
165+
T.float64,
166+
]
167+
torch_dtypes = [
168+
torch.bool,
169+
torch.short,
170+
torch.int,
171+
torch.long,
172+
torch.half,
173+
torch.float,
174+
torch.long,
175+
torch.int8,
176+
torch.int16,
177+
torch.int32,
178+
torch.int64,
179+
torch.uint8,
180+
torch.uint16,
181+
torch.uint32,
182+
torch.uint64,
183+
torch.float8_e4m3fn,
184+
torch.float8_e4m3fnuz,
185+
torch.float8_e5m2,
186+
torch.float8_e5m2fnuz,
187+
torch.float8_e8m0fnu,
188+
torch.float16,
189+
torch.bfloat16,
190+
torch.float32,
191+
torch.float64,
192+
]
193+
for a, b in zip(dtypes, torch_dtypes):
194+
assert a == b, f"{a} and {b} are not equal"
195+
196+
197+
def test_var_assign():
198+
@tilelang.jit(out_idx=-1)
199+
@T.prim_func
200+
def test_var_assign(A: T.Tensor((2,), T.int32)):
201+
with T.Kernel(1) as _:
202+
a: T.int32 = 1
203+
b: T.int32 = a
204+
a = 2
205+
d: T.int32 = a
206+
A[0] = b
207+
A[1] = d
208+
res = test_var_assign()()
209+
assert res[0] == 1
210+
assert res[1] == 2
211+
212+
213+
if __name__ == '__main__':
214+
tilelang.testing.main()

testing/python/transform/test_tilelang_transform_multi_version_buffer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def before(scales: T.Tensor((4,), "float32")):
113113
shared = T.alloc_buffer((8,), "float32", scope="shared.dyn")
114114
accum = T.alloc_buffer((8,), "float32", scope="local")
115115
for k in T.serial(4, annotations={"num_stages": T.int32(2)}):
116-
value: T.float32 = scales[k]
116+
value = scales[k]
117117
for i in T.serial(8):
118118
shared[i] = value
119119
for i in T.serial(8):
@@ -125,7 +125,7 @@ def after(scales: T.Tensor((4,), "float32")):
125125
shared = T.alloc_buffer((2, 8), "float32", scope="shared.dyn")
126126
accum = T.alloc_buffer((8,), "float32", scope="local")
127127
for k in T.serial(4, annotations={"num_stages": T.int32(2)}):
128-
value: T.float32 = scales[k]
128+
value = scales[k]
129129
for i in T.serial(8):
130130
shared[k % 2, i] = value
131131
for i in T.serial(8):

0 commit comments

Comments
 (0)