Skip to content

Commit 7f6efe6

Browse files
committed
Add unit test.
1 parent 4ee7b24 commit 7f6efe6

File tree

1 file changed

+66
-0
lines changed

1 file changed

+66
-0
lines changed
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import tilelang
2+
import tilelang.language as T
3+
4+
5+
@tilelang.jit
def fill_symbolic(value: float, dtype="bfloat16"):
    """Build a kernel that stores ``value`` into a 1-D tensor of symbolic length.

    The tensor length is the symbolic variable ``n`` declared as int64, so it
    is not a fixed Python integer when the kernel is defined.

    Args:
        value: Scalar written to each addressed element.
        dtype: Element type of the tensor argument; defaults to "bfloat16".

    Returns:
        The ``main`` prim_func defined below (as returned by this function
        body, before whatever the ``tilelang.jit`` decorator does with it).
    """
    elems_per_block = 512
    n = T.symbolic("n", "int64")

    @T.prim_func
    def main(x: T.Tensor[n, dtype]):
        # Kernel extent is ceildiv(n, elems_per_block), launched with threads=128.
        with T.Kernel(T.ceildiv(n, elems_per_block), threads=128) as block_idx:
            # Element-wise stores are used in place of
            #   T.fill(x[block_idx * elems_per_block :
            #            (block_idx + 1) * elems_per_block], value)
            # which doesn't yet work with an int64-shaped global tensor.
            # NOTE(review): there is no explicit tail guard here; confirm how
            # lengths that are not a multiple of elems_per_block are handled.
            for offset in T.Parallel(elems_per_block):
                x[block_idx * elems_per_block + offset] = value

    return main
20+
21+
22+
def run_fill_symbolic(n: int):
    """Fill an ``n``-element bfloat16 CUDA tensor with 1.0 and check the result.

    The tensor starts zeroed, is passed to the kernel produced by
    ``fill_symbolic(1.0)``, and must afterwards contain only 1.0.

    Args:
        n: Number of elements to allocate and fill.

    Raises:
        AssertionError: If the minimum or maximum element differs from 1.0.
    """
    import torch

    buf = torch.zeros(n, dtype=torch.bfloat16, device="cuda")
    kernel = fill_symbolic(1.0)
    kernel(buf)

    # If both the minimum and the maximum are 1.0, every element is 1.0.
    assert buf.min() == 1.0
    assert buf.max() == 1.0
28+
29+
30+
def test_fill_symbolic():
    """Run the symbolic-length fill on a tensor of 2**32 elements.

    2**32 is larger than the int32 maximum; presumably the size was chosen to
    exercise 64-bit shapes and indices (confirm with the author).

    The test is skipped, rather than failing, when no CUDA device is present
    or when the device's total memory cannot hold the tensor.

    Raises:
        unittest.SkipTest: If the hardware requirement is not met. pytest
            reports this as a skip.
    """
    import unittest

    import torch

    n = 2**32
    if not torch.cuda.is_available():
        raise unittest.SkipTest("fill test requires a CUDA device")

    # run_fill_symbolic allocates n bfloat16 values at 2 bytes each (8 GiB).
    needed_bytes = 2 * n
    # Compare against total (not free) memory so a test that could still run
    # is never skipped because of memory cached elsewhere in this process.
    _free_bytes, total_bytes = torch.cuda.mem_get_info()
    if total_bytes < needed_bytes:
        raise unittest.SkipTest(
            f"fill test requires {needed_bytes} bytes of device memory, "
            f"device has {total_bytes}"
        )

    run_fill_symbolic(n)
33+
34+
35+
@tilelang.jit
def fill_static(n: int, value: float, dtype="bfloat16"):
    """Build a kernel that stores ``value`` into a 1-D tensor of fixed length.

    Unlike a symbolic-length kernel, the tensor length ``n`` is a Python
    integer supplied when the kernel is defined.

    Args:
        n: Length of the tensor argument.
        value: Scalar written to each addressed element.
        dtype: Element type of the tensor argument; defaults to "bfloat16".

    Returns:
        The ``main`` prim_func defined below (as returned by this function
        body, before whatever the ``tilelang.jit`` decorator does with it).
    """
    elems_per_block = 512

    @T.prim_func
    def main(x: T.Tensor[n, dtype]):
        # Kernel extent is ceildiv(n, elems_per_block), launched with threads=128.
        with T.Kernel(T.ceildiv(n, elems_per_block), threads=128) as block_idx:
            # Element-wise stores are used in place of
            #   T.fill(x[block_idx * elems_per_block :
            #            (block_idx + 1) * elems_per_block], value)
            # which doesn't yet work with an int64-shaped global tensor.
            # NOTE(review): there is no explicit tail guard here; confirm how
            # lengths that are not a multiple of elems_per_block are handled.
            for offset in T.Parallel(elems_per_block):
                x[block_idx * elems_per_block + offset] = value

    return main
49+
50+
51+
def run_fill_static(n: int):
    """Fill an ``n``-element bfloat16 CUDA tensor with 1.0 and check the result.

    The tensor starts zeroed, is passed to the kernel produced by
    ``fill_static(n, 1.0)``, and must afterwards contain only 1.0.

    Args:
        n: Number of elements to allocate and fill; also the length the
            kernel is built for.

    Raises:
        AssertionError: If the minimum or maximum element differs from 1.0.
    """
    import torch

    buf = torch.zeros(n, dtype=torch.bfloat16, device="cuda")
    kernel = fill_static(n, 1.0)
    kernel(buf)

    # If both the minimum and the maximum are 1.0, every element is 1.0.
    assert buf.min() == 1.0
    assert buf.max() == 1.0
57+
58+
59+
def test_fill_static():
    """Run the fixed-length fill on a tensor of 2**32 elements.

    2**32 is larger than the int32 maximum; presumably the size was chosen to
    exercise 64-bit shapes and indices (confirm with the author).

    The test is skipped, rather than failing, when no CUDA device is present
    or when the device's total memory cannot hold the tensor.

    Raises:
        unittest.SkipTest: If the hardware requirement is not met. pytest
            reports this as a skip.
    """
    import unittest

    import torch

    n = 2**32
    if not torch.cuda.is_available():
        raise unittest.SkipTest("fill test requires a CUDA device")

    # run_fill_static allocates n bfloat16 values at 2 bytes each (8 GiB).
    needed_bytes = 2 * n
    # Compare against total (not free) memory so a test that could still run
    # is never skipped because of memory cached elsewhere in this process.
    _free_bytes, total_bytes = torch.cuda.mem_get_info()
    if total_bytes < needed_bytes:
        raise unittest.SkipTest(
            f"fill test requires {needed_bytes} bytes of device memory, "
            f"device has {total_bytes}"
        )

    run_fill_static(n)
62+
63+
64+
if __name__ == "__main__":
    # Script entry point: run the symbolic test first, then the static one.
    for case in (test_fill_symbolic, test_fill_static):
        case()

0 commit comments

Comments
 (0)