@@ -18,9 +18,9 @@ def main(
1818 ):
1919 with T .Kernel (
2020 T .ceildiv (N , block_N ), T .ceildiv (M , block_M ), split_k , threads = 128 ) as (bx , by , bz ):
21- A_shared = T .alloc_shared ((block_M , block_K ), dtype , "shared" )
22- B_shared = T .alloc_shared ((block_K , block_N ), dtype , "shared" )
23- C_shared = T .alloc_shared ((block_M , block_N ), dtype , "shared" )
21+ A_shared = T .alloc_shared ((block_M , block_K ), dtype )
22+ B_shared = T .alloc_shared ((block_K , block_N ), dtype )
23+ C_shared = T .alloc_shared ((block_M , block_N ), dtype )
2424 C_local = T .alloc_fragment ((block_M , block_N ), accum_dtype )
2525
2626 if bz == 0 :
@@ -42,9 +42,9 @@ def main(
4242 m , n = by * block_M + i , bx * block_N + j * 2
4343 # vectorized atomic
4444 T .atomic_addx2 (C [m , n ], C_shared [i , j * 2 ])
45- else :
46- for i , j in T .Parallel (block_M , block_N ):
47- T .atomic_add (C [by * block_M + i , bx * block_N + j ], C_shared [i , j ])
45+ else :
46+ for i , j in T .Parallel (block_M , block_N ):
47+ T .atomic_add (C [by * block_M + i , bx * block_N + j ], C_shared [i , j ])
4848
4949 return main
5050
0 commit comments