
Commit a31973f

[refactor] Refactoring AscendFusedMoE
Signed-off-by: zzzzwwjj <1183291235@qq.com>
1 parent: 0d2074a

File tree

9 files changed: +124 −182 lines changed

vllm_ascend/attention/mla_v1.py

Lines changed: 3 additions & 0 deletions
@@ -136,6 +136,7 @@ class AscendMLAMetadata:
     # For logging.
     num_input_tokens: int = 0  # Number of tokens including padding.
 
+    max_num_tokens_across_dp: int = 0
     with_prefill_across_dp: bool = False
 
     query_lens: Optional[list[int]] = None
@@ -364,6 +365,7 @@ def build(
         common_attn_metadata: CommonAttentionMetadata,
         common_prefix_len: Optional[int] = None,
         graph_pad_size: int = -1,
+        max_num_tokens_across_dp: int = 0,
         with_prefill_across_dp: bool = False,
     ) -> AscendMLAMetadata:
         assert self._num_decodes + self._num_prefills == num_reqs
@@ -509,6 +511,7 @@ def build(
             query_start_loc=query_start_loc,
             block_tables=block_table,
             seq_lens=seq_lens,
+            max_num_tokens_across_dp=max_num_tokens_across_dp,
             with_prefill_across_dp=with_prefill_across_dp,
         )
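
For context, the new max_num_tokens_across_dp field records the largest per-rank token count in the data-parallel group, next to the existing with_prefill_across_dp flag. Below is a minimal, self-contained sketch of how a builder could populate such metadata; the dataclass and builder names are illustrative, not the real AscendMLAMetadata/build() signatures, and the DP-wide values are passed in explicitly instead of being gathered with collectives.

```python
from dataclasses import dataclass


@dataclass
class MLAMetadataSketch:
    """Illustrative subset of the AscendMLAMetadata fields touched by this commit."""
    num_input_tokens: int = 0           # tokens on this rank, including padding
    max_num_tokens_across_dp: int = 0   # new field: max token count over all DP ranks
    with_prefill_across_dp: bool = False


def build_metadata(num_input_tokens: int,
                   dp_token_counts: list[int],
                   dp_has_prefill: list[bool]) -> MLAMetadataSketch:
    """Hypothetical builder: vllm-ascend gathers the DP-wide values with collectives;
    here they are passed in explicitly to keep the sketch runnable."""
    return MLAMetadataSketch(
        num_input_tokens=num_input_tokens,
        max_num_tokens_across_dp=max(dp_token_counts),
        with_prefill_across_dp=any(dp_has_prefill),
    )


if __name__ == "__main__":
    md = build_metadata(4, dp_token_counts=[4, 7, 3], dp_has_prefill=[False, True, False])
    print(md)  # max_num_tokens_across_dp=7, with_prefill_across_dp=True
```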

vllm_ascend/envs.py

Lines changed: 0 additions & 8 deletions
@@ -50,18 +50,10 @@
     # value is None, which means the system default C compiler will be used.
     "C_COMPILER":
     lambda: os.getenv("C_COMPILER", None),
-    # Whether to enable MC2 for DeepSeek. If not set, the default value is False.
-    # MC2 is a fusion operator provided by Ascend to speed up computing and communication.
-    # Find more detail here: https://www.hiascend.com/document/detail/zh/canncommercial/81RC1/developmentguide/opdevg/ascendcbestP/atlas_ascendc_best_practices_10_0043.html
-    "VLLM_ENABLE_MC2":
-    lambda: bool(int(os.getenv("VLLM_ENABLE_MC2", '0'))),
     # Whether to enable the topk optimization. It's disabled by default for experimental support
     # We'll make it enabled by default in the future.
     "VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE":
     lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE", '0'))),
-    # Whether to use LCCL communication. If not set, the default value is False.
-    "USING_LCCL_COM":
-    lambda: bool(int(os.getenv("USING_LCCL_COM", '0'))),
     # The version of the Ascend chip. If not set, the default value is
     # ASCEND910B1. It's used for package building. Please make sure that the
     # version is correct.
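
The two deleted entries follow the lazy-lookup pattern used throughout vllm_ascend/envs.py: each name maps to a zero-argument lambda so the environment is read at access time rather than at import time. A small sketch of that pattern, shown with the two toggles this commit removes (the dict name here is illustrative):

```python
import os
from typing import Any, Callable, Dict

# Each variable maps to a lambda; the environment is only read when the getter is called.
env_variables: Dict[str, Callable[[], Any]] = {
    "VLLM_ENABLE_MC2":
    lambda: bool(int(os.getenv("VLLM_ENABLE_MC2", '0'))),
    "USING_LCCL_COM":
    lambda: bool(int(os.getenv("USING_LCCL_COM", '0'))),
}

if __name__ == "__main__":
    os.environ["VLLM_ENABLE_MC2"] = "1"
    print({name: getter() for name, getter in env_variables.items()})
    # {'VLLM_ENABLE_MC2': True, 'USING_LCCL_COM': False}
```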

vllm_ascend/models/deepseek_dbo.py

Lines changed: 3 additions & 42 deletions
@@ -51,9 +51,9 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.models.deepseek_v2 import \
-    DeepseekV2ForCausalLM  # ruff: noqa: E501
+    DeepseekV2ForCausalLM  # noqa: E501
 from vllm.model_executor.models.deepseek_v2 import \
-    yarn_get_mscale  # ruff: noqa: E501
+    yarn_get_mscale  # noqa: E501
 from vllm.model_executor.models.deepseek_v2 import (DeepseekV2Attention,
                                                     DeepseekV2DecoderLayer,
                                                     DeepseekV2MLAAttention)
@@ -79,7 +79,6 @@
 from vllm_ascend.utils import dispose_tensor
 
 VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO
-VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
 
 
 class CustomDeepseekDBOMLP(CustomDeepseekV2MLP):
@@ -189,26 +188,8 @@ def forward(
         if hasattr(attn_metadata, 'with_prefill_across_dp'):
             is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
 
-        num_tokens, hidden_size = hidden_states.shape
-
         old_hidden_states = hidden_states.clone()
 
-        if self.tp_size > 1:
-            if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
-                chunks = torch.chunk(hidden_states, self.tp_size, dim=0)
-                hidden_states = chunks[self.tp_rank]
-            elif not self.torchair_graph_enabled:
-                num_padding_tokens = (self.tp_size -
-                                      num_tokens % self.tp_size) % self.tp_size
-                # Pad hidden_states to make it divisible by tp_size to avoid cross-ring AllGatherV on 910B2C
-                if num_padding_tokens > 0:
-                    hidden_states = nn.functional.pad(
-                        hidden_states, (0, 0, 0, num_padding_tokens))
-                chunk_hidden_states = torch.tensor_split(hidden_states,
-                                                         self.tp_size,
-                                                         dim=0)
-                hidden_states = chunk_hidden_states[self.tp_rank]
-
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
 
@@ -220,33 +201,13 @@ def forward(
             enable_force_load_balance=enable_force_load_balance,
         ) * self.routed_scaling_factor
 
-        if self.tp_size > 1:
-            if self.torchair_graph_enabled:
-                if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
-                    final_hidden_states = torch.zeros(
-                        [num_tokens, hidden_size],
-                        dtype=self.params_dtype,
-                        device="npu")
-                    dist.all_gather_into_tensor(final_hidden_states,
-                                                hidden_states, self.tp_group)
-                    hidden_states = final_hidden_states
-                else:
-                    hidden_states = tensor_model_parallel_all_reduce(
-                        hidden_states)
-            else:
-                dist.all_gather(list(chunk_hidden_states), hidden_states,
-                                self.tp_group)
-                hidden_states = torch.cat(chunk_hidden_states, dim=0)
-                if num_padding_tokens > 0:
-                    hidden_states = hidden_states[:-num_padding_tokens]
-
         if self.n_shared_experts is not None:
             shared_output = self.shared_experts(old_hidden_states)
 
         if shared_output is not None:
             hidden_states = hidden_states + shared_output
 
-        return hidden_states.view(num_tokens, hidden_size)
+        return hidden_states
 
     # ----------------------------------------- TBO-related --------------------------------------------
     def _forward_ms_op_shared_expert(
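
The deleted branch implemented a pad-split-gather round trip: pad the token dimension so it divides evenly across tensor-parallel ranks, keep only the local chunk, and after the experts run, all-gather the chunks and drop the padding. A standalone sketch of that round trip, with the collective simulated locally (the function name is made up; the real code used dist.all_gather on self.tp_group):

```python
import torch


def pad_split_gather_roundtrip(hidden_states: torch.Tensor, tp_size: int, tp_rank: int):
    """Standalone sketch of the TP communication removed from CustomDeepseekDBOMoE.forward
    (now owned by AscendFusedMoE). The all_gather is simulated by keeping every chunk local."""
    num_tokens = hidden_states.shape[0]
    # Pad so the token dimension divides evenly across TP ranks
    # (original comment: avoids cross-ring AllGatherV on 910B2C).
    num_padding_tokens = (tp_size - num_tokens % tp_size) % tp_size
    if num_padding_tokens > 0:
        hidden_states = torch.nn.functional.pad(
            hidden_states, (0, 0, 0, num_padding_tokens))
    chunks = list(torch.tensor_split(hidden_states, tp_size, dim=0))
    local_chunk = chunks[tp_rank]        # each rank keeps only its slice for the experts...
    gathered = torch.cat(chunks, dim=0)  # ...and an all_gather restores the full tensor after
    if num_padding_tokens > 0:
        gathered = gathered[:-num_padding_tokens]
    return local_chunk, gathered


if __name__ == "__main__":
    x = torch.randn(5, 8)
    local, restored = pad_split_gather_roundtrip(x, tp_size=4, tp_rank=0)
    assert restored.shape == x.shape and torch.equal(restored, x)
```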

vllm_ascend/models/deepseek_v2.py

Lines changed: 7 additions & 36 deletions
@@ -28,7 +28,6 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
-import torch.distributed as dist
 import torch_npu
 import vllm.envs as envs
 from torch import nn
@@ -37,7 +36,7 @@
 from vllm.config import CacheConfig, ModelConfig, VllmConfig
 from vllm.distributed import (get_pp_group,
                               get_tensor_model_parallel_world_size,
-                              get_tp_group, tensor_model_parallel_all_reduce)
+                              get_tp_group)
 from vllm.distributed.parallel_state import get_dp_group
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -54,9 +53,9 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.models.deepseek_v2 import \
-    DeepseekV2ForCausalLM  # ruff: noqa: E501
+    DeepseekV2ForCausalLM  # noqa: E501
 from vllm.model_executor.models.deepseek_v2 import \
-    yarn_get_mscale  # ruff: noqa: E501
+    yarn_get_mscale  # noqa: E501
 from vllm.model_executor.models.deepseek_v2 import (DeepseekV2Attention,
                                                     DeepseekV2DecoderLayer,
                                                     DeepseekV2MLAAttention)
@@ -65,7 +64,6 @@
                                                maybe_prefix)
 from vllm.sequence import IntermediateTensors
 
-import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.parallel_state import get_ep_group
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
@@ -74,8 +72,6 @@
 from vllm_ascend.utils import (dispose_tensor, npu_stream_switch,
                                npu_wait_tensor)
 
-VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
-
 
 
 class CustomDeepseekV2SiluAndMul(SiluAndMul):
@@ -240,9 +236,8 @@ def __init__(
 
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-        # NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on
         self.enable_multistream_moe = \
-            ascend_config.torchair_graph_config.enable_multistream_moe and VLLM_ENABLE_MC2
+            ascend_config.torchair_graph_config.enable_multistream_moe
 
         self.gate = ReplicatedLinear(config.hidden_size,
                                      config.n_routed_experts,
@@ -312,23 +307,11 @@ def forward(
             enable_force_load_balance = False
         if hasattr(attn_metadata, 'with_prefill_across_dp'):
             is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
-        num_tokens, hidden_size = hidden_states.shape
-        old_hidden_states = hidden_states
+
+        old_hidden_states = hidden_states.clone()
         use_separated_shared_experts = (self.shared_experts is not None
                                         and not self.enable_multistream_moe)
 
-        if self.tp_size > 1:
-            if (VLLM_ENABLE_MC2
-                    and not is_prefill) or not (self.torchair_graph_enabled or
-                                                self.ep_group.world_size == 1):
-                if num_tokens < self.tp_size:
-                    hidden_states = nn.functional.pad(
-                        hidden_states, (0, 0, 0, self.tp_size - num_tokens))
-                chunk_hidden_states = torch.tensor_split(hidden_states,
-                                                         self.tp_size,
-                                                         dim=0)
-                hidden_states = chunk_hidden_states[self.tp_rank]
-
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
 
@@ -349,23 +332,11 @@ def forward(
                 experts_hidden_states[0] * self.routed_scaling_factor +
                 experts_hidden_states[1])
 
-        if self.tp_size > 1:
-            if (VLLM_ENABLE_MC2
-                    and not is_prefill) or not (self.torchair_graph_enabled or
-                                                self.ep_group.world_size == 1):
-                dist.all_gather(list(chunk_hidden_states), hidden_states,
-                                self.tp_group)
-                hidden_states = torch.cat(chunk_hidden_states, dim=0)
-                if num_tokens < self.tp_size:
-                    hidden_states = hidden_states[:num_tokens]
-            else:
-                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
-
         if use_separated_shared_experts:
             hidden_states = hidden_states + self.shared_experts(
                 old_hidden_states)
 
-        return hidden_states.view(num_tokens, hidden_size)
+        return hidden_states
 
 
 class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
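
After the refactor, CustomDeepseekV2MoE.forward reduces to: compute router logits, run the fused experts, scale by routed_scaling_factor, and add the shared-expert output computed from a clone of the input; all TP/EP communication now lives inside AscendFusedMoE. A toy sketch of that control flow with dense stand-in modules (FusedMoEStandIn and SimplifiedDeepseekMoE are illustrative, not the real layers):

```python
import torch
from torch import nn


class FusedMoEStandIn(nn.Module):
    """Dense stand-in for AscendFusedMoE: a softmax-weighted mix of per-expert Linears."""

    def __init__(self, hidden_size: int, n_experts: int):
        super().__init__()
        self.experts = nn.ModuleList(
            nn.Linear(hidden_size, hidden_size, bias=False) for _ in range(n_experts))

    def forward(self, x: torch.Tensor, router_logits: torch.Tensor) -> torch.Tensor:
        weights = torch.softmax(router_logits, dim=-1)             # (tokens, n_experts)
        outs = torch.stack([e(x) for e in self.experts], dim=-1)   # (tokens, hidden, n_experts)
        return (outs * weights.unsqueeze(1)).sum(dim=-1)           # (tokens, hidden)


class SimplifiedDeepseekMoE(nn.Module):
    """Sketch of the forward flow kept after this refactor: no pre-split, no post-gather,
    and no final .view(num_tokens, hidden_size)."""

    def __init__(self, hidden_size: int, n_experts: int, routed_scaling_factor: float = 2.5):
        super().__init__()
        self.gate = nn.Linear(hidden_size, n_experts, bias=False)
        self.experts = FusedMoEStandIn(hidden_size, n_experts)
        self.shared_experts = nn.Linear(hidden_size, hidden_size, bias=False)
        self.routed_scaling_factor = routed_scaling_factor

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        old_hidden_states = hidden_states.clone()
        # router_logits: (num_tokens, n_experts)
        router_logits = self.gate(hidden_states)
        hidden_states = self.experts(hidden_states, router_logits) * self.routed_scaling_factor
        # Shared experts still consume the un-routed activations.
        return hidden_states + self.shared_experts(old_hidden_states)


if __name__ == "__main__":
    moe = SimplifiedDeepseekMoE(hidden_size=16, n_experts=4)
    print(moe(torch.randn(3, 16)).shape)  # torch.Size([3, 16])
```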
