[refactor] Refactoring AscendFusedMoE #1229
Changes from all commits
@@ -28,7 +28,6 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
import torch_npu
import vllm.envs as envs
from torch import nn
@@ -37,7 +36,7 @@
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (get_pp_group,
                              get_tensor_model_parallel_world_size,
-                             get_tp_group, tensor_model_parallel_all_reduce)
+                             get_tp_group)
from vllm.distributed.parallel_state import get_dp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import SiluAndMul
@@ -54,9 +53,9 @@
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.models.deepseek_v2 import \
-    DeepseekV2ForCausalLM  # ruff: noqa: E501
+    DeepseekV2ForCausalLM  # noqa: E501
from vllm.model_executor.models.deepseek_v2 import \
-    yarn_get_mscale  # ruff: noqa: E501
+    yarn_get_mscale  # noqa: E501
from vllm.model_executor.models.deepseek_v2 import (DeepseekV2Attention,
                                                    DeepseekV2DecoderLayer,
                                                    DeepseekV2MLAAttention)
@@ -65,7 +64,6 @@
    maybe_prefix)
from vllm.sequence import IntermediateTensors

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_ep_group
from vllm_ascend.ops.fused_moe import AscendFusedMoE
@@ -74,8 +72,6 @@
from vllm_ascend.utils import (dispose_tensor, npu_stream_switch,
                               npu_wait_tensor)

-VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2


class CustomDeepseekV2SiluAndMul(SiluAndMul):
@@ -240,9 +236,8 @@ def __init__(
        ascend_config = get_ascend_config()
        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-        # NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on
        self.enable_multistream_moe = \
-            ascend_config.torchair_graph_config.enable_multistream_moe and VLLM_ENABLE_MC2
+            ascend_config.torchair_graph_config.enable_multistream_moe

        self.gate = ReplicatedLinear(config.hidden_size,
                                     config.n_routed_experts,
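With this hunk, multistream MoE is controlled solely by the torchair graph config and is no longer additionally gated on the MC2 environment switch. The following is a minimal sketch of the old versus new gating; `TorchairGraphConfig` and `multistream_moe_enabled` are illustrative stand-ins, not the real vllm-ascend objects.

from dataclasses import dataclass

@dataclass
class TorchairGraphConfig:
    # Stand-in for ascend_config.torchair_graph_config (illustrative only).
    enabled: bool = True
    enable_multistream_moe: bool = True

def multistream_moe_enabled(cfg: TorchairGraphConfig, vllm_enable_mc2: bool) -> bool:
    # Old gating (before this PR): both switches had to be on.
    #   return cfg.enable_multistream_moe and vllm_enable_mc2
    # New gating: the torchair graph config alone decides.
    return cfg.enable_multistream_moe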
@@ -312,22 +307,6 @@ def forward(
        enable_force_load_balance = False
        if hasattr(attn_metadata, 'with_prefill_across_dp'):
            is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
-        num_tokens, hidden_size = hidden_states.shape
-        old_hidden_states = hidden_states
-        use_separated_shared_experts = (self.shared_experts is not None
-                                        and not self.enable_multistream_moe)
-
-        if self.tp_size > 1:
-            if (VLLM_ENABLE_MC2
-                    and not is_prefill) or not (self.torchair_graph_enabled or
-                                                self.ep_group.world_size == 1):
-                if num_tokens < self.tp_size:
-                    hidden_states = nn.functional.pad(
-                        hidden_states, (0, 0, 0, self.tp_size - num_tokens))
-                chunk_hidden_states = torch.tensor_split(hidden_states,
-                                                         self.tp_size,
-                                                         dim=0)
-                hidden_states = chunk_hidden_states[self.tp_rank]

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
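The deleted block padded and sharded the token dimension across tensor-parallel ranks before routing; after this refactor that pre-sharding is presumably handled inside AscendFusedMoE rather than in the model's forward. A self-contained sketch that reproduces the deleted pad/split step for illustration (the function name is hypothetical):

import torch
import torch.nn.functional as F

def shard_tokens_for_tp(hidden_states: torch.Tensor, tp_size: int, tp_rank: int) -> torch.Tensor:
    # Reproduces the deleted pre-routing step: pad the token dimension up to
    # tp_size when there are fewer tokens than ranks, split along dim 0, and
    # keep only this rank's chunk.
    num_tokens = hidden_states.shape[0]
    if num_tokens < tp_size:
        hidden_states = F.pad(hidden_states, (0, 0, 0, tp_size - num_tokens))
    chunks = torch.tensor_split(hidden_states, tp_size, dim=0)
    return chunks[tp_rank]

# Example: 3 tokens, hidden size 8, 4 TP ranks -> each rank gets one (padded) row.
x = torch.randn(3, 8)
print(shard_tokens_for_tp(x, tp_size=4, tp_rank=0).shape)  # torch.Size([1, 8])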
@@ -338,34 +317,14 @@ def forward(
            is_prefill=is_prefill,
            top_k=CustomDeepseekV2MoE.top_k,
            enable_force_load_balance=enable_force_load_balance,
-            shared_experts=(self.shared_experts
-                            if not use_separated_shared_experts else None),
+            shared_experts=self.shared_experts,
        )

-        if not isinstance(experts_hidden_states, tuple):
-            hidden_states = experts_hidden_states * self.routed_scaling_factor
-        else:
-            hidden_states = (
-                experts_hidden_states[0] * self.routed_scaling_factor +
-                experts_hidden_states[1])
-
-        if self.tp_size > 1:
-            if (VLLM_ENABLE_MC2
-                    and not is_prefill) or not (self.torchair_graph_enabled or
-                                                self.ep_group.world_size == 1):
-                dist.all_gather(list(chunk_hidden_states), hidden_states,
-                                self.tp_group)
-                hidden_states = torch.cat(chunk_hidden_states, dim=0)
-                if num_tokens < self.tp_size:
-                    hidden_states = hidden_states[:num_tokens]
-            else:
-                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+        hidden_states = (
+            experts_hidden_states[0] * self.routed_scaling_factor +
+            experts_hidden_states[1])

-        if use_separated_shared_experts:
-            hidden_states = hidden_states + self.shared_experts(
-                old_hidden_states)
-
-        return hidden_states.view(num_tokens, hidden_size)
+        return hidden_states


class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
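After this hunk, forward no longer branches on the return type of self.experts and no longer performs its own all-gather or all-reduce: the shared experts are always passed into the fused layer, the result is treated as a (routed, shared) pair, and any cross-rank reduction is assumed to happen inside AscendFusedMoE. A hedged sketch of the remaining combine step; the function and tensor names are illustrative, not the vllm-ascend API:

import torch

def combine_expert_outputs(routed_out: torch.Tensor,
                           shared_out: torch.Tensor,
                           routed_scaling_factor: float) -> torch.Tensor:
    # Post-refactor combine: scale the routed-expert output and add the
    # shared-expert output; TP/EP communication is assumed to have already
    # been handled inside the fused MoE layer.
    return routed_out * routed_scaling_factor + shared_out

# Example with toy tensors.
routed = torch.randn(4, 16)
shared = torch.randn(4, 16)
out = combine_expert_outputs(routed, shared, routed_scaling_factor=2.5)
print(out.shape)  # torch.Size([4, 16])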
Review comment: Why are there only deletions in dbo? Have you verified the functionality of this change?