from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
-import torch.distributed as dist
import torch_npu
import vllm.envs as envs
from torch import nn
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (get_pp_group,
                              get_tensor_model_parallel_world_size,
-                             get_tp_group, tensor_model_parallel_all_reduce)
+                             get_tp_group)
from vllm.distributed.parallel_state import get_dp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.models.deepseek_v2 import \
-    DeepseekV2ForCausalLM  # ruff: noqa: E501
+    DeepseekV2ForCausalLM  # noqa: E501
from vllm.model_executor.models.deepseek_v2 import \
-    yarn_get_mscale  # ruff: noqa: E501
+    yarn_get_mscale  # noqa: E501
from vllm.model_executor.models.deepseek_v2 import (DeepseekV2Attention,
                                                    DeepseekV2DecoderLayer,
                                                    DeepseekV2MLAAttention)
    maybe_prefix)
from vllm.sequence import IntermediateTensors

-import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_ep_group
from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.utils import (dispose_tensor, npu_stream_switch,
                               npu_wait_tensor)

-VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
-

class CustomDeepseekV2SiluAndMul(SiluAndMul):

@@ -240,9 +236,8 @@ def __init__(

        ascend_config = get_ascend_config()
        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-        # NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on
        self.enable_multistream_moe = \
-            ascend_config.torchair_graph_config.enable_multistream_moe and VLLM_ENABLE_MC2
+            ascend_config.torchair_graph_config.enable_multistream_moe

        self.gate = ReplicatedLinear(config.hidden_size,
                                     config.n_routed_experts,
@@ -312,23 +307,11 @@ def forward(
        enable_force_load_balance = False
        if hasattr(attn_metadata, 'with_prefill_across_dp'):
            is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
-        num_tokens, hidden_size = hidden_states.shape
-        old_hidden_states = hidden_states
+
+        old_hidden_states = hidden_states.clone()
        use_separated_shared_experts = (self.shared_experts is not None
                                        and not self.enable_multistream_moe)

-        if self.tp_size > 1:
-            if (VLLM_ENABLE_MC2
-                    and not is_prefill) or not (self.torchair_graph_enabled or
-                                                self.ep_group.world_size == 1):
-                if num_tokens < self.tp_size:
-                    hidden_states = nn.functional.pad(
-                        hidden_states, (0, 0, 0, self.tp_size - num_tokens))
-                chunk_hidden_states = torch.tensor_split(hidden_states,
-                                                         self.tp_size,
-                                                         dim=0)
-                hidden_states = chunk_hidden_states[self.tp_rank]
-
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)

@@ -349,23 +332,11 @@ def forward(
            experts_hidden_states[0] * self.routed_scaling_factor +
            experts_hidden_states[1])

-        if self.tp_size > 1:
-            if (VLLM_ENABLE_MC2
-                    and not is_prefill) or not (self.torchair_graph_enabled or
-                                                self.ep_group.world_size == 1):
-                dist.all_gather(list(chunk_hidden_states), hidden_states,
-                                self.tp_group)
-                hidden_states = torch.cat(chunk_hidden_states, dim=0)
-                if num_tokens < self.tp_size:
-                    hidden_states = hidden_states[:num_tokens]
-            else:
-                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
-
        if use_separated_shared_experts:
            hidden_states = hidden_states + self.shared_experts(
                old_hidden_states)

-        return hidden_states.view(num_tokens, hidden_size)
+        return hidden_states


class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
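
Note for reviewers: the two deleted blocks in forward() above implemented a manual sequence-split / all-gather round trip around the routed experts when tensor parallelism was active, falling back to tensor_model_parallel_all_reduce otherwise. The sketch below restates that removed communication pattern as a standalone helper so the deletion is easier to review. It assumes torch.distributed is already initialized; the name demo_split_and_gather and its arguments are illustrative only and are not part of this patch.

# Illustrative sketch of the TP pattern removed by this diff (not part of the patch).
import torch
import torch.distributed as dist
from torch import nn


def demo_split_and_gather(hidden_states: torch.Tensor, tp_size: int,
                          tp_rank: int, tp_group) -> torch.Tensor:
    num_tokens, _ = hidden_states.shape
    # Pad so every TP rank gets at least one token, then keep only this
    # rank's chunk before the routed-expert computation (mirrors the
    # deleted pre-MoE block).
    if num_tokens < tp_size:
        hidden_states = nn.functional.pad(
            hidden_states, (0, 0, 0, tp_size - num_tokens))
    chunks = list(torch.tensor_split(hidden_states, tp_size, dim=0))
    local_chunk = chunks[tp_rank]

    # ... per-rank expert computation would run on `local_chunk` here ...

    # Gather every rank's chunk back and drop the padding tokens
    # (mirrors the deleted post-MoE block).
    dist.all_gather(chunks, local_chunk, group=tp_group)
    return torch.cat(chunks, dim=0)[:num_tokens]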