@@ -9,6 +9,7 @@ def _check(original, transformed):
99 mod = tvm .IRModule .from_expr (func .with_attr ("global_symbol" , "main" ))
1010 mod = tl .transform .InjectSoftwarePipeline ()(mod )
1111 mod = tl .transform .Simplify ()(mod )
12+ mod = tl .transform .LowerOpaqueBlock ()(mod )
1213 tvm .ir .assert_structural_equal (mod ["main" ], transformed .with_attr ("global_symbol" , "main" ),
1314 True )
1415
@@ -39,35 +40,20 @@ def before(A: T.Tensor((16, 1), "float32"), C: T.Tensor((16, 1), "float32")):
3940 C [tx , i ] = B [tx , 0 ] + T .float32 (1 )
4041
4142 @T .prim_func
42- def expected (A : T .Buffer ((16 , 1 ), "float32" ), C : T .Buffer ((16 , 1 ), "float32" )):
43- for tx in T .thread_binding (16 , thread = "threadIdx.x" ):
44- with T .block ():
45- T .reads (A [tx , 0 ])
46- T .writes (C [tx , 0 ])
47- B = T .alloc_buffer ((2 , 16 , 1 ), scope = "shared" )
48- with T .block ():
49- T .reads (A [tx , 0 ])
50- T .writes (B [0 , tx , 0 ])
51- B [0 , tx , 0 ] = A [tx , 0 ] * T .float32 (2.0 )
52- with T .block ():
53- T .reads (A [tx , 1 :1 ], B [0 :2 , tx , 0 ])
54- T .writes (B [1 :1 , tx , 0 ], C [tx , 0 :0 ])
55- for i in range (0 ):
56- with T .block ():
57- T .reads (A [tx , i + 1 ])
58- T .writes (B [i + 1 , tx , 0 ])
59- B [i + 1 , tx , 0 ] = A [tx , i + 1 ] * T .float32 (2.0 )
60- with T .block ():
61- T .reads (B [i , tx , 0 ])
62- T .writes (C [tx , i ])
63- C [tx , i ] = B [i , tx , 0 ] + T .float32 (1.0 )
64- with T .block ():
65- T .reads (B [0 , tx , 0 ])
66- T .writes (C [tx , 0 ])
67- C [tx , 0 ] = B [0 , tx , 0 ] + T .float32 (1.0 )
43+ def expected (A_handle : T .handle , C_handle : T .handle ):
44+ A = T .match_buffer (A_handle , (16 , 1 ), strides = (1 , 1 ))
45+ C = T .match_buffer (C_handle , (16 , 1 ), strides = (1 , 1 ))
46+ tx = T .launch_thread ("threadIdx.x" , 16 )
47+ B = T .decl_buffer ((2 , 16 , 1 ), scope = "shared" )
48+ B [0 , tx , 0 ] = A [tx , 0 ] * T .float32 (2.0 )
49+ for i in range (0 ):
50+ B [i + 1 , tx , 0 ] = A [tx , i + 1 ] * T .float32 (2.0 )
51+ C [tx , i ] = B [i , tx , 0 ] + T .float32 (1.0 )
52+ C [tx , 0 ] = B [0 , tx , 0 ] + T .float32 (1.0 )
6853
6954 _check (before , expected )
7055
7156
7257if __name__ == "__main__" :
73- tilelang .testing .main ()
58+ # tilelang.testing.main()
59+ test_trival_pipeline ()
0 commit comments