diff --git a/tests/e2e/multicard/moe/test_moe_comm.py b/tests/e2e/multicard/moe/test_moe_comm.py new file mode 100644 index 0000000000..b1de5e680f --- /dev/null +++ b/tests/e2e/multicard/moe/test_moe_comm.py @@ -0,0 +1,153 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. + +from types import SimpleNamespace + +import pytest +import torch +from transformers import PretrainedConfig +from vllm import forward_context + +from vllm_ascend.distributed import moe_comm_method +from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl, + NativeAllGatherCommImpl) + + +@pytest.mark.parametrize("num_tokens", [16, 128]) +@pytest.mark.parametrize("hidden_size", [64, 128]) +@pytest.mark.parametrize("global_num_experts", [8, 16]) +@pytest.mark.parametrize("top_k_num", [2, 4]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("num_local_experts", [4, 8]) +@pytest.mark.parametrize("ep_rank", [0, 1]) +def test_all_gather_comm_impl( + num_tokens, + hidden_size, + global_num_experts, + top_k_num, + dtype, + num_local_experts, + ep_rank, +): + """ + Tests the AllGatherCommImpl against the NativeAllGatherCommImpl. + + This test compares the outputs of the NPU-optimized AllGatherCommImpl + with a native PyTorch implementation (NativeAllGatherCommImpl) to ensure + correctness across various configurations. 
+ """ + if top_k_num > global_num_experts: + pytest.skip("top_k_num cannot be greater than global_num_experts") + if num_local_experts > global_num_experts: + pytest.skip( + "num_local_experts cannot be greater than global_num_experts") + + device = torch.device("npu") + hf_config = PretrainedConfig( + num_experts_per_tok=top_k_num, + num_experts=global_num_experts, + ) + + # Instantiate implementations + native_impl = NativeAllGatherCommImpl(device, dtype, hf_config) + + all_gather_impl = AllGatherCommImpl(device, dtype, hf_config) + + # TODO: Find out if this is the correct way to mock the forward context and ep group + # Mock get_forward_context to return an object with moe_comm_method + forward_context._forward_context = SimpleNamespace( + moe_comm_method=all_gather_impl) + # Mock get_ep_group to return a fake group with the specified ep_rank + fake_ep_group = SimpleNamespace(rank_in_group=ep_rank) + moe_comm_method.get_ep_group = lambda: fake_ep_group + + # --- Input Data --- + hidden_states = torch.randn(num_tokens, + hidden_size, + device=device, + dtype=dtype) + topk_ids = torch.randint(0, + global_num_experts, (num_tokens, top_k_num), + device=device, + dtype=torch.int32) + topk_weights = torch.rand(num_tokens, top_k_num, device=device).to(dtype) + topk_weights = torch.nn.functional.softmax(topk_weights, dim=1) + + num_experts = global_num_experts + + expert_map = None + if num_local_experts < global_num_experts: + # Create a map where some experts are local and some are not + expert_map = torch.full((global_num_experts, ), -1, device=device) + expert_map[ep_rank * num_local_experts:(ep_rank + 1) * + num_local_experts] = torch.arange(num_local_experts, + device=device) + num_experts = num_local_experts + + # --- Run Native Implementation (Golden Reference) --- + native_hidden_states_out = hidden_states.clone() + ( + native_permuted_hidden, + native_expert_tokens, + _, + ) = native_impl._pre_process(hidden_states, topk_ids, topk_weights, + expert_map, num_experts) + # Simulate MLP output + native_mlp_output = torch.randn_like(native_permuted_hidden) + native_impl._post_process(native_mlp_output, native_hidden_states_out) + + # --- Run AllGather Implementation --- + all_gather_hidden_states_out = hidden_states.clone() + ( + all_gather_permuted_hidden, + all_gather_expert_tokens, + _, + ) = torch.ops.vllm.moe_comm_pre_process(hidden_states, topk_ids, + topk_weights, expert_map, + num_experts) + + # Use the same simulated MLP output for a fair comparison + all_gather_mlp_output = native_mlp_output.clone() + + torch.ops.vllm.moe_comm_post_process(all_gather_mlp_output, + all_gather_hidden_states_out) + + # --- Assertions --- + # Define tolerance based on dtype + atol = 1e-3 if dtype == torch.float16 else 1e-2 + rtol = 1e-3 if dtype == torch.float16 else 1e-2 + + # 1. Compare expert_tokens from pre_process + assert torch.allclose(native_expert_tokens.to( + all_gather_expert_tokens.device), + all_gather_expert_tokens, + atol=atol, + rtol=rtol), "Expert tokens do not match." + + # 2. Compare permuted_hidden_states from pre_process + num_valid_tokens = native_expert_tokens.sum() + assert torch.allclose(native_permuted_hidden[:num_valid_tokens].to( + all_gather_permuted_hidden.device), + all_gather_permuted_hidden[:num_valid_tokens], + atol=atol, + rtol=rtol), "Permuted hidden states do not match." + + # 3. 
Compare final hidden_states from post_process + assert torch.allclose(native_hidden_states_out.to( + all_gather_hidden_states_out.device), + all_gather_hidden_states_out, + atol=atol, + rtol=rtol), "Final hidden states do not match." diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index c86253472f..c045ad6306 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -5,11 +5,12 @@ import torch from vllm.config import VllmConfig -from vllm.distributed import get_dp_group, get_ep_group, get_tp_group +from vllm.distributed import (get_dp_group, get_ep_group, + get_tensor_model_parallel_world_size) from vllm.forward_context import get_forward_context, set_forward_context import vllm_ascend.envs as envs -from vllm_ascend.platform import NPUPlatform +from vllm_ascend.distributed.moe_comm_method import MoECommMethod class FusedMoEState(Enum): @@ -54,6 +55,8 @@ def set_ascend_forward_context( num_tokens_across_dp: Optional[torch.Tensor] = None, with_prefill: bool = True, in_profile_run: bool = False, + reserved_mc2_mask: Optional[torch.Tensor] = None, + moe_comm_method: Optional[MoECommMethod] = None, num_actual_tokens: Optional[int] = None, ): """A context manager that stores the current forward context, @@ -66,6 +69,7 @@ def set_ascend_forward_context( num_tokens=num_tokens, num_tokens_across_dp=num_tokens_across_dp): forward_context = get_forward_context() + forward_context.moe_comm_method = moe_comm_method forward_context.with_prefill = with_prefill ep_size = (get_ep_group().world_size if vllm_config.parallel_config.enable_expert_parallel else 1) @@ -97,16 +101,17 @@ def set_ascend_forward_context( if num_tokens is not None: if num_actual_tokens is None: num_actual_tokens = num_tokens - tp_world_size = get_tp_group().world_size + tp_world_size = get_tensor_model_parallel_world_size() # NOTE: token num which need to pad to when mc2 forward_context.padded_num_tokens = math.ceil( max_tokens_across_dp / tp_world_size) * tp_world_size - mc2_mask = torch.zeros(forward_context.padded_num_tokens, - dtype=torch.bool, - device=NPUPlatform.device_type) - mc2_mask[:num_actual_tokens] = True - forward_context.mc2_mask = mc2_mask + if reserved_mc2_mask is not None: + mc2_mask = reserved_mc2_mask[:forward_context. + padded_num_tokens] + mc2_mask[:num_actual_tokens] = True + mc2_mask[num_actual_tokens:] = False + forward_context.mc2_mask = mc2_mask try: yield diff --git a/vllm_ascend/distributed/moe_comm_method.py b/vllm_ascend/distributed/moe_comm_method.py new file mode 100644 index 0000000000..f347ab06cb --- /dev/null +++ b/vllm_ascend/distributed/moe_comm_method.py @@ -0,0 +1,449 @@ +from abc import ABC, abstractmethod + +import torch +import torch_npu +from transformers.configuration_utils import PretrainedConfig +from vllm.distributed.parallel_state import get_ep_group, get_tp_group +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.utils import direct_register_custom_op + +from vllm_ascend.distributed.parallel_state import get_mc2_group +from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version + + +class MoECommMethod(ABC): + """Base class for MoE communication methods.""" + + def __init__( + self, + device: torch.device, + dtype: torch.dtype, + hf_config: PretrainedConfig, + ): + self.device = device + self.dtype = dtype + self.top_k_num = getattr(hf_config, "num_experts_per_tok", 0) + # global_num_experts may be called num_experts or n_routed_experts in different models. 
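+        # e.g. Qwen-MoE style configs expose `num_experts`, while DeepSeek style configs expose `n_routed_experts`.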
+ possible_keys = ["num_experts", "n_routed_experts"] + for key in possible_keys: + if hasattr(hf_config, key): + self.global_num_experts = getattr(hf_config, key) + break + else: + self.global_num_experts = 0 + + @abstractmethod + def _pre_process( + self, + hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + expert_map: torch.Tensor, + num_experts: int, + ) -> tuple[torch.Tensor, torch.Tensor, int]: + """Pre-process before MLP. + + Args: + hidden_states (torch.Tensor): Tensor of shape (num_tokens, hidden_size) + topk_ids (torch.Tensor): Tensor of shape (num_tokens, top_k_num) + topk_weights (torch.Tensor): Tensor of shape (num_tokens, top_k_num) + expert_map (torch.Tensor): Tensor of shape (global_num_experts, ) + Mapping from global expert IDs to local expert IDs. + num_experts (int): Number of local experts (experts on this device). + + Returns: + tuple[torch.Tensor, torch.Tensor, int]: Return a tuple containing: + - permuted_hidden_states (torch.Tensor): Tensor of shape + (num_tokens * top_k_num, hidden_size) after permuting + hidden_states based on topk_ids. + - expert_tokens (torch.Tensor): Tensor of shape (num_experts, ) + Number of tokens assigned to each expert. + - group_list_type (int): Type of group list, 0 for `cumsum` + and 1 for `count`. This is mainly for `npu_grouped_matmul` + to determine how to handle the output. + Raises: + NotImplementedError: If the method is not implemented in the subclass. + """ + pass + + @abstractmethod + def _post_process(self, mlp_output: torch.Tensor, + hidden_states: torch.Tensor) -> None: + """Post-process after MLP. + + Args: + mlp_output (torch.Tensor): Tensor of shape + (num_tokens * top_k_num, hidden_size) after MLP. + hidden_states (torch.Tensor): Tensor of shape + (num_tokens, hidden_size) to be updated with the final output. + """ + pass + + +class DummyCommImpl(MoECommMethod): + + def _pre_process( + self, + hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + expert_map: torch.Tensor, + num_experts: int, + ) -> tuple[torch.Tensor, torch.Tensor, int]: + """Dummy implementation, see moe_comm_pre_process_fake for details.""" + return moe_comm_pre_process_fake(hidden_states, topk_ids, topk_weights, + expert_map, num_experts) + + def _post_process(self, mlp_output: torch.Tensor, + hidden_states: torch.Tensor) -> None: + """Dummy implementation that does nothing.""" + pass + + +class NativeAllGatherCommImpl(MoECommMethod): + """This implementation should be compatible with all scenarios. + + Note that this implementation purely consists of native PyTorch ops + and does not use any NPU-specific ops. So the performance may not be optimal. + But it is a good fallback for scenarios where NPU-specific ops are not available. 
+ """ + + def _pre_process( + self, + hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + expert_map: torch.Tensor, + num_experts: int, + ) -> tuple[torch.Tensor, torch.Tensor, int]: + num_tokens = hidden_states.shape[0] + + # Generate token indices and flatten + token_indices = torch.arange(num_tokens, + device=self.device, + dtype=torch.int64) + token_indices = (token_indices.unsqueeze(1).expand( + -1, self.top_k_num).reshape(-1)) + + # Flatten token-to-expert mappings and map to local experts + weights_flat = topk_weights.view(-1) + experts_flat = topk_ids.view(-1) + local_experts_flat = (expert_map[experts_flat] + if expert_map is not None else experts_flat) + + # Filter valid token-expert pairs + mask = local_experts_flat != -1 + # FIXME: npu_grouped_matmul output random values at [num_valid_tokens:, ...] + # So we need to filter out invalid tokens by zeroing their weights. + # This is a workaround and should be removed after the issue is fixed + filtered_weights = torch.where(mask, weights_flat, + torch.zeros_like(weights_flat)).to( + self.dtype) + filtered_experts = torch.where( + mask, + local_experts_flat, + torch.full_like(local_experts_flat, num_experts), + ).to(topk_ids.dtype) + + # Sort by local expert IDs + sort_indices = torch.argsort(filtered_experts.view(torch.float32)) + self.sorted_token_indices = token_indices[sort_indices] + self.sorted_weights = filtered_weights[sort_indices] + + # Compute token counts with minlength of num_experts + # This is equivalent to but faster than: + # >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1] + token_counts = torch.zeros(num_experts + 1, + device=self.device, + dtype=torch.int64) + ones = torch.ones_like(filtered_experts, dtype=torch.int64) + token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones) + expert_tokens = token_counts[:num_experts] + + # Rearrange hidden_states + permuted_hidden_states = hidden_states[self.sorted_token_indices] + + group_list_type = 1 # `count` mode + + return permuted_hidden_states, expert_tokens, group_list_type + + def _post_process(self, mlp_output: torch.Tensor, + hidden_states: torch.Tensor) -> None: + mlp_output = mlp_output * self.sorted_weights.unsqueeze(1) + + final_hidden_states = torch.zeros_like(hidden_states) + final_hidden_states.index_add_(0, self.sorted_token_indices, + mlp_output) + + hidden_states[:] = final_hidden_states + + +class AllGatherCommImpl(MoECommMethod): + """This implementation is the same as NativeAllGatherCommImpl, + but uses NPU-specific ops for better performance. + + This implementation should be compatible with all scenarios, and + thus it is the default implementation for MoE communication methods. + It uses `torch_npu.npu_moe_init_routing_v2` for pre-processing + and `torch_npu.npu_moe_token_unpermute` for post-processing + to handle the token-to-expert mapping and communication efficiently. + + NOTE(Yizhou): TBH, it is really weird that we were supposed to use + `torch_npu.npu_moe_init_routing_v2` and `torch_npu.npu_moe_finalize_routing` + or `torch_npu.npu_moe_token_permute` and `torch_npu.npu_moe_token_unpermute` + for pre-processing and post-processing, respectively. + But `npu_moe_finalize_routing` will lead to accuracy issues so we have to + use `torch_npu.npu_moe_token_unpermute` instead. + This is a workaround and should be removed after the issue is fixed. 
+ """ + + def _pre_process( + self, + hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + expert_map: torch.Tensor, # noqa: F841 + num_experts: int, + ) -> tuple[torch.Tensor, torch.Tensor, int]: + num_tokens = hidden_states.shape[0] + + self.topk_weights = topk_weights + self.topk_ids = topk_ids + + first_expert_idx = 0 + if expert_map is not None: + # FIXME: npu_grouped_matmul output random values at [num_valid_tokens:, ...] + # So we need to filter out invalid tokens by zeroing their weights. + # This is a workaround and should be removed after the issue is fixed + mask = expert_map[topk_ids] != -1 + # NOTE: This is equivalent to self.topk_weights[~mask] = 0.0, + # but ~mask will dispatch to aclnnNonzeroV2, which is not supported in ACL Graph + self.topk_weights = torch.where(mask, topk_weights, 0.0) + + first_expert_idx = get_ep_group().rank_in_group * num_experts + last_expert_idx = first_expert_idx + num_experts + + permuted_hidden_states, expanded_row_idx, expert_tokens, _ = ( + torch_npu.npu_moe_init_routing_v2( + hidden_states, + topk_ids, + active_num=num_tokens * self.top_k_num, + expert_num=self.global_num_experts, + expert_tokens_num_type=1, # Only support `count` mode now + expert_tokens_num_flag=True, # Output `expert_tokens` + active_expert_range=[first_expert_idx, last_expert_idx], + quant_mode=-1, + )) + self.expanded_row_idx = expanded_row_idx + permuted_hidden_states = permuted_hidden_states + + group_list_type = 1 # `count` mode + + return permuted_hidden_states, expert_tokens, group_list_type + + def _post_process(self, mlp_output: torch.Tensor, + hidden_states: torch.Tensor) -> None: + hidden_states[:] = torch_npu.npu_moe_token_unpermute( + permuted_tokens=mlp_output, + sorted_indices=self.expanded_row_idx, + probs=self.topk_weights) + + +class MC2CommImpl(MoECommMethod): + """This implementation is for the scenarios listed below: + 1. `enable_expert_parallel=True`. + 2. `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine` are available. + 3. `enable_expert_parallel=False` is not supported. + + This implementation uses the MC2 communication method, which is optimized for + Communication and Computation parallelism on Ascend devices. 
+ """ + + def __init__( + self, + device: torch.device, + dtype: torch.dtype, + hf_config: PretrainedConfig, + ): + super().__init__(device, dtype, hf_config) + + # Shared communication configurations + ep_group = get_mc2_group() + self.ep_rank_id = ep_group.rank_in_group + self.ep_world_size = ep_group.world_size + self.tp_world_size = get_tp_group().world_size + + device_group = ep_group.device_group + local_rank = torch.distributed.get_rank(group=device_group) + backend = device_group._get_backend(torch.device("npu")) + self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank) + + # Feature flags + self.enable_dispatch_v2 = hasattr(torch_npu, + "npu_moe_distribute_dispatch_v2") + self.is_ascend_a3 = get_ascend_soc_version() == AscendSocVersion.A3 + self.need_extra_args = self.is_ascend_a3 # or is_torchair + + # Intermediate tensors to be passed from pre_process to post_process + self.topk_ids = None + self.topk_weights = None + self.mc2_mask = None + self.assist_info_for_combine = None + self.ep_recv_counts = None + self.tp_recv_counts = None + + def _pre_process( + self, + hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + expert_map: torch.Tensor, + num_experts: int, + ) -> tuple[torch.Tensor, torch.Tensor, int]: + # Store tensors needed for post_process + self.topk_ids = topk_ids + self.topk_weights = topk_weights.to(torch.float32) + self.mc2_mask = get_forward_context().mc2_mask + + dispatch_kwargs = { + "x": hidden_states, + "expert_ids": self.topk_ids, + "expert_shard_type": 0, + "shared_expert_rank_num": 0, + "moe_expert_num": self.global_num_experts, + "global_bs": 0, + "scales": None, + "quant_mode": 0, + "group_ep": self.moe_all_to_all_group_name, + "ep_world_size": self.ep_world_size, + "ep_rank_id": self.ep_rank_id, + } + + if self.need_extra_args: + dispatch_kwargs.update({ + "group_tp": self.moe_all_to_all_group_name, + "tp_world_size": 1, + "tp_rank_id": 0, + }) + if self.is_ascend_a3 and self.enable_dispatch_v2: + dispatch_kwargs.update({ + "x_active_mask": self.mc2_mask, + }) + + dispatch = torch_npu.npu_moe_distribute_dispatch_v2 if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch + + ( + permuted_hidden_states, + _, # dynamic_scale is not used + self.assist_info_for_combine, + expert_tokens, + self.ep_recv_counts, + self.tp_recv_counts, + ) = dispatch(**dispatch_kwargs)[:6] + + group_list_type = 1 + + return permuted_hidden_states, expert_tokens, group_list_type + + def _post_process(self, mlp_output: torch.Tensor, + hidden_states: torch.Tensor) -> None: + combine_kwargs = { + "expand_x": mlp_output, + "expert_ids": self.topk_ids, + "expert_scales": self.topk_weights, + "expert_shard_type": 0, + "shared_expert_rank_num": 0, + "moe_expert_num": self.global_num_experts, + "global_bs": 0, + "ep_send_counts": self.ep_recv_counts, + "group_ep": self.moe_all_to_all_group_name, + "ep_world_size": self.ep_world_size, + "ep_rank_id": self.ep_rank_id, + } + + if self.enable_dispatch_v2: + combine_kwargs[ + "assist_info_for_combine"] = self.assist_info_for_combine + else: + combine_kwargs["expand_idx"] = self.assist_info_for_combine + + if self.need_extra_args: + combine_kwargs.update({ + "tp_send_counts": self.tp_recv_counts, + "group_tp": self.moe_all_to_all_group_name, + "tp_world_size": 1, + "tp_rank_id": 0, + }) + if self.is_ascend_a3 and self.enable_dispatch_v2: + combine_kwargs.update({ + "x_active_mask": self.mc2_mask, + }) + + combine = torch_npu.npu_moe_distribute_combine_v2 if self.enable_dispatch_v2 
else torch_npu.npu_moe_distribute_combine + + hidden_states[:] = combine(**combine_kwargs) + + +def moe_comm_pre_process( + hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + expert_map: torch.Tensor, + num_experts: int, +) -> tuple[torch.Tensor, torch.Tensor, int]: + """This function is a wrapper for the pre_process method of the + MoECommMethod instance stored in the ForwardContext. So it can be + used as a custom op in the vllm framework. + """ + forward_context: ForwardContext = get_forward_context() + self = forward_context.moe_comm_method + return self._pre_process(hidden_states, topk_ids, topk_weights, expert_map, + num_experts) + + +def moe_comm_pre_process_fake( + hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + expert_map: torch.Tensor, + num_experts: int, +) -> tuple[torch.Tensor, torch.Tensor, int]: + """This is a fake implementation of the pre_process method. + torch.compile will use this implementation to generate FX graph. + """ + top_k_num = topk_ids.shape[1] + permuted_hidden_states = hidden_states.repeat_interleave(top_k_num, dim=0) + expert_tokens = torch.zeros((num_experts, ), + dtype=torch.int64, + device=hidden_states.device) + group_list_type = 0 + return permuted_hidden_states, expert_tokens, group_list_type + + +def moe_comm_post_process(mlp_output: torch.Tensor, + hidden_states: torch.Tensor) -> None: + """This function is a wrapper for the post_process method of the + MoECommMethod instance stored in the ForwardContext. So it can be + used as a custom op in the vllm framework. + """ + forward_context: ForwardContext = get_forward_context() + self = forward_context.moe_comm_method + self._post_process(mlp_output, hidden_states) + return + + +direct_register_custom_op( + op_name="moe_comm_pre_process", + op_func=moe_comm_pre_process, + mutates_args=[], + fake_impl=moe_comm_pre_process_fake, + dispatch_key="PrivateUse1", +) + +direct_register_custom_op( + op_name="moe_comm_post_process", + op_func=moe_comm_post_process, + mutates_args=["hidden_states"], + fake_impl=lambda x, y: None, # No-op for fake implementation + dispatch_key="PrivateUse1", +) diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py index eeb8ec3223..b97aef7de1 100644 --- a/vllm_ascend/ops/common_fused_moe.py +++ b/vllm_ascend/ops/common_fused_moe.py @@ -19,12 +19,13 @@ import torch from vllm.config import CompilationLevel, get_current_vllm_config +from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fused_moe.layer import \ UnquantizedFusedMoEMethod from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.ops.fused_moe import (fused_experts, fused_experts_moge, - select_experts) +from vllm_ascend.ops.fused_moe import (fused_experts_moge, select_experts, + unified_fused_experts) from vllm_ascend.utils import is_310p original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__ @@ -95,20 +96,18 @@ def forward_oot( expert_map=expert_map, apply_router_weight_on_input=apply_router_weight_on_input) - # If use aclgraph, we need to set max_num_tokens to make - # the input shape of `npu_moe_init_routing` fixed - max_num_tokens = self.max_num_batched_tokens if self.use_aclgraph else None + moe_comm_method = get_forward_context().moe_comm_method - return fused_experts( + return unified_fused_experts( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, - top_k=top_k, + 
global_num_experts=global_num_experts, expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, - max_num_tokens=max_num_tokens) + moe_comm_method=moe_comm_method, + ) UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index aec6e72264..aeb75cfa0d 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -43,6 +43,7 @@ from vllm_ascend.ascend_forward_context import FusedMoEState from vllm_ascend.distributed.communication_op import \ data_parallel_reduce_scatter +from vllm_ascend.distributed.moe_comm_method import MoECommMethod from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer from vllm_ascend.ops.moe_dispatcher.token_dispatcher import ( @@ -57,6 +58,62 @@ MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER +def unified_fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_int8_w8a8: bool = False, + use_int4_w4a8: bool = False, + global_num_experts: Optional[int] = None, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_scale_bias: torch.Tensor = None, + w2_scale_bias: torch.Tensor = None, + moe_comm_method: Optional[MoECommMethod] = None, + # For TorchAir graph + is_torchair: bool = False, + # For Cube/Vector parallel + shared_experts: Optional[Any] = None, + quantized_x_for_share: Optional[Any] = None, + dynamic_scale_for_share: Optional[Any] = None, + # For load balance + log2phy: torch.Tensor = None, + global_redundant_expert_num: int = 0, +) -> torch.Tensor: + # Check constraints + assert hidden_states.shape[1] == w1.shape[2], ( + f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[2]}") + + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.stride(-1) == 1, "Stride of last dimension must be 1" + assert w2.stride(-1) == 1, "Stride of last dimension must be 1" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] + assert moe_comm_method is not None, "Missing communication context" + + num_experts = w1.shape[0] + + permuted_hidden_states, expert_tokens, group_list_type = torch.ops.vllm.moe_comm_pre_process( + hidden_states, topk_ids, topk_weights, expert_map, num_experts) + mlp_output = apply_mlp( + permuted_hidden_states, + w1, + w2, + expert_tokens, + group_list_type=group_list_type, + ) + torch.ops.vllm.moe_comm_post_process(mlp_output, hidden_states) + + return hidden_states + + def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int, max_row_per_ep_rank: int, num_tokens: int, top_k: int) -> tuple[torch.Tensor, torch.Tensor]: diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index eb7ea8276c..f101ccdc7a 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -205,8 +205,15 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: register_ascend_customop() @classmethod - def get_attn_backend_cls(cls, selected_backend, head_size, dtype, - kv_cache_dtype, block_size, use_v1, use_mla): + def get_attn_backend_cls(cls, + selected_backend, + head_size, + dtype, + kv_cache_dtype, + block_size, + use_v1, + use_mla, + 
has_sink=False): if not use_v1: raise ValueError("vLLM Ascend does not support V0 engine.") diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 594649c6d4..3aeabc6e18 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -26,7 +26,7 @@ import weakref from contextlib import contextmanager, nullcontext from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Union, cast +from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union, cast import numpy as np import numpy.typing as npt @@ -43,7 +43,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.distributed.parallel_state import (get_dp_group, get_pp_group, get_tp_group) -from vllm.forward_context import get_forward_context +from vllm.forward_context import DPMetadata, get_forward_context from vllm.logger import logger from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding @@ -79,6 +79,9 @@ AscendMetadata) from vllm_ascend.attention.attention_v1_torchair import AscendTorchairMetadata from vllm_ascend.attention.mla_v1 import AscendMLAMetadata +from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl, + DummyCommImpl, + MoECommMethod) from vllm_ascend.multistream.ms_split import compute_split_seq_index from vllm_ascend.platform import NPUPlatform from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler @@ -335,7 +338,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.use_aclgraph = (self.vllm_config.compilation_config.level == CompilationLevel.PIECEWISE - and not self.model_config.enforce_eager) + and not self.model_config.enforce_eager and + not ascend_config.torchair_graph_config.enabled) self.aclgraph_batch_sizes = list( reversed( self.vllm_config.compilation_config.cudagraph_capture_sizes)) @@ -375,6 +379,14 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer self.is_kv_consumer = vllm_config.kv_transfer_config.is_kv_consumer + self.reserved_mc2_mask = torch.zeros( + 512, + dtype=torch.bool, + device=self.device, + ) + + self.moe_comm_method = AllGatherCommImpl + def check_batch_sizes_consistency(self) -> None: if not dist.is_initialized(): return @@ -1003,6 +1015,32 @@ def _gather_mm_embeddings( mm_embeds.append(mm_embeds_item) return mm_embeds + def get_dp_padding(self, + num_tokens: int) -> tuple[int, Optional[torch.Tensor]]: + """This implementation is derived from vLLM's `GPUModelRunner.get_dp_padding`. + Please note that vLLM may refactor or modify this function over time, + at present, we are using the version introduced in PR #18935. + """ + dp_size = self.vllm_config.parallel_config.data_parallel_size + dp_rank = self.vllm_config.parallel_config.data_parallel_rank + + # For DP: Don't pad when setting enforce_eager. + # This lets us set enforce_eager on the prefiller in a P/D setup and + # still use ACL graphs (enabled by this padding) on the decoder. + + if dp_size == 1 or self.vllm_config.model_config.enforce_eager: + # Early exit. 
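+            # No DP padding required, and no per-rank token-count tensor is built (hence None).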
+ return 0, None + + num_tokens_across_dp = DPMetadata.num_tokens_across_dp( + num_tokens, dp_size, dp_rank) + max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item() + num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] * + dp_size, + device="cpu", + dtype=torch.int32) + return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding + def _process_reqs( self, scheduler_output: "SchedulerOutput", @@ -1025,6 +1063,11 @@ def _process_reqs( # Eager mode. num_input_tokens = total_num_scheduled_tokens + # Padding for DP + num_pad, num_tokens_across_dp_native = self.get_dp_padding( + num_input_tokens) + num_input_tokens += num_pad + modified_batch = self.attn_metadata_builder.reorder_batch( self.input_batch, scheduler_output) if modified_batch: @@ -1250,13 +1293,26 @@ def _process_reqs( for k, v in self.intermediate_tensors.items() }) + moe_comm_method = self.moe_comm_method + + # NOTE: Currently this padding logic is really messy, + # MC2 may not be available in eager mode + # TODO: Unify the padding logic between TorchAir and ACL Graph ASAP + if self.use_aclgraph: + num_tokens_across_dp = num_tokens_across_dp_native + else: + num_input_tokens = padded_num_tokens_across_dp + # Run forward pass with set_ascend_forward_context( attn_metadata, self.vllm_config, - num_tokens=padded_num_tokens_across_dp, + num_tokens=num_input_tokens, num_tokens_across_dp=num_tokens_across_dp, with_prefill=with_prefill, + reserved_mc2_mask=self.reserved_mc2_mask, + moe_comm_method=moe_comm_method(self.device, self.dtype, + self.model_config.hf_config), num_actual_tokens=total_num_scheduled_tokens): with ProfileExecuteDuration().capture_async("forward"): self.maybe_setup_kv_connector(scheduler_output) @@ -1865,6 +1921,7 @@ def _dummy_run( skip_attn: bool = True, with_prefill: bool = False, is_torchair_compile: bool = False, + moe_comm_method: Type[MoECommMethod] = DummyCommImpl, ) -> torch.Tensor: # Padding for DP (num_tokens, num_tokens_across_dp, with_prefill, @@ -1932,6 +1989,9 @@ def _dummy_run( num_tokens_across_dp=num_tokens_across_dp, with_prefill=with_prefill, in_profile_run=self.in_profile_run, + reserved_mc2_mask=self.reserved_mc2_mask, + moe_comm_method=moe_comm_method( + self.device, self.dtype, self.model_config.hf_config), num_actual_tokens=0, ): hidden_states = self._generate_dummy_run_hidden_states( @@ -2328,13 +2388,21 @@ def _capture_model(self): # Trigger ACL graph capture for specific shapes. # Capture the large shapes first so that the smaller shapes # can reuse the memory pool allocated for the large shapes. - # TODO(zzzzwwjj): Check dummy_run with ACL Graph and full graph mode with graph_capture(device=self.device): + skip_attn = not self.vllm_config.compilation_config.full_cuda_graph for num_tokens in reversed(self.aclgraph_batch_sizes): for _ in range(self.vllm_config.compilation_config. cudagraph_num_of_warmups): - self._dummy_run(num_tokens) - self._dummy_run(num_tokens) + self._dummy_run( + num_tokens, + skip_attn=skip_attn, + moe_comm_method=self.moe_comm_method, + ) + self._dummy_run( + num_tokens, + skip_attn=skip_attn, + moe_comm_method=self.moe_comm_method, + ) def capture_model(self) -> None: start_time = time.perf_counter()
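Below is a minimal sketch of how the new MoECommMethod abstraction is driven, using the pure-PyTorch NativeAllGatherCommImpl in the same way the e2e test above does. The toy shapes and the direct calls to the private _pre_process/_post_process methods are illustrative only, and the snippet assumes a working vllm-ascend install (the module imports torch_npu at import time).

import torch
from transformers import PretrainedConfig

from vllm_ascend.distributed.moe_comm_method import NativeAllGatherCommImpl

device = torch.device("npu")  # the native impl itself is built from plain torch ops
dtype = torch.bfloat16
hf_config = PretrainedConfig(num_experts_per_tok=2, num_experts=8)

comm = NativeAllGatherCommImpl(device, dtype, hf_config)

num_tokens, hidden_size, top_k, num_experts = 16, 64, 2, 8
hidden_states = torch.randn(num_tokens, hidden_size, device=device, dtype=dtype)
topk_ids = torch.randint(0, num_experts, (num_tokens, top_k),
                         device=device, dtype=torch.int32)
topk_weights = torch.softmax(
    torch.rand(num_tokens, top_k, device=device), dim=1).to(dtype)

# 1) Group tokens by expert; expert_tokens counts tokens per local expert
#    (group_list_type == 1 means `count` mode for the grouped matmul).
permuted, expert_tokens, group_list_type = comm._pre_process(
    hidden_states, topk_ids, topk_weights, expert_map=None,
    num_experts=num_experts)

# 2) The grouped expert MLP would run here; stand in with random activations.
mlp_output = torch.randn_like(permuted)

# 3) Un-permute, apply the router weights, and reduce back into hidden_states in place.
comm._post_process(mlp_output, hidden_states)

In the model-runner path the same two steps are reached through the registered custom ops torch.ops.vllm.moe_comm_pre_process and torch.ops.vllm.moe_comm_post_process, which dispatch to whatever MoECommMethod instance set_ascend_forward_context stored on the forward context.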