 import pytest
 import torch
 import torch.nn as nn
+import torch_npu
 from pytest_mock import MockerFixture
 
+from vllm_ascend.ascend_forward_context import get_fused_moe_state
 from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
                                        AscendUnquantizedFusedMoEMethod)
-from vllm_ascend.utils import adapt_patch, AscendSocVersion  # noqa E402
-from vllm_ascend.ascend_forward_context import get_fused_moe_state
+from vllm_ascend.utils import AscendSocVersion, adapt_patch  # noqa E402
 
 adapt_patch(True)
 
@@ -107,15 +108,9 @@ def mock_moe_env(mocker: MockerFixture):
          patch("torch_npu.npu_moe_distribute_dispatch", return_value=(
              torch.randn(16, 2)
          )), \
-         patch("torch_npu.npu_moe_distribute_dispatch_v2", return_value=(
-             torch.randn(16, 2)
-         )), \
          patch("torch_npu.npu_moe_distribute_combine", return_value=(
              torch.randn(16, 2)
          )), \
-         patch("torch_npu.npu_moe_distribute_combine_v2", return_value=(
-             torch.randn(16, 2)
-         )), \
          patch("torch_npu.npu_grouped_matmul", return_value=(
              (torch.randn(8, 2), torch.randn(8, 2))
          )), \
@@ -130,7 +125,14 @@ def mock_moe_env(mocker: MockerFixture):
          patch("torch_npu.npu_moe_finalize_routing", return_value=(
              torch.randn(16, 2)
          )):
-        yield
+        if hasattr(torch_npu, 'npu_moe_distribute_dispatch_v2'):
+            with patch("torch_npu.npu_moe_distribute_dispatch_v2", return_value=(
+                    torch.randn(16, 2))), \
+                 patch("torch_npu.npu_moe_distribute_combine_v2", return_value=(
+                    torch.randn(16, 2))):
+                yield
+        else:
+            yield
 
 
 @pytest.fixture
@@ -245,8 +247,11 @@ def test_forward(self, mock_dist_env, default_moe_config, others_param):
         moe.moe_parallel_config.ep_size = 1
 
         moe.quant_method = MockQuantMethod(shared_experts, num_tokens)
-        forward_context = MagicMock(mc2_mask=torch.zeros(num_tokens, dtype=torch.bool), padded_num_tokens=num_tokens)
-        with patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context):
+        forward_context = MagicMock(mc2_mask=torch.zeros(num_tokens,
+                                                         dtype=torch.bool),
+                                    padded_num_tokens=num_tokens)
+        with patch("vllm_ascend.ops.fused_moe.get_forward_context",
+                   return_value=forward_context):
             output = moe.forward(inputs,
                                  router_logits,
                                  is_prefill=is_prefill,
@@ -306,7 +311,8 @@ def test_apply_without_expert_map(self, moe_method, mock_dist_env,
         global_num_experts, ep_size, select_softmax = others_param
         is_prefill = False
         is_deepseek_v3_r1 = global_num_experts == 256
-        forward_context = MagicMock(fused_moe_state=get_fused_moe_state(ep_size, is_prefill, is_deepseek_v3_r1))
+        forward_context = MagicMock(fused_moe_state=get_fused_moe_state(
+            ep_size, is_prefill, is_deepseek_v3_r1))
         with patch(
                 "vllm_ascend.ops.fused_moe.SELECT_GATING_TOPK_SOTFMAX_EXPERTS",
                 select_softmax), \
@@ -342,7 +348,8 @@ def test_apply_with_expert_map(self, moe_method, mock_dist_env,
342348 """
343349 ep_size , alltoall_buffer = others_param
344350 is_prefill = False
345- forward_context = MagicMock (fused_moe_state = get_fused_moe_state (ep_size , is_prefill , True ))
351+ forward_context = MagicMock (
352+ fused_moe_state = get_fused_moe_state (ep_size , is_prefill , True ))
346353 with patch ("vllm_ascend.ops.fused_moe.MOE_ALL2ALL_BUFFER" ,
347354 alltoall_buffer ), \
348355 patch ("vllm_ascend.ops.fused_moe.get_forward_context" , return_value = forward_context ), \
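
Note on the fixture change above: unittest.mock.patch raises AttributeError when its target attribute does not exist, which is why the _v2 dispatch/combine ops are only patched when the installed torch_npu build actually exposes them. Below is a minimal standalone sketch of that guard; the fixture name, the 16x2 tensor shape, and the use of contextlib.ExitStack (in place of the nested with/else branching in the diff) are illustrative assumptions, not part of the change.

import contextlib
from unittest.mock import patch

import pytest
import torch
import torch_npu


@pytest.fixture
def mock_moe_dispatch_ops():
    # Ops that every supported torch_npu build provides.
    targets = [
        "torch_npu.npu_moe_distribute_dispatch",
        "torch_npu.npu_moe_distribute_combine",
    ]
    # The _v2 variants exist only on newer torch_npu builds; patching a
    # missing attribute would raise AttributeError inside mock.patch.
    if hasattr(torch_npu, "npu_moe_distribute_dispatch_v2"):
        targets += [
            "torch_npu.npu_moe_distribute_dispatch_v2",
            "torch_npu.npu_moe_distribute_combine_v2",
        ]
    with contextlib.ExitStack() as stack:
        for target in targets:
            stack.enter_context(
                patch(target, return_value=torch.randn(16, 2)))
        yield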