Skip to content

Commit 3952de3

Browse files
authored
[Typo] Fix a typo in gemm splitk examples (#111)
1 parent e055782 commit 3952de3

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

examples/gemm_splitk/example_tilelang_gemm_splitk.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@ def main(
1818
):
1919
with T.Kernel(
2020
T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz):
21-
A_shared = T.alloc_shared((block_M, block_K), dtype, "shared")
22-
B_shared = T.alloc_shared((block_K, block_N), dtype, "shared")
23-
C_shared = T.alloc_shared((block_M, block_N), dtype, "shared")
21+
A_shared = T.alloc_shared((block_M, block_K), dtype)
22+
B_shared = T.alloc_shared((block_K, block_N), dtype)
23+
C_shared = T.alloc_shared((block_M, block_N), dtype)
2424
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
2525

2626
if bz == 0:
@@ -42,9 +42,9 @@ def main(
4242
m, n = by * block_M + i, bx * block_N + j * 2
4343
# vectorized atomic
4444
T.atomic_addx2(C[m, n], C_shared[i, j * 2])
45-
else:
46-
for i, j in T.Parallel(block_M, block_N):
47-
T.atomic_add(C[by * block_M + i, bx * block_N + j], C_shared[i, j])
45+
else:
46+
for i, j in T.Parallel(block_M, block_N):
47+
T.atomic_add(C[by * block_M + i, bx * block_N + j], C_shared[i, j])
4848

4949
return main
5050

0 commit comments

Comments
 (0)