 from triton_kernels.testing import assert_close

 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
-from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
-    BatchedPrepareAndFinalize,
-)
-from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
 from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
-    BatchedOAITritonExperts,
     triton_kernel_moe_forward,
 )
-from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
 from vllm.model_executor.layers.utils import shuffle_weight
 from vllm.utils import round_up

@@ -302,8 +296,8 @@ def test_equiv(num_token, a_dtype, w_dtype, tp):
     quant_config = FusedMoEQuantConfig.make(
         w1_bias=w1_bias_tri,
         w2_bias=w2_bias_tri,
-        w1_precision=pc1,
-        w2_precision=pc2,
+        w1_scale=pc1,
+        w2_scale=pc2,
     )

     out_triton_monolithic = triton_kernel_moe_forward(
@@ -329,115 +323,6 @@ def test_equiv(num_token, a_dtype, w_dtype, tp):
     assert_close(ref=out_ref, tri=out_triton_monolithic, maxtol=0.025, rmstol=0.005)


-def batched_moe(
-    a: torch.Tensor,
-    w1,
-    w2,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-    w1_bias: torch.Tensor,
-    w2_bias: torch.Tensor,
-    w1_precision: PrecisionConfig,
-    w2_precision: PrecisionConfig,
-) -> torch.Tensor:
-    max_num_tokens = round_up(a.shape[0], 64)
-
-    quant_config = FusedMoEQuantConfig.make(
-        w1_precision=w1_precision,
-        w2_precision=w2_precision,
-        w1_bias=w1_bias,
-        w2_bias=w2_bias,
-    )
-
-    fused_experts = FusedMoEModularKernel(
-        BatchedPrepareAndFinalize(
-            max_num_tokens,
-            num_dispatchers=1,
-            num_local_experts=w1.shape[0],
-            rank=0,
-        ),
-        BatchedOAITritonExperts(
-            max_num_tokens=max_num_tokens,
-            num_dispatchers=1,
-            quant_config=quant_config,
-        ),
-    )
-
-    topk_weight, topk_ids, _ = fused_topk(a, gating_output, topk, renormalize)
-
-    return fused_experts(
-        a,
-        w1,
-        w2,
-        topk_weight,
-        topk_ids,
-    )
-
-
-@pytest.mark.parametrize(
-    ", ".join(f.name for f in fields(Case)),
-    [
-        tuple(getattr(case, f.name) for f in fields(Case))
-        for case in [
-            # Case(a_dtype="bf16", w_dtype="bf16"),
-            # Case(a_dtype="fp8_e4m3", w_dtype="fp8_e5m2"),
-            Case(a_dtype="bf16", w_dtype="mx4")
-        ]
-    ],
-)
-@pytest.mark.parametrize("num_token", [64])
-@pytest.mark.parametrize("ep", [1, 2, 4, 8])
-def test_triton_kernel_batched_moe(num_token, a_dtype, w_dtype, ep):
-    M = num_token
-    E = ModelConfig.num_experts // ep
-    K = ModelConfig.hidden_size
-    N = ModelConfig.intermediate_size
-    topk = ModelConfig.experts_per_token
-
-    (
-        x,
-        w1,
-        w1_bias,
-        w2,
-        w2_bias,
-        exp_data,
-        x_tri,
-        w1_tri,
-        w2_tri,
-        exp_data_tri,
-        w1_bias_tri,
-        w2_bias_tri,
-        pc1,
-        pc2,
-    ) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=4)
-
-    out_tri = batched_moe(
-        a=x_tri,
-        w1=w1_tri,
-        w2=w2_tri,
-        gating_output=exp_data_tri,
-        topk=topk,
-        renormalize=True,
-        w1_bias=w1_bias_tri,
-        w2_bias=w2_bias_tri,
-        w1_precision=pc1,
-        w2_precision=pc2,
-    )
-    out_tri = out_tri[..., :K]
-
-    out_ref = oai_moe_forward(
-        hidden_states=x,
-        w1=w1,
-        w1_bias=w1_bias,
-        w2=w2,
-        w2_bias=w2_bias,
-        gating_output=exp_data,
-        topk=topk,
-    )
-    assert_close(ref=out_ref, tri=out_tri, maxtol=0.025, rmstol=0.005)
-
-
 def test_unit_shuffle():
     N = ModelConfig.intermediate_size
     K = ModelConfig.hidden_size