 from vllm.scalar_type import scalar_types
 
 NUM_EXPERTS = [8, 64]
+EP_SIZE = [1, 4]
 TOP_KS = [2, 6]
 
 
 @pytest.mark.parametrize("k", [128, 511, 1024])
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("ep_size", EP_SIZE)
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 def test_fused_moe(
     m: int,
     n: int,
     k: int,
     e: int,
     topk: int,
+    ep_size: int,
     dtype: torch.dtype,
 ):
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
     w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
     w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
 
     score = torch.randn((m, e), device="cuda", dtype=dtype)
-    triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
-    torch_output = torch_moe(a, w1, w2, score, topk)
+
+    if ep_size > 1:
+        local_e = e // ep_size
+        e_ids = torch.randint(0,
+                              e, (local_e, ),
+                              device="cuda",
+                              dtype=torch.int32)
+        e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32)
+        e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
+        w1 = w1[e_ids]
+        w2 = w2[e_ids]
+    else:
+        e_map = None
+
+    triton_output = fused_moe(a,
+                              w1,
+                              w2,
+                              score,
+                              topk,
+                              global_num_experts=e,
+                              expert_map=e_map,
+                              renormalize=False)
+    torch_output = torch_moe(a, w1, w2, score, topk, e_map)
     torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
-    iterative_output = iterative_moe(a, w1, w2, score, topk, renormalize=False)
+    iterative_output = iterative_moe(a,
+                                     w1,
+                                     w2,
+                                     score,
+                                     topk,
+                                     global_num_experts=e,
+                                     expert_map=e_map,
+                                     renormalize=False)
     torch.testing.assert_close(iterative_output,
                                torch_output,
                                atol=2e-2,
@@ -63,13 +94,14 @@ def test_fused_moe(
 @pytest.mark.parametrize("k", [128, 1024])
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("ep_size", EP_SIZE)
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("group_size", [64, 128])
 @pytest.mark.parametrize("has_zp", [True, False])
 @pytest.mark.parametrize("weight_bits", [4, 8])
 def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
-                        dtype: torch.dtype, group_size: int, has_zp: bool,
-                        weight_bits: int):
+                        ep_size: int, dtype: torch.dtype, group_size: int,
+                        has_zp: bool, weight_bits: int):
     print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits)
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
     w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
@@ -130,6 +162,25 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
         if has_zp:
             w_qzeros[expert_id] = qzeros
 
+    if ep_size > 1:
+        local_e = e // ep_size
+        e_ids = torch.randint(0,
+                              e, (local_e, ),
+                              device="cuda",
+                              dtype=torch.int32)
+        e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32)
+        e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
+        w1_ref = w1_ref[e_ids]
+        w2_ref = w2_ref[e_ids]
+        w1_qweight = w1_qweight[e_ids]
+        w2_qweight = w2_qweight[e_ids]
+        w1_scales = w1_scales[e_ids]
+        w2_scales = w2_scales[e_ids]
+        w1_qzeros = w1_qzeros[e_ids]
+        w2_qzeros = w2_qzeros[e_ids]
+    else:
+        e_map = None
+
     triton_output = fused_moe(a,
                               w1_qweight,
                               w2_qweight,
@@ -138,12 +189,14 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
                               renormalize=False,
                               use_int4_w4a16=weight_bits == 4,
                               use_int8_w8a16=weight_bits == 8,
+                              global_num_experts=e,
+                              expert_map=e_map,
                               w1_scale=w1_scales,
                               w2_scale=w2_scales,
                               w1_zp=w1_qzeros if has_zp else None,
                               w2_zp=w2_qzeros if has_zp else None,
                               block_shape=[0, group_size])
-    torch_output = torch_moe(a, w1_ref, w2_ref, score, topk)
+    torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map)
     torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
 
 
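Note on the expert_map semantics (context for the diff above, not part of it): the e_map tensor built in both tests maps global expert ids to local ids, with -1 marking experts owned by other EP ranks. The sketch below illustrates how such a map translates top-k routed expert ids; the helper name apply_expert_map and the gather-then-skip logic are my assumptions for illustration, not the actual vLLM kernel code.

import torch

def apply_expert_map(topk_ids: torch.Tensor,
                     expert_map: torch.Tensor) -> torch.Tensor:
    # Gather local ids for the globally routed experts; entries that map
    # to -1 belong to other ranks and contribute nothing on this rank.
    return expert_map[topk_ids]

# Toy setup: 8 global experts, 2 hosted locally (assumed global ids 3 and 5).
e, local_e = 8, 2
e_ids = torch.tensor([3, 5])
e_map = torch.full((e, ), -1, dtype=torch.int32)
e_map[e_ids] = torch.arange(local_e, dtype=torch.int32)

topk_ids = torch.tensor([[3, 7], [5, 3]])  # top-2 routing for 2 tokens
print(apply_expert_map(topk_ids, e_map))
# tensor([[ 0, -1],
#         [ 1,  0]], dtype=torch.int32)

Because the reference torch_moe and the fused kernels are handed the same e_map, both skip exactly the same non-local experts, so the outputs stay comparable even though each simulated rank only holds e // ep_size expert weights.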