@@ -10,7 +10,7 @@
 
 from tests.kernels.moe.utils import (batched_moe,
                                      make_quantized_test_activations,
-                                     make_test_weights, triton_moe)
+                                     make_test_weights, naive_batched_moe)
 from tests.kernels.quant_utils import native_batched_masked_quant_matmul
 from tests.kernels.utils import torch_experts
 from vllm.config import VllmConfig, set_current_vllm_config
@@ -33,12 +33,10 @@
     (45, 512, 512),
     (45, 1024, 128),
     (45, 1024, 2048),
-    (64, 128, 128),
     (64, 512, 512),
     (64, 1024, 2048),
     (222, 128, 128),
     (222, 128, 2048),
-    (222, 512, 512),
     (222, 1024, 128),
     (222, 1024, 2048),
 ]
@@ -95,11 +93,12 @@ def make_tensors(config: BatchedMMConfig):
 @pytest.mark.parametrize("max_tokens_per_expert",
                          [32, 64, 128, 192, 224, 256, 512])
 @pytest.mark.parametrize("K", [128, 256, 1024])
-@pytest.mark.parametrize("N", [128, 256, 512, 1024])
-@pytest.mark.parametrize("dtype",
-                         [torch.float32, torch.float16, torch.bfloat16])
-@pytest.mark.parametrize("block_shape", [None])
-@pytest.mark.parametrize("per_act_token_quant", [False])
+@pytest.mark.parametrize("N", [128, 256, 1024])
+@pytest.mark.parametrize(
+    "dtype",
+    [torch.float8_e4m3fn, torch.float32, torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("block_shape", [None, [128, 128]])
+@pytest.mark.parametrize("per_act_token_quant", [False, True])
 def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
                     N: int, dtype: torch.dtype,
                     block_shape: Optional[list[int]],
@@ -134,7 +133,8 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
         in_dtype=act_dtype,
         quant_dtype=quant_dtype,
         block_shape=block_shape,
-        per_act_token_quant=per_act_token_quant)
+        per_act_token_quant=per_act_token_quant,
+    )
 
     B, B_q, B_scale, _, _, _ = make_test_weights(
         num_experts,
@@ -143,6 +143,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
         in_dtype=act_dtype,
         quant_dtype=quant_dtype,
         block_shape=block_shape,
+        per_act_token_quant=per_act_token_quant,
     )
 
     out_shape = (num_experts, max_tokens_per_expert, N)
@@ -177,6 +178,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
             "BLOCK_SIZE_N": 16,
             "BLOCK_SIZE_K": 16 if dtype.itemsize > 1 else 32
         },
+        per_act_token_quant=per_act_token_quant,
         block_shape=block_shape,
     )
 
@@ -185,32 +187,31 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
         B,
         ref_output,
         num_expert_tokens,
-        None,
-        None,
-        None,
     )
 
     q_ref_output = native_batched_masked_quant_matmul(A_q, B_q, q_ref_output,
                                                       num_expert_tokens,
                                                       A_scale, B_scale,
-                                                      block_shape)
+                                                      block_shape,
+                                                      per_act_token_quant)
 
     rtol, atol = {
         torch.float16: (6e-2, 6e-2),
         torch.bfloat16: (6e-2, 6e-2),
         torch.float32: (1e-2, 1e-2),
     }[test_output.dtype]
 
-    torch.testing.assert_close(ref_output, test_output, atol=atol, rtol=rtol)
+    torch.testing.assert_close(ref_output, q_ref_output, atol=atol, rtol=rtol)
     torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol)
 
 
 @pytest.mark.parametrize(("m", "n", "k"), MNK_FACTORS)
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
-@pytest.mark.parametrize("dtype", [torch.bfloat16])
-@pytest.mark.parametrize("per_act_token_quant", [False])
-@pytest.mark.parametrize("block_shape", [None])
+@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
+@pytest.mark.parametrize("per_act_token_quant", [False, True])
+@pytest.mark.parametrize("block_shape", [None, [128, 128]])
+@pytest.mark.parametrize("input_scales", [False])
 def test_fused_moe_batched_experts(
     m: int,
     n: int,
@@ -220,15 +221,19 @@ def test_fused_moe_batched_experts(
     dtype: torch.dtype,
     per_act_token_quant: bool,
     block_shape: Optional[list[int]],
+    input_scales: bool,
 ):
     current_platform.seed_everything(7)
 
     use_fp8_w8a8 = dtype == torch.float8_e4m3fn
 
+    if topk > e:
+        pytest.skip("topk > e")
+
     if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None):
         pytest.skip("Skip quantization test for non-quantized type")
 
-    if per_act_token_quant and block_shape is not None or topk > e:
+    if per_act_token_quant and block_shape is not None:
         pytest.skip("Skip illegal quantization test.")
 
     a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
@@ -241,55 +246,74 @@ def test_fused_moe_batched_experts(
         act_dtype = dtype
         quant_dtype = None
 
-    _, w1, w1_s, _, w2, w2_s = make_test_weights(e,
-                                                 n,
-                                                 k,
-                                                 block_shape=block_shape,
-                                                 in_dtype=act_dtype,
-                                                 quant_dtype=quant_dtype)
+    w1_16, w1, w1_s, w2_16, w2, w2_s = make_test_weights(
+        e,
+        n,
+        k,
+        block_shape=block_shape,
+        in_dtype=act_dtype,
+        quant_dtype=quant_dtype,
+        per_act_token_quant=per_act_token_quant,
+    )
+
+    if input_scales and quant_dtype is not None:
+        a1_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
+        a2_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
+    else:
+        a1_scale = None
+        a2_scale = None
 
     with set_current_vllm_config(vllm_config):
         topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
-        batched_output = batched_moe(
+
+        baseline_output = torch_experts(
             a,
             w1,
             w2,
             topk_weight,
             topk_ids,
             w1_scale=w1_s,
             w2_scale=w2_s,
+            a1_scale=a1_scale,
+            a2_scale=a2_scale,
             quant_dtype=quant_dtype,
             per_act_token_quant=per_act_token_quant,
             block_shape=block_shape,
         )
-        baseline_output = torch_experts(
+
+        batched_output = naive_batched_moe(
             a,
             w1,
             w2,
             topk_weight,
             topk_ids,
             w1_scale=w1_s,
             w2_scale=w2_s,
+            a1_scale=a1_scale,
+            a2_scale=a2_scale,
             quant_dtype=quant_dtype,
             per_act_token_quant=per_act_token_quant,
-            block_shape=block_shape)
+            block_shape=block_shape,
+        )
 
-        triton_output = triton_moe(
+        triton_output = batched_moe(
             a,
             w1,
             w2,
             topk_weight,
             topk_ids,
             w1_scale=w1_s,
             w2_scale=w2_s,
+            a1_scale=a1_scale,
+            a2_scale=a2_scale,
             quant_dtype=quant_dtype,
             per_act_token_quant=per_act_token_quant,
             block_shape=block_shape,
         )
 
-        torch.testing.assert_close(triton_output,
+        torch.testing.assert_close(batched_output,
                                    baseline_output,
-                                   atol=2e-2,
+                                   atol=3e-2,
                                    rtol=2e-2)
 
         torch.testing.assert_close(triton_output,
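
The new per_act_token_quant=True cases exercise per-token activation quantization: each token (row) of the activation matrix gets its own fp8 scale, rather than one tensor-wide scale. A minimal sketch of that scheme, with illustrative names only (this is not vllm's actual quantization helper):

import torch

def per_token_quant_fp8(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # One scale per token (row): map the row's max magnitude to the fp8 max.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn
    scale = x.abs().amax(dim=-1, keepdim=True).float() / fp8_max
    scale = scale.clamp(min=1e-12)  # guard against all-zero rows
    x_q = (x.float() / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return x_q, scale  # dequantize as x_q.float() * scale

Under this scheme a scale tensor has shape (num_tokens, 1), which is why the weights and the reference matmul must also be built with per_act_token_quant threaded through, as the diff does.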
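The reference path, native_batched_masked_quant_matmul, computes a per-expert matmul in which only the first num_expert_tokens[e] rows of each expert's batch slot are valid; the remaining rows are padding. A hypothetical unquantized equivalent of that masking logic (not the actual helper, which also takes the quantization scales and block shape):

import torch

def ref_batched_masked_mm(A: torch.Tensor, B: torch.Tensor,
                          num_expert_tokens: torch.Tensor) -> torch.Tensor:
    # A: (E, max_tokens, K), B: (E, N, K). Rows at or beyond
    # num_expert_tokens[e] are padding and must stay zero in the output.
    E, max_tokens, _ = A.shape
    out = torch.zeros((E, max_tokens, B.shape[1]),
                      dtype=torch.float32, device=A.device)
    for e in range(E):
        t = int(num_expert_tokens[e])
        out[e, :t] = A[e, :t].float() @ B[e].float().t()
    return out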