 # here to make the test pass.
 import vllm_ascend.patch.worker.patch_common.patch_utils  # type: ignore[import]  # isort: skip  # noqa
 
+from unittest.mock import MagicMock, patch
+
 import pytest
 import torch
 from vllm.model_executor.layers.activation import SiluAndMul
 
-from vllm_ascend.ops.fused_moe import fused_experts
+from vllm_ascend.ascend_forward_context import FusedMoEState
+from vllm_ascend.ops.fused_moe import (AscendUnquantizedFusedMoEMethod,
+                                       fused_experts)
 
 NUM_EXPERTS = [8, 64]
 EP_SIZE = [1, 4]
@@ -98,3 +102,119 @@ def test_fused_experts(
     # TODO: The native params are: atol=2e-2, rtol=0, maybe related to the nan problem
     torch.testing.assert_close(output, torch_output, atol=4e-2, rtol=1)
     torch.npu.empty_cache()
+
+
+@pytest.mark.parametrize("m", [1, 33, 64])
+@pytest.mark.parametrize("n", [128, 1024, 2048])
+@pytest.mark.parametrize("k", [128, 511])
+@pytest.mark.parametrize("e", NUM_EXPERTS)
+@pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("renormalize", [True, False])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("device", DEVICE)
+def test_ascend_unquantized_fused_moe_softmax(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    renormalize: bool,
+    dtype: torch.dtype,
+    device: str,
+):
+
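+    # Lightweight stand-ins for the objects returned by the helpers patched
+    # further below (vLLM config, Ascend config, MC2 group, forward context),
+    # plus a bare layer carrying only the expert weights the MoE path reads.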
+    class MockVllmConfig:
+
+        @property
+        def scheduler_config(self):
+
+            class SchedulerConfig:
+                max_num_seqs = 256
+
+            return SchedulerConfig()
+
+        @property
+        def model_config(self):
+
+            class ModelConfig:
+                max_model_len = 2048
+
+            return ModelConfig()
+
+    class MockAscendConfig:
+
+        @property
+        def torchair_graph_config(self):
+
+            class TorchairGraphConfig:
+                enabled = False
+
+            return TorchairGraphConfig()
+
+    class MockMC2Group:
+
+        @property
+        def device_group(self):
+            return MagicMock()
+
+    class MockForwardContext:
+
+        @property
+        def fused_moe_state(self):
+            return FusedMoEState.AllGather
+
+    class MockLayer(torch.nn.Module):
+
+        def __init__(self):
+            super().__init__()
+            self.w13_weight = torch.randn(
+                (e, 2 * n, k), device=device, dtype=dtype) / 10
+            self.w2_weight = torch.randn(
+                (e, k, n), device=device, dtype=dtype) / 10
+
+    x = torch.randn((m, k), device=device, dtype=dtype) / 10
+    router_logits = torch.randn((m, e), device=device, dtype=dtype)
+
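+    # Patch the module-level lookups in vllm_ascend.ops.fused_moe and mock the
+    # fused_experts kernel itself, so that apply() exercises the softmax
+    # routing logic without launching the real expert computation.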
+    with patch('vllm_ascend.ops.fused_moe.get_current_vllm_config') as mock_get_vllm_config, \
+         patch('vllm_ascend.ops.fused_moe.get_ascend_config') as mock_get_ascend_config, \
+         patch('vllm_ascend.ops.fused_moe.get_mc2_group') as mock_get_mc2_group, \
+         patch('vllm_ascend.ops.fused_moe.get_forward_context') as mock_get_context, \
+         patch('vllm_ascend.ops.fused_moe.fused_experts') as mock_fused_experts, \
+         patch('torch.distributed.get_rank') as mock_get_rank:
+        mock_get_vllm_config.return_value = MockVllmConfig()
+        mock_get_ascend_config.return_value = MockAscendConfig()
+        mock_get_mc2_group.return_value = MockMC2Group()
+        mock_get_context.return_value = MockForwardContext()
+        mock_fused_experts.return_value = torch.zeros_like(x)
+
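+        # Forcing torch.distributed.get_rank to raise keeps the all-to-all
+        # group-name lookup from succeeding, so moe_all_to_all_group_name
+        # should remain None (asserted below).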
+        mock_get_rank.side_effect = AttributeError("mock error")
+
+        method = AscendUnquantizedFusedMoEMethod()
+        layer = MockLayer()
+
+        output = method.apply(
+            layer=layer,
+            x=x,
+            router_logits=router_logits,
+            top_k=topk,
+            renormalize=renormalize,
+            scoring_func="softmax",
+            global_num_experts=e,
+        )
+
+        assert method.moe_all_to_all_group_name is None
+
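+        # Inspect the keyword arguments captured by the mocked fused_experts
+        # call to validate the shapes and dtypes of the routing outputs.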
+        assert mock_fused_experts.called
+        call_args = mock_fused_experts.call_args[1]
+
+        topk_weights = call_args['topk_weights']
+        topk_ids = call_args['topk_ids']
+
+        assert topk_weights.shape == (m, topk)
+        assert topk_ids.shape == (m, topk)
+        assert output.shape == (m, k)
+        assert topk_weights.dtype == dtype
+        assert topk_ids.dtype == torch.int32
+        assert output.dtype == dtype
+
+    torch.npu.empty_cache()