@@ -382,19 +382,19 @@ def add_lora_fused_moe(
382382 num_tokens_post_padded ,
383383 max_lora_rank ,
384384 top_k_num ,
385- shrink_config .get ("BLOCK_SIZE_M" , 64 ) ,
386- shrink_config .get ("BLOCK_SIZE_N" , 64 ) ,
387- shrink_config .get ("BLOCK_SIZE_K" , 32 ) ,
388- shrink_config .get ("GROUP_SIZE_M" , 8 ) ,
389- shrink_config .get ("NUM_WARPS" , 4 ) ,
390- shrink_config .get ("NUM_STAGES" , 3 ) ,
391- shrink_config .get ("SPLIT_K" , 1 ) ,
392- expand_config .get ("BLOCK_SIZE_M" , 64 ) ,
393- expand_config .get ("BLOCK_SIZE_N" , 64 ) ,
394- expand_config .get ("BLOCK_SIZE_K" , 64 ) ,
395- expand_config .get ("GROUP_SIZE_M" , 64 ) ,
396- expand_config .get ("NUM_WARPS" , 4 ) ,
397- expand_config .get ("NUM_STAGES" , 3 ) ,
398- expand_config .get ("SPLIT_K" , 1 ) ,
385+ shrink_config .get ("BLOCK_SIZE_M" ) or shrink_config . get ( "block_m" ) or 64 ,
386+ shrink_config .get ("BLOCK_SIZE_N" ) or shrink_config . get ( "block_n" ) or 64 ,
387+ shrink_config .get ("BLOCK_SIZE_K" ) or shrink_config . get ( "block_k" ) or 32 ,
388+ shrink_config .get ("GROUP_SIZE_M" ) or shrink_config . get ( "group_m" ) or 8 ,
389+ shrink_config .get ("NUM_WARPS" ) or shrink_config . get ( "num_warps" ) or 4 ,
390+ shrink_config .get ("NUM_STAGES" ) or shrink_config . get ( "num_stages" ) or 3 ,
391+ shrink_config .get ("SPLIT_K" ) or shrink_config . get ( "split_k" ) or 1 ,
392+ expand_config .get ("BLOCK_SIZE_M" ) or expand_config . get ( "block_m" ) or 64 ,
393+ expand_config .get ("BLOCK_SIZE_N" ) or expand_config . get ( "block_n" ) or 64 ,
394+ expand_config .get ("BLOCK_SIZE_K" ) or expand_config . get ( "block_k" ) or 64 ,
395+ expand_config .get ("GROUP_SIZE_M" ) or expand_config . get ( "group_m" ) or 64 ,
396+ expand_config .get ("NUM_WARPS" ) or expand_config . get ( "num_warps" ) or 4 ,
397+ expand_config .get ("NUM_STAGES" ) or expand_config . get ( "num_stages" ) or 3 ,
398+ expand_config .get ("SPLIT_K" ) or expand_config . get ( "split_k" ) or 1 ,
399399 mul_routed_weight ,
400400 )
0 commit comments