 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, VllmConfig,
+from vllm.config import (CacheConfig, ParallelConfig, VllmConfig,
                          get_current_vllm_config)
 from vllm.distributed import (get_ep_group, get_pp_group,
                               get_tensor_model_parallel_rank,
@@ -133,7 +133,6 @@ class DeepseekV2MoE(nn.Module):
     def __init__(
         self,
         config: Union[DeepseekV2Config, DeepseekV3Config],
-        model_config: ModelConfig,
         parallel_config: ParallelConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
@@ -184,8 +183,6 @@ def __init__(
 
         if config.n_shared_experts is None:
             self.shared_experts = None
-            fused_output_scaling_factor = 1.0
-            shared_output_scaling_factor = 1.0
         else:
             intermediate_size = (config.moe_intermediate_size *
                                  config.n_shared_experts)
@@ -196,28 +193,17 @@ def __init__(
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
                 is_sequence_parallel=self.is_sequence_parallel,
-                reduce_results=False,  # XXXXX
+                reduce_results=False,
                 prefix=f"{prefix}.shared_experts",
             )
 
-        # Fix FP16 overflow
-        # See DeepseekV2DecoderLayer for more details.
-        if model_config.dtype != torch.float16:
-            fused_output_scaling_factor = self.routed_scaling_factor
-            shared_output_scaling_factor = 1.0
-        else:
-            fused_output_scaling_factor = 1.0
-            shared_output_scaling_factor = (1. /
-                                            self.routed_scaling_factor)
-
         self.experts = SharedFusedMoE(
             shared_experts=self.shared_experts,
-            fused_output_scaling_factor=fused_output_scaling_factor,
-            shared_output_scaling_factor=shared_output_scaling_factor,
             num_experts=config.n_routed_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
+            reduce_results=False,
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             use_grouped_topk=True,
@@ -247,15 +233,36 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
 
-        final_hidden_states = self.experts(hidden_states=hidden_states,
-                                           router_logits=router_logits)
+        fused_moe_out = self.experts(hidden_states=hidden_states,
+                                     router_logits=router_logits)
+
+        if self.shared_experts is not None:
+            shared_output, final_hidden_states = fused_moe_out
+        else:
+            shared_output = None
+            final_hidden_states = fused_moe_out
+
+        # Fix FP16 overflow
+        # See DeepseekV2DecoderLayer for more details.
+        if hidden_states.dtype != torch.float16:
+            final_hidden_states *= self.routed_scaling_factor
+        elif self.shared_experts is not None:
+            assert shared_output is not None
+            shared_output *= (1. / self.routed_scaling_factor)
+
+        if self.shared_experts is not None:
+            assert shared_output is not None
+            final_hidden_states += shared_output
 
         if self.is_sequence_parallel:
             final_hidden_states = tensor_model_parallel_all_gather(
                 final_hidden_states, 0)
             final_hidden_states = final_hidden_states[:num_tokens]
+        elif self.tp_size > 1:
+            final_hidden_states = (
+                self.experts.maybe_all_reduce_tensor_model_parallel(
+                    final_hidden_states))
 
-        # TODO(bnell): why is this view needed?
         return final_hidden_states.view(num_tokens, hidden_dim)
 
 
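As a reading aid for this hunk, here is a minimal, self-contained sketch of the FP16-safe scaling order that the new forward() applies after unpacking the fused MoE output. The helper name combine_moe_outputs and the toy tensors are illustrative assumptions, not part of vLLM's API.

```python
from typing import Optional

import torch


def combine_moe_outputs(
    routed_output: torch.Tensor,
    shared_output: Optional[torch.Tensor],
    routed_scaling_factor: float,
) -> torch.Tensor:
    """Mirror the scaling order used in DeepseekV2MoE.forward().

    For dtypes other than FP16, the routed output is scaled directly by
    routed_scaling_factor. In FP16 that multiply can overflow, so the
    scale is instead folded into the shared-expert branch as a division
    (the decoder layer compensates elsewhere, per the original comment).
    """
    if routed_output.dtype != torch.float16:
        routed_output = routed_output * routed_scaling_factor
    elif shared_output is not None:
        shared_output = shared_output * (1.0 / routed_scaling_factor)

    # Add the shared-expert contribution after the dtype-dependent scaling.
    if shared_output is not None:
        routed_output = routed_output + shared_output
    return routed_output


# Toy usage with random tensors; shapes and the scaling factor are arbitrary.
routed = torch.randn(4, 16, dtype=torch.float16)
shared = torch.randn(4, 16, dtype=torch.float16)
out = combine_moe_outputs(routed, shared, routed_scaling_factor=2.5)
print(out.shape)  # torch.Size([4, 16])
```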
@@ -1009,7 +1016,6 @@ def __init__(self,
                 and layer_idx % config.moe_layer_freq == 0):
             self.mlp = DeepseekV2MoE(
                 config=config,
-                model_config=model_config,
                 parallel_config=parallel_config,
                 quant_config=quant_config,
                 prefix=f"{prefix}.mlp",