
Commit 05010a7

zzzzwwjj authored and ganyi1996ppo committed
[refactor] Refactoring AscendFusedMoE (#1229)
This PR resolves [issue 1147](#1147).

1. Move the fused MoE code into a single file, `fused_moe.py`.
2. Consolidate the branch conditions into the function `get_fused_moe_state`.

User-facing changes:

1. Removed the env `VLLM_ENABLE_MC2`. The same decision can be made from the current runtime scenario, so the env only added complexity.
2. Removed the env `USING_LCCL_COM`, which was already obsolete.
3. `additional_config.expert_tensor_parallel_size` is deprecated; the `enable_expert_parallel` parameter is now used instead, consistent with vLLM.

Signed-off-by: zzzzwwjj <1183291235@qq.com>
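Since `fused_moe.py` itself is not part of the hunks shown below, here is a minimal sketch of how a `get_fused_moe_state` helper could fold the former `VLLM_ENABLE_MC2` branching into a single scenario-based decision. The enum members, signature, and threshold are assumptions for illustration, not the actual vllm-ascend implementation.

```python
from enum import Enum


class FusedMoEState(Enum):
    # Hypothetical states; the real enum members may differ.
    AllGather = 0
    All2All = 1
    MC2 = 2


def get_fused_moe_state(ep_size: int, with_prefill: bool) -> FusedMoEState:
    """Pick a fused-MoE communication strategy from the current scenario
    instead of reading the removed VLLM_ENABLE_MC2 env variable.

    Assumed rules (for illustration only):
      - no expert parallelism  -> plain all-gather
      - small EP group or prefill -> all-to-all dispatch
      - large EP group during decode -> MC2 fused compute/communication
    """
    if ep_size == 1:
        return FusedMoEState.AllGather
    if ep_size < 16 or with_prefill:
        return FusedMoEState.All2All
    return FusedMoEState.MC2
```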
1 parent d798125 commit 05010a7

File tree

9 files changed: +192 −209 lines changed


vllm_ascend/attention/mla_v1.py

Lines changed: 3 additions & 0 deletions

@@ -136,6 +136,7 @@ class AscendMLAMetadata:
     # For logging.
     num_input_tokens: int = 0  # Number of tokens including padding.

+    max_num_tokens_across_dp: int = 0
     with_prefill_across_dp: bool = False

     query_lens: Optional[list[int]] = None
@@ -364,6 +365,7 @@ def build(
        common_attn_metadata: CommonAttentionMetadata,
        common_prefix_len: Optional[int] = None,
        graph_pad_size: int = -1,
+       max_num_tokens_across_dp: int = 0,
        with_prefill_across_dp: bool = False,
    ) -> AscendMLAMetadata:
        assert self._num_decodes + self._num_prefills == num_reqs
@@ -509,6 +511,7 @@ def build(
            query_start_loc=query_start_loc,
            block_tables=block_table,
            seq_lens=seq_lens,
+           max_num_tokens_across_dp=max_num_tokens_across_dp,
            with_prefill_across_dp=with_prefill_across_dp,
        )

vllm_ascend/envs.py

Lines changed: 0 additions & 8 deletions

@@ -50,18 +50,10 @@
     # value is None, which means the system default C compiler will be used.
     "C_COMPILER":
     lambda: os.getenv("C_COMPILER", None),
-    # Whether to enable MC2 for DeepSeek. If not set, the default value is False.
-    # MC2 is a fusion operator provided by Ascend to speed up computing and communication.
-    # Find more detail here: https://www.hiascend.com/document/detail/zh/canncommercial/81RC1/developmentguide/opdevg/ascendcbestP/atlas_ascendc_best_practices_10_0043.html
-    "VLLM_ENABLE_MC2":
-    lambda: bool(int(os.getenv("VLLM_ENABLE_MC2", '0'))),
     # Whether to enable the topk optimization. It's disabled by default for experimental support
     # We'll make it enabled by default in the future.
     "VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE":
     lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE", '0'))),
-    # Whether to use LCCL communication. If not set, the default value is False.
-    "USING_LCCL_COM":
-    lambda: bool(int(os.getenv("USING_LCCL_COM", '0'))),
     # The version of the Ascend chip. If not set, the default value is
     # ASCEND910B1. It's used for package building. Please make sure that the
     # version is correct.

vllm_ascend/models/deepseek_dbo.py

Lines changed: 3 additions & 42 deletions

@@ -51,9 +51,9 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.models.deepseek_v2 import \
-    DeepseekV2ForCausalLM  # ruff: noqa: E501
+    DeepseekV2ForCausalLM  # noqa: E501
 from vllm.model_executor.models.deepseek_v2 import \
-    yarn_get_mscale  # ruff: noqa: E501
+    yarn_get_mscale  # noqa: E501
 from vllm.model_executor.models.deepseek_v2 import (DeepseekV2Attention,
                                                     DeepseekV2DecoderLayer,
                                                     DeepseekV2MLAAttention)
@@ -79,7 +79,6 @@
 from vllm_ascend.utils import dispose_tensor

 VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO
-VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2


 class CustomDeepseekDBOMLP(CustomDeepseekV2MLP):
@@ -189,26 +188,8 @@ def forward(
         if hasattr(attn_metadata, 'with_prefill_across_dp'):
             is_prefill = is_prefill or attn_metadata.with_prefill_across_dp

-        num_tokens, hidden_size = hidden_states.shape
-
         old_hidden_states = hidden_states.clone()

-        if self.tp_size > 1:
-            if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
-                chunks = torch.chunk(hidden_states, self.tp_size, dim=0)
-                hidden_states = chunks[self.tp_rank]
-            elif not self.torchair_graph_enabled:
-                num_padding_tokens = (self.tp_size -
-                                      num_tokens % self.tp_size) % self.tp_size
-                # Pad hidden_states to make it divisible by tp_size to avoid cross-ring AllGatherV on 910B2C
-                if num_padding_tokens > 0:
-                    hidden_states = nn.functional.pad(
-                        hidden_states, (0, 0, 0, num_padding_tokens))
-                chunk_hidden_states = torch.tensor_split(hidden_states,
-                                                         self.tp_size,
-                                                         dim=0)
-                hidden_states = chunk_hidden_states[self.tp_rank]
-
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)

@@ -220,33 +201,13 @@ def forward(
             enable_force_load_balance=enable_force_load_balance,
         ) * self.routed_scaling_factor

-        if self.tp_size > 1:
-            if self.torchair_graph_enabled:
-                if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
-                    final_hidden_states = torch.zeros(
-                        [num_tokens, hidden_size],
-                        dtype=self.params_dtype,
-                        device="npu")
-                    dist.all_gather_into_tensor(final_hidden_states,
-                                                hidden_states, self.tp_group)
-                    hidden_states = final_hidden_states
-                else:
-                    hidden_states = tensor_model_parallel_all_reduce(
-                        hidden_states)
-            else:
-                dist.all_gather(list(chunk_hidden_states), hidden_states,
-                                self.tp_group)
-                hidden_states = torch.cat(chunk_hidden_states, dim=0)
-                if num_padding_tokens > 0:
-                    hidden_states = hidden_states[:-num_padding_tokens]
-
         if self.n_shared_experts is not None:
             shared_output = self.shared_experts(old_hidden_states)

         if shared_output is not None:
             hidden_states = hidden_states + shared_output

-        return hidden_states.view(num_tokens, hidden_size)
+        return hidden_states

     # ----------------------------------------- TBO-related --------------------------------------------
     def _forward_ms_op_shared_expert(

vllm_ascend/models/deepseek_v2.py

Lines changed: 9 additions & 50 deletions

@@ -28,7 +28,6 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
-import torch.distributed as dist
 import torch_npu
 import vllm.envs as envs
 from torch import nn
@@ -37,7 +36,7 @@
 from vllm.config import CacheConfig, ModelConfig, VllmConfig
 from vllm.distributed import (get_pp_group,
                               get_tensor_model_parallel_world_size,
-                              get_tp_group, tensor_model_parallel_all_reduce)
+                              get_tp_group)
 from vllm.distributed.parallel_state import get_dp_group
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -54,9 +53,9 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.models.deepseek_v2 import \
-    DeepseekV2ForCausalLM  # ruff: noqa: E501
+    DeepseekV2ForCausalLM  # noqa: E501
 from vllm.model_executor.models.deepseek_v2 import \
-    yarn_get_mscale  # ruff: noqa: E501
+    yarn_get_mscale  # noqa: E501
 from vllm.model_executor.models.deepseek_v2 import (DeepseekV2Attention,
                                                     DeepseekV2DecoderLayer,
                                                     DeepseekV2MLAAttention)
@@ -65,7 +64,6 @@
                                               maybe_prefix)
 from vllm.sequence import IntermediateTensors

-import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.parallel_state import get_ep_group
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
@@ -74,8 +72,6 @@
 from vllm_ascend.utils import (dispose_tensor, npu_stream_switch,
                                npu_wait_tensor)

-VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
-

 class CustomDeepseekV2SiluAndMul(SiluAndMul):

@@ -240,9 +236,8 @@ def __init__(

         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-        # NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on
         self.enable_multistream_moe = \
-            ascend_config.torchair_graph_config.enable_multistream_moe and VLLM_ENABLE_MC2
+            ascend_config.torchair_graph_config.enable_multistream_moe

         self.gate = ReplicatedLinear(config.hidden_size,
                                      config.n_routed_experts,
@@ -312,22 +307,6 @@ def forward(
             enable_force_load_balance = False
         if hasattr(attn_metadata, 'with_prefill_across_dp'):
             is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
-        num_tokens, hidden_size = hidden_states.shape
-        old_hidden_states = hidden_states
-        use_separated_shared_experts = (self.shared_experts is not None
-                                        and not self.enable_multistream_moe)
-
-        if self.tp_size > 1:
-            if (VLLM_ENABLE_MC2
-                    and not is_prefill) or not (self.torchair_graph_enabled or
-                                                self.ep_group.world_size == 1):
-                if num_tokens < self.tp_size:
-                    hidden_states = nn.functional.pad(
-                        hidden_states, (0, 0, 0, self.tp_size - num_tokens))
-                chunk_hidden_states = torch.tensor_split(hidden_states,
-                                                         self.tp_size,
-                                                         dim=0)
-                hidden_states = chunk_hidden_states[self.tp_rank]

         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
@@ -338,34 +317,14 @@ def forward(
             is_prefill=is_prefill,
             top_k=CustomDeepseekV2MoE.top_k,
             enable_force_load_balance=enable_force_load_balance,
-            shared_experts=(self.shared_experts
-                            if not use_separated_shared_experts else None),
+            shared_experts=self.shared_experts,
         )

-        if not isinstance(experts_hidden_states, tuple):
-            hidden_states = experts_hidden_states * self.routed_scaling_factor
-        else:
-            hidden_states = (
-                experts_hidden_states[0] * self.routed_scaling_factor +
-                experts_hidden_states[1])
-
-        if self.tp_size > 1:
-            if (VLLM_ENABLE_MC2
-                    and not is_prefill) or not (self.torchair_graph_enabled or
-                                                self.ep_group.world_size == 1):
-                dist.all_gather(list(chunk_hidden_states), hidden_states,
-                                self.tp_group)
-                hidden_states = torch.cat(chunk_hidden_states, dim=0)
-                if num_tokens < self.tp_size:
-                    hidden_states = hidden_states[:num_tokens]
-            else:
-                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+        hidden_states = (
+            experts_hidden_states[0] * self.routed_scaling_factor +
+            experts_hidden_states[1])

-        if use_separated_shared_experts:
-            hidden_states = hidden_states + self.shared_experts(
-                old_hidden_states)
-
-        return hidden_states.view(num_tokens, hidden_size)
+        return hidden_states


 class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
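
The simplified `forward` above assumes the refactored `AscendFusedMoE` now performs any TP/EP communication internally and, when `shared_experts` is passed, returns a `(routed_output, shared_output)` tuple. A minimal sketch of that assumed contract follows; the class name and signature are illustrative stand-ins, since `fused_moe.py` is not among the hunks shown here.

```python
import torch
import torch.nn as nn


class FusedMoESketch(nn.Module):
    """Illustrative stand-in for the assumed AscendFusedMoE contract.

    Assumption: communication (all-gather / all-reduce / MC2 dispatch) is
    chosen inside this module (e.g. via get_fused_moe_state), so callers such
    as CustomDeepseekV2MoE.forward only combine the returned tensors.
    """

    def __init__(self, experts: nn.Module):
        super().__init__()
        self.experts = experts

    def forward(self,
                hidden_states: torch.Tensor,
                router_logits: torch.Tensor,
                shared_experts: nn.Module = None):
        # Routed expert computation (communication handled internally).
        routed_out = self.experts(hidden_states, router_logits)
        # Shared experts run alongside, so the caller no longer branches.
        shared_out = (shared_experts(hidden_states)
                      if shared_experts is not None else None)
        # Callers index [0] and [1], so always return a tuple.
        return routed_out, shared_out
```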
