 import pytest
 import torch
 import torch.nn as nn
+import torch_npu
 from pytest_mock import MockerFixture
 
+from vllm_ascend.ascend_forward_context import get_fused_moe_state
 from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
                                        AscendUnquantizedFusedMoEMethod)
-from vllm_ascend.utils import adapt_patch  # noqa E402
+from vllm_ascend.utils import AscendSocVersion, adapt_patch  # noqa E402
 
 adapt_patch(True)
 
 
-def mock_ep_group(mocker):
+def mock_ep_and_mc2_group(mocker):
     mock_group = mocker.MagicMock()
     mock_group.rank_in_group = 0
     mock_group.rank = 0
@@ -52,7 +54,8 @@ def mock_dist_env(mocker: MockerFixture):
 
     with patch('torch.distributed.get_rank', return_value=0), \
         patch('torch.distributed.get_world_size', return_value=4), \
-        patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_group(mocker)), \
+        patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
+        patch('vllm_ascend.ops.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
         patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
         patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
         patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
@@ -73,7 +76,7 @@ def mock_dist_env(mocker: MockerFixture):
               return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
         patch('vllm_ascend.ops.fused_moe.get_forward_context',
               return_value=MagicMock(
-                  attn_metadata=MagicMock(max_num_tokens_across_dp=10),
+                  max_tokens_across_dp=10,
                   dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10])
               )), \
         patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
@@ -122,7 +125,14 @@ def mock_moe_env(mocker: MockerFixture):
         patch("torch_npu.npu_moe_finalize_routing", return_value=(
             torch.randn(16, 2)
         )):
-        yield
+        if hasattr(torch_npu, 'npu_moe_distribute_dispatch_v2'):
+            with patch("torch_npu.npu_moe_distribute_dispatch_v2", return_value=(
+                    torch.randn(16, 2))), \
+                patch("torch_npu.npu_moe_distribute_combine_v2", return_value=(
+                    torch.randn(16, 2))):
+                yield
+        else:
+            yield
 
 
 @pytest.fixture
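The duplicated yield in the hunk above could also be expressed with contextlib.ExitStack, which registers the v2 dispatch/combine patches only when the installed torch_npu exposes them. A minimal sketch of that alternative (not part of this commit; the mock_optional_v2_ops name is hypothetical and the return shapes simply mirror the mocks above):

import contextlib
from unittest.mock import patch

import torch
import torch_npu


@contextlib.contextmanager
def mock_optional_v2_ops():
    # Enter the v2 dispatch/combine patches only when this torch_npu build
    # exposes them; otherwise yield with no extra patches active.
    with contextlib.ExitStack() as stack:
        for name in ("npu_moe_distribute_dispatch_v2",
                     "npu_moe_distribute_combine_v2"):
            if hasattr(torch_npu, name):
                stack.enter_context(
                    patch(f"torch_npu.{name}", return_value=torch.randn(16, 2)))
        yield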
@@ -237,11 +247,16 @@ def test_forward(self, mock_dist_env, default_moe_config, others_param):
         moe.moe_parallel_config.ep_size = 1
 
         moe.quant_method = MockQuantMethod(shared_experts, num_tokens)
-        output = moe.forward(inputs,
-                             router_logits,
-                             is_prefill=is_prefill,
-                             top_k=top_k,
-                             shared_experts=shared_experts)
+        forward_context = MagicMock(mc2_mask=torch.zeros(num_tokens,
+                                                         dtype=torch.bool),
+                                    padded_num_tokens=num_tokens)
+        with patch("vllm_ascend.ops.fused_moe.get_forward_context",
+                   return_value=forward_context):
+            output = moe.forward(inputs,
+                                 router_logits,
+                                 is_prefill=is_prefill,
+                                 top_k=top_k,
+                                 shared_experts=shared_experts)
 
         moe.quant_method.apply.assert_called_once()
 
@@ -288,15 +303,20 @@ def test_process_weights_after_loading(self, moe_method, mock_dist_env):
     def test_apply_without_expert_map(self, moe_method, mock_dist_env,
                                       mock_moe_env, others_param):
         """
-        1 test is_deepseek_v3_r1=true and use fused_expters_with_all2all
+        1 test is_deepseek_v3_r1=true and use fused_experts_with_all2all
         2 test use_select_experts and fused_experts
         3 test use select_gating_topk_softmax_experts and fused_experts
         4 test use select_experts and fused_experts_with_all2all_buffer
         """
         global_num_experts, ep_size, select_softmax = others_param
+        is_prefill = False
+        is_deepseek_v3_r1 = global_num_experts == 256
+        forward_context = MagicMock(fused_moe_state=get_fused_moe_state(
+            ep_size, is_prefill, is_deepseek_v3_r1))
         with patch(
                 "vllm_ascend.ops.fused_moe.SELECT_GATING_TOPK_SOTFMAX_EXPERTS",
-                select_softmax):
+                select_softmax), \
+            patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context):
             moe_method.ep_size = ep_size
             x = torch.randn(8, 2, 2)
             router_logits = torch.randn(8, 8)
@@ -309,7 +329,7 @@ def test_apply_without_expert_map(self, moe_method, mock_dist_env,
                 top_k=2,
                 renormalize=True,
                 global_num_experts=global_num_experts,
-                is_prefill=False)
+                is_prefill=is_prefill)
 
             if ep_size == 1:
                 assert result.shape == (16, 2)
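When tracing which of the four documented branches a given parametrization exercises, the state the test injects can be reproduced in isolation. A small sketch, assuming only the positional (ep_size, is_prefill, is_deepseek_v3_r1) call shown in the hunk above; the concrete values returned by get_fused_moe_state are not visible in this diff:

from unittest.mock import MagicMock, patch

from vllm_ascend.ascend_forward_context import get_fused_moe_state

# Same derivation as the test: ep_size and the prefill flag come from the
# parametrization, and 256 global experts marks the DeepSeek V3/R1 layout.
state = get_fused_moe_state(16, False, True)  # ep_size, is_prefill, is_deepseek_v3_r1
print(state)  # shows which fused-MoE dispatch path the mocked context will report

forward_context = MagicMock(fused_moe_state=state)
with patch("vllm_ascend.ops.fused_moe.get_forward_context",
           return_value=forward_context):
    pass  # moe_method.apply(...) would observe this state, as in the test body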
@@ -327,8 +347,13 @@ def test_apply_with_expert_map(self, moe_method, mock_dist_env,
         4 test use_select_experts and fused_experts
         """
         ep_size, alltoall_buffer = others_param
+        is_prefill = False
+        forward_context = MagicMock(
+            fused_moe_state=get_fused_moe_state(ep_size, is_prefill, True))
         with patch("vllm_ascend.ops.fused_moe.MOE_ALL2ALL_BUFFER",
-                   alltoall_buffer):
+                   alltoall_buffer), \
+            patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \
+            patch("vllm_ascend.ops.fused_moe.get_ascend_soc_version", return_value=AscendSocVersion.A3):
             expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
             moe_method.ep_size = ep_size
             x = torch.randn(8, 2, 2)
@@ -347,7 +372,7 @@ def test_apply_with_expert_map(self, moe_method, mock_dist_env,
                 renormalize=True,
                 global_num_experts=128,
                 expert_map=expert_map,
-                is_prefill=False)
+                is_prefill=is_prefill)
 
             if ep_size == 16 or ep_size == 1:
                 assert result.shape == (16, 2)
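For readers unfamiliar with the expert_map tensor used in both apply tests, the usual vLLM convention (an assumption here, not spelled out in this diff) is that entry i holds the local index of global expert i on the current rank, with -1 for experts owned by other ranks; this is presumably also why the fixture earlier mocks a return value of (3, torch.tensor([0, 1, 2, -1, ...])). A short illustration:

import torch

# Global-to-local expert mapping: this rank hosts global experts 0, 1 and 2;
# -1 marks experts that live on other ranks.
expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])

num_local_experts = int((expert_map >= 0).sum())              # 3
owned_global_ids = torch.nonzero(expert_map >= 0).flatten()   # tensor([0, 1, 2])
local_index_of_global_2 = int(expert_map[2])                  # 2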