Skip to content

Commit 5a733ec

Browse files
committed
[Example] Fix lint to improve grouped GEMM performance with TMA
1 parent: f691e79 · commit: 5a733ec

File tree

1 file changed

+4
-6
lines changed

1 file changed

+4
-6
lines changed

examples/grouped_gemm/example_grouped_gemm_fwd.py

Lines changed: 4 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
 4  4   import tilelang.language as T
 5  5   import math
 6  6
 7    - tilelang.disable_cache()
    7 + # tilelang.disable_cache()
 8  8
 9  9
10 10   def torch_gmm(a, b, batch_sizes, batch_offsets_tensor, trans_b=False):
@@ -57,6 +57,7 @@ def grouped_gemm(batch_sizes_list,
57 57   batch_sum = sum(batch_sizes_list)
58 58   batch_count = len(batch_sizes_list)
59 59   accum_dtype = "float32"
   60 + total_m_blocks = sum((size + block_M - 1) // block_M for size in batch_sizes_list)
60 61
61 62   @T.prim_func
62 63   def kernel(
@@ -68,9 +69,7 @@ def kernel(
68 69   batch_padded_offsets: T.Tensor([batch_count], "int32"), # type: ignore
69 70   ):
70 71
71    - with T.Kernel(
72    - T.ceildiv(batch_sum, block_M) + batch_count, T.ceildiv(N, block_N),
73    - threads=threads) as (bx, by):
   72 + with T.Kernel(total_m_blocks, T.ceildiv(N, block_N), threads=threads) as (bx, by):
74 73   A_shared = T.alloc_shared([block_M, block_K], dtype)
75 74   B_shared = T.alloc_shared([block_K, block_N], dtype)
76 75   C_local = T.alloc_fragment([block_M, block_N], accum_dtype)
@@ -115,8 +114,7 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype):
115 114   batch_offsets_list.append(batch_offsets_list[-1] + batch_sizes_list[i])
116 115   for i in range(batch_count - 1):
117 116   batch_padded_offsets_list.append(batch_padded_offsets_list[-1] +
118     - math.ceil((batch_sizes_list[i] + 1) / padding_M) *
119     - padding_M)
    117 + math.ceil((batch_sizes_list[i]) / padding_M) * padding_M)
120 118   A = torch.randn(batch_sum, K, device=device, dtype=dtype)
121 119   B = torch.randn(batch_count, K, M, device=device, dtype=dtype)
122 120   C = torch.empty(batch_sum, M, device=device, dtype=dtype)

0 commit comments

Comments
 (0)