 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
                          get_current_vllm_config)
-from vllm.distributed import (get_dp_group, get_pp_group,
+from vllm.distributed import (get_pp_group,
                               get_tensor_model_parallel_world_size,
                               get_tp_group, tensor_model_parallel_all_reduce)
 from vllm.forward_context import get_forward_context
@@ -205,17 +205,16 @@ def __init__(
         )
         CustomDeepseekV2MoE.top_k = config.num_experts_per_tok
 
-        vllm_config = get_current_vllm_config()
-        self.dp_size = get_dp_group().world_size
-        batch_size = vllm_config.scheduler_config.max_num_seqs
-
-        params_dtype = torch.get_default_dtype()
-        self.final_hidden_states = torch.zeros(
-            [batch_size, config.hidden_size], dtype=params_dtype, device="npu")
+        self.params_dtype = torch.get_default_dtype()
+        self.tp_rank_in_group = get_tp_group().rank_in_group
         self.tp_group = get_tp_group().device_group
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        attn_metadata = get_forward_context().attn_metadata
+    def forward(
+            self,
+            hidden_states: torch.Tensor,
+            attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
+        if attn_metadata is None:
+            attn_metadata = get_forward_context().attn_metadata
         if attn_metadata is None:
             # for profile run
             is_prefill = True
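Aside: a minimal, standalone sketch of the fallback pattern this hunk introduces — accept attn_metadata as an explicit argument, otherwise pull it from the ambient forward context. The stub context below is a hypothetical stand-in for illustration, not vllm's real get_forward_context API.

from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class _StubForwardContext:
    # Hypothetical stand-in for the ambient forward context.
    attn_metadata: Optional[Any] = None

_CONTEXT = _StubForwardContext()

def forward(hidden_states, attn_metadata: Optional[Any] = None):
    # Prefer the explicit argument; fall back to the ambient context,
    # which may still be None during a profile run (treated as prefill).
    if attn_metadata is None:
        attn_metadata = _CONTEXT.attn_metadata
    is_prefill = attn_metadata is None
    return hidden_states, is_prefill

print(forward("h"))               # -> ("h", True): context fallback, profile run
print(forward("h", {"seqs": 1}))  # -> ("h", False): explicit metadata wins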
@@ -224,34 +223,36 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
 
+        if self.n_shared_experts is not None:
+            shared_output = self.shared_experts(hidden_states)
+
         if (self.tp_size > 1 and VLLM_ENABLE_MC2 and not is_prefill):
-            chunks = torch.chunk(hidden_states,
-                                 get_tp_group().world_size,
-                                 dim=0)
-            hidden_states = chunks[get_tp_group().rank_in_group]
+            chunks = torch.chunk(hidden_states, self.tp_size, dim=0)
+            hidden_states = chunks[self.tp_rank_in_group]
 
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
 
-        final_hidden_states = self.experts(
+        hidden_states = self.experts(
             hidden_states=hidden_states,
             router_logits=router_logits,
             is_prefill=is_prefill,
             top_k=CustomDeepseekV2MoE.top_k) * self.routed_scaling_factor
 
         if self.tp_size > 1:
             if VLLM_ENABLE_MC2 and not is_prefill:
-                dist.all_gather_into_tensor(self.final_hidden_states,
-                                            final_hidden_states, self.tp_group)
-                final_hidden_states = self.final_hidden_states
+                final_hidden_states = torch.zeros([num_tokens, hidden_dim],
+                                                  dtype=self.params_dtype,
+                                                  device="npu")
+                dist.all_gather_into_tensor(final_hidden_states, hidden_states,
+                                            self.tp_group)
+                hidden_states = final_hidden_states
             else:
-                final_hidden_states = tensor_model_parallel_all_reduce(
-                    final_hidden_states)
+                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
         if self.n_shared_experts is not None:
-            shared_output = self.shared_experts(hidden_states)
-            final_hidden_states = final_hidden_states + shared_output
+            hidden_states = hidden_states + shared_output
 
-        return final_hidden_states.view(num_tokens, hidden_dim)
+        return hidden_states.view(num_tokens, hidden_dim)
 
 
 class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
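Aside: the MC2 decode path in the hunk above shards the token dimension per TP rank, runs the experts on the local shard, then all-gathers the results into a freshly allocated [num_tokens, hidden_dim] buffer. A single-process sketch of that shape bookkeeping, with torch.cat standing in for dist.all_gather_into_tensor:

import torch

tp_size, num_tokens, hidden_dim = 4, 8, 16
hidden_states = torch.randn(num_tokens, hidden_dim)

# Each rank r keeps only chunks[r] of the token dimension.
chunks = torch.chunk(hidden_states, tp_size, dim=0)

# Placeholder for the per-rank expert computation.
partial_outputs = [c * 2.0 for c in chunks]

# The real code preallocates the full-size buffer and calls
# dist.all_gather_into_tensor across the TP group; concatenation is the
# single-process equivalent of that collective.
gathered = torch.empty(num_tokens, hidden_dim)
torch.cat(partial_outputs, dim=0, out=gathered)
assert gathered.shape == (num_tokens, hidden_dim)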
@@ -524,7 +525,11 @@ def forward(
         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
-        hidden_states = self.mlp(hidden_states)
+
+        if isinstance(self.mlp, CustomDeepseekV2MoE):
+            hidden_states = self.mlp(hidden_states, attn_metadata)
+        else:
+            hidden_states = self.mlp(hidden_states)
 
         if isinstance(
                 self.mlp,
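Aside: a small sketch of the dispatch added in this last hunk — only the MoE layer takes the extra attn_metadata argument, so the decoder layer branches on the module type. The two stub modules are illustrative stand-ins for CustomDeepseekV2MoE and the dense MLP, not the real classes.

import torch
import torch.nn as nn

class _MoEStub(nn.Module):
    # Stand-in for CustomDeepseekV2MoE: forward takes optional attn_metadata.
    def forward(self, hidden_states, attn_metadata=None):
        return hidden_states

class _DenseStub(nn.Module):
    # Stand-in for the dense MLP: forward takes hidden_states only.
    def forward(self, hidden_states):
        return hidden_states

def run_mlp(mlp, hidden_states, attn_metadata):
    # Mirror the hunk: pass the metadata only to the MoE variant.
    if isinstance(mlp, _MoEStub):
        return mlp(hidden_states, attn_metadata)
    return mlp(hidden_states)

x = torch.randn(4, 8)
assert torch.equal(run_mlp(_MoEStub(), x, None), x)
assert torch.equal(run_mlp(_DenseStub(), x, None), x)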