 import tilelang.language as T
 import math

-tilelang.disable_cache()
-

 def torch_gmm(a, b, batch_sizes, batch_offsets_tensor, trans_b=False):
     """
@@ -57,6 +55,7 @@ def grouped_gemm(batch_sizes_list,
     batch_sum = sum(batch_sizes_list)
     batch_count = len(batch_sizes_list)
     accum_dtype = "float32"
+    total_m_blocks = sum((size + block_M - 1) // block_M for size in batch_sizes_list)

     @T.prim_func
     def kernel(
@@ -68,9 +67,7 @@ def kernel(
             batch_padded_offsets: T.Tensor([batch_count], "int32"),  # type: ignore
     ):

-        with T.Kernel(
-                T.ceildiv(batch_sum, block_M) + batch_count, T.ceildiv(N, block_N),
-                threads=threads) as (bx, by):
+        with T.Kernel(total_m_blocks, T.ceildiv(N, block_N), threads=threads) as (bx, by):
             A_shared = T.alloc_shared([block_M, block_K], dtype)
             B_shared = T.alloc_shared([block_K, block_N], dtype)
             C_local = T.alloc_fragment([block_M, block_N], accum_dtype)
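
Note on the launch grid: with `total_m_blocks` computed up front, the `bx` dimension now matches the per-batch block counts exactly instead of the previous `T.ceildiv(batch_sum, block_M) + batch_count` bound. A minimal sketch of the arithmetic, using an illustrative `batch_sizes_list` and `block_M` (values assumed here, not taken from this file):

# Illustrative values only (assumed, not the test's defaults).
batch_sizes_list = [300, 128, 65]
block_M = 64

# Old grid: one ceildiv over the flattened M dimension plus one extra block per batch.
old_grid = -(-sum(batch_sizes_list) // block_M) + len(batch_sizes_list)   # 11
# New grid: exact sum of per-batch M-block counts; never larger than the old bound.
new_grid = sum((size + block_M - 1) // block_M for size in batch_sizes_list)  # 9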
@@ -115,8 +112,7 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype):
         batch_offsets_list.append(batch_offsets_list[-1] + batch_sizes_list[i])
     for i in range(batch_count - 1):
         batch_padded_offsets_list.append(batch_padded_offsets_list[-1] +
-                                         math.ceil((batch_sizes_list[i] + 1) / padding_M) *
-                                         padding_M)
+                                         math.ceil((batch_sizes_list[i]) / padding_M) * padding_M)
     A = torch.randn(batch_sum, K, device=device, dtype=dtype)
     B = torch.randn(batch_count, K, M, device=device, dtype=dtype)
     C = torch.empty(batch_sum, M, device=device, dtype=dtype)
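
Note on the padded offsets: the change is easiest to check on concrete numbers, since the old `+ 1` pushed a batch whose size is an exact multiple of `padding_M` into one extra padding block. A minimal sketch that mirrors the loop in `construct_inputs`, with `padding_M` and the batch sizes chosen purely for illustration:

import math

padding_M = 64
batch_sizes_list = [128, 100, 64]  # assumed example values

def padded_offsets(sizes, bump):
    # Each batch starts at the previous offset plus its (bumped) size
    # rounded up to a multiple of padding_M.
    offsets = [0]
    for size in sizes[:-1]:
        offsets.append(offsets[-1] + math.ceil((size + bump) / padding_M) * padding_M)
    return offsets

print(padded_offsets(batch_sizes_list, 1))  # old behaviour: [0, 192, 320]
print(padded_offsets(batch_sizes_list, 0))  # new behaviour: [0, 128, 256]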