@@ -1755,9 +1755,6 @@ def forward_impl(
             self.dp_size > 1
             and not self.moe_parallel_config.use_deepep_ht_kernels
             and not self.moe_config.use_flashinfer_cutlass_kernels)
-        if do_naive_dispatch_combine:
-            hidden_states, router_logits = get_ep_group().dispatch(
-                hidden_states, router_logits)
 
         # If there are shared experts but we are not using a modular kernel, the
         # shared experts must be called here
@@ -1768,6 +1765,10 @@ def forward_impl(
         else:
             shared_output = None
 
+        if do_naive_dispatch_combine:
+            hidden_states, router_logits = get_ep_group().dispatch(
+                hidden_states, router_logits)
+
         # Matrix multiply.
         final_hidden_states = self.quant_method.apply(
             layer=self,
@@ -1800,8 +1801,9 @@ def forward_impl(
             final_hidden_states,
         )
 
-        def reduce_output(states: torch.Tensor) -> torch.Tensor:
-            if do_naive_dispatch_combine:
+        def reduce_output(states: torch.Tensor,
+                          do_combine: bool = True) -> torch.Tensor:
+            if do_naive_dispatch_combine and do_combine:
                 states = get_ep_group().combine(states)
 
             if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
@@ -1810,10 +1812,11 @@ def reduce_output(states: torch.Tensor) -> torch.Tensor:
             return states
 
         if self.shared_experts is None:
+            assert not isinstance(final_hidden_states, tuple)
             return reduce_output(final_hidden_states)
         else:
             return (
-                reduce_output(final_hidden_states[0]),
+                reduce_output(final_hidden_states[0], do_combine=False),
                 reduce_output(final_hidden_states[1]),
             )
 
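The control flow after these hunks: shared experts consume the pre-dispatch activations, the naive EP dispatch runs just before the routed-expert matmul, and only the routed output is passed back through combine (the shared output skips it via do_combine=False). Below is a minimal illustrative sketch of that ordering, not vLLM code; _FakeEPGroup, moe_forward_sketch, and the placeholder expert math are invented stand-ins for get_ep_group() and quant_method.apply.

import torch


class _FakeEPGroup:
    """Stand-in for get_ep_group(): identity dispatch/combine."""

    def dispatch(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
        return hidden_states, router_logits

    def combine(self, states: torch.Tensor) -> torch.Tensor:
        return states


def moe_forward_sketch(hidden_states: torch.Tensor,
                       router_logits: torch.Tensor,
                       do_naive_dispatch_combine: bool,
                       has_shared_experts: bool):
    ep_group = _FakeEPGroup()

    # Shared experts see the *pre-dispatch* activations, so they run first.
    shared_output = hidden_states * 0.5 if has_shared_experts else None

    # Naive EP dispatch now happens after the shared-expert call.
    if do_naive_dispatch_combine:
        hidden_states, router_logits = ep_group.dispatch(
            hidden_states, router_logits)

    # Placeholder for the routed-expert compute (quant_method.apply).
    final_hidden_states = hidden_states + router_logits.sum(
        dim=-1, keepdim=True)

    def reduce_output(states: torch.Tensor,
                      do_combine: bool = True) -> torch.Tensor:
        # Only the routed-expert output goes through EP combine; the shared
        # output was produced before dispatch and must skip it.
        if do_naive_dispatch_combine and do_combine:
            states = ep_group.combine(states)
        return states

    if shared_output is None:
        return reduce_output(final_hidden_states)
    return (reduce_output(shared_output, do_combine=False),
            reduce_output(final_hidden_states))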