Skip to content

Commit a78e640

Browse files
committed
let test fix
1 parent 6bdecec commit a78e640

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import tilelang.testing
2+
from tilelang import tvm as tvm
3+
from tvm import IRModule
4+
from tilelang import language as T
5+
from tilelang.utils.tensor import map_torch_type
6+
7+
def test_let_vectorize_load():
8+
@T.prim_func
9+
def main(A_ptr: T.handle):
10+
A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16)
11+
12+
for blockIdx in T.thread_binding(1, thread="blockIdx.x"):
13+
for threadIdx in T.thread_binding(128, thread="threadIdx.x"):
14+
b: T.float32x4 = A[0, 0:4]
15+
A[0, 4:8] = b
16+
17+
mod = tvm.IRModule({"main": main})
18+
mod = tvm.compile(mod, target="cuda")
19+
assert "float4 b" in mod.mod.imported_modules[0].get_source()
20+
21+
22+
if __name__ == "__main__":
23+
tilelang.testing.main()

0 commit comments

Comments
 (0)