From 1d9d03733d018c74303adcfff5dd8b8a3b56f168 Mon Sep 17 00:00:00 2001 From: ttanzhiqiang <389825161@qq.com> Date: Wed, 28 May 2025 16:08:58 +0800 Subject: [PATCH 1/8] fix ep=1 etp=16 Signed-off-by: ttanzhiqiang <389825161@qq.com> --- vllm_ascend/core/scheduler.py | 1 + vllm_ascend/models/deepseek_v2.py | 44 ++++++++ vllm_ascend/ops/fused_moe.py | 128 ++++++++++++++++++----- vllm_ascend/quantization/w8a8_dynamic.py | 1 + 4 files changed, 148 insertions(+), 26 deletions(-) diff --git a/vllm_ascend/core/scheduler.py b/vllm_ascend/core/scheduler.py index 122f7b9654..fd813aa178 100644 --- a/vllm_ascend/core/scheduler.py +++ b/vllm_ascend/core/scheduler.py @@ -50,6 +50,7 @@ def __init__( include_finished_set, log_stats) self.scheduled_req_ids: set[str] = set() self.running: list[Request] = [] + self.include_finished_set = include_finished_set if self.vllm_config.kv_transfer_config is not None and \ self.vllm_config.kv_transfer_config.is_kv_consumer: diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index 5e97444157..05ef2ff5e4 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -215,6 +215,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # when profile runs, force experts to load balanced tokens # to avoid high memory consumption on a single rank. # TODO: need a better flag to indicate whether in profile run or not. + + #ep=1 etp=tp + if self.experts.ep_group.world_size == 1: + return self.forward_etp(hidden_states) + if attn_metadata is None: # for profile run is_prefill = True @@ -269,6 +274,45 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return final_hidden_states.view(num_tokens, hidden_dim) + def forward_etp(self, + hidden_states: torch.Tensor, + is_prefill: bool = False) -> torch.Tensor: + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + + if self.n_shared_experts is not None: + shared_output = self.shared_experts(hidden_states) + + if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill and self.tp_size > 1: + chunks = torch.chunk(hidden_states, self.tp_size, dim=0) + hidden_states = chunks[self.tp_rank_in_group] + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + + hidden_states = self.experts( + hidden_states=hidden_states, + router_logits=router_logits, + is_prefill=is_prefill, + top_k=CustomDeepseekV2MoE.top_k) * self.routed_scaling_factor + + if self.tp_size > 1: + if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill: + final_hidden_states = torch.zeros([num_tokens, hidden_dim], + dtype=self.params_dtype, + device="npu") + dist.all_gather_into_tensor(final_hidden_states, hidden_states, + self.tp_group) + hidden_states = final_hidden_states + else: + hidden_states = tensor_model_parallel_all_reduce( + hidden_states) + + if self.n_shared_experts is not None: + hidden_states = hidden_states + shared_output + + return hidden_states.view(num_tokens, hidden_dim) + class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention): diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 6313a7506b..990cc9c96b 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -25,6 +25,7 @@ get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.distributed.parallel_state import get_dp_group +from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, UnquantizedFusedMoEMethod, 
determine_expert_map) @@ -337,6 +338,7 @@ def fused_experts( num_experts = w1.shape[0] dtype = hidden_states.dtype device = hidden_states.device + topk_weights = topk_weights.to(dtype) # assert dtype in [torch.float32, torch.float16, torch.bfloat16 # ], "Only float32, float16, and bfloat16 are supported" @@ -748,6 +750,7 @@ def __init__( vllm_parallel_config=vllm_config.parallel_config)) self.moe_parallel_config.ep_size = get_ep_group().world_size + self.moe_parallel_config.tp_size = get_etp_group().world_size self.top_k = top_k self.num_experts = num_experts @@ -767,35 +770,19 @@ def __init__( self.expert_map = None self.activation = activation - if self.ep_size > 1: - # Create a tensor of size num_experts filled with -1 - self.local_num_experts, self.expert_map = determine_expert_map( - self.ep_size, - get_ep_group().rank_in_group, self.global_num_experts) - if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"): - self.tp_rank = get_etp_group().rank_in_group - self.ep_rank = get_ep_group().rank_in_group - else: - self.moe_parallel_config.tp_rank = get_etp_group( - ).rank_in_group - self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group - + # Create a tensor of size num_experts filled with -1 + self.local_num_experts, self.expert_map = determine_expert_map( + self.ep_size, + get_ep_group().rank_in_group, self.global_num_experts) + if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"): + self.tp_rank = get_etp_group().rank_in_group + self.ep_rank = get_ep_group().rank_in_group else: - # Adjust TP size for DP attention - # haven't test its functionality yet, may remove in the future - if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"): - self.tp_rank = self.tp_size * self.dp_rank - self.ep_rank = 0 - self.tp_size = self.tp_size * self.dp_size - self.ep_size = 1 - else: - self.moe_parallel_config.tp_rank = self.tp_size * self.dp_rank - self.moe_parallel_config.ep_rank = 0 - self.moe_parallel_config.tp_size = self.tp_size * self.dp_size - self.moe_parallel_config.ep_size = 1 - + self.moe_parallel_config.tp_rank = get_etp_group().rank_in_group + self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group self.local_num_experts, self.expert_map = (self.global_num_experts, None) + if self.scoring_func != "softmax" and not self.use_grouped_topk: raise ValueError("Only softmax scoring function is supported for " "non-grouped topk.") @@ -839,6 +826,7 @@ def __init__( in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")): moe_quant_params["intermediate_size_full"] = intermediate_size + self.ep_group = get_ep_group() self.quant_method.create_weights(layer=self, **moe_quant_params) def forward(self, @@ -848,6 +836,9 @@ def forward(self, enable_force_load_balance: bool = False, top_k=None): assert self.quant_method is not None + #ep=1 etp=tp + if self.ep_group.world_size == 1: + return self.forward_etp(hidden_states, router_logits, is_prefill, top_k) if top_k: real_top_k = top_k @@ -883,3 +874,88 @@ def forward(self, final_hidden_states) return final_hidden_states + + def forward_etp(self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + is_prefill: bool, + top_k=None): + assert self.quant_method is not None + + if top_k: + real_top_k = top_k + else: + real_top_k = self.top_k + + # MC2 ag/rs boardcast/all_reduce + # prefill_req x x √ + # decode_req √ x √ + # graph_mode √ √ x + if self.dp_size > 1: + if envs.VLLM_ENABLE_MC2 and not is_prefill: + ... 
+ elif int(os.environ.get("USING_LCCL_COM", + '0')) == 1: # type: ignore + hidden_states = get_dp_group().all_gather( + hidden_states, 0, False) + router_logits = get_dp_group().all_gather( + router_logits, 0, False) + elif self.enable_graph_mode and not is_prefill: + hidden_states = get_dp_group().all_gather(hidden_states, 0) + router_logits = get_dp_group().all_gather(router_logits, 0) + else: + cu_tokens_across_dp_cpu = get_forward_context( + ).dp_metadata.cu_tokens_across_dp_cpu + hidden_states = self.naive_multicast(hidden_states, + cu_tokens_across_dp_cpu) + router_logits = self.naive_multicast(router_logits, + cu_tokens_across_dp_cpu) + + # Matrix multiply. + final_hidden_states = self.quant_method.apply( + layer=self, + x=hidden_states, + router_logits=router_logits, + top_k=real_top_k, + renormalize=self.renormalize, + use_grouped_topk=self.use_grouped_topk, + global_num_experts=self.global_num_experts, + expert_map=self.expert_map, + topk_group=self.topk_group, + num_expert_group=self.num_expert_group, + custom_routing_function=self.custom_routing_function, + scoring_func=self.scoring_func, + e_score_correction_bias=self.e_score_correction_bias, + is_prefill=is_prefill) + + + + if self.dp_size > 1: + if envs.VLLM_ENABLE_MC2 and not is_prefill: + ... + elif int(os.environ.get("USING_LCCL_COM", + '0')) == 1: # type: ignore + final_hidden_states = dist._functional_collectives.reduce_scatter_tensor( + final_hidden_states, + "sum", + scatter_dim=0, + group=get_dp_group().device_group) + elif self.enable_graph_mode and not is_prefill: + final_hidden_states = dist._functional_collectives.reduce_scatter_tensor( + final_hidden_states, + "sum", + scatter_dim=0, + group=get_dp_group().device_group) + else: + start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ + self.dp_rank - 1] + end = cu_tokens_across_dp_cpu[self.dp_rank] + all_hidden_states = get_dp_group().all_reduce( + final_hidden_states) + final_hidden_states = all_hidden_states[start:end, :] + + if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 0f54b012f1..2d61c9ff0d 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -342,6 +342,7 @@ def fused_experts(hidden_states: torch.Tensor, num_experts = w1.shape[0] dtype = hidden_states.dtype device = hidden_states.device + topk_weights = topk_weights.to(dtype) if expert_map is not None: # Generate token indices and flatten From 0cb29fad155d1576263ddba469563ee191b832f0 Mon Sep 17 00:00:00 2001 From: ttanzhiqiang <389825161@qq.com> Date: Wed, 28 May 2025 16:25:40 +0800 Subject: [PATCH 2/8] update Signed-off-by: ttanzhiqiang <389825161@qq.com> --- vllm_ascend/models/deepseek_v2.py | 7 +++---- vllm_ascend/ops/fused_moe.py | 26 +++++++++++++------------- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index 05ef2ff5e4..6e60803746 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -275,8 +275,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return final_hidden_states.view(num_tokens, hidden_dim) def forward_etp(self, - hidden_states: torch.Tensor, - is_prefill: bool = False) -> torch.Tensor: + hidden_states: torch.Tensor, + is_prefill: bool = False) -> torch.Tensor: 
num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) @@ -305,8 +305,7 @@ def forward_etp(self, self.tp_group) hidden_states = final_hidden_states else: - hidden_states = tensor_model_parallel_all_reduce( - hidden_states) + hidden_states = tensor_model_parallel_all_reduce(hidden_states) if self.n_shared_experts is not None: hidden_states = hidden_states + shared_output diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 990cc9c96b..f69fba3d18 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -40,6 +40,7 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) +import os import vllm_ascend.envs as envs_ascend from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group @@ -838,7 +839,8 @@ def forward(self, assert self.quant_method is not None #ep=1 etp=tp if self.ep_group.world_size == 1: - return self.forward_etp(hidden_states, router_logits, is_prefill, top_k) + return self.forward_etp(hidden_states, router_logits, is_prefill, + top_k) if top_k: real_top_k = top_k @@ -876,10 +878,10 @@ def forward(self, return final_hidden_states def forward_etp(self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - is_prefill: bool, - top_k=None): + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + is_prefill: bool, + top_k=None): assert self.quant_method is not None if top_k: @@ -887,12 +889,12 @@ def forward_etp(self, else: real_top_k = self.top_k - # MC2 ag/rs boardcast/all_reduce + # MC2 ag/rs broadcast/all_reduce # prefill_req x x √ # decode_req √ x √ # graph_mode √ √ x if self.dp_size > 1: - if envs.VLLM_ENABLE_MC2 and not is_prefill: + if VLLM_ENABLE_MC2 and not is_prefill: ... elif int(os.environ.get("USING_LCCL_COM", '0')) == 1: # type: ignore @@ -902,7 +904,7 @@ def forward_etp(self, router_logits, 0, False) elif self.enable_graph_mode and not is_prefill: hidden_states = get_dp_group().all_gather(hidden_states, 0) - router_logits = get_dp_group().all_gather(router_logits, 0) + router_logits = get_dp_group().all_gather(router_logits, 0) else: cu_tokens_across_dp_cpu = get_forward_context( ).dp_metadata.cu_tokens_across_dp_cpu @@ -928,10 +930,8 @@ def forward_etp(self, e_score_correction_bias=self.e_score_correction_bias, is_prefill=is_prefill) - - if self.dp_size > 1: - if envs.VLLM_ENABLE_MC2 and not is_prefill: + if VLLM_ENABLE_MC2 and not is_prefill: ... 
elif int(os.environ.get("USING_LCCL_COM", '0')) == 1: # type: ignore @@ -955,7 +955,7 @@ def forward_etp(self, final_hidden_states = all_hidden_states[start:end, :] if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) return final_hidden_states From 8f9e635c1feae8d7ac3151a60c240cb79d1f1a0d Mon Sep 17 00:00:00 2001 From: ttanzhiqiang <389825161@qq.com> Date: Wed, 28 May 2025 16:46:39 +0800 Subject: [PATCH 3/8] update Signed-off-by: ttanzhiqiang <389825161@qq.com> --- vllm_ascend/ops/fused_moe.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index f69fba3d18..6ecfb589c7 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -37,10 +37,11 @@ else: MoEConfig = None +import os + from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) -import os import vllm_ascend.envs as envs_ascend from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group From 6bd5891173b30a9586c15bec7af999d594d740cb Mon Sep 17 00:00:00 2001 From: ttanzhiqiang <389825161@qq.com> Date: Wed, 28 May 2025 18:19:40 +0800 Subject: [PATCH 4/8] update graph Signed-off-by: ttanzhiqiang <389825161@qq.com> --- vllm_ascend/ops/fused_moe.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 6ecfb589c7..b4db921b96 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -771,7 +771,12 @@ def __init__( self.e_score_correction_bias = e_score_correction_bias self.expert_map = None self.activation = activation - + self.enable_graph_mode = False + additional_config = get_current_vllm_config().additional_config + if additional_config: + self.enable_graph_mode = additional_config.get( + "enable_graph_mode", False) + # Create a tensor of size num_experts filled with -1 self.local_num_experts, self.expert_map = determine_expert_map( self.ep_size, From d201ca06a2e0f317e4db34c93787e34032333030 Mon Sep 17 00:00:00 2001 From: ttanzhiqiang <389825161@qq.com> Date: Wed, 28 May 2025 19:11:44 +0800 Subject: [PATCH 5/8] update Signed-off-by: ttanzhiqiang <389825161@qq.com> --- vllm_ascend/ops/fused_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index b4db921b96..38aead3b9f 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -776,7 +776,7 @@ def __init__( if additional_config: self.enable_graph_mode = additional_config.get( "enable_graph_mode", False) - + # Create a tensor of size num_experts filled with -1 self.local_num_experts, self.expert_map = determine_expert_map( self.ep_size, From eccdcc2bb0a10b0d7f6ebaecea0923f363f7a626 Mon Sep 17 00:00:00 2001 From: ttanzhiqiang <389825161@qq.com> Date: Thu, 29 May 2025 21:56:07 +0800 Subject: [PATCH 6/8] update Signed-off-by: ttanzhiqiang <389825161@qq.com> --- vllm_ascend/ops/fused_moe.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 82bd0ddc62..1ec18b156b 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -15,6 +15,7 @@ # This file is a part of the vllm-ascend project. 
# Adapted from vllm/tests/kernels/test_moe.py +import os from typing import Callable, Optional import torch From 3c254d437567869be0602fd0e624e7e6c8a2137f Mon Sep 17 00:00:00 2001 From: ttanzhiqiang <389825161@qq.com> Date: Thu, 29 May 2025 22:39:13 +0800 Subject: [PATCH 7/8] update ep>1 Signed-off-by: ttanzhiqiang <389825161@qq.com> --- vllm_ascend/ops/fused_moe.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 1ec18b156b..de2aa9f826 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -753,22 +753,28 @@ def __init__( self.e_score_correction_bias = e_score_correction_bias self.expert_map = None self.activation = activation + + if self.ep_size > 1: + # Create a tensor of size num_experts filled with -1 + self.local_num_experts, self.expert_map = determine_expert_map( + self.ep_size, + get_ep_group().rank_in_group, self.global_num_experts) + + self.moe_parallel_config.tp_rank = get_etp_group().rank_in_group + self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group + + else: + self.moe_parallel_config.tp_rank = get_etp_group().rank_in_group + self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group + self.local_num_experts, self.expert_map = (self.global_num_experts, + None) + self.enable_graph_mode = False additional_config = get_current_vllm_config().additional_config if additional_config: self.enable_graph_mode = additional_config.get( "enable_graph_mode", False) - # Create a tensor of size num_experts filled with -1 - self.local_num_experts, self.expert_map = determine_expert_map( - self.ep_size, - get_ep_group().rank_in_group, self.global_num_experts) - - self.moe_parallel_config.tp_rank = get_etp_group().rank_in_group - self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group - self.local_num_experts, self.expert_map = (self.global_num_experts, - None) - if self.scoring_func != "softmax" and not self.use_grouped_topk: raise ValueError("Only softmax scoring function is supported for " "non-grouped topk.") From 4d3a521e45a99189b838c6132d857efde4319998 Mon Sep 17 00:00:00 2001 From: ttanzhiqiang <389825161@qq.com> Date: Thu, 29 May 2025 22:42:04 +0800 Subject: [PATCH 8/8] update Signed-off-by: ttanzhiqiang <389825161@qq.com> --- vllm_ascend/ops/fused_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index de2aa9f826..2d50f603a5 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -768,7 +768,7 @@ def __init__( self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group self.local_num_experts, self.expert_map = (self.global_num_experts, None) - + self.enable_graph_mode = False additional_config = get_current_vllm_config().additional_config if additional_config:
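
For readers tracing the series, the core idea of patches 1 and 2 is a dispatch at the top of the MoE forward pass: when the expert-parallel group has a single rank (ep=1), the layer falls back to a tensor-parallel path (`forward_etp`) that, on the MC2 decode path, lets each TP rank process its own chunk of tokens and then gathers the per-rank outputs back to full length. The sketch below is a minimal, single-process illustration of that control flow only; the `MoEDispatch` class, the `all_gather_stub` helper, and the fixed sizes are illustrative stand-ins, not the vllm-ascend API (the real code uses vLLM's EP/ETP process groups, `dist.all_gather_into_tensor`, and the fused-experts kernels).

```python
# Minimal single-process sketch of the ep=1 / etp=tp dispatch added in this
# series. `MoEDispatch`, `all_gather_stub`, and the sizes are illustrative
# stand-ins; real code uses vLLM's EP/ETP process groups and MC2 kernels.
import torch


def all_gather_stub(chunk: torch.Tensor, world_size: int) -> torch.Tensor:
    # Stand-in for dist.all_gather_into_tensor: in this single-process sketch
    # every "rank" holds the same chunk, so gathering is just repetition.
    return chunk.repeat(world_size, 1)


class MoEDispatch:
    def __init__(self, ep_size: int, tp_size: int, tp_rank: int):
        self.ep_size = ep_size   # expert-parallel world size
        self.tp_size = tp_size   # tensor-parallel (etp) world size
        self.tp_rank = tp_rank

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Patch 1 adds this branch: with a single EP rank, experts are
        # effectively sharded across TP ranks, so take the etp path.
        if self.ep_size == 1:
            return self.forward_etp(hidden_states)
        return self.forward_ep(hidden_states)

    def forward_etp(self, hidden_states: torch.Tensor) -> torch.Tensor:
        num_tokens, hidden_dim = hidden_states.shape
        # MC2 decode path: each TP rank works on its own slice of tokens...
        chunk = torch.chunk(hidden_states, self.tp_size, dim=0)[self.tp_rank]
        out_chunk = chunk * 2.0  # placeholder for the fused-experts kernel
        # ...and per-rank outputs are gathered back to the full token count.
        out = all_gather_stub(out_chunk, self.tp_size)
        return out.view(num_tokens, hidden_dim)

    def forward_ep(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Placeholder for the existing expert-parallel path (ep > 1).
        return hidden_states * 2.0


if __name__ == "__main__":
    layer = MoEDispatch(ep_size=1, tp_size=4, tp_rank=0)
    x = torch.randn(8, 16)   # 8 tokens, hidden dim 16; 8 divides by tp_size
    y = layer.forward(x)
    print(y.shape)           # torch.Size([8, 16])
```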