 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+from typing import Optional
+
 import pytest
 import torch

-from tests.pplx_utils import ProcessGroupInfo, parallel_launch
 from vllm import _custom_ops as ops
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.activation import SiluAndMul
     FusedMoEModularKernel)
 from vllm.platforms import current_platform

+from .deepep_utils import ProcessGroupInfo, parallel_launch
+
 try:
     from pplx_kernels import AllToAll
     from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
@@ -64,6 +67,7 @@ def pplx_cutlass_moe(
     out_dtype,
     per_act_token: bool,
     per_out_ch: bool,
+    group_name: Optional[str],
 ):
     from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
         PplxPrepareAndFinalize)
@@ -84,7 +88,7 @@ def pplx_cutlass_moe(
     else:
         scale_elems = (hidden_dim + block_size - 1) // block_size

-    ata = AllToAll.internode(
+    args = dict(
         max_num_tokens=max_num_tokens,
         num_experts=num_experts,
         experts_per_token=topk,
@@ -96,6 +100,12 @@ def pplx_cutlass_moe(
         hidden_dim_scale_bytes=scale_elems * torch.float32.itemsize,
     )

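+    # group_name picks the all-to-all flavor below: None selects the
+    # NVSHMEM-backed internode kernels, while a named torch.distributed
+    # group selects the single-node intranode kernels.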
+    if group_name is None:
+        ata = AllToAll.internode(**args)
+    else:
+        args["group_name"] = group_name
+        ata = AllToAll.intranode(**args)
+
     w1 = w1.to(device)
     w2 = w2.to(device)
     w1_scale = w1_scale.to(device)
@@ -113,7 +123,10 @@ def pplx_cutlass_moe(
     )

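+    # Experts are split across ranks with a ceiling division; the pplx
+    # prepare/finalize path hands activations over in batched per-expert
+    # format, hence use_batched_format=True.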
     experts = CutlassExpertsFp8((num_experts + world_size - 1) // world_size,
-                                out_dtype, per_act_token, per_out_ch)
+                                out_dtype,
+                                per_act_token,
+                                per_out_ch,
+                                use_batched_format=True)

     fused_cutlass_experts = FusedMoEModularKernel(
         prepare_finalize,
@@ -184,19 +197,25 @@ def _pplx_moe(
     w2_full: torch.Tensor,
     per_act_token: bool,
     per_out_ch: bool,
+    use_internode: bool,
 ):
-    uid = nvshmem_get_unique_id(
-    ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
-    torch.distributed.broadcast(uid, src=0)
-    nvshmem_init(uid, pgi.rank, pgi.world_size)
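+    # internode: bootstrap NVSHMEM by broadcasting a unique id from rank 0;
+    # intranode: a gloo group suffices, and its name is threaded through to
+    # AllToAll.intranode inside pplx_cutlass_moe.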
+    if use_internode:
+        uid = nvshmem_get_unique_id(
+        ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
+        torch.distributed.broadcast(uid, src=0)
+        nvshmem_init(uid, pgi.rank, pgi.world_size)
+        group_name = None  # internode takes no group; avoids a NameError below
+    else:
+        group_ranks = list(range(pgi.world_size))
+        cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
+        group_name = cpu_group.group_name

     with set_current_vllm_config(vllm_config):
         torch_output = torch_moe2(a_full, w1_full, w2_full, topk_weights,
                                   topk_ids)
         pplx_output = pplx_cutlass_moe(pgi, dp_size, a, w1, w2, w1_scale,
                                        w2_scale, topk_weights, topk_ids,
                                        a1_scale, out_dtype, per_act_token,
-                                       per_out_ch)
+                                       per_out_ch, group_name)

     torch_output = chunk_by_rank(torch_output, pgi.rank,
                                  pgi.world_size).to(pplx_output.device)
@@ -207,7 +226,8 @@ def _pplx_moe(

     torch.testing.assert_close(pplx_output, torch_output, atol=0.05, rtol=0)

-    nvshmem_finalize()
+    if use_internode:
+        nvshmem_finalize()


 @pytest.mark.parametrize("m", [2, 224])
@@ -218,6 +238,7 @@ def _pplx_moe(
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
 @pytest.mark.parametrize("world_dp_size", [[2, 1]])  #, [4, 2]])
+@pytest.mark.parametrize("use_internode", [False])
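+# Only the intranode path runs by default; switch to [True] to exercise the
+# NVSHMEM-backed internode kernels.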
 @pytest.mark.skipif(
     (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
         current_platform.get_device_capability()),
@@ -232,6 +253,7 @@ def test_cutlass_moe_pplx(
     per_act_token: bool,
     per_out_ch: bool,
     world_dp_size: tuple[int, int],
+    use_internode: bool,
 ):
     current_platform.seed_everything(7)

@@ -284,4 +306,5 @@ def test_cutlass_moe_pplx(

     parallel_launch(world_size, _pplx_moe, dp_size, a, w1_q, w2_q,
                     w1_scale, w2_scale, topk_weights, topk_ids, a_scale1,
-                    dtype, a, w1_d, w2_d, per_act_token, per_out_ch)
+                    dtype, a, w1_d, w2_d, per_act_token, per_out_ch,
+                    use_internode)