
Commit ce4f6a9

use maximum number of batched tokens to autotune
Signed-off-by: Julien Lin <jullin@nvidia.com>
1 parent 428bc7b commit ce4f6a9
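
For context: under top-k routing, the worst case for a fused-MoE kernel is that every batched token lands on the same expert, so the scheduler's max_num_batched_tokens upper-bounds the per-expert workload, while the previous bound (the largest CUDA graph capture size) can be much smaller. A minimal sketch with hypothetical numbers:

# Hypothetical values; illustrates why max_num_batched_tokens is the
# conservative upper bound to autotune for.
max_num_batched_tokens = 8192      # scheduler_config.max_num_batched_tokens
max_cudagraph_capture_size = 512   # old bound, from compilation_config

# Worst-case MoE imbalance: one expert receives every routed token.
worst_case_tokens_per_expert = max_num_batched_tokens
assert worst_case_tokens_per_expert > max_cudagraph_capture_size
# Tuning only up to the capture size would leave the kernel untuned for the
# larger problem sizes that show up under this imbalance.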

File tree

vllm/model_executor/layers/fused_moe/trtllm_moe.py
vllm/model_executor/layers/quantization/mxfp4.py

2 files changed: +10 -8 lines changed


vllm/model_executor/layers/fused_moe/trtllm_moe.py
Lines changed: 3 additions & 3 deletions

@@ -21,14 +21,14 @@ def __init__(
         gemm1_alpha,
         gemm1_beta,
         gemm1_clamp_limit,
-        max_capture_size,
+        tune_max_num_tokens,
     ):
         super().__init__(quant_config)
         self.moe = moe
         self.gemm1_alpha = gemm1_alpha
         self.gemm1_beta = gemm1_beta
         self.gemm1_clamp_limit = gemm1_clamp_limit
-        self.max_capture_size = max_capture_size
+        self.tune_max_num_tokens = tune_max_num_tokens
 
     @property
     def activation_formats(
@@ -127,7 +127,7 @@ def apply(
             "routing_method_type": 1,
             "do_finalize": True,
             "output": output,
-            "tune_max_num_tokens": max(self.max_capture_size, 1),
+            "tune_max_num_tokens": self.tune_max_num_tokens,
         }
 
         from flashinfer import trtllm_fp4_block_scale_routed_moe
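
A side effect visible in the second hunk above (and again in mxfp4.py below): the old code wrapped the bound in max(self.max_capture_size, 1), presumably to guard against a capture size of 0 when no CUDA graphs are captured, whereas max_num_batched_tokens is always positive, so the floor is dropped. A tiny sketch with hypothetical values:

max_capture_size = 0                   # e.g., no CUDA graph capture configured
old_bound = max(max_capture_size, 1)   # floored so the tuner sees >= 1 token
max_num_batched_tokens = 2048          # hypothetical; always >= 1 in practice
new_bound = max_num_batched_tokens     # no floor needed
print(old_bound, new_bound)            # 1 2048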

vllm/model_executor/layers/quantization/mxfp4.py
Lines changed: 7 additions & 5 deletions

@@ -204,8 +204,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
     def __init__(self, moe: FusedMoEConfig):
         super().__init__(moe)
         self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled)
-        self.max_capture_size = (
-            get_current_vllm_config().compilation_config.max_cudagraph_capture_size
+        # Be conservative and tune for the most extreme imbalance for MoE,
+        # i.e., one expert receives all the tokens.
+        self.tune_max_num_tokens = (
+            get_current_vllm_config().scheduler_config.max_num_batched_tokens
         )
 
         assert self.mxfp4_backend != Mxfp4Backend.NONE, (
@@ -845,7 +847,7 @@ def select_gemm_impl(
             "gemm1_beta": layer.gemm1_beta,
             "gemm1_clamp_limit": layer.gemm1_clamp_limit,
             # TODO(bnell): part of quant_config
-            "max_capture_size": self.max_capture_size,
+            "tune_max_num_tokens": self.tune_max_num_tokens,
         }
         return TrtLlmGenExperts(self.moe, self.moe_quant_config, **kwargs)
     elif self.mxfp4_backend == Mxfp4Backend.MARLIN:
@@ -981,7 +983,7 @@ def apply(
                 None,
                 1 if renormalize else 0,  # routing_method_type, renormalize
                 True,  # do finalize
-                tune_max_num_tokens=max(self.max_capture_size, 1),
+                tune_max_num_tokens=self.tune_max_num_tokens,
             )[0]
             return trtllm_gen_output
         elif (
@@ -1056,7 +1058,7 @@ def apply(
                 tp_rank=self.moe.tp_rank,
                 ep_size=self.moe.ep_size,
                 ep_rank=self.moe.ep_rank,
-                tune_max_num_tokens=max(self.max_capture_size, 1),
+                tune_max_num_tokens=self.tune_max_num_tokens,
                 **extra_kwargs,
             )
 
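To make the new lookup in Mxfp4MoEMethod.__init__ concrete, here is a self-contained sketch; the stub config classes stand in for vLLM's real get_current_vllm_config() (which this snippet does not import), and the default value is hypothetical:

from dataclasses import dataclass, field

@dataclass
class SchedulerConfig:
    max_num_batched_tokens: int = 8192  # hypothetical default for this sketch

@dataclass
class VllmConfig:
    scheduler_config: SchedulerConfig = field(default_factory=SchedulerConfig)

def get_current_vllm_config() -> VllmConfig:
    # Stand-in for vLLM's helper of the same name.
    return VllmConfig()

# Mirrors the new __init__: tune for the most extreme MoE imbalance,
# i.e., one expert receiving all batched tokens.
tune_max_num_tokens = (
    get_current_vllm_config().scheduler_config.max_num_batched_tokens
)
print(tune_max_num_tokens)  # 8192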