 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import (CacheConfig, ParallelConfig, VllmConfig,
+from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, VllmConfig,
                          get_current_vllm_config)
 from vllm.distributed import (get_ep_group, get_pp_group,
                               get_tensor_model_parallel_rank,
@@ -133,6 +133,7 @@ class DeepseekV2MoE(nn.Module):
     def __init__(
         self,
         config: Union[DeepseekV2Config, DeepseekV3Config],
+        model_config: ModelConfig,
         parallel_config: ParallelConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
@@ -182,27 +183,9 @@ def __init__(
                                     self.n_local_physical_experts)

         if config.n_shared_experts is None:
-            self.experts = FusedMoE(
-                num_experts=config.n_routed_experts,
-                top_k=config.num_experts_per_tok,
-                hidden_size=config.hidden_size,
-                intermediate_size=config.moe_intermediate_size,
-                reduce_results=False,
-                renormalize=config.norm_topk_prob,
-                quant_config=quant_config,
-                use_grouped_topk=True,
-                num_expert_group=config.n_group,
-                topk_group=config.topk_group,
-                prefix=f"{prefix}.experts",
-                scoring_func=config.scoring_func,
-                # we do scaling outside, set factor to 1.0 to avoid double mul
-                routed_scaling_factor=1.0,
-                e_score_correction_bias=self.gate.e_score_correction_bias,
-                enable_eplb=self.enable_eplb,
-                num_redundant_experts=self.n_redundant_experts,
-                is_sequence_parallel=self.is_sequence_parallel,
-            )
             self.shared_experts = None
+            fused_output_scaling_factor = 1.0
+            shared_output_scaling_factor = 1.0
         else:
             intermediate_size = (config.moe_intermediate_size *
                                  config.n_shared_experts)
@@ -213,31 +196,42 @@ def __init__(
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
                 is_sequence_parallel=self.is_sequence_parallel,
-                reduce_results=False,
+                reduce_results=False,  # XXXXX
                 prefix=f"{prefix}.shared_experts",
             )

-            self.experts = SharedFusedMoE(
-                shared_experts=self.shared_experts,
-                num_experts=config.n_routed_experts,
-                top_k=config.num_experts_per_tok,
-                hidden_size=config.hidden_size,
-                intermediate_size=config.moe_intermediate_size,
-                reduce_results=False,
-                renormalize=config.norm_topk_prob,
-                quant_config=quant_config,
-                use_grouped_topk=True,
-                num_expert_group=config.n_group,
-                topk_group=config.topk_group,
-                prefix=f"{prefix}.experts",
-                scoring_func=config.scoring_func,
-                # we do scaling outside, set factor to 1.0 to avoid double mul
-                routed_scaling_factor=1.0,
-                e_score_correction_bias=self.gate.e_score_correction_bias,
-                enable_eplb=self.enable_eplb,
-                num_redundant_experts=self.n_redundant_experts,
-                is_sequence_parallel=self.is_sequence_parallel,
-            )
+            # Fix FP16 overflow
+            # See DeepseekV2DecoderLayer for more details.
+            if model_config.dtype != torch.float16:
+                fused_output_scaling_factor = self.routed_scaling_factor
+                shared_output_scaling_factor = 1.0
+            else:
+                fused_output_scaling_factor = 1.0
+                shared_output_scaling_factor = (1. /
+                                                self.routed_scaling_factor)
+
+        self.experts = SharedFusedMoE(
+            shared_experts=self.shared_experts,
+            fused_output_scaling_factor=fused_output_scaling_factor,
+            shared_output_scaling_factor=shared_output_scaling_factor,
+            num_experts=config.n_routed_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            renormalize=config.norm_topk_prob,
+            quant_config=quant_config,
+            use_grouped_topk=True,
+            num_expert_group=config.n_group,
+            topk_group=config.topk_group,
+            prefix=f"{prefix}.experts",
+            scoring_func=config.scoring_func,
+            # we do scaling outside, set factor to 1.0 to avoid double mul
+            routed_scaling_factor=1.0,
+            e_score_correction_bias=self.gate.e_score_correction_bias,
+            enable_eplb=self.enable_eplb,
+            num_redundant_experts=self.n_redundant_experts,
+            is_sequence_parallel=self.is_sequence_parallel,
+        )

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
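Note on the new constructor arguments: the diff passes fused_output_scaling_factor and shared_output_scaling_factor into SharedFusedMoE but does not show how that class consumes them. The sketch below is only an illustration of the assumed behaviour (the combine helper and its signature are hypothetical, not vLLM API): each branch is scaled by its own factor before the two are summed, which is what lets the constructor-time factors reproduce the dtype-dependent scaling that the old forward() applied by hand.

import torch
from typing import Optional

# Hypothetical sketch, NOT the real SharedFusedMoE internals: it assumes the
# layer multiplies each branch by its factor before adding them together.
def combine(routed_out: torch.Tensor,
            shared_out: Optional[torch.Tensor],
            fused_output_scaling_factor: float,
            shared_output_scaling_factor: float) -> torch.Tensor:
    # non-FP16: factors are (routed_scaling_factor, 1.0)
    #   -> routed_scaling_factor * routed + shared   (old non-FP16 path)
    # FP16:     factors are (1.0, 1.0 / routed_scaling_factor)
    #   -> routed + shared / routed_scaling_factor   (old FP16 path)
    out = routed_out * fused_output_scaling_factor
    if shared_out is not None:
        out = out + shared_out * shared_output_scaling_factor
    return out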
@@ -253,36 +247,15 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)

-        fused_moe_out = self.experts(hidden_states=hidden_states,
-                                     router_logits=router_logits)
-
-        if self.shared_experts is not None:
-            shared_output, final_hidden_states = fused_moe_out
-        else:
-            shared_output = None
-            final_hidden_states = fused_moe_out
-
-        # Fix FP16 overflow
-        # See DeepseekV2DecoderLayer for more details.
-        if hidden_states.dtype != torch.float16:
-            final_hidden_states *= self.routed_scaling_factor
-        elif self.shared_experts is not None:
-            assert shared_output is not None
-            shared_output *= (1. / self.routed_scaling_factor)
-
-        if self.shared_experts is not None:
-            assert shared_output is not None
-            final_hidden_states += shared_output
+        final_hidden_states = self.experts(hidden_states=hidden_states,
+                                           router_logits=router_logits)

         if self.is_sequence_parallel:
             final_hidden_states = tensor_model_parallel_all_gather(
                 final_hidden_states, 0)
             final_hidden_states = final_hidden_states[:num_tokens]
-        elif self.tp_size > 1:
-            final_hidden_states = (
-                self.experts.maybe_all_reduce_tensor_model_parallel(
-                    final_hidden_states))

+        # TODO(bnell): why is this view needed?
         return final_hidden_states.view(num_tokens, hidden_dim)


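With the combine step moved into SharedFusedMoE, forward() no longer unpacks a (shared_output, final_hidden_states) tuple, branches on dtype, or calls maybe_all_reduce_tensor_model_parallel; the experts module is assumed to return an already scaled, combined (and, for tp_size > 1, reduced) tensor. A quick numeric check, under the same per-branch-multiplier assumption as the sketch above, shows the old and new scaling agree on both dtype paths (the values are arbitrary):

# Arbitrary example values; only checks that the refactored scaling matches
# the arithmetic of the removed forward() logic.
routed_scaling_factor = 2.5
routed, shared = 0.8, 0.4

# non-FP16: old code multiplied the routed output after the fused MoE call.
old_bf16 = routed * routed_scaling_factor + shared
new_bf16 = routed * routed_scaling_factor + shared * 1.0
assert abs(old_bf16 - new_bf16) < 1e-12

# FP16: old code instead divided only the shared output to avoid overflow.
old_fp16 = routed + shared * (1.0 / routed_scaling_factor)
new_fp16 = routed * 1.0 + shared * (1.0 / routed_scaling_factor)
assert abs(old_fp16 - new_fp16) < 1e-12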
@@ -1036,6 +1009,7 @@ def __init__(self,
                 and layer_idx % config.moe_layer_freq == 0):
             self.mlp = DeepseekV2MoE(
                 config=config,
+                model_config=model_config,
                 parallel_config=parallel_config,
                 quant_config=quant_config,
                 prefix=f"{prefix}.mlp",