Skip to content

Commit 95170ab

Browse files
authored
[Enhancement] Fix lint to improve grouped GEMM performance with TMA (#938)
* [Example] Fix lint to improve grouped GEMM performance with TMA * fix lint
1 parent b31de0c commit 95170ab

File tree

1 file changed

+3
-7
lines changed

1 file changed

+3
-7
lines changed

examples/grouped_gemm/example_grouped_gemm_fwd.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
```diff
@@ -4,8 +4,6 @@
 import tilelang.language as T
 import math
 
-tilelang.disable_cache()
-
 
 def torch_gmm(a, b, batch_sizes, batch_offsets_tensor, trans_b=False):
     """
```
```diff
@@ -57,6 +55,7 @@ def grouped_gemm(batch_sizes_list,
     batch_sum = sum(batch_sizes_list)
     batch_count = len(batch_sizes_list)
     accum_dtype = "float32"
+    total_m_blocks = sum((size + block_M - 1) // block_M for size in batch_sizes_list)
 
     @T.prim_func
     def kernel(
```
```diff
@@ -68,9 +67,7 @@ def kernel(
             batch_padded_offsets: T.Tensor([batch_count], "int32"),  # type: ignore
     ):
 
-        with T.Kernel(
-                T.ceildiv(batch_sum, block_M) + batch_count, T.ceildiv(N, block_N),
-                threads=threads) as (bx, by):
+        with T.Kernel(total_m_blocks, T.ceildiv(N, block_N), threads=threads) as (bx, by):
             A_shared = T.alloc_shared([block_M, block_K], dtype)
             B_shared = T.alloc_shared([block_K, block_N], dtype)
             C_local = T.alloc_fragment([block_M, block_N], accum_dtype)
```
```diff
@@ -115,8 +112,7 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype):
         batch_offsets_list.append(batch_offsets_list[-1] + batch_sizes_list[i])
     for i in range(batch_count - 1):
         batch_padded_offsets_list.append(batch_padded_offsets_list[-1] +
-                                         math.ceil((batch_sizes_list[i] + 1) / padding_M) *
-                                         padding_M)
+                                         math.ceil((batch_sizes_list[i]) / padding_M) * padding_M)
     A = torch.randn(batch_sum, K, device=device, dtype=dtype)
     B = torch.randn(batch_count, K, M, device=device, dtype=dtype)
     C = torch.empty(batch_sum, M, device=device, dtype=dtype)
```

0 commit comments

Comments
 (0)