diff --git a/vllm_ascend/core/scheduler.py b/vllm_ascend/core/scheduler.py
index 122f7b9654..fd813aa178 100644
--- a/vllm_ascend/core/scheduler.py
+++ b/vllm_ascend/core/scheduler.py
@@ -50,6 +50,7 @@ def __init__(
                          include_finished_set, log_stats)
         self.scheduled_req_ids: set[str] = set()
         self.running: list[Request] = []
+        self.include_finished_set = include_finished_set
 
         if self.vllm_config.kv_transfer_config is not None and \
             self.vllm_config.kv_transfer_config.is_kv_consumer:
diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py
index 264a798eeb..92a715050c 100644
--- a/vllm_ascend/models/deepseek_v2.py
+++ b/vllm_ascend/models/deepseek_v2.py
@@ -216,6 +216,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # when profile runs, force experts to load balanced tokens
         # to avoid high memory consumption on a single rank.
         # TODO: need a better flag to indicate whether in profile run or not.
+
+        # ep=1, etp=tp
+        if self.experts.ep_group.world_size == 1:
+            return self.forward_etp(hidden_states)
+
         if attn_metadata is None:
             # for profile run
             is_prefill = True
@@ -270,6 +275,44 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
         return final_hidden_states.view(num_tokens, hidden_dim)
 
+    def forward_etp(self,
+                    hidden_states: torch.Tensor,
+                    is_prefill: bool = False) -> torch.Tensor:
+        num_tokens, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+
+        if self.n_shared_experts is not None:
+            shared_output = self.shared_experts(hidden_states)
+
+        if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill and self.tp_size > 1:
+            chunks = torch.chunk(hidden_states, self.tp_size, dim=0)
+            hidden_states = chunks[self.tp_rank_in_group]
+
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.gate(hidden_states)
+
+        hidden_states = self.experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            is_prefill=is_prefill,
+            top_k=CustomDeepseekV2MoE.top_k) * self.routed_scaling_factor
+
+        if self.tp_size > 1:
+            if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
+                final_hidden_states = torch.zeros([num_tokens, hidden_dim],
+                                                  dtype=self.params_dtype,
+                                                  device="npu")
+                dist.all_gather_into_tensor(final_hidden_states, hidden_states,
+                                            self.tp_group)
+                hidden_states = final_hidden_states
+            else:
+                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+
+        if self.n_shared_experts is not None:
+            hidden_states = hidden_states + shared_output
+
+        return hidden_states.view(num_tokens, hidden_dim)
+
 
 class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index 74a292d576..2d50f603a5 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -15,6 +15,7 @@
 # This file is a part of the vllm-ascend project.
 # Adapted from vllm/tests/kernels/test_moe.py
 
+import os
 from typing import Callable, Optional
 
 import torch
@@ -25,6 +26,7 @@
     get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce)
 from vllm.distributed.parallel_state import get_dp_group
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, FusedMoEParallelConfig, MoEConfig, UnquantizedFusedMoEMethod,
     determine_expert_map)
@@ -329,6 +331,7 @@ def fused_experts(
     num_experts = w1.shape[0]
     dtype = hidden_states.dtype
     device = hidden_states.device
+    topk_weights = topk_weights.to(dtype)
 
     # assert dtype in [torch.float32, torch.float16, torch.bfloat16
     #                  ], "Only float32, float16, and bfloat16 are supported"
@@ -761,16 +764,17 @@ def __init__(
             self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group
         else:
-            # Adjust TP size for DP attention
-            # haven't test its functionality yet, may remove in the future
-
-            self.moe_parallel_config.tp_rank = self.tp_size * self.dp_rank
-            self.moe_parallel_config.ep_rank = 0
-            self.moe_parallel_config.tp_size = self.tp_size * self.dp_size
-            self.moe_parallel_config.ep_size = 1
-
+            self.moe_parallel_config.tp_rank = get_etp_group().rank_in_group
+            self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group
             self.local_num_experts, self.expert_map = (self.global_num_experts,
                                                        None)
+
+        self.enable_graph_mode = False
+        additional_config = get_current_vllm_config().additional_config
+        if additional_config:
+            self.enable_graph_mode = additional_config.get(
+                "enable_graph_mode", False)
+
         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
                              "non-grouped topk.")
@@ -808,6 +812,7 @@ def __init__(
                 in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
             moe_quant_params["intermediate_size_full"] = intermediate_size
 
+        self.ep_group = get_ep_group()
         self.quant_method.create_weights(layer=self, **moe_quant_params)
 
     def forward(self,
@@ -817,6 +822,10 @@ def forward(self,
                 enable_force_load_balance: bool = False,
                 top_k=None):
         assert self.quant_method is not None
+        # ep=1, etp=tp
+        if self.ep_group.world_size == 1:
+            return self.forward_etp(hidden_states, router_logits, is_prefill,
+                                    top_k)
 
         if top_k:
             real_top_k = top_k
@@ -852,3 +861,86 @@ def forward(self,
                 final_hidden_states)
 
         return final_hidden_states
+
+    def forward_etp(self,
+                    hidden_states: torch.Tensor,
+                    router_logits: torch.Tensor,
+                    is_prefill: bool,
+                    top_k=None):
+        assert self.quant_method is not None
+
+        if top_k:
+            real_top_k = top_k
+        else:
+            real_top_k = self.top_k
+
+        #               MC2     ag/rs   broadcast/all_reduce
+        # prefill_req   x       x       √
+        # decode_req    √       x       √
+        # graph_mode    √       √       x
+        if self.dp_size > 1:
+            if VLLM_ENABLE_MC2 and not is_prefill:
+                ...
+            elif int(os.environ.get("USING_LCCL_COM",
+                                    '0')) == 1:  # type: ignore
+                hidden_states = get_dp_group().all_gather(
+                    hidden_states, 0, False)
+                router_logits = get_dp_group().all_gather(
+                    router_logits, 0, False)
+            elif self.enable_graph_mode and not is_prefill:
+                hidden_states = get_dp_group().all_gather(hidden_states, 0)
+                router_logits = get_dp_group().all_gather(router_logits, 0)
+            else:
+                cu_tokens_across_dp_cpu = get_forward_context(
+                ).dp_metadata.cu_tokens_across_dp_cpu
+                hidden_states = self.naive_multicast(hidden_states,
+                                                     cu_tokens_across_dp_cpu)
+                router_logits = self.naive_multicast(router_logits,
+                                                     cu_tokens_across_dp_cpu)
+
+        # Matrix multiply.
+        final_hidden_states = self.quant_method.apply(
+            layer=self,
+            x=hidden_states,
+            router_logits=router_logits,
+            top_k=real_top_k,
+            renormalize=self.renormalize,
+            use_grouped_topk=self.use_grouped_topk,
+            global_num_experts=self.global_num_experts,
+            expert_map=self.expert_map,
+            topk_group=self.topk_group,
+            num_expert_group=self.num_expert_group,
+            custom_routing_function=self.custom_routing_function,
+            scoring_func=self.scoring_func,
+            e_score_correction_bias=self.e_score_correction_bias,
+            is_prefill=is_prefill)
+
+        if self.dp_size > 1:
+            if VLLM_ENABLE_MC2 and not is_prefill:
+                ...
+            elif int(os.environ.get("USING_LCCL_COM",
+                                    '0')) == 1:  # type: ignore
+                final_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
+                    final_hidden_states,
+                    "sum",
+                    scatter_dim=0,
+                    group=get_dp_group().device_group)
+            elif self.enable_graph_mode and not is_prefill:
+                final_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
+                    final_hidden_states,
+                    "sum",
+                    scatter_dim=0,
+                    group=get_dp_group().device_group)
+            else:
+                start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
+                    self.dp_rank - 1]
+                end = cu_tokens_across_dp_cpu[self.dp_rank]
+                all_hidden_states = get_dp_group().all_reduce(
+                    final_hidden_states)
+                final_hidden_states = all_hidden_states[start:end, :]
+
+        if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
+            final_hidden_states = tensor_model_parallel_all_reduce(
+                final_hidden_states)
+
+        return final_hidden_states
diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py
index a84736425a..68ade1aaaa 100644
--- a/vllm_ascend/quantization/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/w8a8_dynamic.py
@@ -340,6 +340,7 @@ def fused_experts(hidden_states: torch.Tensor,
     num_experts = w1.shape[0]
     dtype = hidden_states.dtype
     device = hidden_states.device
+    topk_weights = topk_weights.to(dtype)
 
     if expert_map is not None:
         # Generate token indices and flatten
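Reviewer note on the MC2 decode path added in forward_etp: each TP rank keeps only its chunk of the decode tokens (torch.chunk, then chunks[self.tp_rank_in_group]), runs the experts on that chunk, and the full [num_tokens, hidden_dim] activation is rebuilt with dist.all_gather_into_tensor. The snippet below is a minimal single-process sketch of that pattern, not part of the patch; the helper name mc2_decode_pattern and the toy per-rank "expert" (a scale by 2.0) are made up for illustration, and torch.cat stands in for the real all-gather collective.

    import torch


    def mc2_decode_pattern(hidden_states: torch.Tensor,
                           tp_size: int) -> torch.Tensor:
        # Single-process emulation: loop over the "ranks" that the real code
        # spreads across devices.
        num_tokens, hidden_dim = hidden_states.shape
        assert num_tokens % tp_size == 0, "decode batch must split evenly"

        # Mirrors: chunks = torch.chunk(hidden_states, self.tp_size, dim=0)
        #          hidden_states = chunks[self.tp_rank_in_group]
        chunks = torch.chunk(hidden_states, tp_size, dim=0)

        per_rank_out = []
        for rank_chunk in chunks:
            # Toy stand-in for self.experts(...) on that rank's tokens.
            per_rank_out.append(rank_chunk * 2.0)

        # Stand-in for dist.all_gather_into_tensor: every rank ends up with
        # the per-rank outputs concatenated in rank order.
        final_hidden_states = torch.cat(per_rank_out, dim=0)
        assert final_hidden_states.shape == (num_tokens, hidden_dim)
        return final_hidden_states


    if __name__ == "__main__":
        x = torch.randn(8, 16)
        print(mc2_decode_pattern(x, tp_size=4).shape)  # torch.Size([8, 16])

The gather only lines up when every rank contributes an equally sized chunk, which is presumably why this path is limited to decode (fixed-size batches) and skipped for prefill.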