Skip to content

Commit a73122d

Browse files
authored
[Bugfix] fix benchmark moe (#14653)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
1 parent bd44b81 commit a73122d

File tree

1 file changed

+21
-5
lines changed

1 file changed

+21
-5
lines changed

benchmarks/kernels/benchmark_moe.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,7 @@ def benchmark(
365365
dtype: torch.dtype,
366366
use_fp8_w8a8: bool,
367367
use_int8_w8a16: bool,
368+
block_quant_shape: List[int] = None,
368369
) -> tuple[dict[str, int], float]:
369370
current_platform.seed_everything(self.seed)
370371
dtype_str = get_config_dtype_str(dtype,
@@ -385,10 +386,17 @@ def benchmark(
385386
else:
386387
config = op_config[min(op_config.keys(),
387388
key=lambda x: abs(x - num_tokens))]
388-
kernel_time = benchmark_config(config, num_tokens, num_experts,
389-
shard_intermediate_size, hidden_size,
390-
topk, dtype, use_fp8_w8a8,
391-
use_int8_w8a16)
389+
kernel_time = benchmark_config(config,
390+
num_tokens,
391+
num_experts,
392+
shard_intermediate_size,
393+
hidden_size,
394+
topk,
395+
dtype,
396+
use_fp8_w8a8,
397+
use_int8_w8a16,
398+
num_iters=100,
399+
block_quant_shape=block_quant_shape)
392400
return config, kernel_time
393401

394402
def tune(
@@ -487,6 +495,14 @@ def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int,
487495
f.write("\n")
488496

489497

498+
def get_weight_block_size_safety(config, default_value=None):
    """Return the ``weight_block_size`` entry of a config's quantization dict.

    Reads ``config.quantization_config`` and, when that attribute is a
    ``dict``, returns its ``'weight_block_size'`` value.

    Args:
        config: Any object that may carry a ``quantization_config`` attribute.
        default_value: Value returned when the attribute is absent, is not a
            ``dict``, or does not contain the ``'weight_block_size'`` key.

    Returns:
        The stored ``'weight_block_size'`` value, or ``default_value``.
    """
    quant_cfg = getattr(config, 'quantization_config', None)
    # Guard clause: a missing attribute or a non-dict value means there is
    # nothing to look up, so the caller's fallback is used.
    if not isinstance(quant_cfg, dict):
        return default_value
    return quant_cfg.get('weight_block_size', default_value)
504+
505+
490506
def main(args: argparse.Namespace):
491507
print(args)
492508
block_quant_shape = None
@@ -508,7 +524,7 @@ def main(args: argparse.Namespace):
508524
topk = config.num_experts_per_tok
509525
intermediate_size = config.moe_intermediate_size
510526
shard_intermediate_size = 2 * intermediate_size // args.tp_size
511-
block_quant_shape = config.quantization_config['weight_block_size']
527+
block_quant_shape = get_weight_block_size_safety(config)
512528
elif config.architectures[0] == "Qwen2MoeForCausalLM":
513529
E = config.num_experts
514530
topk = config.num_experts_per_tok

0 commit comments

Comments
 (0)