@@ -365,6 +365,7 @@ def benchmark(
365365 dtype : torch .dtype ,
366366 use_fp8_w8a8 : bool ,
367367 use_int8_w8a16 : bool ,
368+ block_quant_shape : List [int ] = None ,
368369 ) -> tuple [dict [str , int ], float ]:
369370 current_platform .seed_everything (self .seed )
370371 dtype_str = get_config_dtype_str (dtype ,
@@ -385,10 +386,17 @@ def benchmark(
385386 else :
386387 config = op_config [min (op_config .keys (),
387388 key = lambda x : abs (x - num_tokens ))]
388- kernel_time = benchmark_config (config , num_tokens , num_experts ,
389- shard_intermediate_size , hidden_size ,
390- topk , dtype , use_fp8_w8a8 ,
391- use_int8_w8a16 )
389+ kernel_time = benchmark_config (config ,
390+ num_tokens ,
391+ num_experts ,
392+ shard_intermediate_size ,
393+ hidden_size ,
394+ topk ,
395+ dtype ,
396+ use_fp8_w8a8 ,
397+ use_int8_w8a16 ,
398+ num_iters = 100 ,
399+ block_quant_shape = block_quant_shape )
392400 return config , kernel_time
393401
394402 def tune (
@@ -487,6 +495,14 @@ def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int,
487495 f .write ("\n " )
488496
489497
def get_weight_block_size_safety(config, default_value=None):
    """Safely read ``weight_block_size`` from a model config's quantization config.

    Args:
        config: A HuggingFace-style model config object; it may or may not
            carry a ``quantization_config`` attribute.
        default_value: Value returned when the block size cannot be found.

    Returns:
        ``config.quantization_config['weight_block_size']`` when the
        quantization config exists, is a dict, and contains that key;
        otherwise ``default_value``.
    """
    # quantization_config may be absent entirely, or may be an object rather
    # than a dict on some configs — only subscript it when it is a dict, so
    # unquantized models fall through to the default instead of raising.
    quantization_config = getattr(config, 'quantization_config', {})
    if isinstance(quantization_config, dict):
        return quantization_config.get('weight_block_size', default_value)
    return default_value
505+
490506def main (args : argparse .Namespace ):
491507 print (args )
492508 block_quant_shape = None
@@ -508,7 +524,7 @@ def main(args: argparse.Namespace):
508524 topk = config .num_experts_per_tok
509525 intermediate_size = config .moe_intermediate_size
510526 shard_intermediate_size = 2 * intermediate_size // args .tp_size
511- block_quant_shape = config . quantization_config [ 'weight_block_size' ]
527+ block_quant_shape = get_weight_block_size_safety ( config )
512528 elif config .architectures [0 ] == "Qwen2MoeForCausalLM" :
513529 E = config .num_experts
514530 topk = config .num_experts_per_tok
0 commit comments