from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
-import torch.distributed as dist
import torch_npu
import vllm.envs as envs
from torch import nn
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (get_pp_group,
                              get_tensor_model_parallel_world_size,
-                             get_tp_group, tensor_model_parallel_all_reduce)
+                             get_tp_group)
from vllm.distributed.parallel_state import get_dp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.models.deepseek_v2 import \
-    DeepseekV2ForCausalLM  # ruff: noqa: E501
+    DeepseekV2ForCausalLM  # noqa: E501
from vllm.model_executor.models.deepseek_v2 import \
-    yarn_get_mscale  # ruff: noqa: E501
+    yarn_get_mscale  # noqa: E501
from vllm.model_executor.models.deepseek_v2 import (DeepseekV2Attention,
                                                    DeepseekV2DecoderLayer,
                                                    DeepseekV2MLAAttention)
    maybe_prefix)
from vllm.sequence import IntermediateTensors

-import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_ep_group
from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.quantization.quant_config import AscendLinearMethod
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
from vllm_ascend.utils import dispose_tensor

-VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
-

class CustomDeepseekV2SiluAndMul(SiluAndMul):

@@ -239,9 +235,8 @@ def __init__(

        ascend_config = get_ascend_config()
        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-        # NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on
        self.enable_multistream_moe = \
-            ascend_config.torchair_graph_config.enable_multistream_moe and VLLM_ENABLE_MC2
+            ascend_config.torchair_graph_config.enable_multistream_moe

        self.gate = ReplicatedLinear(config.hidden_size,
                                     config.n_routed_experts,
@@ -311,23 +306,11 @@ def forward(
        enable_force_load_balance = False
        if hasattr(attn_metadata, 'with_prefill_across_dp'):
            is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
-        num_tokens, hidden_size = hidden_states.shape
-        old_hidden_states = hidden_states
+
+        old_hidden_states = hidden_states.clone()
        use_separated_shared_experts = (self.shared_experts is not None
                                        and not self.enable_multistream_moe)

-        if self.tp_size > 1:
-            if (VLLM_ENABLE_MC2
-                    and not is_prefill) or not (self.torchair_graph_enabled or
-                                                self.ep_group.world_size == 1):
-                if num_tokens < self.tp_size:
-                    hidden_states = nn.functional.pad(
-                        hidden_states, (0, 0, 0, self.tp_size - num_tokens))
-                chunk_hidden_states = torch.tensor_split(hidden_states,
-                                                         self.tp_size,
-                                                         dim=0)
-                hidden_states = chunk_hidden_states[self.tp_rank]
-
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)

@@ -348,23 +331,11 @@ def forward(
            experts_hidden_states[0] * self.routed_scaling_factor +
            experts_hidden_states[1])

-        if self.tp_size > 1:
-            if (VLLM_ENABLE_MC2
-                    and not is_prefill) or not (self.torchair_graph_enabled or
-                                                self.ep_group.world_size == 1):
-                dist.all_gather(list(chunk_hidden_states), hidden_states,
-                                self.tp_group)
-                hidden_states = torch.cat(chunk_hidden_states, dim=0)
-                if num_tokens < self.tp_size:
-                    hidden_states = hidden_states[:num_tokens]
-            else:
-                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
-
        if use_separated_shared_experts:
            hidden_states = hidden_states + self.shared_experts(
                old_hidden_states)

-        return hidden_states.view(num_tokens, hidden_size)
+        return hidden_states


class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
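For context on what the deleted branch did, below is a minimal, self-contained sketch of that tensor-parallel round trip: pad the token dimension, run the experts only on this rank's chunk, then all-gather the chunks and trim the padding. The names `chunked_moe_forward` and `moe_fn` are hypothetical, the group handling uses plain `torch.distributed` rather than this repository's helpers, and the padding is rounded up to a multiple of the TP world size (a simplification over the removed code) so every all-gathered chunk has the same shape.

import torch
import torch.distributed as dist
from torch import nn


def chunked_moe_forward(hidden_states: torch.Tensor, moe_fn, tp_group):
    # hidden_states: (num_tokens, hidden_size); moe_fn is a shape-preserving
    # expert computation applied to this rank's slice of the tokens.
    num_tokens = hidden_states.shape[0]
    tp_size = dist.get_world_size(tp_group)
    tp_rank = dist.get_rank(tp_group)

    # Pad the token dimension so it divides evenly across TP ranks.
    pad = (-num_tokens) % tp_size
    if pad:
        hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad))

    # Each rank computes the experts only for its own chunk of tokens.
    chunks = list(torch.tensor_split(hidden_states, tp_size, dim=0))
    local_out = moe_fn(chunks[tp_rank])

    # Re-assemble the full batch on every rank, then drop the padding rows.
    dist.all_gather(chunks, local_out, group=tp_group)
    return torch.cat(chunks, dim=0)[:num_tokens]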