@@ -193,7 +193,6 @@ def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared, Scale_shared,
193193 B_dequantize_shared [index // block_K ,
194194 index % block_K ] = B_dequantize_local_thread [v ]
195195
196-
197196 return fast_dequant_bf16_fp4_twiddling
198197
199198 def get_simple_dequant_func (in_dtype = "fp4" , out_dtype = "bfloat16" ):
@@ -260,8 +259,7 @@ def main(
260259 if threads == 512 :
261260 T .disable_warp_group_reg_alloc ()
262261
263- T .copy (sorted_token_ids [by * block_M :(by + 1 ) * block_M ],
264- sorted_token_ids_shared )
262+ T .copy (sorted_token_ids [by * block_M :(by + 1 ) * block_M ], sorted_token_ids_shared )
265263 expert_id [0 ] = expert_ids [by ]
266264
267265 # Get the topk weights of each token in the current block
@@ -287,7 +285,8 @@ def main(
287285 if sorted_token_ids_shared [i ] != - 1 :
288286 A_shared [i , j ] = A [sorted_token_ids_shared [i ] // topk , k * block_K + j ]
289287 if fast_dequant :
290- get_fast_dequant_twiddling_func ()(B_shared , B_dequantize_shared , Scale_shared , k )
288+ get_fast_dequant_twiddling_func ()(B_shared , B_dequantize_shared , Scale_shared ,
289+ k )
291290 else :
292291 get_simple_dequant_func ()(B_shared , B_dequantize_shared , Scale_shared , k )
293292
@@ -300,7 +299,7 @@ def main(
300299 for i , j in T .Parallel (block_M , block_N ):
301300 if sorted_token_ids_shared [i ] != - 1 :
302301 C [sorted_token_ids_shared [i ] // topk , sorted_token_ids_shared [i ] % topk ,
303- bx * block_N + j ] = C_shared [i , j ]
302+ bx * block_N + j ] = C_shared [i , j ]
304303
305304 return main
306305
@@ -397,20 +396,13 @@ def get_data(m, n, k, qk, scale_size, topk, E, block_M):
397396 return A , qB , Scale , Bias , topk_weights , sorted_token_ids , expert_ids , padding_M
398397
399398
400- def main (m = 256 ,
401- n = 256 ,
402- k = 256 ,
403- scale_size = 32 ,
404- fast_dequant = True ,
405- with_bias = False ,
406- topk = 4 ,
407- E = 32 ):
399+ def main (m = 256 , n = 256 , k = 256 , scale_size = 32 , fast_dequant = True , with_bias = False , topk = 4 , E = 32 ):
408400 # Tunable parameters
409401 block_M , block_N , block_K = 128 , 128 , 256
410402 num_stages = 2
411403 threads = 512
412404 split = 1
413-
405+
414406 total_flops = 2 * m * n * k
415407 num_bits = 4
416408 num_elems_per_byte = 8 // num_bits
@@ -453,7 +445,8 @@ def main(m=256,
453445 A , qB , Scale , Bias , topk_weights , sorted_token_ids , expert_ids , block_M = block_M )
454446
455447 print ("All checks pass. ✅" )
456- latency = tilelang .profiler .do_bench (lambda : kernel (A , qB , Scale , Bias , topk_weights , sorted_token_ids , expert_ids ), warmup = 500 )
448+ latency = tilelang .profiler .do_bench (
449+ lambda : kernel (A , qB , Scale , Bias , topk_weights , sorted_token_ids , expert_ids ), warmup = 500 )
457450 print ("Tile-lang: {:.2f} ms" .format (latency ))
458451 print ("Tile-lang: {:.2f} TFlops" .format (total_flops / latency * 1e-9 ))
459452
@@ -463,7 +456,7 @@ def main(m=256,
463456 print (f"max abs diff: { max_val } at index: { max_idx } " )
464457 assert_similar (output , ref_output , name = "output" , eps = 1e-5 )
465458
466-
459+
467460if __name__ == "__main__" :
468461 M , N , K = 1024 , 2944 , 3072 # From gpt-oss-20b
469462 scale_size = 32
0 commit comments