import tilelang.language as T
import math

-tilelang.disable_cache()
+# tilelang.disable_cache()


def torch_gmm(a, b, batch_sizes, batch_offsets_tensor, trans_b=False):
@@ -57,6 +57,7 @@ def grouped_gemm(batch_sizes_list,
    batch_sum = sum(batch_sizes_list)
    batch_count = len(batch_sizes_list)
    accum_dtype = "float32"
+    total_m_blocks = sum((size + block_M - 1) // block_M for size in batch_sizes_list)

    @T.prim_func
    def kernel(
@@ -68,9 +69,7 @@ def kernel(
6869 batch_padded_offsets : T .Tensor ([batch_count ], "int32" ), # type: ignore
6970 ):
7071
71- with T .Kernel (
72- T .ceildiv (batch_sum , block_M ) + batch_count , T .ceildiv (N , block_N ),
73- threads = threads ) as (bx , by ):
72+ with T .Kernel (total_m_blocks , T .ceildiv (N , block_N ), threads = threads ) as (bx , by ):
7473 A_shared = T .alloc_shared ([block_M , block_K ], dtype )
7574 B_shared = T .alloc_shared ([block_K , block_N ], dtype )
7675 C_local = T .alloc_fragment ([block_M , block_N ], accum_dtype )
@@ -115,8 +114,7 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype):
        batch_offsets_list.append(batch_offsets_list[-1] + batch_sizes_list[i])
    for i in range(batch_count - 1):
        batch_padded_offsets_list.append(batch_padded_offsets_list[-1] +
-                                         math.ceil((batch_sizes_list[i] + 1) / padding_M) *
-                                         padding_M)
+                                         math.ceil((batch_sizes_list[i]) / padding_M) * padding_M)
    A = torch.randn(batch_sum, K, device=device, dtype=dtype)
    B = torch.randn(batch_count, K, M, device=device, dtype=dtype)
    C = torch.empty(batch_sum, M, device=device, dtype=dtype)
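For reference, the new grid extent counts M-blocks exactly, one ceil-div per group, instead of taking a ceil-div over the whole batch and over-provisioning one spare block per group, and the padded offsets no longer add 1 before rounding up. A minimal sketch, not part of the commit and using hypothetical example values for batch_sizes_list, block_M, and padding_M, comparing the two expressions:

import math

batch_sizes_list = [100, 256, 7]  # hypothetical group sizes
block_M = 64
padding_M = block_M

# New exact extent: one ceil-div per group -> 2 + 4 + 1 = 7 blocks.
total_m_blocks = sum((size + block_M - 1) // block_M for size in batch_sizes_list)

# Old conservative extent: ceil-div over the whole batch plus one spare block
# per group -> ceil(363 / 64) + 3 = 9 blocks.
old_m_blocks = (sum(batch_sizes_list) + block_M - 1) // block_M + len(batch_sizes_list)
assert total_m_blocks <= old_m_blocks

# Padded offsets now round each group size straight up to a multiple of padding_M;
# the old "+ 1" forced an extra block whenever a group size was already a multiple.
offsets = [0]
for size in batch_sizes_list[:-1]:
    offsets.append(offsets[-1] + math.ceil(size / padding_M) * padding_M)
# -> [0, 128, 384]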