
Commit 20e3bb6

[refactor] Refactoring AscendFusedMoE
Signed-off-by: zzzzwwjj <1183291235@qq.com>
1 parent 4153a50 commit 20e3bb6

9 files changed: +124 −182 lines

vllm_ascend/attention/mla_v1.py

Lines changed: 3 additions & 0 deletions

@@ -121,6 +121,7 @@ class AscendMLAMetadata:
     # For logging.
     num_input_tokens: int = 0  # Number of tokens including padding.

+    max_num_tokens_across_dp: int = 0
     with_prefill_across_dp: bool = False

     query_lens: Optional[list[int]] = None
@@ -324,6 +325,7 @@ def build(
         common_attn_metadata: CommonAttentionMetadata,
         common_prefix_len: Optional[int] = None,
         graph_pad_size: int = -1,
+        max_num_tokens_across_dp: int = 0,
         with_prefill_across_dp: bool = False,
     ) -> AscendMLAMetadata:
         assert self._num_decodes + self._num_prefills == num_reqs
@@ -432,6 +434,7 @@ def build(
             query_start_loc=query_start_loc,
             block_tables=block_table,
             seq_lens=seq_lens,
+            max_num_tokens_across_dp=max_num_tokens_across_dp,
             with_prefill_across_dp=with_prefill_across_dp,
         )
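The new field defaults to 0 and is threaded from build() into the metadata so every data-parallel rank can pad its batch to a common token count. For orientation only: a value like this is typically produced by max-reducing the per-rank token count over the DP group. The helper below is an illustrative sketch under that assumption, not code from this commit; the function name and dp_group argument are hypothetical.

import torch
import torch.distributed as dist

def compute_max_num_tokens_across_dp(num_tokens: int, dp_group) -> int:
    # Hypothetical helper: every DP rank contributes its local token count
    # and receives the maximum, so all ranks can pad to the same length.
    t = torch.tensor([num_tokens], dtype=torch.int64)
    dist.all_reduce(t, op=dist.ReduceOp.MAX, group=dp_group)
    return int(t.item())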
vllm_ascend/envs.py

Lines changed: 0 additions & 8 deletions

@@ -50,18 +50,10 @@
     # value is None, which means the system default C compiler will be used.
     "C_COMPILER":
     lambda: os.getenv("C_COMPILER", None),
-    # Whether to enable MC2 for DeepSeek. If not set, the default value is False.
-    # MC2 is a fusion operator provided by Ascend to speed up computing and communication.
-    # Find more detail here: https://www.hiascend.com/document/detail/zh/canncommercial/81RC1/developmentguide/opdevg/ascendcbestP/atlas_ascendc_best_practices_10_0043.html
-    "VLLM_ENABLE_MC2":
-    lambda: bool(int(os.getenv("VLLM_ENABLE_MC2", '0'))),
     # Whether to enable the topk optimization. It's disabled by default for experimental support
     # We'll make it enabled by default in the future.
     "VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE":
     lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE", '0'))),
-    # Whether to use LCCL communication. If not set, the default value is False.
-    "USING_LCCL_COM":
-    lambda: bool(int(os.getenv("USING_LCCL_COM", '0'))),
     # The version of the Ascend chip. If not set, the default value is
     # ASCEND910B1. It's used for package building. Please make sure that the
     # version is correct.
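Removing the VLLM_ENABLE_MC2 and USING_LCCL_COM entries retires those knobs completely, since this registry is the single source of truth: each name maps to a zero-argument lambda that is evaluated on access, so the environment is read at use time rather than import time. A minimal sketch of that pattern follows; the module-level __getattr__ is how vLLM-style envs modules usually expose the dict, and treating this file the same way is an assumption, not something shown in the diff.

import os
from typing import Any, Callable, Dict

env_variables: Dict[str, Callable[[], Any]] = {
    # "1" -> True, "0" or unset -> False.
    "VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE":
    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE", '0'))),
}

def __getattr__(name: str) -> Any:
    # Called for attribute access like envs_ascend.SOME_NAME (PEP 562);
    # evaluates the lambda lazily on each lookup.
    if name in env_variables:
        return env_variables[name]()
    raise AttributeError(f"module has no attribute {name!r}")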

vllm_ascend/models/deepseek_dbo.py

Lines changed: 3 additions & 42 deletions

@@ -51,9 +51,9 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.models.deepseek_v2 import \
-    DeepseekV2ForCausalLM  # ruff: noqa: E501
+    DeepseekV2ForCausalLM  # noqa: E501
 from vllm.model_executor.models.deepseek_v2 import \
-    yarn_get_mscale  # ruff: noqa: E501
+    yarn_get_mscale  # noqa: E501
 from vllm.model_executor.models.deepseek_v2 import (DeepseekV2Attention,
                                                     DeepseekV2DecoderLayer,
                                                     DeepseekV2MLAAttention)
@@ -79,7 +79,6 @@
 from vllm_ascend.utils import dispose_tensor

 VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO
-VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2


 class CustomDeepseekDBOMLP(CustomDeepseekV2MLP):
@@ -189,26 +188,8 @@ def forward(
         if hasattr(attn_metadata, 'with_prefill_across_dp'):
             is_prefill = is_prefill or attn_metadata.with_prefill_across_dp

-        num_tokens, hidden_size = hidden_states.shape
-
         old_hidden_states = hidden_states.clone()

-        if self.tp_size > 1:
-            if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
-                chunks = torch.chunk(hidden_states, self.tp_size, dim=0)
-                hidden_states = chunks[self.tp_rank]
-            elif not self.torchair_graph_enabled:
-                num_padding_tokens = (self.tp_size -
-                                      num_tokens % self.tp_size) % self.tp_size
-                # Pad hidden_states to make it divisible by tp_size to avoid cross-ring AllGatherV on 910B2C
-                if num_padding_tokens > 0:
-                    hidden_states = nn.functional.pad(
-                        hidden_states, (0, 0, 0, num_padding_tokens))
-                chunk_hidden_states = torch.tensor_split(hidden_states,
-                                                         self.tp_size,
-                                                         dim=0)
-                hidden_states = chunk_hidden_states[self.tp_rank]
-
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)

@@ -220,33 +201,13 @@ def forward(
             enable_force_load_balance=enable_force_load_balance,
         ) * self.routed_scaling_factor

-        if self.tp_size > 1:
-            if self.torchair_graph_enabled:
-                if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
-                    final_hidden_states = torch.zeros(
-                        [num_tokens, hidden_size],
-                        dtype=self.params_dtype,
-                        device="npu")
-                    dist.all_gather_into_tensor(final_hidden_states,
-                                                hidden_states, self.tp_group)
-                    hidden_states = final_hidden_states
-                else:
-                    hidden_states = tensor_model_parallel_all_reduce(
-                        hidden_states)
-            else:
-                dist.all_gather(list(chunk_hidden_states), hidden_states,
-                                self.tp_group)
-                hidden_states = torch.cat(chunk_hidden_states, dim=0)
-                if num_padding_tokens > 0:
-                    hidden_states = hidden_states[:-num_padding_tokens]
-
         if self.n_shared_experts is not None:
             shared_output = self.shared_experts(old_hidden_states)

         if shared_output is not None:
             hidden_states = hidden_states + shared_output

-        return hidden_states.view(num_tokens, hidden_size)
+        return hidden_states

     # ----------------------------------------- TBO-related --------------------------------------------
     def _forward_ms_op_shared_expert(
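The two deleted branches in forward() hand-rolled tensor-parallel communication around the experts: pad the token dimension so it divides tp_size (the removed comment notes this avoids a cross-ring AllGatherV on 910B2C), run only this rank's chunk through the MoE, then all-gather and strip the padding. After this refactor that communication is owned by AscendFusedMoE. For reference, the removed pattern boils down to roughly the following pair of helpers; this is a sketch with illustrative names, not the repo's API.

import torch
import torch.distributed as dist
import torch.nn.functional as F

def scatter_to_tp_rank(x: torch.Tensor, tp_size: int, tp_rank: int):
    # Pad the token dim so it divides tp_size evenly, then keep this
    # rank's chunk. The chunk views are returned so the gather can
    # write back into the same buffers.
    num_tokens = x.shape[0]
    pad = (tp_size - num_tokens % tp_size) % tp_size
    if pad > 0:
        x = F.pad(x, (0, 0, 0, pad))  # append padding rows
    chunks = list(torch.tensor_split(x, tp_size, dim=0))
    return chunks[tp_rank], chunks, pad

def gather_from_tp_ranks(local: torch.Tensor, chunks, pad: int, tp_group):
    # Each rank contributes its processed chunk; concatenation restores
    # the original token order, and the padding rows are dropped.
    dist.all_gather(chunks, local, group=tp_group)
    x = torch.cat(chunks, dim=0)
    return x[:-pad] if pad > 0 else x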

vllm_ascend/models/deepseek_v2.py

Lines changed: 7 additions & 36 deletions

@@ -28,7 +28,6 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
-import torch.distributed as dist
 import torch_npu
 import vllm.envs as envs
 from torch import nn
@@ -37,7 +36,7 @@
 from vllm.config import CacheConfig, ModelConfig, VllmConfig
 from vllm.distributed import (get_pp_group,
                               get_tensor_model_parallel_world_size,
-                              get_tp_group, tensor_model_parallel_all_reduce)
+                              get_tp_group)
 from vllm.distributed.parallel_state import get_dp_group
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -54,9 +53,9 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.models.deepseek_v2 import \
-    DeepseekV2ForCausalLM  # ruff: noqa: E501
+    DeepseekV2ForCausalLM  # noqa: E501
 from vllm.model_executor.models.deepseek_v2 import \
-    yarn_get_mscale  # ruff: noqa: E501
+    yarn_get_mscale  # noqa: E501
 from vllm.model_executor.models.deepseek_v2 import (DeepseekV2Attention,
                                                     DeepseekV2DecoderLayer,
                                                     DeepseekV2MLAAttention)
@@ -65,16 +64,13 @@
                                               maybe_prefix)
 from vllm.sequence import IntermediateTensors

-import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.parallel_state import get_ep_group
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.quantization.quant_config import AscendLinearMethod
 from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
 from vllm_ascend.utils import dispose_tensor

-VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
-

 class CustomDeepseekV2SiluAndMul(SiluAndMul):

@@ -239,9 +235,8 @@ def __init__(

         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-        # NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on
         self.enable_multistream_moe = \
-            ascend_config.torchair_graph_config.enable_multistream_moe and VLLM_ENABLE_MC2
+            ascend_config.torchair_graph_config.enable_multistream_moe

         self.gate = ReplicatedLinear(config.hidden_size,
                                      config.n_routed_experts,
@@ -311,23 +306,11 @@ def forward(
         enable_force_load_balance = False
         if hasattr(attn_metadata, 'with_prefill_across_dp'):
             is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
-        num_tokens, hidden_size = hidden_states.shape
-        old_hidden_states = hidden_states
+
+        old_hidden_states = hidden_states.clone()
         use_separated_shared_experts = (self.shared_experts is not None
                                         and not self.enable_multistream_moe)

-        if self.tp_size > 1:
-            if (VLLM_ENABLE_MC2
-                    and not is_prefill) or not (self.torchair_graph_enabled or
-                                                self.ep_group.world_size == 1):
-                if num_tokens < self.tp_size:
-                    hidden_states = nn.functional.pad(
-                        hidden_states, (0, 0, 0, self.tp_size - num_tokens))
-                chunk_hidden_states = torch.tensor_split(hidden_states,
-                                                         self.tp_size,
-                                                         dim=0)
-                hidden_states = chunk_hidden_states[self.tp_rank]
-
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)

@@ -348,23 +331,11 @@ def forward(
                 experts_hidden_states[0] * self.routed_scaling_factor +
                 experts_hidden_states[1])

-        if self.tp_size > 1:
-            if (VLLM_ENABLE_MC2
-                    and not is_prefill) or not (self.torchair_graph_enabled or
-                                                self.ep_group.world_size == 1):
-                dist.all_gather(list(chunk_hidden_states), hidden_states,
-                                self.tp_group)
-                hidden_states = torch.cat(chunk_hidden_states, dim=0)
-                if num_tokens < self.tp_size:
-                    hidden_states = hidden_states[:num_tokens]
-            else:
-                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
-
         if use_separated_shared_experts:
             hidden_states = hidden_states + self.shared_experts(
                 old_hidden_states)

-        return hidden_states.view(num_tokens, hidden_size)
+        return hidden_states


 class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
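Two small but deliberate changes ride along in forward(): old_hidden_states becomes a .clone() instead of an alias, and the trailing .view(num_tokens, hidden_size) disappears because the tensor now keeps its full (num_tokens, hidden_size) shape end to end rather than being chunked and re-gathered. The clone matters if the fused-MoE path reuses or disposes of its input buffer (note the dispose_tensor import) while the shared experts still need the original activations; that rationale is an inference, and the toy demonstration of the aliasing hazard below is under that assumption.

import torch

hidden_states = torch.randn(4, 8)
alias = hidden_states            # shares storage with the original
copy = hidden_states.clone()     # independent storage

hidden_states.mul_(0.0)          # stand-in for in-place reuse by the MoE op

print(torch.equal(alias, hidden_states))  # True: the alias followed the change
print(bool(copy.abs().sum() > 0))         # True: the clone kept the data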
