Commit b23fb78

[Bugfix] Fix for 24530. Fix naive all2all shared expert overlap. (#24538)

Parent: 561f38d

1 file changed: vllm/model_executor/layers/fused_moe/layer.py (9 additions, 6 deletions)
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -1755,9 +1755,6 @@ def forward_impl(
             self.dp_size > 1
             and not self.moe_parallel_config.use_deepep_ht_kernels
             and not self.moe_config.use_flashinfer_cutlass_kernels)
-        if do_naive_dispatch_combine:
-            hidden_states, router_logits = get_ep_group().dispatch(
-                hidden_states, router_logits)

         # If there are shared experts but we are not using a modular kernel, the
         # shared experts must be called here
@@ -1768,6 +1765,10 @@ def forward_impl(
         else:
             shared_output = None

+        if do_naive_dispatch_combine:
+            hidden_states, router_logits = get_ep_group().dispatch(
+                hidden_states, router_logits)
+
         # Matrix multiply.
         final_hidden_states = self.quant_method.apply(
             layer=self,
@@ -1800,8 +1801,9 @@ def forward_impl(
             final_hidden_states,
         )

-        def reduce_output(states: torch.Tensor) -> torch.Tensor:
-            if do_naive_dispatch_combine:
+        def reduce_output(states: torch.Tensor,
+                          do_combine: bool = True) -> torch.Tensor:
+            if do_naive_dispatch_combine and do_combine:
                 states = get_ep_group().combine(states)

             if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
@@ -1810,10 +1812,11 @@ def reduce_output(states: torch.Tensor) -> torch.Tensor:
             return states

         if self.shared_experts is None:
+            assert not isinstance(final_hidden_states, tuple)
             return reduce_output(final_hidden_states)
         else:
             return (
-                reduce_output(final_hidden_states[0]),
+                reduce_output(final_hidden_states[0], do_combine=False),
                 reduce_output(final_hidden_states[1]),
             )
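What the fix does: on the naive all2all path, get_ep_group().dispatch() used to run before the shared experts, so the shared experts were applied to the already-exchanged token set rather than this rank's own tokens. The diff moves the dispatch after the shared-expert call and, since the shared output never leaves the rank, skips the combine for it (do_combine=False). Below is a minimal, framework-free sketch of the corrected ordering; the helpers naive_dispatch, naive_combine, shared_expert, and routed_experts are illustrative stand-ins, not vLLM APIs.

import torch


def naive_dispatch(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for get_ep_group().dispatch(): in vLLM this exchanges
    # tokens across the expert-parallel group; a no-op here so the
    # sketch runs in a single process.
    return x


def naive_combine(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for get_ep_group().combine(): returns each rank's own
    # tokens after the routed experts ran; also a no-op in this sketch.
    return x


def shared_expert(x: torch.Tensor) -> torch.Tensor:
    # Toy shared-expert MLP.
    return x * 2.0


def routed_experts(x: torch.Tensor) -> torch.Tensor:
    # Toy routed (fused-MoE) computation.
    return x + 1.0


def forward(hidden_states: torch.Tensor,
            do_naive_dispatch_combine: bool) -> tuple[torch.Tensor, torch.Tensor]:
    # 1) Shared experts consume the PRE-dispatch hidden states, i.e.
    #    this rank's own tokens. Running them after dispatch (the bug)
    #    would feed them the exchanged token set instead.
    shared_output = shared_expert(hidden_states)

    # 2) Only now are tokens dispatched across the EP group.
    if do_naive_dispatch_combine:
        hidden_states = naive_dispatch(hidden_states)

    routed_output = routed_experts(hidden_states)

    # 3) Combine applies only to the routed output. The shared output
    #    never left this rank, so combining it would be wrong; this is
    #    what reduce_output(..., do_combine=False) enforces in the diff.
    if do_naive_dispatch_combine:
        routed_output = naive_combine(routed_output)

    return shared_output, routed_output


if __name__ == "__main__":
    shared, routed = forward(torch.ones(4, 8), do_naive_dispatch_combine=True)
    print(shared.shape, routed.shape)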

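Two smaller points from the diff: reduce_output gains a do_combine flag so the routed output is still combined across the EP group while the shared-expert output, computed on pre-dispatch tokens, is not; and the new assert not isinstance(final_hidden_states, tuple) documents the invariant that a tuple return (shared output, routed output) only occurs when shared experts are present.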