@@ -204,8 +204,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
     def __init__(self, moe: FusedMoEConfig):
         super().__init__(moe)
         self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled)
-        self.max_capture_size = (
-            get_current_vllm_config().compilation_config.max_cudagraph_capture_size
+        # Be conservative and tune for the most extreme imbalance for MoE,
+        # i.e., one expert receives all the tokens.
+        self.tune_max_num_tokens = (
+            get_current_vllm_config().scheduler_config.max_num_batched_tokens
         )

         assert self.mxfp4_backend != Mxfp4Backend.NONE, (
@@ -845,7 +847,7 @@ def select_gemm_impl(
                 "gemm1_beta": layer.gemm1_beta,
                 "gemm1_clamp_limit": layer.gemm1_clamp_limit,
                 # TODO(bnell): part of quant_config
-                "max_capture_size": self.max_capture_size,
+                "tune_max_num_tokens": self.tune_max_num_tokens,
             }
             return TrtLlmGenExperts(self.moe, self.moe_quant_config, **kwargs)
         elif self.mxfp4_backend == Mxfp4Backend.MARLIN:
@@ -981,7 +983,7 @@ def apply(
                 None,
                 1 if renormalize else 0,  # routing_method_type, renormalize
                 True,  # do finalize
-                tune_max_num_tokens=max(self.max_capture_size, 1),
+                tune_max_num_tokens=self.tune_max_num_tokens,
             )[0]
             return trtllm_gen_output
         elif (
@@ -1056,7 +1058,7 @@ def apply(
             tp_rank=self.moe.tp_rank,
             ep_size=self.moe.ep_size,
             ep_rank=self.moe.ep_rank,
-            tune_max_num_tokens=max(self.max_capture_size, 1),
+            tune_max_num_tokens=self.tune_max_num_tokens,
             **extra_kwargs,
         )

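The rationale for switching the autotuning bound: the CUDA graph capture size can be much smaller than the scheduler's token budget, and under fully imbalanced MoE routing a single expert may receive every batched token. Below is a minimal sketch of that worst case, assuming only the quantity named in the diff (`max_num_batched_tokens`); the helper is illustrative and not part of vLLM.

```python
# Illustrative sketch only; the helper name is hypothetical, not vLLM API.
def tokens_seen_by_busiest_expert(max_num_batched_tokens: int) -> int:
    # Worst-case routing imbalance: every token in the batch is routed to the
    # same expert, so that expert's grouped GEMM covers the full batch.
    return max_num_batched_tokens

# Autotuning up to this bound keeps the kernels covered even when routing is
# maximally skewed, whereas the old bound (max_cudagraph_capture_size) only
# tracks how many tokens a captured CUDA graph can hold.
assert tokens_seen_by_busiest_expert(8192) == 8192
```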