diff --git a/tests/e2e/singlecard/ops/test_fused_moe.py b/tests/e2e/singlecard/ops/test_fused_moe.py index bc1309ebcc..f8dde632f3 100644 --- a/tests/e2e/singlecard/ops/test_fused_moe.py +++ b/tests/e2e/singlecard/ops/test_fused_moe.py @@ -28,9 +28,8 @@ import torch_npu from vllm.model_executor.layers.activation import SiluAndMul -from vllm_ascend.ops.layers.experts_selector import select_experts -from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \ - TokenDispatcherWithAllGather +from vllm_ascend.ops.moe.experts_selector import select_experts +from vllm_ascend.ops.moe.token_dispatcher import TokenDispatcherWithAllGather NUM_EXPERTS = [8, 64] EP_SIZE = [1] @@ -209,7 +208,7 @@ def test_select_experts( dtype=torch.int32) custom_routing_function.return_value = (mock_weights, mock_ids) - with patch("vllm_ascend.ops.layers.experts_selector._native_grouped_topk" + with patch("vllm_ascend.ops.moe.experts_selector._native_grouped_topk" ) as mock_native_grouped_topk: mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like( x) diff --git a/tests/e2e/singlecard/ops/test_moe_comm.py b/tests/e2e/singlecard/ops/test_moe_comm.py deleted file mode 100644 index b034ed4b5b..0000000000 --- a/tests/e2e/singlecard/ops/test_moe_comm.py +++ /dev/null @@ -1,175 +0,0 @@ -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2023 The vLLM team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# This file is a part of the vllm-ascend project. - -import gc -from types import SimpleNamespace - -import pytest -import torch - -from vllm.model_executor.layers.fused_moe.config import ( # isort: skip - FusedMoEConfig, FusedMoEParallelConfig) - -from vllm_ascend.distributed.moe_comm_method import ( # isort: skip - AllGatherCommImpl, NativeAllGatherCommImpl) - - -@pytest.mark.parametrize("num_tokens", [16, 128]) -@pytest.mark.parametrize("hidden_size", [64, 128]) -@pytest.mark.parametrize("global_num_experts", [8, 16]) -@pytest.mark.parametrize("num_local_experts", [4, 8]) -@pytest.mark.parametrize("top_k_num", [2, 4]) -@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) -@pytest.mark.parametrize("ep_rank", [0, 1]) -@pytest.mark.parametrize("apply_a8_quantization", [False]) -def test_all_gather_comm_impl( - num_tokens, - hidden_size, - global_num_experts, - num_local_experts, - top_k_num, - dtype, - ep_rank, - apply_a8_quantization, - mocker, -): - """ - Tests the AllGatherCommImpl against the NativeAllGatherCommImpl. - - This test compares the outputs of the NPU-optimized AllGatherCommImpl - with a native PyTorch implementation (NativeAllGatherCommImpl) to ensure - correctness across various configurations. 
- """ - if top_k_num > global_num_experts: - pytest.skip("top_k_num cannot be greater than global_num_experts") - if num_local_experts > global_num_experts: - pytest.skip( - "num_local_experts cannot be greater than global_num_experts") - - device = torch.device("npu") - - # mock get_tensor_model_parallel_rank to return ep_rank - mocker.patch( - "vllm.model_executor.layers.fused_moe.config.get_tensor_model_parallel_rank", - return_value=ep_rank, - ) - - # make moe config - parallel_config = SimpleNamespace( - enable_expert_parallel=num_local_experts < global_num_experts) - moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make( - tp_size_=max(2, global_num_experts // num_local_experts), - dp_size_=1, - vllm_parallel_config=parallel_config, - ) - - moe_config = FusedMoEConfig( - num_experts=global_num_experts, - experts_per_token=top_k_num, - hidden_dim=hidden_size, - num_local_experts=num_local_experts, - moe_parallel_config=moe_parallel_config, - in_dtype=dtype, - quant_config=None, # No quantization in this test - max_num_tokens=num_tokens, - ) - - # Instantiate implementations - native_impl = NativeAllGatherCommImpl(moe_config) - - all_gather_impl = AllGatherCommImpl(moe_config) - - # --- Input Data --- - hidden_states = torch.randn(num_tokens, - hidden_size, - device=device, - dtype=dtype) - topk_ids = torch.randint(0, - global_num_experts, (num_tokens, top_k_num), - device=device, - dtype=torch.int32) - topk_weights = torch.rand(num_tokens, top_k_num, device=device).to(dtype) - topk_weights = torch.nn.functional.softmax(topk_weights, dim=1) - - num_experts = global_num_experts - - expert_map = None - if num_local_experts < global_num_experts: - # Create a map where some experts are local and some are not - expert_map = torch.full((global_num_experts, ), -1, device=device) - expert_map[ep_rank * num_local_experts:(ep_rank + 1) * - num_local_experts] = torch.arange(num_local_experts, - device=device) - num_experts = num_local_experts - - # --- Run Native Implementation (Golden Reference) --- - native_hidden_states_out = hidden_states.clone() - ( - native_permuted_hidden, - native_expert_tokens, - _, - _, - ) = native_impl.permute(hidden_states, topk_ids, topk_weights, expert_map, - num_experts, apply_a8_quantization) - # Simulate MLP output - native_mlp_output = torch.randn_like(native_permuted_hidden) - native_impl.unpermute(native_mlp_output, native_hidden_states_out) - - # --- Run AllGather Implementation --- - all_gather_hidden_states_out = hidden_states.clone() - ( - all_gather_permuted_hidden, - all_gather_expert_tokens, - _, - _, - ) = all_gather_impl.permute(hidden_states, topk_ids, topk_weights, - expert_map, num_experts, apply_a8_quantization) - - # Use the same simulated MLP output for a fair comparison - all_gather_mlp_output = native_mlp_output.clone() - - all_gather_impl.unpermute(all_gather_mlp_output, - all_gather_hidden_states_out) - - # --- Assertions --- - # Define tolerance based on dtype - atol = 1e-3 if dtype == torch.float16 else 1e-2 - rtol = 1e-3 if dtype == torch.float16 else 1e-2 - - # 1. Compare expert_tokens from pre_process - assert torch.allclose(native_expert_tokens.to( - all_gather_expert_tokens.device), - all_gather_expert_tokens, - atol=atol, - rtol=rtol), "Expert tokens do not match." - - # 2. 
Compare permuted_hidden_states from pre_process - num_valid_tokens = native_expert_tokens.sum() - assert torch.allclose(native_permuted_hidden[:num_valid_tokens].to( - all_gather_permuted_hidden.device), - all_gather_permuted_hidden[:num_valid_tokens], - atol=atol, - rtol=rtol), "Permuted hidden states do not match." - - # 3. Compare final hidden_states from post_process - assert torch.allclose(native_hidden_states_out.to( - all_gather_hidden_states_out.device), - all_gather_hidden_states_out, - atol=atol, - rtol=rtol), "Final hidden states do not match." - gc.collect() - torch.npu.empty_cache() - torch.npu.reset_peak_memory_stats() diff --git a/tests/ut/ops/test_fused_moe_prepare_and_finalize.py b/tests/ut/ops/test_fused_moe_prepare_and_finalize.py new file mode 100644 index 0000000000..f0c5ff8278 --- /dev/null +++ b/tests/ut/ops/test_fused_moe_prepare_and_finalize.py @@ -0,0 +1,218 @@ +import unittest +from unittest.mock import MagicMock, patch + +import torch +from vllm.model_executor.layers.fused_moe import FusedMoEConfig + +from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import ( + FusedMoEPrepareAndFinalizeWithAll2All, + FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2) + + +class TestFusedMoEPrepareAndFinalize(unittest.TestCase): + + def setUp(self): + # Mock FusedMoEConfig + self.moe_config = MagicMock(spec=FusedMoEConfig) + self.moe_config.tp_group = MagicMock() + self.moe_config.tp_group.device_group = MagicMock() + self.moe_config.dp_size = 1 + self.moe_config.tp_size = 1 + self.moe_config.ep_size = 1 + self.moe_config.dp_group = MagicMock() + + @patch( + "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size", + return_value=1) + @patch( + "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank", + return_value=0) + @patch( + "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context" + ) + def test_mc2_prepare_finalize(self, mock_get_forward_context, mock_tp_rank, + mock_tp_size): + mock_context = MagicMock() + mock_context.mc2_mask = torch.tensor([1, 0, 1]) + mock_context.padded_num_tokens = 4 + mock_get_forward_context.return_value = mock_context + + layer = FusedMoEPrepareAndFinalizeWithMC2(self.moe_config) + + hidden_states = torch.randn(3, 8) + router_logits = torch.randn(3, 2) + + h_out, r_out, mask = layer.prepare(hidden_states, router_logits) + + # Check padding and split + self.assertEqual(h_out.shape[0], 4) + self.assertEqual(r_out.shape[0], 4) + self.assertEqual(mask.tolist(), [1, 0, 1]) + + # Finalize + result = layer.finalize(h_out, reduce_results=False) + self.assertEqual(result.shape[0], 3) + + @patch( + "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size", + return_value=2) + @patch( + "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank", + return_value=0) + @patch( + "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context" + ) + @patch("torch.distributed.all_gather") + def test_mc2_tp_split_allgather(self, mock_all_gather, + mock_get_forward_context, mock_tp_rank, + mock_tp_size): + mock_context = MagicMock() + mock_context.mc2_mask = torch.tensor([1, 0, 1, 0]) + mock_context.padded_num_tokens = 4 + mock_get_forward_context.return_value = mock_context + + layer = FusedMoEPrepareAndFinalizeWithMC2(self.moe_config) + hidden_states = torch.randn(4, 8) + router_logits = torch.randn(4, 2) + + h_out, r_out, mask = layer.prepare(hidden_states, + router_logits, + 
enable_shared_expert_dp=False, + replace_allreduce=False) + + # With TP=2, should split into 2 parts + self.assertEqual(h_out.shape[0], 2) + + # Mock all_gather behavior + def mock_all_gather_func(tensor_list, tensor, group=None): + tensor_list[0] = tensor + tensor_list[1] = tensor.clone() + + mock_all_gather.side_effect = mock_all_gather_func + + layer.split_hidden_states = [ + torch.zeros_like(h_out), + torch.zeros_like(h_out) + ] + final_result = layer.finalize(h_out, reduce_results=False) + + # Should concat back to original size + self.assertEqual(final_result.shape[0], 4) + + @patch( + "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size", + return_value=1) + @patch( + "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank", + return_value=0) + def test_all2all_prepare_finalize(self, mock_tp_rank, mock_tp_size): + layer = FusedMoEPrepareAndFinalizeWithAll2All(self.moe_config) + hidden_states = torch.randn(3, 8) + router_logits = torch.randn(3, 2) + + h_out, r_out, _ = layer.prepare(hidden_states, router_logits) + + # Pad to tp_size=1, so no change + self.assertEqual(h_out.shape[0], 3) + + result = layer.finalize(h_out, reduce_results=False) + self.assertEqual(result.shape[0], 3) + + @patch( + "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size", + return_value=2) + @patch( + "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank", + return_value=0) + @patch("torch.distributed.all_gather") + def test_all2all_tp_split_allgather(self, mock_all_gather, mock_tp_rank, + mock_tp_size): + layer = FusedMoEPrepareAndFinalizeWithAll2All(self.moe_config) + hidden_states = torch.randn(2, 8) + router_logits = torch.randn(2, 2) + + h_out, r_out, _ = layer.prepare(hidden_states, + router_logits, + enable_shared_expert_dp=False, + replace_allreduce=False) + + # Split due to TP=2 + self.assertEqual(h_out.shape[0], 1) + + # Mock all_gather + def mock_all_gather_func(tensor_list, tensor, group=None): + tensor_list[0] = tensor + tensor_list[1] = tensor.clone() + + mock_all_gather.side_effect = mock_all_gather_func + + layer.split_hidden_states = [ + torch.zeros_like(h_out), + torch.zeros_like(h_out) + ] + final_result = layer.finalize(h_out, reduce_results=False) + + # Should concat back + self.assertEqual(final_result.shape[0], 2) + + @patch("vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_dp_group") + @patch( + "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.tensor_model_parallel_all_reduce" + ) + @patch( + "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context" + ) + def test_allgather_prepare_finalize(self, mock_get_forward_context, + mock_tp_all_reduce, mock_get_dp_group): + # Mock forward context + mock_context = MagicMock() + mock_context.max_tokens_across_dp = 6 + mock_get_forward_context.return_value = mock_context + + # Create a proper mock for DP group with working all_gather + mock_dp_group = MagicMock() + + def mock_all_gather_func(tensor, dim): + # Simulate DP=2: repeat the tensor along the specified dimension + return torch.cat([tensor, tensor], dim=dim) + + mock_dp_group.all_gather = mock_all_gather_func + mock_get_dp_group.return_value = mock_dp_group + + self.moe_config.dp_size = 2 + self.moe_config.tp_size = 1 + self.moe_config.ep_size = 1 + self.moe_config.dp_group = mock_dp_group + + layer = FusedMoEPrepareAndFinalizeWithAllGather(self.moe_config) + + hidden_states = torch.randn(3, 8) + router_logits = torch.randn(3, 2) + + # 
Mock the gate function for rm_router_logits=False case + mock_gate = MagicMock() + mock_gate.return_value = (router_logits.repeat(2, 1), None) + + h_out, r_out, _ = layer.prepare(hidden_states, + router_logits, + rm_router_logits=False, + gate=mock_gate) + + # After all-gather with DP=2, should double the batch size + self.assertEqual(h_out.shape[0], 12) + self.assertEqual(r_out.shape[0], 12) + + # Finalize with reduce_scatter + def mock_reduce_scatter_func(tensor, dim): + # Simulate reduce_scatter: take first half + return tensor[:3] + + mock_dp_group.reduce_scatter = mock_reduce_scatter_func + result = layer.finalize(h_out, reduce_results=False) + + self.assertEqual(result.shape[0], 3) + + # Test with TP all-reduce + mock_tp_all_reduce.return_value = result + result_with_tp = layer.finalize(h_out, reduce_results=True) + self.assertEqual(result_with_tp.shape[0], 3) diff --git a/tests/ut/ops/test_fused_ops.py b/tests/ut/ops/test_fused_ops.py index 58cba6d8e5..3e9351af01 100644 --- a/tests/ut/ops/test_fused_ops.py +++ b/tests/ut/ops/test_fused_ops.py @@ -22,14 +22,14 @@ from pytest_mock import MockerFixture from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase -import vllm_ascend.ops.moe_dispatcher.token_dispatcher as token_dispatcher_module +import vllm_ascend.ops.moe.token_dispatcher as token_dispatcher_module from tests.ut.base import TestBase from vllm_ascend.ascend_forward_context import (FusedMoEState, _get_fused_moe_state) from vllm_ascend.ops.fused_moe import (AscendFusedMoE, AscendUnquantizedFusedMoEMethod) -from vllm_ascend.ops.layers.experts_selector import select_experts -from vllm_ascend.ops.layers.moe_mlp import cumsum_group_list, unified_apply_mlp +from vllm_ascend.ops.moe.experts_selector import select_experts +from vllm_ascend.ops.moe.moe_mlp import cumsum_group_list, unified_apply_mlp from vllm_ascend.utils import AscendSocVersion, adapt_patch adapt_patch(True) @@ -110,11 +110,11 @@ def capture_register(dispatcher_instance): captured_dispatchers[key] = mock_token_dispatcher_with_mc2 mock_register_token_dispatcher_patcher = patch( - 'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher', + 'vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher', side_effect=capture_register) mock_get_token_dispatcher_patcher = patch( - 'vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_token_dispatcher', + 'vllm_ascend.ops.moe.token_dispatcher.get_token_dispatcher', side_effect=lambda name: captured_dispatchers.get(name)) default_mock_token_dispatcher = mock_token_dispatcher_with_allgather @@ -158,7 +158,7 @@ def capture_register(dispatcher_instance): )), \ patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \ patch.object(token_dispatcher_module, 'setup_token_dispatchers', mock_setup_token_dispatchers), \ - patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context', + patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context', return_value=mock_forward_context_obj): yield { @@ -562,8 +562,8 @@ def test_cumsum_group_list_with_type_2(self): class TestUnifiedApplyMLP(TestBase): - @patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context') - @patch('vllm_ascend.ops.layers.moe_mlp.is_310p') + @patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context') + @patch('vllm_ascend.ops.moe.moe_mlp.is_310p') @patch('torch_npu.npu_grouped_matmul') @patch('torch_npu.npu_dynamic_quant') @patch('torch_npu.npu_dequant_swiglu_quant') @@ -629,7 +629,7 @@ def test_unified_apply_mlp_with_quantization_mc2(self, mock_npu_dequant, 
self.assertEqual(result.dtype, torch.bfloat16) - @patch('vllm_ascend.ops.layers.moe_mlp.is_310p') + @patch('vllm_ascend.ops.moe.moe_mlp.is_310p') @patch('torch_npu.npu_grouped_matmul') @patch('torch_npu.npu_swiglu') @patch('torch_npu.npu_dynamic_quant') @@ -671,7 +671,7 @@ def test_unified_apply_mlp_without_quantization(self, self.assertEqual(result.shape, hidden_states.shape) self.assertEqual(result.dtype, torch.float16) - @patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context') + @patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context') @patch('torch_npu.npu_grouped_matmul') @patch('torch_npu.npu_swiglu') @patch('torch_npu.npu_dynamic_quant') @@ -731,7 +731,7 @@ def test_unified_apply_mlp_with_quantization_and_dynamic_scale( self.assertEqual(result.shape, hidden_states.shape) self.assertEqual(result.dtype, torch.bfloat16) - @patch('vllm_ascend.ops.layers.moe_mlp.is_310p') + @patch('vllm_ascend.ops.moe.moe_mlp.is_310p') @patch('torch_npu.npu_grouped_matmul') @patch('torch_npu.npu_swiglu') @patch('torch_npu.npu_dynamic_quant') @@ -776,7 +776,7 @@ def test_unified_apply_mlp_without_quantization_310p( self.assertEqual(result.shape, hidden_states.shape) self.assertEqual(result.dtype, torch.float16) - @patch("vllm_ascend.ops.layers.moe_mlp.get_forward_context") + @patch("vllm_ascend.ops.moe.moe_mlp.get_forward_context") @patch("torch_npu.npu_grouped_matmul") @patch("torch_npu.npu_swiglu") @patch("torch_npu.npu_grouped_matmul_swiglu_quant") diff --git a/tests/ut/ops/test_moe_comm_method.py b/tests/ut/ops/test_moe_comm_method.py new file mode 100644 index 0000000000..4bd0e103ec --- /dev/null +++ b/tests/ut/ops/test_moe_comm_method.py @@ -0,0 +1,212 @@ +from unittest.mock import MagicMock, patch + +import torch +from vllm.model_executor.layers.fused_moe import FusedMoEConfig + +from tests.ut.base import TestBase +from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl, + AlltoAllCommImpl, MC2CommImpl) + + +class TestMoECommMethod(TestBase): + + def setUp(self): + # Mock FusedMoEConfig + self.moe_config = MagicMock(spec=FusedMoEConfig) + self.moe_config.num_experts = 8 + self.moe_config.num_local_experts = 2 + self.moe_config.experts_per_token = 2 + self.moe_config.tp_group = MagicMock() + self.moe_config.tp_group.device_group = MagicMock() + self.moe_config.dp_size = 1 + self.moe_config.tp_size = 1 + self.moe_config.ep_size = 1 + self.moe_config.dp_group = MagicMock() + self.moe_config.num_global_redundant_experts = 0 + + @patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context") + @patch( + "vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAllGather" + ) + @patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAllGather") + def test_all_gather_comm_impl(self, mock_token_dispatcher, + mock_prepare_finalize, + mock_get_forward_context): + # Mock forward context + mock_context = MagicMock() + mock_context.moe_comm_method = "all_gather" + mock_get_forward_context.return_value = mock_context + + # Mock prepare finalize + mock_pf_instance = MagicMock() + mock_pf_instance.prepare.return_value = (torch.randn(4, 8), + torch.randn(4, 2), None) + mock_pf_instance.finalize.return_value = torch.randn(4, 8) + mock_prepare_finalize.return_value = mock_pf_instance + + # Mock token dispatcher + mock_td_instance = MagicMock() + mock_token_dispatcher.return_value = mock_td_instance + + # Create instance + comm_impl = AllGatherCommImpl(self.moe_config) + + # Test prepare method + hidden_states = torch.randn(3, 8) + router_logits = torch.randn(3, 2) + h_out, r_out = 
comm_impl.prepare(hidden_states, router_logits) + + # Verify prepare was called with correct arguments + mock_pf_instance.prepare.assert_called_once_with( + hidden_states, router_logits, False, False, False, None) + + # Test finalize method + comm_impl.finalize(h_out, reduce_results=True) + mock_pf_instance.finalize.assert_called_once_with(h_out, True) + + @patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context") + @patch( + "vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithMC2" + ) + @patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithMC2") + def test_mc2_comm_impl(self, mock_token_dispatcher, mock_prepare_finalize, + mock_get_forward_context): + # Mock forward context + mock_context = MagicMock() + mock_context.moe_comm_method = "mc2" + mock_get_forward_context.return_value = mock_context + + # Mock prepare finalize + mock_pf_instance = MagicMock() + mock_pf_instance.prepare.return_value = (torch.randn(4, 8), + torch.randn(4, 2), + torch.tensor([1, 0, 1, 0])) + mock_pf_instance.finalize.return_value = torch.randn(4, 8) + mock_prepare_finalize.return_value = mock_pf_instance + + # Mock token dispatcher + mock_td_instance = MagicMock() + mock_token_dispatcher.return_value = mock_td_instance + + # Create instance + comm_impl = MC2CommImpl(self.moe_config) + + # Test prepare method + hidden_states = torch.randn(3, 8) + router_logits = torch.randn(3, 2) + h_out, r_out = comm_impl.prepare(hidden_states, router_logits) + + # Verify prepare was called with correct arguments + mock_pf_instance.prepare.assert_called_once_with( + hidden_states, router_logits, False, False, False, None) + + # Test finalize method + comm_impl.finalize(h_out, reduce_results=True) + mock_pf_instance.finalize.assert_called_once_with(h_out, True) + + @patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context") + @patch( + "vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAll2All" + ) + @patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAll2AllV") + def test_alltoall_comm_impl(self, mock_token_dispatcher, + mock_prepare_finalize, + mock_get_forward_context): + # Mock forward context + mock_context = MagicMock() + mock_context.moe_comm_method = "alltoall" + mock_get_forward_context.return_value = mock_context + + # Mock prepare finalize + mock_pf_instance = MagicMock() + mock_pf_instance.prepare.return_value = (torch.randn(4, 8), + torch.randn(4, 2), None) + mock_pf_instance.finalize.return_value = torch.randn(4, 8) + mock_prepare_finalize.return_value = mock_pf_instance + + # Mock token dispatcher + mock_td_instance = MagicMock() + mock_token_dispatcher.return_value = mock_td_instance + + # Create instance + comm_impl = AlltoAllCommImpl(self.moe_config) + + # Test prepare method + hidden_states = torch.randn(3, 8) + router_logits = torch.randn(3, 2) + h_out, r_out = comm_impl.prepare(hidden_states, router_logits) + + # Verify prepare was called with correct arguments + mock_pf_instance.prepare.assert_called_once_with( + hidden_states, router_logits, False, False, False, None) + + @patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context") + @patch( + "vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAllGather" + ) + @patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAllGather") + @patch("vllm_ascend.ops.moe.moe_comm_method.unified_apply_mlp") + def test_fused_experts_method(self, mock_unified_apply_mlp, + mock_token_dispatcher, mock_prepare_finalize, + mock_get_forward_context): + # Mock forward context + mock_context 
= MagicMock() + mock_context.moe_comm_method = "all_gather" + mock_get_forward_context.return_value = mock_context + + # Mock prepare finalize + mock_pf_instance = MagicMock() + mock_pf_instance.prepare.return_value = (torch.randn(4, 8), + torch.randn(4, 2), None) + mock_pf_instance.finalize.return_value = torch.randn(4, 8) + mock_prepare_finalize.return_value = mock_pf_instance + + # Mock token dispatcher + mock_td_instance = MagicMock() + mock_td_instance.token_dispatch.return_value = { + "hidden_states": torch.randn(6, 8), + "group_list": torch.tensor([2, 2, 2]), + "group_list_type": 1 + } + mock_td_instance.token_combine.return_value = torch.randn(4, 8) + mock_token_dispatcher.return_value = mock_td_instance + + # Mock unified_apply_mlp + mock_unified_apply_mlp.return_value = torch.randn(6, 8) + + # Create instance + comm_impl = AllGatherCommImpl(self.moe_config) + + # Test fused_experts method + hidden_states = torch.randn(4, 8).contiguous() + w1 = torch.randn(16, 8).contiguous() + w2 = torch.randn(16, 8).contiguous() + topk_weights = torch.tensor([[0.5, 0.5], [0.3, 0.7], [0.8, 0.2], + [0.6, 0.4]]) + topk_ids = torch.tensor([[0, 1], [1, 2], [2, 0], [1, 1]]) + row_idx = torch.arange(4) + + # Make sure tensors are contiguous and have correct strides + hidden_states = hidden_states.contiguous() + w1 = w1.contiguous() + w2 = w2.contiguous() + + result = comm_impl.fused_experts(hidden_states=hidden_states, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + row_idx=row_idx, + activation="silu") + + # Verify result shape + self.assertEqual(result.shape, (4, 8)) + + # Verify token_dispatch was called + mock_td_instance.token_dispatch.assert_called_once() + + # Verify unified_apply_mlp was called + mock_unified_apply_mlp.assert_called_once() + + # Verify token_combine was called + mock_td_instance.token_combine.assert_called_once() diff --git a/tests/ut/ops/test_token_dispatcher.py b/tests/ut/ops/test_token_dispatcher.py index 6782f455c8..53b2fa9d1e 100644 --- a/tests/ut/ops/test_token_dispatcher.py +++ b/tests/ut/ops/test_token_dispatcher.py @@ -20,7 +20,8 @@ import torch from tests.ut.base import TestBase -from vllm_ascend.ops.moe_dispatcher.token_dispatcher import ( + +from vllm_ascend.ops.moe.token_dispatcher import ( # isort: skip AscendSocVersion, TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather, TokenDispatcherWithMC2, _Dispatchers, _register_token_dispatcher, get_token_dispatcher, setup_token_dispatchers) @@ -34,7 +35,7 @@ def setUp(self): self.mc2_group.rank_in_group = 0 self.mc2_group.world_size = 8 self.mc2_group_patch = patch( - "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_mc2_group", + "vllm_ascend.ops.moe.token_dispatcher.get_mc2_group", return_value=self.mc2_group) self.mc2_group_patch.start() @@ -52,7 +53,7 @@ def setUp(self): # Mock get_ascend_soc_version() self.ascend_soc_version_patch = patch( - "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ascend_soc_version", + "vllm_ascend.ops.moe.token_dispatcher.get_ascend_soc_version", return_value=AscendSocVersion.A3) self.ascend_soc_version_patch.start() @@ -329,7 +330,7 @@ def setUp(self): # Mock gather_from_sequence_parallel_region patcher7 = patch( - 'vllm_ascend.ops.moe_dispatcher.token_dispatcher.gather_from_sequence_parallel_region' + 'vllm_ascend.ops.moe.token_dispatcher.gather_from_sequence_parallel_region' ) self.mock_gather_from_sequence_parallel_region = patcher7.start() self.addCleanup(patcher7.stop) @@ -518,12 +519,8 @@ def test_register_and_get_token_dispatcher(self): 
self.assertIsNone(get_token_dispatcher("NonExistentDispatcher")) - @patch( - 'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAllGather' - ) - @patch( - 'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher' - ) + @patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAllGather') + @patch('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher') def test_setup_token_dispatchers_ep_size_1_creates_allgather( self, mock_register, mock_allgather_class): kwargs = {"top_k": 2, "num_experts": 8} @@ -537,12 +534,8 @@ def test_setup_token_dispatchers_ep_size_1_creates_allgather( mock_allgather_class.assert_called_once_with(**kwargs) mock_register.assert_called_once_with(mock_instance) - @patch( - 'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV' - ) - @patch( - 'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher' - ) + @patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAll2AllV') + @patch('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher') def test_setup_token_dispatchers_ep_size_2_creates_all2allv( self, mock_register, mock_all2allv_class): kwargs = {"top_k": 2, "num_experts": 16, "num_local_experts": 2} @@ -556,15 +549,9 @@ def test_setup_token_dispatchers_ep_size_2_creates_all2allv( mock_all2allv_class.assert_called_once_with(**kwargs) mock_register.assert_called_once_with(mock_instance) - @patch( - 'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV' - ) - @patch( - 'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithMC2' - ) - @patch( - 'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher' - ) + @patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAll2AllV') + @patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithMC2') + @patch('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher') def test_setup_token_dispatchers_ep_size_16_creates_all2allv_and_mc2( self, mock_register, mock_mc2_class, mock_all2allv_class): kwargs = {"top_k": 2, "num_experts": 32, "num_local_experts": 2} @@ -584,15 +571,9 @@ def test_setup_token_dispatchers_ep_size_16_creates_all2allv_and_mc2( mock_register.assert_any_call(mock_all2allv_instance) mock_register.assert_any_call(mock_mc2_instance) - @patch( - 'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV' - ) - @patch( - 'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithMC2' - ) - @patch( - 'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher' - ) + @patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAll2AllV') + @patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithMC2') + @patch('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher') def test_setup_token_dispatchers_ep_size_16_skips_if_exist( self, mock_register, mock_mc2_class, mock_all2allv_class): kwargs = {"top_k": 2, "num_experts": 32, "num_local_experts": 2} diff --git a/tests/ut/quantization/test_w8a8.py b/tests/ut/quantization/test_w8a8.py index 90a5f59b06..3f2557bebe 100644 --- a/tests/ut/quantization/test_w8a8.py +++ b/tests/ut/quantization/test_w8a8.py @@ -5,8 +5,8 @@ from tests.ut.base import TestBase from vllm_ascend.attention.attention_v1 import AscendAttentionState -from vllm_ascend.ops.layers.experts_selector import (_native_grouped_topk, - select_experts) +from vllm_ascend.ops.moe.experts_selector import (_native_grouped_topk, + select_experts) from 
vllm_ascend.quantization.w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod, AscendW8A8LinearMethod, @@ -784,7 +784,7 @@ def test_grouped_topk(self, mock_topk): self.assertEqual(ids.shape, (self.num_tokens, self.top_k)) self.assertEqual(ids.dtype, torch.int32) - @patch('vllm_ascend.ops.layers.experts_selector._native_grouped_topk') + @patch('vllm_ascend.ops.moe.experts_selector._native_grouped_topk') def test_grouped_topk_with_correction_bias(self, mock_grouped_topk): """Test grouped topk with expert score correction bias""" mock_grouped_topk.return_value = torch.ones(self.num_tokens, diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index 31822af018..71ae4d07e6 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -94,8 +94,7 @@ def set_ascend_forward_context( forward_context.fused_moe_state = fused_moe_state forward_context.in_profile_run = in_profile_run - from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \ - get_token_dispatcher + from vllm_ascend.ops.moe.token_dispatcher import get_token_dispatcher dispatcher_name = get_dispatcher_name(ep_size, with_prefill) dispatcher = get_token_dispatcher(dispatcher_name) forward_context.token_dispatcher = dispatcher diff --git a/vllm_ascend/distributed/moe_comm_method.py b/vllm_ascend/distributed/moe_comm_method.py deleted file mode 100644 index 5c6d8c67fb..0000000000 --- a/vllm_ascend/distributed/moe_comm_method.py +++ /dev/null @@ -1,555 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Optional - -import torch -import torch.distributed as dist -import torch.nn as nn -import torch_npu -from vllm.distributed import tensor_model_parallel_all_reduce -from vllm.distributed.parallel_state import ( - get_dp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) -from vllm.forward_context import get_forward_context -from vllm.model_executor.layers.fused_moe import FusedMoEConfig - -from vllm_ascend.distributed.parallel_state import get_mc2_group -from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version - - -class MoECommMethod(ABC): - """Base class for MoE communication methods.""" - - def __init__(self, moe_config: FusedMoEConfig): - self.moe_config = moe_config - - @abstractmethod - def prepare( - self, hidden_states: torch.Tensor, - router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """Prepare the MoE communication method. - - This method is called before quant_method.apply to prepare the - communication method. It can be used to initialize any necessary - resources or configurations. - """ - pass - - @abstractmethod - def finalize(self, hidden_states: torch.Tensor, - reduce_results: bool) -> torch.Tensor: - """Finalize the MoE communication method. - - This method is called after quant_method.apply to finalize the - communication method. It can be used to clean up any resources or - configurations. - """ - pass - - @abstractmethod - def permute( - self, - hidden_states: torch.Tensor, - topk_ids: torch.Tensor, - topk_weights: torch.Tensor, - expert_map: torch.Tensor, - num_experts: int, - apply_a8_quantization: bool, - ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]: - """Pre-process before MLP. 
- - Args: - hidden_states (torch.Tensor): Tensor of shape (num_tokens, hidden_size) - topk_ids (torch.Tensor): Tensor of shape (num_tokens, top_k_num) - topk_weights (torch.Tensor): Tensor of shape (num_tokens, top_k_num) - expert_map (torch.Tensor): Tensor of shape (global_num_experts, ) - Mapping from global expert IDs to local expert IDs. - num_experts (int): Number of local experts (experts on this device). - apply_a8_quantization (bool): Whether to apply A8 quantization (W4A8 and W8A8). - - Returns: - tuple[torch.Tensor, torch.Tensor, int]: Return a tuple containing: - - permuted_hidden_states (torch.Tensor): Tensor of shape - (num_tokens * top_k_num, hidden_size) after permuting - hidden_states based on topk_ids. - - expert_tokens (torch.Tensor): Tensor of shape (num_experts, ) - Number of tokens assigned to each expert. - - dynamic_scale (torch.Tensor, optional): Tensor of shape (num_experts, ) - Dynamic scale for each expert, used for quantization. - - group_list_type (int): Type of group list, 0 for `cumsum` - and 1 for `count`. This is mainly for `npu_grouped_matmul` - to determine how to handle the output. - Raises: - NotImplementedError: If the method is not implemented in the subclass. - """ - pass - - @abstractmethod - def unpermute(self, mlp_output: torch.Tensor, - hidden_states: torch.Tensor) -> None: - """Post-process after MLP. - - Args: - mlp_output (torch.Tensor): Tensor of shape - (num_tokens * top_k_num, hidden_size) after MLP. - hidden_states (torch.Tensor): Tensor of shape - (num_tokens, hidden_size) to be updated with the final output. - """ - pass - - -class AllGatherCommImpl(MoECommMethod): - """This implementation is the same as NativeAllGatherCommImpl, - but uses NPU-specific ops for better performance. - - This implementation should be compatible with all scenarios, and - thus it is the default implementation for MoE communication methods. - It uses `torch_npu.npu_moe_init_routing_v2` for pre-processing - and `torch_npu.npu_moe_token_unpermute` for post-processing - to handle the token-to-expert mapping and communication efficiently. - - NOTE(Yizhou): TBH, it is really weird that we were supposed to use - `torch_npu.npu_moe_init_routing_v2` and `torch_npu.npu_moe_finalize_routing` - or `torch_npu.npu_moe_token_permute` and `torch_npu.npu_moe_token_unpermute` - for pre-processing and post-processing, respectively. - But `npu_moe_finalize_routing` will lead to accuracy issues so we have to - use `torch_npu.npu_moe_token_unpermute` instead. - This is a workaround and should be removed after the issue is fixed. 
- """ - - def prepare( - self, hidden_states: torch.Tensor, - router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """When DP size > 1, pad the hidden states and router logits for communication.""" - if self.moe_config.dp_size > 1: - forward_context = get_forward_context() - max_tokens_across_dp = forward_context.max_tokens_across_dp - - self.num_tokens = hidden_states.shape[0] - pad_size = max_tokens_across_dp - self.num_tokens - if pad_size > 0: - hidden_states = nn.functional.pad(hidden_states, - (0, 0, 0, pad_size)) - router_logits = nn.functional.pad(router_logits, - (0, 0, 0, pad_size)) - - hidden_states = self.moe_config.dp_group.all_gather( - hidden_states, 0) - router_logits = self.moe_config.dp_group.all_gather( - router_logits, 0) - - return hidden_states, router_logits - - def finalize(self, hidden_states: torch.Tensor, - reduce_results: bool) -> torch.Tensor: - """When DP size > 1, reduce-scatter the hidden states to get the final output. - - When TP size > 1, all-reduce the hidden states to get the final output. - """ - if self.moe_config.dp_size > 1: - hidden_states = get_dp_group().reduce_scatter(hidden_states, 0) - hidden_states = hidden_states[:self.num_tokens] - - if reduce_results and (self.moe_config.tp_size > 1 - or self.moe_config.ep_size > 1): - hidden_states = tensor_model_parallel_all_reduce(hidden_states) - - return hidden_states - - def permute( - self, - hidden_states: torch.Tensor, - topk_ids: torch.Tensor, - topk_weights: torch.Tensor, - expert_map: torch.Tensor, # noqa: F841 - num_experts: int, - apply_a8_quantization: bool, - ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]: - num_tokens = hidden_states.shape[0] - - self.topk_weights = topk_weights - self.topk_ids = topk_ids - - first_expert_idx = 0 - if expert_map is not None: - # FIXME: npu_grouped_matmul output random values at [num_valid_tokens:, ...] - # So we need to filter out invalid tokens by zeroing their weights. - # This is a workaround and should be removed after the issue is fixed - mask = expert_map[topk_ids] != -1 - # NOTE: This is equivalent to self.topk_weights[~mask] = 0.0, - # but ~mask will dispatch to aclnnNonzeroV2, which is not supported in ACL Graph - self.topk_weights = torch.where(mask, topk_weights, 0.0) - - first_expert_idx = self.moe_config.ep_rank * num_experts - last_expert_idx = first_expert_idx + num_experts - - permuted_hidden_states, expanded_row_idx, expert_tokens, _ = ( - torch_npu.npu_moe_init_routing_v2( - hidden_states, - topk_ids, - active_num=num_tokens * self.moe_config.experts_per_token, - expert_num=self.moe_config.num_experts, - expert_tokens_num_type=1, # Only support `count` mode now - expert_tokens_num_flag=True, # Output `expert_tokens` - active_expert_range=[first_expert_idx, last_expert_idx], - quant_mode=-1, - )) - self.expanded_row_idx = expanded_row_idx - permuted_hidden_states = permuted_hidden_states - - group_list_type = 1 # `count` mode - - return permuted_hidden_states, expert_tokens, None, group_list_type - - def unpermute(self, mlp_output: torch.Tensor, - hidden_states: torch.Tensor) -> None: - hidden_states[:] = torch_npu.npu_moe_token_unpermute( - permuted_tokens=mlp_output, - sorted_indices=self.expanded_row_idx, - probs=self.topk_weights) - - -class NativeAllGatherCommImpl(AllGatherCommImpl): - """This implementation should be compatible with all scenarios. - - Note that this implementation purely consists of native PyTorch ops - and does not use any NPU-specific ops. So the performance may not be optimal. 
- But it is a good fallback for scenarios where NPU-specific ops are not available. - """ - - def permute( - self, - hidden_states: torch.Tensor, - topk_ids: torch.Tensor, - topk_weights: torch.Tensor, - expert_map: torch.Tensor, - num_experts: int, - apply_a8_quantization: bool, - ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]: - num_tokens = hidden_states.shape[0] - - # Generate token indices and flatten - token_indices = torch.arange(num_tokens, - device=hidden_states.device, - dtype=torch.int64) - token_indices = (token_indices.unsqueeze(1).expand( - -1, self.moe_config.experts_per_token).reshape(-1)) - - # Flatten token-to-expert mappings and map to local experts - weights_flat = topk_weights.view(-1) - experts_flat = topk_ids.view(-1) - local_experts_flat = (expert_map[experts_flat] - if expert_map is not None else experts_flat) - - # Filter valid token-expert pairs - mask = local_experts_flat != -1 - # FIXME: npu_grouped_matmul output random values at [num_valid_tokens:, ...] - # So we need to filter out invalid tokens by zeroing their weights. - # This is a workaround and should be removed after the issue is fixed - filtered_weights = torch.where(mask, weights_flat, - torch.zeros_like(weights_flat)).to( - topk_weights.dtype) - filtered_experts = torch.where( - mask, - local_experts_flat, - torch.full_like(local_experts_flat, num_experts), - ).to(topk_ids.dtype) - - # Sort by local expert IDs - sort_indices = torch.argsort(filtered_experts.view(torch.float32)) - self.sorted_token_indices = token_indices[sort_indices] - self.sorted_weights = filtered_weights[sort_indices] - - # Compute token counts with minlength of num_experts - # This is equivalent to but faster than: - # >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1] - token_counts = torch.zeros(num_experts + 1, - device=hidden_states.device, - dtype=torch.int64) - ones = torch.ones_like(filtered_experts, dtype=torch.int64) - token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones) - expert_tokens = token_counts[:num_experts] - - # Rearrange hidden_states - permuted_hidden_states = hidden_states[self.sorted_token_indices] - - group_list_type = 1 # `count` mode - - return permuted_hidden_states, expert_tokens, None, group_list_type - - def unpermute(self, mlp_output: torch.Tensor, - hidden_states: torch.Tensor) -> None: - mlp_output = mlp_output * self.sorted_weights.unsqueeze(1) - - final_hidden_states = torch.zeros_like(hidden_states) - final_hidden_states.index_add_(0, self.sorted_token_indices, - mlp_output) - - hidden_states[:] = final_hidden_states - - -class MC2CommImpl(MoECommMethod): - """This implementation is for the scenarios listed below: - 1. `enable_expert_parallel=True`. - 2. `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine` are available. - 3. `enable_expert_parallel=False` is not supported. - - This implementation uses the MC2 communication method, which is optimized for - Communication and Computation parallelism on Ascend devices. - """ - - def __init__(self, moe_config: Optional[FusedMoEConfig]): - super().__init__(moe_config) - - # NOTE: We do not need to use mc2_group's rank and world size - # because ep_group and mc2_group basically have the same init params. - # We only init another group because of the restriction of MC2: - # "No other groups can be used in the same process as the MC2 group." 
- self.mc2_comm_name = get_mc2_group().device_group._get_backend( - torch.device("npu")).get_hccl_comm_name(self.moe_config.ep_rank) - - # Feature flags - self.enable_dispatch_v2 = hasattr(torch_npu, - "npu_moe_distribute_dispatch_v2") - self.is_ascend_a3 = get_ascend_soc_version() == AscendSocVersion.A3 - self.need_extra_args = self.is_ascend_a3 - self._restore_tp_across_dp() - - def _restore_tp_across_dp(self): - # NOTE: Since vLLM flatten tp across dp, we need to restore the original - # tp_size and tp_rank. - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - - def prepare( - self, hidden_states: torch.Tensor, - router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """The target_pad_length is calculated in forward_context, here we pad the - hidden states and router logits. And if TP size > 1, we also need to split - the tensors accordingly. - """ - self.num_tokens, _ = hidden_states.shape - forward_context = get_forward_context() - self.mc2_mask = forward_context.mc2_mask - target_pad_length = forward_context.padded_num_tokens - pad_size = target_pad_length - self.num_tokens - - if pad_size > 0: - hidden_states = nn.functional.pad(hidden_states, - (0, 0, 0, pad_size)) - router_logits = nn.functional.pad(router_logits, - (0, 0, 0, pad_size)) - - if self.tp_size > 1: - split_hidden_states = torch.tensor_split(hidden_states, - self.tp_size, - dim=0) - split_router_logits = torch.tensor_split(router_logits, - self.tp_size, - dim=0) - split_mc2_mask = torch.tensor_split(self.mc2_mask, - self.tp_size, - dim=0) - self.split_hidden_states = split_hidden_states - - hidden_states = split_hidden_states[self.tp_rank] - router_logits = split_router_logits[self.tp_rank] - self.mc2_mask = split_mc2_mask[self.tp_rank] - - return hidden_states, router_logits - - def finalize(self, hidden_states: torch.Tensor, - reduce_results: bool) -> torch.Tensor: - """If TP size > 1, all-gather the hidden states to get the final output. - - Also, unpad the hidden states if needed. 
- """ - if self.tp_size > 1: - dist.all_gather(list(self.split_hidden_states), hidden_states, - self.moe_config.tp_group.device_group) - hidden_states = torch.cat(self.split_hidden_states, dim=0) - - if self.num_tokens < hidden_states.shape[0]: - hidden_states = hidden_states[:self.num_tokens] - - return hidden_states - - def permute( - self, - hidden_states: torch.Tensor, - topk_ids: torch.Tensor, - topk_weights: torch.Tensor, - expert_map: torch.Tensor, - num_experts: int, - apply_a8_quantization: bool, - ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]: - # Store tensors needed for post_process - self.topk_ids = topk_ids - self.topk_weights = topk_weights.to(torch.float32) - - dispatch_kwargs = { - "x": hidden_states, - "expert_ids": self.topk_ids, - "expert_shard_type": 0, - "shared_expert_rank_num": 0, - "moe_expert_num": self.moe_config.num_experts, - "global_bs": 0, - "scales": None, - "quant_mode": 2 if apply_a8_quantization else 0, - "group_ep": self.mc2_comm_name, - "ep_world_size": self.moe_config.ep_size, - "ep_rank_id": self.moe_config.ep_rank, - } - - if self.need_extra_args: - dispatch_kwargs.update({ - "group_tp": self.mc2_comm_name, - "tp_world_size": 1, - "tp_rank_id": 0, - }) - if self.is_ascend_a3 and self.enable_dispatch_v2: - dispatch_kwargs.update({ - "x_active_mask": self.mc2_mask, - }) - - dispatch = torch_npu.npu_moe_distribute_dispatch_v2 if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch - - ( - permuted_hidden_states, - dynamic_scale, - self.assist_info_for_combine, - expert_tokens, - self.ep_recv_counts, - self.tp_recv_counts, - ) = dispatch(**dispatch_kwargs)[:6] - - group_list_type = 1 - - return permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type - - def unpermute(self, mlp_output: torch.Tensor, - hidden_states: torch.Tensor) -> None: - combine_kwargs = { - "expand_x": mlp_output, - "expert_ids": self.topk_ids, - "expert_scales": self.topk_weights, - "expert_shard_type": 0, - "shared_expert_rank_num": 0, - "moe_expert_num": self.moe_config.num_experts, - "global_bs": 0, - "ep_send_counts": self.ep_recv_counts, - "group_ep": self.mc2_comm_name, - "ep_world_size": self.moe_config.ep_size, - "ep_rank_id": self.moe_config.ep_rank, - } - - if self.enable_dispatch_v2: - combine_kwargs[ - "assist_info_for_combine"] = self.assist_info_for_combine - else: - combine_kwargs["expand_idx"] = self.assist_info_for_combine - - if self.need_extra_args: - combine_kwargs.update({ - "tp_send_counts": self.tp_recv_counts, - "group_tp": self.mc2_comm_name, - "tp_world_size": 1, - "tp_rank_id": 0, - }) - if self.is_ascend_a3 and self.enable_dispatch_v2: - combine_kwargs.update({ - "x_active_mask": self.mc2_mask, - }) - - combine = torch_npu.npu_moe_distribute_combine_v2 if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine - - hidden_states[:] = combine(**combine_kwargs) - - -class AlltoAllCommImpl(MoECommMethod): - """This implementation is for the scenarios listed below: - 1. `enable_expert_parallel=True`. - 2. `npu_grouped_matmul` is available. - - This implementation uses all-to-all communication to exchange tokens - between data parallel ranks before and after the MLP computation. It should - have better performance than AllGatherCommImpl when DP size > 1. 
- """ - - def __init__(self, moe_config: Optional[FusedMoEConfig]): - super().__init__(moe_config) - from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \ - get_token_dispatcher - self.token_dispatcher = get_token_dispatcher( - "TokenDispatcherWithAll2AllV") - self._restore_tp_across_dp() - - def _restore_tp_across_dp(self): - # NOTE: Since vLLM flatten tp across dp, we need to restore the original - # tp_size and tp_rank. - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - - def prepare( - self, hidden_states: torch.Tensor, - router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - self.num_tokens, _ = hidden_states.shape - pad_size = self.tp_size - self.num_tokens - - if pad_size > 0: - hidden_states = nn.functional.pad(hidden_states, - (0, 0, 0, pad_size)) - router_logits = nn.functional.pad(router_logits, - (0, 0, 0, pad_size)) - - if self.tp_size > 1: - split_hidden_states = torch.tensor_split(hidden_states, - self.tp_size, - dim=0) - split_router_logits = torch.tensor_split(router_logits, - self.tp_size, - dim=0) - self.split_hidden_states = split_hidden_states - - hidden_states = split_hidden_states[self.tp_rank] - router_logits = split_router_logits[self.tp_rank] - - return hidden_states, router_logits - - def finalize(self, hidden_states: torch.Tensor, - reduce_results: bool) -> torch.Tensor: - """If TP size > 1, all-gather the hidden states to get the final output. - - Also, unpad the hidden states if needed. - """ - if self.tp_size > 1: - dist.all_gather(list(self.split_hidden_states), hidden_states, - self.moe_config.tp_group.device_group) - hidden_states = torch.cat(self.split_hidden_states, dim=0) - - if self.num_tokens < hidden_states.shape[0]: - hidden_states = hidden_states[:self.num_tokens] - - return hidden_states - - def permute( - self, - hidden_states: torch.Tensor, - topk_ids: torch.Tensor, - topk_weights: torch.Tensor, - expert_map: torch.Tensor, - num_experts: int, - apply_a8_quantization: bool, - ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]: - results = self.token_dispatcher.token_dispatch( - hidden_states, - topk_weights, - topk_ids, - None, - log2phy=None, - with_quant=apply_a8_quantization) - return results["hidden_states"], results["group_list"], results[ - "dynamic_scale"], results["group_list_type"] - - def unpermute(self, mlp_output: torch.Tensor, - hidden_states: torch.Tensor) -> None: - hidden_states[:] = self.token_dispatcher.token_combine(mlp_output) diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py index 5cb2d6fa5b..3142bc8323 100644 --- a/vllm_ascend/ops/common_fused_moe.py +++ b/vllm_ascend/ops/common_fused_moe.py @@ -15,7 +15,7 @@ # limitations under the License. 
# -from typing import Any, Callable, Optional +from typing import Callable, Optional import torch import torch_npu @@ -28,118 +28,16 @@ FusedMoE, UnquantizedFusedMoEMethod) from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl, - AlltoAllCommImpl, - MC2CommImpl) from vllm_ascend.distributed.parallel_state import get_mc2_group -from vllm_ascend.ops.layers.experts_selector import select_experts -from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \ - setup_token_dispatchers +from vllm_ascend.ops.moe.experts_selector import select_experts +from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl, + AlltoAllCommImpl, MC2CommImpl) +from vllm_ascend.ops.moe.token_dispatcher import setup_token_dispatchers from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, vllm_version_is original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__ -def fused_experts( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - use_int8_w8a8: bool = False, - use_int4_w4a8: bool = False, - global_num_experts: Optional[int] = None, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_scale_bias: torch.Tensor = None, - w2_scale_bias: torch.Tensor = None, - # For TorchAir graph - is_torchair: bool = False, - # For Cube/Vector parallel - shared_experts: Optional[Any] = None, - quantized_x_for_share: Optional[Any] = None, - dynamic_scale_for_share: Optional[Any] = None, - # For load balance - log2phy: torch.Tensor = None, - global_redundant_expert_num: int = 0, -) -> torch.Tensor: - # Check constraints - assert hidden_states.shape[1] == w1.shape[1], ( - f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[1]}") - assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" - assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" - assert w1.stride(-1) == 1, "Stride of last dimension must be 1" - assert w2.stride(-1) == 1, "Stride of last dimension must be 1" - assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16 - ] - if (use_int8_w8a8 or use_int4_w4a8): - assert w1_scale is not None and w2_scale is not None, \ - "INT8 quantization requires weight scales." 
- - w1_scale = w1_scale.to(torch.float32) - down_scale = [w2_scale] - down_output_dtype = w2_scale.dtype - else: - down_scale = None - down_output_dtype = None - - moe_comm_method = get_forward_context().moe_comm_method - assert moe_comm_method is not None, "Missing communication context" - - num_experts = w1.shape[0] - - permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type = moe_comm_method.permute( - hidden_states, topk_ids, topk_weights, expert_map, num_experts, - use_int8_w8a8 or use_int4_w4a8) - - gate_up_output = torch_npu.npu_grouped_matmul( - x=[permuted_hidden_states], - weight=[w1], - split_item=2, - group_list_type=group_list_type, - group_type=0, - group_list=expert_tokens, - output_dtype=torch.int32 if use_int8_w8a8 else None, - )[0] - - if (use_int8_w8a8 or use_int4_w4a8): - activated_output, activated_output_scale = torch_npu.npu_dequant_swiglu_quant( - x=gate_up_output, - weight_scale=w1_scale, - activation_scale=dynamic_scale, - bias=None, - quant_scale=None, - quant_offset=None, - group_index=expert_tokens, - activate_left=True, - quant_mode=1, - ) - activated_output_scale = [activated_output_scale] - else: - activated_output = torch_npu.npu_swiglu(gate_up_output) - activated_output_scale = None - - down_output = torch_npu.npu_grouped_matmul( - x=[activated_output], - weight=[w2], - scale=down_scale, - per_token_scale=activated_output_scale, - split_item=2, - group_list_type=group_list_type, - group_type=0, - group_list=expert_tokens, - output_dtype=down_output_dtype, - )[0] - - moe_comm_method.unpermute(down_output, hidden_states) - - return hidden_states - - def fused_experts_moge( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -259,7 +157,7 @@ def forward_oot_v01011( logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor: - topk_weights, topk_ids, _ = select_experts( + topk_weights, topk_ids, row_idx = select_experts( hidden_states=x, router_logits=router_logits, top_k=top_k, @@ -287,15 +185,15 @@ def forward_oot_v01011( expert_map=expert_map, apply_router_weight_on_input=apply_router_weight_on_input) - return fused_experts( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - global_num_experts=global_num_experts, - expert_map=expert_map, - ) + moe_comm_method = get_forward_context().moe_comm_method + return moe_comm_method.fused_experts(hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + row_idx=row_idx, + global_num_experts=global_num_experts, + expert_map=expert_map) def forward_oot( @@ -321,7 +219,7 @@ def forward_oot( logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor: - topk_weights, topk_ids, _ = select_experts( + topk_weights, topk_ids, row_idx = select_experts( hidden_states=x, router_logits=router_logits, top_k=top_k, @@ -349,15 +247,15 @@ def forward_oot( expert_map=expert_map, apply_router_weight_on_input=apply_router_weight_on_input) - return fused_experts( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - global_num_experts=global_num_experts, - expert_map=expert_map, - ) + moe_comm_method = get_forward_context().moe_comm_method + return moe_comm_method.fused_experts(hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + row_idx=row_idx, + 
global_num_experts=global_num_experts, + expert_map=expert_map) def process_weights_after_loading(self, layer): diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index e5b4dff35b..e42fdc0d52 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -42,8 +42,8 @@ from vllm_ascend.ascend_forward_context import FusedMoEState from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer -from vllm_ascend.ops.layers.experts_selector import select_experts -from vllm_ascend.ops.layers.moe_mlp import unified_apply_mlp +from vllm_ascend.ops.moe.experts_selector import select_experts +from vllm_ascend.ops.moe.moe_mlp import unified_apply_mlp from vllm_ascend.ops.sequence_parallel import MetadataForPadding from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, dispose_tensor, get_all_reduce_merge_state, @@ -358,7 +358,7 @@ def __init__( ep_size = (get_ep_group().world_size if vllm_config.parallel_config.enable_expert_parallel else 1) - from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \ + from vllm_ascend.ops.moe.token_dispatcher import \ setup_token_dispatchers setup_token_dispatchers( ep_size, diff --git a/vllm_ascend/ops/layers/__init__.py b/vllm_ascend/ops/moe/__init__.py similarity index 100% rename from vllm_ascend/ops/layers/__init__.py rename to vllm_ascend/ops/moe/__init__.py diff --git a/vllm_ascend/ops/layers/experts_selector.py b/vllm_ascend/ops/moe/experts_selector.py similarity index 100% rename from vllm_ascend/ops/layers/experts_selector.py rename to vllm_ascend/ops/moe/experts_selector.py diff --git a/vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py b/vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py new file mode 100644 index 0000000000..b07c48971a --- /dev/null +++ b/vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py @@ -0,0 +1,240 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
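+#
+# This module hosts the prepare()/finalize() pair that each MoE communication
+# method wraps around the fused-experts computation: prepare() pads and, when
+# TP or DP > 1, splits or all-gathers hidden_states and router_logits before
+# token dispatch, while finalize() all-gathers or reduce-scatters the results
+# and strips the padding again. The concrete implementation is selected by
+# MoECommMethod._get_fused_moe_prepare_finalize() in moe_comm_method.py.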
+ +from abc import ABC, abstractmethod + +import torch +import torch.distributed as dist +import torch.nn as nn +from vllm.distributed import tensor_model_parallel_all_reduce +from vllm.distributed.parallel_state import ( + get_dp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.fused_moe import FusedMoEConfig + + +class FusedMoEPrepareAndFinalize(ABC): + + def __init__(self, moe_config: FusedMoEConfig): + self.moe_config = moe_config + + @abstractmethod + def prepare(self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + enable_shared_expert_dp: bool = False, + rm_router_logits: bool = False, + replace_allreduce: bool = False, + gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + raise NotImplementedError("Prepare not implemented.") + + def finalize(self, hidden_states: torch.Tensor, + reduce_results: bool) -> torch.Tensor: + raise NotImplementedError("Combine function not implemented.") + + +class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize): + + def __init__(self, moe_config: FusedMoEConfig): + super().__init__(moe_config) + self._restore_tp_across_dp() + + def _restore_tp_across_dp(self): + # NOTE: Since vLLM flatten tp across dp, we need to restore the original + # tp_size and tp_rank. + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + def prepare(self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + enable_shared_expert_dp: bool = False, + rm_router_logits: bool = False, + replace_allreduce: bool = False, + gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """The target_pad_length is calculated in forward_context, here we pad the + hidden states and router logits. And if TP size > 1, we also need to split + the tensors accordingly. + """ + self.replace_allreduce = replace_allreduce + self.enable_shared_expert_dp = enable_shared_expert_dp + + if not self.replace_allreduce: + self.num_tokens, _ = hidden_states.shape + forward_context = get_forward_context() + mc2_mask = forward_context.mc2_mask + target_pad_length = forward_context.padded_num_tokens + pad_size = target_pad_length - self.num_tokens + + if pad_size > 0 and not self.enable_shared_expert_dp: + hidden_states = nn.functional.pad(hidden_states, + (0, 0, 0, pad_size)) + router_logits = nn.functional.pad(router_logits, + (0, 0, 0, pad_size)) + + if self.tp_size > 1: + if not self.enable_shared_expert_dp: + split_hidden_states = torch.tensor_split(hidden_states, + self.tp_size, + dim=0) + split_router_logits = torch.tensor_split(router_logits, + self.tp_size, + dim=0) + hidden_states = split_hidden_states[self.tp_rank] + router_logits = split_router_logits[self.tp_rank] + self.split_hidden_states = split_hidden_states + + split_mc2_mask = torch.tensor_split(mc2_mask, + self.tp_size, + dim=0) + mc2_mask = split_mc2_mask[self.tp_rank] + + return hidden_states, router_logits, mc2_mask + + def finalize(self, hidden_states: torch.Tensor, + reduce_results: bool) -> torch.Tensor: + """If TP size > 1, all-gather the hidden states to get the final output. + + Also, unpad the hidden states if needed. 
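+
+        The per-rank splits saved by prepare() are reused here as the
+        all_gather output buffers and concatenated back in rank order.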
+ """ + if not (self.enable_shared_expert_dp or self.replace_allreduce): + if self.tp_size > 1: + dist.all_gather(list(self.split_hidden_states), hidden_states, + self.moe_config.tp_group.device_group) + hidden_states = torch.cat(self.split_hidden_states, dim=0) + + if self.num_tokens < hidden_states.shape[0]: + hidden_states = hidden_states[:self.num_tokens] + + return hidden_states + + +class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize): + + def __init__(self, moe_config: FusedMoEConfig): + super().__init__(moe_config) + self._restore_tp_across_dp() + + def _restore_tp_across_dp(self): + # NOTE: Since vLLM flatten tp across dp, we need to restore the original + # tp_size and tp_rank. + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + def prepare(self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + enable_shared_expert_dp: bool = False, + rm_router_logits: bool = False, + replace_allreduce: bool = False, + gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + self.replace_allreduce = replace_allreduce + self.enable_shared_expert_dp = enable_shared_expert_dp + + if not (self.replace_allreduce or self.enable_shared_expert_dp): + self.num_tokens, _ = hidden_states.shape + pad_size = self.tp_size - self.num_tokens + + if pad_size > 0: + hidden_states = nn.functional.pad(hidden_states, + (0, 0, 0, pad_size)) + router_logits = nn.functional.pad(router_logits, + (0, 0, 0, pad_size)) + + if self.tp_size > 1: + split_hidden_states = torch.tensor_split(hidden_states, + self.tp_size, + dim=0) + split_router_logits = torch.tensor_split(router_logits, + self.tp_size, + dim=0) + self.split_hidden_states = split_hidden_states + + hidden_states = split_hidden_states[self.tp_rank] + router_logits = split_router_logits[self.tp_rank] + + return hidden_states, router_logits, None + + def finalize(self, hidden_states: torch.Tensor, + reduce_results: bool) -> torch.Tensor: + """If TP size > 1, all-gather the hidden states to get the final output. + + Also, unpad the hidden states if needed. 
+ """ + if not (self.enable_shared_expert_dp or self.replace_allreduce): + if self.tp_size > 1: + dist.all_gather(list(self.split_hidden_states), hidden_states, + self.moe_config.tp_group.device_group) + hidden_states = torch.cat(self.split_hidden_states, dim=0) + + if self.num_tokens < hidden_states.shape[0]: + hidden_states = hidden_states[:self.num_tokens] + + return hidden_states + + +class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize): + + def prepare(self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + enable_shared_expert_dp: bool = False, + rm_router_logits: bool = False, + replace_allreduce: bool = False, + gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """When DP size > 1, pad the hidden states and router logits for communication.""" + self.rm_router_logits = rm_router_logits + self.enable_shared_expert_dp = enable_shared_expert_dp + + if self.moe_config.dp_size > 1: + forward_context = get_forward_context() + max_tokens_across_dp = forward_context.max_tokens_across_dp + + self.num_tokens = hidden_states.shape[0] + pad_size = max_tokens_across_dp - self.num_tokens + if pad_size > 0: + hidden_states = nn.functional.pad(hidden_states, + (0, 0, 0, pad_size)) + if not self.rm_router_logits: + router_logits = nn.functional.pad(router_logits, + (0, 0, 0, pad_size)) + + hidden_states = self.moe_config.dp_group.all_gather( + hidden_states, 0) + if self.rm_router_logits: + router_logits, _ = gate(hidden_states) + else: + router_logits = self.moe_config.dp_group.all_gather( + router_logits, 0) + + return hidden_states, router_logits, None + + def finalize(self, hidden_states: torch.Tensor, + reduce_results: bool) -> torch.Tensor: + """When DP size > 1, reduce-scatter the hidden states to get the final output. + + When TP size > 1, all-reduce the hidden states to get the final output. + """ + if self.moe_config.dp_size > 1 and not self.enable_shared_expert_dp: + hidden_states = get_dp_group().reduce_scatter(hidden_states, 0) + hidden_states = hidden_states[:self.num_tokens] + + if reduce_results and (self.moe_config.tp_size > 1 + or self.moe_config.ep_size > 1): + hidden_states = tensor_model_parallel_all_reduce(hidden_states) + + return hidden_states diff --git a/vllm_ascend/ops/moe/moe_comm_method.py b/vllm_ascend/ops/moe/moe_comm_method.py new file mode 100644 index 0000000000..af46a3fbc2 --- /dev/null +++ b/vllm_ascend/ops/moe/moe_comm_method.py @@ -0,0 +1,298 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
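+#
+# Each MoECommMethod subclass below pairs a token dispatcher with a
+# prepare/finalize implementation and exposes a single fused_experts() entry
+# point. Call sites fetch the active method from the forward context, as the
+# updated forward_oot() and the quantized apply() paths in this patch do:
+#
+#     comm_method = get_forward_context().moe_comm_method
+#     out = comm_method.fused_experts(hidden_states=x,
+#                                     w1=layer.w13_weight,
+#                                     w2=layer.w2_weight,
+#                                     topk_weights=topk_weights,
+#                                     topk_ids=topk_ids,
+#                                     row_idx=row_idx,
+#                                     global_num_experts=global_num_experts,
+#                                     expert_map=expert_map)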
+ +from abc import ABC, abstractmethod +from typing import Any, Optional + +import torch +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.fused_moe import FusedMoEConfig + +from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import ( + FusedMoEPrepareAndFinalizeWithAll2All, + FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2) +from vllm_ascend.ops.moe.moe_mlp import unified_apply_mlp +from vllm_ascend.ops.moe.token_dispatcher import (TokenDispatcherWithAll2AllV, + TokenDispatcherWithAllGather, + TokenDispatcherWithMC2) + + +class MoECommMethod(ABC): + """Base class for MoE communication methods.""" + + def __init__(self, moe_config: FusedMoEConfig): + self.moe_config = moe_config + self.mc2_mask = None + + self.token_dispatcher = self._get_token_dispatcher() + self.fused_moe_prepare_finalize = self._get_fused_moe_prepare_finalize( + ) + + def prepare(self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + enable_shared_expert_dp: bool = False, + rm_router_logits: bool = False, + replace_allreduce: bool = False, + gate=None) -> tuple[torch.Tensor, torch.Tensor]: + hidden_states, router_logits, mc2_mask = self.fused_moe_prepare_finalize.prepare( + hidden_states, router_logits, enable_shared_expert_dp, + rm_router_logits, replace_allreduce, gate) + self.mc2_mask = mc2_mask + return hidden_states, router_logits + + def finalize(self, hidden_states: torch.Tensor, + reduce_results: bool) -> torch.Tensor: + hidden_states = self.fused_moe_prepare_finalize.finalize( + hidden_states, reduce_results) + return hidden_states + + def fused_experts( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + row_idx: torch.Tensor, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_int8_w8a8: bool = False, + use_int4_w4a8: bool = False, + global_num_experts: Optional[int] = None, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_scale_bias: torch.Tensor = None, + w2_scale_bias: torch.Tensor = None, + # For TorchAir graph + is_torchair: bool = False, + # For Cube/Vector parallel + shared_experts: Optional[Any] = None, + shared_gate_up: Optional[Any] = None, + shared_dequant_scale: Optional[Any] = None, + quantized_x_for_share: Optional[Any] = None, + dynamic_scale_for_share: Optional[Any] = None, + # For load balance + log2phy: torch.Tensor = None, + global_redundant_expert_num: int = 0, + need_trans: bool = False) -> torch.Tensor: + # Check constraints + assert hidden_states.shape[1] == w1.shape[1], ( + f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[1]}") + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert hidden_states.is_contiguous( + ), "Hidden_states must be contiguous" + assert w1.stride(-1) == 1, "Stride of last dimension must be 1" + assert w2.stride(-1) == 1, "Stride of last dimension must be 1" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] + + moe_comm_method = get_forward_context().moe_comm_method + assert moe_comm_method is not None, "Missing communication context" + + results = self.token_dispatcher.token_dispatch( + hidden_states=hidden_states, + topk_weights=topk_weights, + topk_ids=topk_ids, + row_idx=row_idx, + expert_map=expert_map, + log2phy=log2phy, + global_redundant_expert_num=global_redundant_expert_num, + shared_experts=shared_experts, + 
shared_gate_up=shared_gate_up, + shared_dequant_scale=shared_dequant_scale, + mc2_mask=self.mc2_mask, + apply_router_weight_on_input=apply_router_weight_on_input, + with_quant=use_int8_w8a8 or use_int4_w4a8) + + permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type = \ + results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"] + + mlp_output = unified_apply_mlp(hidden_states=permuted_hidden_states, + w1=w1, + w1_scale=w1_scale, + w2=w2, + w2_scale=w2_scale, + group_list=expert_tokens, + dynamic_scale=dynamic_scale, + group_list_type=group_list_type, + w1_scale_bias=w1_scale_bias, + w2_scale_bias=w2_scale_bias, + with_quant=use_int8_w8a8 + or use_int4_w4a8, + need_trans=need_trans) + + hidden_states[:] = self.token_dispatcher.token_combine( + hidden_states=mlp_output) + + return hidden_states + + @abstractmethod + def _get_token_dispatcher(self): + raise NotImplementedError( + "_get_token_dispatcher function not implemented.") + + @abstractmethod + def _get_fused_moe_prepare_finalize(self): + raise NotImplementedError( + "_get_fused_moe_prepare_finalize function not implemented.") + + +class AllGatherCommImpl(MoECommMethod): + """This implementation is the same as NativeAllGatherCommImpl, + but uses NPU-specific ops for better performance. + + This implementation should be compatible with all scenarios, and + thus it is the default implementation for MoE communication methods. + It uses `torch_npu.npu_moe_init_routing_v2` for pre-processing + and `torch_npu.npu_moe_token_unpermute` for post-processing + to handle the token-to-expert mapping and communication efficiently. + + NOTE(Yizhou): TBH, it is really weird that we were supposed to use + `torch_npu.npu_moe_init_routing_v2` and `torch_npu.npu_moe_finalize_routing` + or `torch_npu.npu_moe_token_permute` and `torch_npu.npu_moe_token_unpermute` + for pre-processing and post-processing, respectively. + But `npu_moe_finalize_routing` will lead to accuracy issues so we have to + use `torch_npu.npu_moe_token_unpermute` instead. + This is a workaround and should be removed after the issue is fixed. + """ + + def _get_token_dispatcher(self): + return TokenDispatcherWithAllGather( + top_k=self.moe_config.experts_per_token, + num_experts=self.moe_config.num_experts, + num_local_experts=self.moe_config.num_local_experts) + + def _get_fused_moe_prepare_finalize(self): + return FusedMoEPrepareAndFinalizeWithAllGather(self.moe_config) + + +class NativeAllGatherCommImpl(AllGatherCommImpl): + """This implementation should be compatible with all scenarios. + + Note that this implementation purely consists of native PyTorch ops + and does not use any NPU-specific ops. So the performance may not be optimal. + But it is a good fallback for scenarios where NPU-specific ops are not available. 
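+
+    permute() sorts the dispatched tokens by local expert id and returns the
+    per-expert token counts with group_list_type=1 (`count` mode); unpermute()
+    scales the MLP outputs by the routing weights and index_add_s them back to
+    their original token rows.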
+ """ + + def permute( + self, + hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + expert_map: torch.Tensor, + num_experts: int, + apply_a8_quantization: bool, + ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]: + num_tokens = hidden_states.shape[0] + + # Generate token indices and flatten + token_indices = torch.arange(num_tokens, + device=hidden_states.device, + dtype=torch.int64) + token_indices = (token_indices.unsqueeze(1).expand( + -1, self.moe_config.experts_per_token).reshape(-1)) + + # Flatten token-to-expert mappings and map to local experts + weights_flat = topk_weights.view(-1) + experts_flat = topk_ids.view(-1) + local_experts_flat = (expert_map[experts_flat] + if expert_map is not None else experts_flat) + + # Filter valid token-expert pairs + mask = local_experts_flat != -1 + # FIXME: npu_grouped_matmul output random values at [num_valid_tokens:, ...] + # So we need to filter out invalid tokens by zeroing their weights. + # This is a workaround and should be removed after the issue is fixed + filtered_weights = torch.where(mask, weights_flat, + torch.zeros_like(weights_flat)).to( + topk_weights.dtype) + filtered_experts = torch.where( + mask, + local_experts_flat, + torch.full_like(local_experts_flat, num_experts), + ).to(topk_ids.dtype) + + # Sort by local expert IDs + sort_indices = torch.argsort(filtered_experts.view(torch.float32)) + self.sorted_token_indices = token_indices[sort_indices] + self.sorted_weights = filtered_weights[sort_indices] + + # Compute token counts with minlength of num_experts + # This is equivalent to but faster than: + # >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1] + token_counts = torch.zeros(num_experts + 1, + device=hidden_states.device, + dtype=torch.int64) + ones = torch.ones_like(filtered_experts, dtype=torch.int64) + token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones) + expert_tokens = token_counts[:num_experts] + + # Rearrange hidden_states + permuted_hidden_states = hidden_states[self.sorted_token_indices] + + group_list_type = 1 # `count` mode + + return permuted_hidden_states, expert_tokens, None, group_list_type + + def unpermute(self, mlp_output: torch.Tensor, + hidden_states: torch.Tensor) -> None: + mlp_output = mlp_output * self.sorted_weights.unsqueeze(1) + + final_hidden_states = torch.zeros_like(hidden_states) + final_hidden_states.index_add_(0, self.sorted_token_indices, + mlp_output) + + hidden_states[:] = final_hidden_states + + +class MC2CommImpl(MoECommMethod): + """This implementation is for the scenarios listed below: + 1. `enable_expert_parallel=True`. + 2. `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine` are available. + 3. `enable_expert_parallel=False` is not supported. + + This implementation uses the MC2 communication method, which is optimized for + Communication and Computation parallelism on Ascend devices. + """ + + def _get_token_dispatcher(self): + return TokenDispatcherWithMC2() + + def _get_fused_moe_prepare_finalize(self): + return FusedMoEPrepareAndFinalizeWithMC2(self.moe_config) + + +class AlltoAllCommImpl(MoECommMethod): + """This implementation is for the scenarios listed below: + 1. `enable_expert_parallel=True`. + 2. `npu_grouped_matmul` is available. + + This implementation uses all-to-all communication to exchange tokens + between data parallel ranks before and after the MLP computation. It should + have better performance than AllGatherCommImpl when DP size > 1. 
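+    It pairs TokenDispatcherWithAll2AllV with
+    FusedMoEPrepareAndFinalizeWithAll2All (see the two factory methods below).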
+ """ + + def _get_token_dispatcher(self): + return TokenDispatcherWithAll2AllV( + top_k=self.moe_config.experts_per_token, + num_experts=self.moe_config.num_experts, + num_local_experts=self.moe_config.num_local_experts) + + def _get_fused_moe_prepare_finalize(self): + return FusedMoEPrepareAndFinalizeWithAll2All(self.moe_config) diff --git a/vllm_ascend/ops/layers/moe_mlp.py b/vllm_ascend/ops/moe/moe_mlp.py similarity index 90% rename from vllm_ascend/ops/layers/moe_mlp.py rename to vllm_ascend/ops/moe/moe_mlp.py index d6f67bb1f1..77e8318434 100644 --- a/vllm_ascend/ops/layers/moe_mlp.py +++ b/vllm_ascend/ops/moe/moe_mlp.py @@ -176,14 +176,18 @@ def quant_apply_mlp(hidden_states: torch.Tensor, return hidden_states -def unquant_apply_mlp( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - group_list: torch.Tensor, - group_list_type: int = 1, - topk_scales: Optional[torch.Tensor] = None) -> torch.Tensor: - w1 = w1.transpose(1, 2) +def unquant_apply_mlp(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + group_list: torch.Tensor, + group_list_type: int = 1, + topk_scales: Optional[torch.Tensor] = None, + need_trans: bool = True) -> torch.Tensor: + + if need_trans: + w1 = w1.transpose(1, 2) + w2 = w2.transpose(1, 2) + gate_up_out = torch_npu.npu_grouped_matmul( x=[hidden_states], weight=[w1], @@ -201,7 +205,6 @@ def unquant_apply_mlp( if topk_scales is not None: gate_up_out *= topk_scales - w2 = w2.transpose(1, 2) hidden_states = torch_npu.npu_grouped_matmul( x=[gate_up_out], weight=[w2], @@ -225,7 +228,8 @@ def unified_apply_mlp(hidden_states: torch.Tensor, w2_scale_bias: torch.Tensor = None, topk_scales: Optional[torch.Tensor] = None, with_quant: bool = False, - fusion: bool = False) -> torch.Tensor: + fusion: bool = False, + need_trans: bool = True) -> torch.Tensor: if with_quant: return quant_apply_mlp(hidden_states=hidden_states, w1=w1, @@ -244,4 +248,5 @@ def unified_apply_mlp(hidden_states: torch.Tensor, w2=w2, group_list=group_list, group_list_type=group_list_type, - topk_scales=topk_scales) + topk_scales=topk_scales, + need_trans=need_trans) diff --git a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py b/vllm_ascend/ops/moe/token_dispatcher.py similarity index 100% rename from vllm_ascend/ops/moe_dispatcher/token_dispatcher.py rename to vllm_ascend/ops/moe/token_dispatcher.py diff --git a/vllm_ascend/ops/moe_dispatcher/__init__.py b/vllm_ascend/ops/moe_dispatcher/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/vllm_ascend/quantization/w4a8_dynamic.py b/vllm_ascend/quantization/w4a8_dynamic.py index 72f956d1d2..47aa99cfc0 100644 --- a/vllm_ascend/quantization/w4a8_dynamic.py +++ b/vllm_ascend/quantization/w4a8_dynamic.py @@ -27,7 +27,7 @@ from vllm_ascend.ascend_forward_context import FusedMoEState from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.ops.fused_moe import unified_fused_experts_eager -from vllm_ascend.ops.layers.experts_selector import select_experts +from vllm_ascend.ops.moe.experts_selector import select_experts class AscendW4A8DynamicLinearMethod: diff --git a/vllm_ascend/quantization/w8a8.py b/vllm_ascend/quantization/w8a8.py index e4cbdc897c..010d45da41 100644 --- a/vllm_ascend/quantization/w8a8.py +++ b/vllm_ascend/quantization/w8a8.py @@ -23,7 +23,7 @@ from vllm.distributed.parallel_state import get_ep_group from vllm_ascend.attention.attention_v1 import AscendAttentionState -from vllm_ascend.ops.layers.experts_selector import select_experts +from 
vllm_ascend.ops.moe.experts_selector import select_experts from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index f710bd2f22..54114b7195 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -27,10 +27,8 @@ from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_forward_context import FusedMoEState from vllm_ascend.distributed.parallel_state import get_mc2_group -from vllm_ascend.ops.common_fused_moe import \ - fused_experts as unified_fused_experts from vllm_ascend.ops.fused_moe import unified_fused_experts_eager -from vllm_ascend.ops.layers.experts_selector import select_experts +from vllm_ascend.ops.moe.experts_selector import select_experts from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ @@ -221,12 +219,14 @@ def apply( global_num_experts=global_num_experts) if self.use_aclgraph: - return unified_fused_experts( + moe_comm_method = get_forward_context().moe_comm_method + return moe_comm_method.fused_experts( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, + row_idx=row_idx, use_int8_w8a8=True, w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale,