@@ -212,6 +212,14 @@ def __init__(
         self.tp_group = get_tp_group().device_group
         self.tp_rank = get_tp_group().rank_in_group

+        self.params_dtype = torch.get_default_dtype()
+
+        self.enable_graph_mode = False
+        additional_config = get_current_vllm_config().additional_config
+        if additional_config:
+            self.enable_graph_mode = additional_config.get(
+                "enable_graph_mode", False)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
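
Note (not part of the diff): the new enable_graph_mode flag is read from
vLLM's additional_config dict. A minimal sketch of how a caller might set it,
assuming vLLM forwards additional_config as a plain dict (the model name is
illustrative only):

    from vllm import LLM

    llm = LLM(
        model="deepseek-ai/DeepSeek-V2-Lite",  # illustrative model name
        tensor_parallel_size=2,
        # consumed in __init__ above via additional_config.get(...)
        additional_config={"enable_graph_mode": True},
    )
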
@@ -228,52 +236,65 @@ def forward(
         else:
             is_prefill = attn_metadata.num_prefills > 0
             enable_force_load_balance = False
-        num_tokens, hidden_dim = hidden_states.shape
+        if hasattr(attn_metadata, 'with_prefill_across_dp'):
+            is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
+
+        num_tokens, hidden_size = hidden_states.shape

         if self.n_shared_experts is not None:
             shared_output = self.shared_experts(hidden_states)

         if self.tp_size > 1:
-            # pass
-            num_tokens, hidden_size = hidden_states.shape
-            if num_tokens < self.tp_size:
-                target_size = self.tp_size
-                new_hidden_states = torch.empty([target_size, hidden_size],
-                                                dtype=hidden_states.dtype,
-                                                device=hidden_states.device)
-                new_hidden_states[:num_tokens] = hidden_states
-                hidden_states = new_hidden_states
-            chunk_hidden_states = torch.tensor_split(hidden_states,
-                                                     self.tp_size,
-                                                     dim=0)
-            local_hidden_states = chunk_hidden_states[self.tp_rank]
-        else:
-            local_hidden_states = hidden_states
+            if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
+                chunks = torch.chunk(hidden_states, self.tp_size, dim=0)
+                hidden_states = chunks[self.tp_rank]
+            elif not self.enable_graph_mode:
+                num_padding_tokens = (self.tp_size -
+                                      num_tokens % self.tp_size) % self.tp_size
+                # Pad hidden_states so the token count is divisible by
+                # tp_size, avoiding a cross-ring AllGatherV on 910B2C.
+                if num_padding_tokens > 0:
+                    hidden_states = nn.functional.pad(
+                        hidden_states, (0, 0, 0, num_padding_tokens))
+                chunk_hidden_states = torch.tensor_split(hidden_states,
+                                                         self.tp_size,
+                                                         dim=0)
+                hidden_states = chunk_hidden_states[self.tp_rank]

         # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(local_hidden_states)
+        router_logits, _ = self.gate(hidden_states)

-        router_hidden_states = self.experts(
-            hidden_states=local_hidden_states,
+        hidden_states = self.experts(
+            hidden_states=hidden_states,
             router_logits=router_logits,
             is_prefill=is_prefill,
             top_k=CustomDeepseekV2MoE.top_k,
             enable_force_load_balance=enable_force_load_balance,
         ) * self.routed_scaling_factor

         if self.tp_size > 1:
-            dist.all_gather(list(chunk_hidden_states), router_hidden_states,
-                            self.tp_group)
-            final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
-            if num_tokens < self.tp_size:
-                final_hidden_states = final_hidden_states[:num_tokens]
-        else:
-            final_hidden_states = router_hidden_states
+            if self.enable_graph_mode:
+                if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
+                    final_hidden_states = torch.zeros(
+                        [num_tokens, hidden_size],
+                        dtype=self.params_dtype,
+                        device="npu")
+                    dist.all_gather_into_tensor(final_hidden_states,
+                                                hidden_states, self.tp_group)
+                    hidden_states = final_hidden_states
+                else:
+                    hidden_states = tensor_model_parallel_all_reduce(
+                        hidden_states)
+            else:
+                dist.all_gather(list(chunk_hidden_states), hidden_states,
+                                self.tp_group)
+                hidden_states = torch.cat(chunk_hidden_states, dim=0)
+                if num_padding_tokens > 0:
+                    hidden_states = hidden_states[:-num_padding_tokens]

         if shared_output is not None:
-            final_hidden_states = final_hidden_states + shared_output
+            hidden_states = hidden_states + shared_output

-        return final_hidden_states.view(num_tokens, hidden_dim)
+        return hidden_states.view(num_tokens, hidden_size)


 class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
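
Note (not part of the diff): the eager path pads the token dimension up to a
multiple of tp_size so every rank holds an equally sized chunk and a plain
AllGather suffices (no AllGatherV). A standalone sketch of that round-trip
with hypothetical sizes (tp_size=4, num_tokens=6):

    import torch
    import torch.nn as nn

    tp_size = 4
    hidden_states = torch.randn(6, 8)  # 6 tokens, hidden_size 8
    num_tokens = hidden_states.shape[0]

    # Same formula as the diff: rows of padding needed, 0 if already aligned.
    num_padding_tokens = (tp_size - num_tokens % tp_size) % tp_size  # -> 2
    if num_padding_tokens > 0:
        # (0, 0) leaves the hidden dim alone; (0, num_padding_tokens)
        # appends zero rows at the end of the token dim.
        hidden_states = nn.functional.pad(hidden_states,
                                          (0, 0, 0, num_padding_tokens))

    chunks = torch.tensor_split(hidden_states, tp_size, dim=0)
    assert all(c.shape[0] == hidden_states.shape[0] // tp_size for c in chunks)

    # After the gather, the padding rows are stripped again.
    gathered = torch.cat(chunks, dim=0)
    if num_padding_tokens > 0:
        gathered = gathered[:-num_padding_tokens]
    assert gathered.shape[0] == num_tokens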
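
Note (not part of the diff): in the eager combine path, dist.all_gather writes
each rank's expert output into the views returned by torch.tensor_split, and
torch.cat then rebuilds the full padded tensor. A single-process Gloo stand-in
for the TP group (sizes hypothetical):

    import os
    import torch
    import torch.distributed as dist

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    tp_size = dist.get_world_size()
    padded = torch.randn(8, 16)  # already padded to a multiple of tp_size
    chunk_hidden_states = torch.tensor_split(padded, tp_size, dim=0)
    local = chunk_hidden_states[dist.get_rank()].clone()  # this rank's shard

    # all_gather fills each list slot with the corresponding rank's shard;
    # the slots are views into `padded`, so the cat below sees fresh data.
    dist.all_gather(list(chunk_hidden_states), local)
    full = torch.cat(chunk_hidden_states, dim=0)
    assert full.shape == padded.shape
    dist.destroy_process_group()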