Commit 9508960

[Model][gpt-oss] Support DP+EP for GPT-OSS with FlashInfer trtllm-gen MoE (#23819)
Signed-off-by: Po-Han Huang <pohanh@nvidia.com>
Parent: 1f096f9

3 files changed: 14 additions & 15 deletions

vllm/model_executor/layers/fused_moe/config.py

Lines changed: 8 additions & 7 deletions
@@ -190,12 +190,6 @@ def use_deepep_ll_kernels(self):
         return (self.use_all2all_kernels
                 and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
 
-    @property
-    def use_flashinfer_cutlass_kernels(self):
-        return (envs.VLLM_USE_FLASHINFER_MOE_FP4
-                and has_flashinfer_cutlass_fused_moe()
-                and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput")
-
     @staticmethod
     def make(tp_size_: int, dp_size_: int,
              vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig":

@@ -404,7 +398,14 @@ def use_deepep_ll_kernels(self):
 
     @property
     def use_flashinfer_cutlass_kernels(self):
-        return self.moe_parallel_config.use_flashinfer_cutlass_kernels
+        """
+        Whether to use FlashInfer cutlass kernels for NVFP4 MoE.
+        """
+        return (self.quant_config is not None
+                and self.quant_config.quant_dtype == "nvfp4"
+                and envs.VLLM_USE_FLASHINFER_MOE_FP4
+                and has_flashinfer_cutlass_fused_moe()
+                and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput")
 
     @staticmethod
     def make(

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 4 additions & 4 deletions
@@ -920,7 +920,7 @@ def __init__(
         self.batched_router_logits: Optional[torch.Tensor] = None
         if (self.moe_parallel_config.use_pplx_kernels
                 or self.moe_parallel_config.use_deepep_ll_kernels
-                or self.moe_parallel_config.use_flashinfer_cutlass_kernels):
+                or self.moe_config.use_flashinfer_cutlass_kernels):
             self.batched_hidden_states = torch.zeros(
                 (moe.max_num_tokens, self.hidden_size),
                 dtype=moe.in_dtype,

@@ -974,7 +974,7 @@ def use_deepep_ll_kernels(self):
 
     @property
     def use_flashinfer_cutlass_kernels(self):
-        return self.moe_parallel_config.use_flashinfer_cutlass_kernels
+        return self.moe_config.use_flashinfer_cutlass_kernels
 
     def update_expert_map(self):
         # ep_size and ep_rank should already be updated

@@ -1665,7 +1665,7 @@ def forward_impl(self, hidden_states: torch.Tensor,
         # only when data parallelism (DP) is enabled.
         use_flashinfer_cutlass_kernels = (
             self.dp_size > 1
-            and self.moe_parallel_config.use_flashinfer_cutlass_kernels)
+            and self.moe_config.use_flashinfer_cutlass_kernels)
         if (self.moe_parallel_config.use_pplx_kernels
                 or self.moe_parallel_config.use_deepep_ll_kernels
                 or use_flashinfer_cutlass_kernels):

@@ -1674,7 +1674,7 @@ def forward_impl(self, hidden_states: torch.Tensor,
         do_naive_dispatch_combine: bool = (
             self.dp_size > 1
             and not self.moe_parallel_config.use_deepep_ht_kernels
             and not self.moe_parallel_config.use_flashinfer_cutlass_kernels)
-            and not self.moe_parallel_config.use_flashinfer_cutlass_kernels)
+            and not self.moe_config.use_flashinfer_cutlass_kernels)
         if do_naive_dispatch_combine:
             hidden_states, router_logits = get_ep_group().dispatch(
                 hidden_states, router_logits)
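The layer now reads the flag from moe_config, which carries the quantization config, instead of the parallel-only config, so buffer allocation and the dispatch/combine decision agree. A condensed sketch of the two conditions as they read after this change (helper names are illustrative; attribute names come from the diff above):

def uses_chunked_batched_path(layer) -> bool:
    # Condition in forward_impl for the pre-batched/chunked path; the
    # FlashInfer CUTLASS case only counts when data parallelism is enabled.
    use_fi_cutlass = (layer.dp_size > 1
                      and layer.moe_config.use_flashinfer_cutlass_kernels)
    return (layer.moe_parallel_config.use_pplx_kernels
            or layer.moe_parallel_config.use_deepep_ll_kernels
            or use_fi_cutlass)

def uses_naive_dispatch_combine(layer) -> bool:
    # Condition guarding get_ep_group().dispatch()/combine() around the
    # fused MoE call.
    return (layer.dp_size > 1
            and not layer.moe_parallel_config.use_deepep_ht_kernels
            and not layer.moe_config.use_flashinfer_cutlass_kernels)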

vllm/model_executor/layers/quantization/mxfp4.py

Lines changed: 2 additions & 4 deletions
@@ -623,8 +623,6 @@ def apply(
 
         if should_use_flashinfer_mxfp4():
             from flashinfer import mxfp8_quantize, trtllm_fp4_block_scale_moe
-            assert not self.moe.use_ep, (
-                "EP is not supported for flashinfer mxfp4 moe backend yet.")
             if _should_use_flashinfer_mxfp4_bf16():
                 assert x.dtype == torch.bfloat16
                 x_quant = x

@@ -650,12 +648,12 @@ def apply(
                 None,  # output1_scale_scalar
                 None,  # output1_scale_gate_scalar
                 None,  # output2_scale_scalar
-                self.num_experts,
+                global_num_experts,
                 top_k,
                 None,  # n_group
                 None,  # topk_group
                 self.intermediate_size,  # padded to multiple of 256
-                0,  # local_expert_offset
+                layer.ep_rank * layer.local_num_experts,  # local_expert_offset
                 self.num_experts,  # local num experts
                 None,
                 self._get_tile_tokens_dim(x, top_k),
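With the EP assertion removed, each rank passes its own expert slice to trtllm_fp4_block_scale_moe: the kernel is given the global expert count, and local_expert_offset marks where the rank's experts begin. A small worked example of the offset arithmetic (the numbers are made up for illustration):

# 32 global experts split across ep_size = 4 ranks gives 8 local experts
# per rank; rank r's experts are [r * 8, r * 8 + 8).
global_num_experts = 32
ep_size = 4
local_num_experts = global_num_experts // ep_size   # 8
offsets = [ep_rank * local_num_experts for ep_rank in range(ep_size)]
print(offsets)  # [0, 8, 16, 24]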
