 def get_configs():
     """
     Generate a list of hyperparameter configuration dictionaries for tuning.
-
+
     Each configuration is a dict with keys: 'block_M', 'block_N', 'block_K',
     'num_stages', 'threads', and 'split'. The function returns the Cartesian
     product of the parameter value lists:
     - block_M, block_N, block_K: tiling sizes
     - num_stages: pipeline stages
     - threads: thread counts
     - split: K-splitting factor
-
+
     Returns:
         List[dict]: A list of configuration dictionaries covering all combinations.
     """
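
For orientation, here is a minimal sketch of how such a generator is commonly assembled with itertools.product. The candidate value lists and the name get_configs_sketch are illustrative assumptions, not the values used by this kernel:

    import itertools

    def get_configs_sketch():
        # Hypothetical search space; the real value lists live in get_configs itself.
        space = {
            "block_M": [64, 128],
            "block_N": [64, 128],
            "block_K": [64, 128],
            "num_stages": [2, 3],
            "threads": [128, 256],
            "split": [1, 2],
        }
        # One dict per element of the Cartesian product of the value lists.
        return [dict(zip(space.keys(), combo))
                for combo in itertools.product(*space.values())]

Each returned dict is then one candidate kernel configuration for the tuner.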
@@ -309,17 +309,20 @@ def main(
                 C_local[i, j] = Bias_shared[j]

             tx = T.get_thread_binding()
-
+
             for k in T.Pipelined(K // block_K, num_stages=num_stages):
                 for copy_i in T.serial(block_M * block_K // threads // 16):
                     base = copy_i * threads * 16 + tx * 16
                     if sorted_token_ids_shared[base // block_K] != -1:
                         for copy_j in T.vectorized(16):
-                            A_shared[base // block_K, base % block_K + copy_j] = A[sorted_token_ids_shared[base // block_K] // topk, k * block_K + base % block_K + copy_j]
+                            A_shared[base // block_K, base % block_K +
+                                     copy_j] = A[sorted_token_ids_shared[base // block_K] // topk,
+                                                 k * block_K + base % block_K + copy_j]

                 T.copy(B[expert_id[0], bx * block_N, k * block_K // num_elems_per_byte], B_shared)
                 if fast_dequant:
-                    get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared, k)
+                    get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared,
+                                                      k)
                 else:
                     get_simple_dequant_func()(B_shared, B_dequantize_shared, Scale_shared, k)

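
A note on the gathered A copy above: sorted_token_ids stores padded (token, top-k slot) indices, so an id divided by topk selects the source row of A, and ids equal to -1 mark padding rows that are skipped, mirroring the kernel's A[... // topk, ...] indexing. A rough host-side sketch of loading one (block, k-slice) tile, with hypothetical names and zero-filled padding rows assumed:

    import torch

    def gather_a_tile(A, sorted_token_ids_block, topk, k, block_K):
        # A: (num_tokens, K) activations; sorted_token_ids_block: (block_M,) ids padded with -1.
        block_M = sorted_token_ids_block.numel()
        tile = torch.zeros(block_M, block_K, dtype=A.dtype, device=A.device)
        for r in range(block_M):
            tid = int(sorted_token_ids_block[r])
            if tid == -1:
                continue                       # padding row: the kernel skips the copy entirely
            src_row = tid // topk              # same indexing as the kernel's A[... // topk, ...]
            tile[r] = A[src_row, k * block_K:(k + 1) * block_K]
        return tile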
@@ -331,7 +334,7 @@ def main(
             T.copy(C_local, C_shared)
             for i, j in T.Parallel(block_M, block_N):
                 C[sorted_token_ids_shared[i] // topk, sorted_token_ids_shared[i] % topk,
-                  bx * block_N + j] = C_shared[i, j]
+                  bx * block_N + j] = C_shared[i, j]

     return main

@@ -366,7 +369,8 @@ def ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, bloc

         # Compute the output for this token-expert pair
         # token_embedding @ B.T + bias
-        output = torch.matmul(token_embedding.to(torch.bfloat16), B.T.to(torch.bfloat16)) + Bias[expert_id]
+        output = torch.matmul(token_embedding.to(torch.bfloat16), B.T.to(
+            torch.bfloat16)) + Bias[expert_id]
         output = output.to(torch.__getattribute__(dtypeC))

         # Apply the topk weight
@@ -491,7 +495,9 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False,
         max_val = diff.max()
         max_idx = diff.argmax()
         print(f"max abs diff: {max_val} at index: {max_idx}")
-    assert_similar(output, ref_output, name="output", eps=1e-5)  # We care about the similarity rather than abs. difference
+    assert_similar(
+        output, ref_output, name="output",
+        eps=1e-5)  # We care about the similarity rather than abs. difference
     print("All checks pass. ✅")

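
assert_similar itself is not shown in this diff; per the trailing comment, the check targets overall similarity rather than element-wise absolute error, which is the more meaningful criterion for low-precision dequantized matmuls. Purely as an illustration (not the helper's actual implementation), a cosine-similarity style check could look like:

    import torch

    def assert_cosine_similar(out, ref, name="output", eps=1e-5):
        # Compare direction/magnitude of the whole tensors instead of element-wise error.
        a = out.float().flatten()
        b = ref.float().flatten()
        sim = torch.dot(a, b) / (a.norm() * b.norm() + 1e-12)
        assert sim > 1 - eps, f"{name}: cosine similarity {sim.item():.6f} is below {1 - eps}"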