# SPDX-License-Identifier: Apache-2.0
"""
Test DeepEP + DeepGEMM integration
+DeepGEMM provides GEMM kernels specialized for the
+fp8 block-quantized case.
"""

import dataclasses
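For readers unfamiliar with the format the docstring refers to: block-quantized fp8 keeps one scale per block of elements instead of one per tensor. The sketch below is illustrative only and is not part of this diff; the per_block_cast_to_fp8 helper defined later in the file is the authoritative version and differs in details (padding, 2-D weight blocks). The name per_block_quant_sketch and its internals are invented for illustration.

import torch

def per_block_quant_sketch(x: torch.Tensor, block: int = 128):
    # Illustrative: one scale per 128-column block, scale = amax(|block|) / fp8_max,
    # then the block is stored as float8_e4m3fn after dividing by its scale.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    m, k = x.shape
    assert k % block == 0, "sketch assumes k is a multiple of the block size"
    out = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    scales = torch.empty(m, k // block, dtype=torch.float32, device=x.device)
    for j in range(0, k, block):
        blk = x[:, j:j + block].float()
        s = blk.abs().amax(dim=-1, keepdim=True).clamp(min=1e-4) / fp8_max
        out[:, j:j + block] = (blk / s).to(torch.float8_e4m3fn)
        scales[:, j // block] = s.squeeze(-1)
    return out, scales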
if has_deep_ep:
    from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (  # noqa: E501
        DeepEPHTPrepareAndFinalize)
+    from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (  # noqa: E501
+        DeepEPLLPrepareAndFinalize)

-    from .deepep_utils import DeepEPHTArgs, make_deepep_a2a
+    from .deepep_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a

if has_deep_gemm:
+    from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
+        BatchedDeepGemmExperts)
    from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
        DeepGemmExperts)

P = ParamSpec("P")


+def next_power_of_2(x):
+    import math
+    if x == 0:
+        return 1
+    return 2**math.ceil(math.log2(x))
+
+
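Worked examples for the helper above (illustrative, not part of the diff); this is how make_modular_kernel further down sizes the low-latency path's per-rank token capacity:

# max(64, next_power_of_2(1))   == 64
# max(64, next_power_of_2(45))  == 64
# max(64, next_power_of_2(222)) == 256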
def per_block_cast_to_fp8(
        x: torch.Tensor,
        block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
@@ -126,6 +139,9 @@ class TestConfig:
    n: int
    num_experts: int
    block_size: list[int]
+    # configs for testing low-latency kernels
+    low_latency: bool
+    use_fp8_dispatch: Optional[bool] = False

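For orientation, the two test suites at the bottom of the file fill these fields in two fixed patterns (values below are illustrative picks from their parametrizations; the high-throughput block size is assumed to resolve to 128):

# High-throughput test:
#   TestConfig(topk=2, m=222, k=2048, n=1024, num_experts=32,
#              block_size=[128, 128], low_latency=False, use_fp8_dispatch=None)
# Low-latency test:
#   TestConfig(topk=2, m=32, k=2560, n=128, num_experts=32,
#              block_size=[128, 128], low_latency=True, use_fp8_dispatch=False)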

@dataclasses.dataclass
@@ -170,9 +186,43 @@ def make(config: TestConfig, rank) -> "TestTensors":
                           config=config)


-def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
-                        num_local_experts: int, q_dtype: Optional[torch.dtype],
-                        block_shape: list[int]) -> FusedMoEModularKernel:
+def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
+                           max_tokens_per_rank: int, dp_size: int,
+                           hidden_size: int, q_dtype: Optional[torch.dtype],
+                           test_config: TestConfig) -> FusedMoEModularKernel:
+
+    assert test_config.low_latency
+    assert test_config.use_fp8_dispatch is not None
+
+    a2a: DeepEPLLPrepareAndFinalize = make_deepep_a2a(
+        pg=pg,
+        pgi=pgi,
+        dp_size=dp_size,
+        deepep_ht_args=None,
+        deepep_ll_args=DeepEPLLArgs(
+            max_tokens_per_rank=max_tokens_per_rank,
+            hidden_size=hidden_size,
+            num_experts=test_config.num_experts,
+            use_fp8_dispatch=test_config.use_fp8_dispatch),
+        q_dtype=q_dtype,
+        block_shape=test_config.block_size)
+
+    fused_experts = BatchedDeepGemmExperts(max_num_tokens=max_tokens_per_rank,
+                                           world_size=pgi.world_size,
+                                           dp_size=dp_size,
+                                           block_shape=test_config.block_size)
+    mk = FusedMoEModularKernel(prepare_finalize=a2a,
+                               fused_experts=fused_experts)
+    return mk
+
+
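A note on the pairing above, stated as background rather than something this diff spells out: the low-latency DeepEP dispatch delivers a fixed-capacity, per-expert batched layout (hence max_tokens_per_rank and BatchedDeepGemmExperts), whereas the high-throughput variant below feeds the contiguous-layout DeepGemmExperts.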
+def make_ht_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
+                           dp_size: int, num_local_experts: int,
+                           q_dtype: Optional[torch.dtype],
+                           test_config: TestConfig) -> FusedMoEModularKernel:
+
+    assert not test_config.low_latency
+    assert test_config.use_fp8_dispatch is None

    a2a: DeepEPHTPrepareAndFinalize = make_deepep_a2a(
        pg=pg,
@@ -181,20 +231,50 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
        deepep_ht_args=DeepEPHTArgs(num_local_experts=num_local_experts),
        deepep_ll_args=None,
        q_dtype=q_dtype,
-        block_shape=block_shape)
+        block_shape=test_config.block_size)

    fused_experts = DeepGemmExperts()
    mk = FusedMoEModularKernel(prepare_finalize=a2a,
                               fused_experts=fused_experts)
    return mk


-def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
-                     test_tensors: TestTensors, w1: torch.Tensor,
-                     w2: torch.Tensor, w1_scale: Optional[torch.Tensor],
-                     w2_scale: Optional[torch.Tensor],
-                     num_experts: int) -> torch.Tensor:
+def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
+                        num_local_experts: int,
+                        test_tensors: TestTensors) -> FusedMoEModularKernel:
+
+    q_dtype = torch.float8_e4m3fn
+    test_config = test_tensors.config
+
+    mk: FusedMoEModularKernel
+    # Make modular kernel
+    if test_config.low_latency:
+        max_tokens_per_rank = max(
+            64, next_power_of_2(test_tensors.rank_tokens.size(0)))
+        hidden_size = test_tensors.rank_tokens.size(-1)
+
+        mk = make_ll_modular_kernel(pg=pg,
+                                    pgi=pgi,
+                                    max_tokens_per_rank=max_tokens_per_rank,
+                                    dp_size=dp_size,
+                                    hidden_size=hidden_size,
+                                    q_dtype=q_dtype,
+                                    test_config=test_config)
+    else:
+        mk = make_ht_modular_kernel(pg, pgi, dp_size, num_local_experts,
+                                    q_dtype, test_config)
+
+    return mk
+
+
+def deepep_deepgemm_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
+                             dp_size: int, test_tensors: TestTensors,
+                             w1: torch.Tensor, w2: torch.Tensor,
+                             w1_scale: Optional[torch.Tensor],
+                             w2_scale: Optional[torch.Tensor]) -> torch.Tensor:

+    test_config = test_tensors.config
+    num_experts = test_config.num_experts
    num_local_experts = w1.size(0)

    def build_expert_map():
@@ -208,14 +288,17 @@ def build_expert_map():
        return expert_map.to(device=torch.cuda.current_device(),
                             dtype=torch.int32)

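(Background, not part of the diff: build_expert_map produces the usual vLLM expert map, a tensor indexed by global expert id that holds this rank's local expert index, with -1 for experts owned by other ranks.)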
-    q_dtype = torch.float8_e4m3fn
-
    # Make modular kernel
    mk: FusedMoEModularKernel = make_modular_kernel(
-        pg, pgi, dp_size, num_local_experts, q_dtype,
-        test_tensors.config.block_size)
+        pg=pg,
+        pgi=pgi,
+        dp_size=dp_size,
+        num_local_experts=num_local_experts,
+        test_tensors=test_tensors)

-    a1_scale = test_tensors.rank_token_scales
+    # Low-Latency kernels can't dispatch scales.
+    a1_scale = (None
+                if test_config.low_latency else test_tensors.rank_token_scales)

    out = mk.forward(hidden_states=test_tensors.rank_tokens,
                     w1=w1,
@@ -258,7 +341,7 @@ def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor,
        allow_deep_gemm=False)


-def _deep_ep_moe(
+def _test_deepep_deepgemm_moe(
    pgi: ProcessGroupInfo,
    dp_size: int,
    config: TestConfig,
@@ -302,7 +385,7 @@ def _deep_ep_moe(
    w1_scale_ep = w1_scale[e_start:e_end]
    w2_scale_ep = w2_scale[e_start:e_end]

-    deepep_moe = deep_ep_moe_impl(
+    deepep_moe = deepep_deepgemm_moe_impl(
        pg,
        pgi,
        dp_size,
@@ -311,7 +394,6 @@ def _deep_ep_moe(
        w2_ep,
        w1_scale_ep,
        w2_scale_ep,
-        config.num_experts,
    )

    torch.testing.assert_close(
@@ -335,15 +417,21 @@ def _deep_ep_moe(
    (222, 1024, 2048),
]

+TOPKS = [2, 6]
+NUM_EXPERTS = [32]
+

@pytest.mark.parametrize("mnk", MNKs)
-@pytest.mark.parametrize("num_experts", [32])
-@pytest.mark.parametrize("topk", [2, 6])
+@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
+@pytest.mark.parametrize("topk", TOPKS)
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@requires_deep_ep
@requires_deep_gemm
-def test_deep_ep_moe(mnk: tuple[int, int, int], num_experts: int, topk: int,
-                     world_dp_size: tuple[int, int]):
+def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
+                                topk: int, world_dp_size: tuple[int, int]):
+    """
+    Tests for High-Throughput DeepEP + DeepGemm integration.
+    """

    m, n, k = mnk
    current_platform.seed_everything(7)
@@ -354,6 +442,58 @@ def test_deep_ep_moe(mnk: tuple[int, int, int], num_experts: int, topk: int,
    block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
    block_size = [block_m, block_m]

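(Background, not part of the diff: get_m_alignment_for_contiguous_layout() is the M-tile alignment DeepGEMM requires for its contiguous grouped layout, 128 in current releases, which is why the weight quantization blocks end up 128x128 here.)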
+    world_size, dp_size = world_dp_size
+    config = TestConfig(topk=topk,
+                        m=m,
+                        k=k,
+                        n=n,
+                        num_experts=num_experts,
+                        block_size=block_size,
+                        low_latency=False,
+                        use_fp8_dispatch=None)
+
+    w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights(
+        num_experts, n, k, block_size)
+
+    parallel_launch(world_size, _test_deepep_deepgemm_moe, dp_size, config, w1,
+                    w2, w1_scale, w2_scale)
+
+
+MNKs = [
+    (1, 128, 2560),
+    (2, 128, 2560),
+    (3, 1024, 2560),
+    (32, 128, 2560),
+    (45, 512, 2560),
+    (64, 1024, 2560),
+    (222, 1024, 2560),
+]
+# TODO: Fix tests for USE_FP8_DISPATCH=True; only False is exercised for now.
+USE_FP8_DISPATCH = [False]
+
+
+@pytest.mark.parametrize("mnk", MNKs)
+@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
+@pytest.mark.parametrize("topk", TOPKS)
+@pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH)
+@pytest.mark.parametrize("block_size", [[128, 128]])
+@pytest.mark.parametrize("world_dp_size", [(2, 1)])
+@requires_deep_ep
+@requires_deep_gemm
+def test_ll_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
+                                topk: int, use_fp8_dispatch: bool,
+                                block_size: list[int],
+                                world_dp_size: tuple[int, int]):
+    """
+    Tests for Low-Latency DeepEP + DeepGemm integration.
+    """
+
+    m, n, k = mnk
+    current_platform.seed_everything(7)
+
+    if topk > num_experts:
+        pytest.skip(f"Skipping test: topk={topk} > E={num_experts}")
+
    world_size, dp_size = world_dp_size
    config = TestConfig(
        topk=topk,
@@ -362,10 +502,12 @@ def test_deep_ep_moe(mnk: tuple[int, int, int], num_experts: int, topk: int,
        n=n,
        num_experts=num_experts,
        block_size=block_size,
+        low_latency=True,
+        use_fp8_dispatch=use_fp8_dispatch,
    )

    w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights(
        num_experts, n, k, block_size)

-    parallel_launch(world_size, _deep_ep_moe, dp_size, config, w1, w2,
-                    w1_scale, w2_scale)
+    parallel_launch(world_size, _test_deepep_deepgemm_moe, dp_size, config, w1,
+                    w2, w1_scale, w2_scale)