
Commit c3fd4d6

varun-sundar-rabindranath and Varun authored
[Kernel] Integrate batched/masked deepgemm kernel (#19111)
Signed-off-by: Varun <vsundarr@redhat.com>
Co-authored-by: Varun <vsundarr@redhat.com>
1 parent ef3f98b commit c3fd4d6

6 files changed: +472 −51 lines changed
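The commit pairs the DeepEP low-latency prepare/finalize with the new batched/masked DeepGEMM experts inside vLLM's modular MoE kernel. A minimal wiring sketch, mirroring make_ll_modular_kernel() from the test diff below — the BatchedDeepGemmExperts import path and constructor arguments come from this diff, while the FusedMoEModularKernel import path and the helper name are assumptions:

# Sketch only: wire BatchedDeepGemmExperts into a FusedMoEModularKernel,
# following the test code in this commit. The modular_kernel import path is
# assumed from the vLLM source tree at this revision.
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
    BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel)


def build_batched_deepgemm_kernel(prepare_finalize, max_tokens_per_rank: int,
                                  world_size: int, dp_size: int,
                                  block_shape: list[int]) -> FusedMoEModularKernel:
    # block_shape is the fp8 block-quantization tile, e.g. [128, 128].
    fused_experts = BatchedDeepGemmExperts(max_num_tokens=max_tokens_per_rank,
                                           world_size=world_size,
                                           dp_size=dp_size,
                                           block_shape=block_shape)
    return FusedMoEModularKernel(prepare_finalize=prepare_finalize,
                                 fused_experts=fused_experts)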

tests/kernels/moe/deepep_utils.py

Lines changed: 4 additions & 1 deletion
@@ -162,12 +162,14 @@ def make_deepep_ll_a2a(pg: ProcessGroup,
                           low_latency_mode=True,
                           num_qps_per_rank=deepep_ll_args.num_experts //
                           pgi.world_size)
+
     return DeepEPLLPrepareAndFinalize(
         buffer=buffer,
         world_size=pgi.world_size,
         dp_size=dp_size,
         max_tokens_per_rank=deepep_ll_args.max_tokens_per_rank,
         quant_dtype=q_dtype,
+        block_shape=block_shape,
         use_fp8_dispatch=deepep_ll_args.use_fp8_dispatch,
     )

@@ -185,4 +187,5 @@ def make_deepep_a2a(pg: ProcessGroup,
                              block_shape)

     assert deepep_ll_args is not None
-    return make_deepep_ll_a2a(pg, pgi, dp_size, deepep_ll_args, q_dtype)
+    return make_deepep_ll_a2a(pg, pgi, dp_size, deepep_ll_args, q_dtype,
+                              block_shape)
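With this change the test helper forwards block_shape to DeepEPLLPrepareAndFinalize. A hedged usage sketch of the updated helper — the keyword arguments and DeepEPLLArgs fields follow the diffs in this commit, while the module path and the concrete values are illustrative assumptions:

# Illustrative call to the updated helper; values are made up for the example.
import torch

from tests.kernels.moe.deepep_utils import DeepEPLLArgs, make_deepep_a2a


def build_ll_prepare_finalize(pg, pgi, dp_size):
    return make_deepep_a2a(pg=pg,
                           pgi=pgi,
                           dp_size=dp_size,
                           deepep_ht_args=None,
                           deepep_ll_args=DeepEPLLArgs(max_tokens_per_rank=64,
                                                       hidden_size=2560,
                                                       num_experts=32,
                                                       use_fp8_dispatch=False),
                           q_dtype=torch.float8_e4m3fn,
                           block_shape=[128, 128])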

tests/kernels/moe/test_deepep_deepgemm_moe.py

Lines changed: 166 additions & 24 deletions
@@ -1,6 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 """
 Test DeepEP + DeepGEMM integration
+DeepGEMM are gemm kernels specialized for the
+fp8 block-quantized case.
 """

 import dataclasses
@@ -33,10 +35,14 @@
 if has_deep_ep:
     from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
         DeepEPHTPrepareAndFinalize)
+    from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
+        DeepEPLLPrepareAndFinalize)

-    from .deepep_utils import DeepEPHTArgs, make_deepep_a2a
+    from .deepep_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a

 if has_deep_gemm:
+    from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
+        BatchedDeepGemmExperts)
     from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
         DeepGemmExperts)

@@ -53,6 +59,13 @@
 P = ParamSpec("P")


+def next_power_of_2(x):
+    import math
+    if x == 0:
+        return 1
+    return 2**math.ceil(math.log2(x))
+
+
 def per_block_cast_to_fp8(
         x: torch.Tensor,
         block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
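Later in this diff, make_modular_kernel() sizes the low-latency buffers as max(64, next_power_of_2(m)). A quick worked example of the values this produces for the test token counts (standalone sketch, reusing the helper added above):

import math


def next_power_of_2(x):
    # Same helper as the one added in the hunk above.
    if x == 0:
        return 1
    return 2**math.ceil(math.log2(x))


# max_tokens_per_rank = max(64, next_power_of_2(m)), as in make_modular_kernel()
for m in (1, 45, 64, 222):
    print(m, "->", max(64, next_power_of_2(m)))
# prints: 1 -> 64, 45 -> 64, 64 -> 64, 222 -> 256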
@@ -126,6 +139,9 @@ class TestConfig:
     n: int
     num_experts: int
     block_size: list[int]
+    # configs for testing low-latency kernels
+    low_latency: bool
+    use_fp8_dispatch: Optional[bool] = False


 @dataclasses.dataclass
@@ -170,9 +186,43 @@ def make(config: TestConfig, rank) -> "TestTensors":
                        config=config)


-def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
-                        num_local_experts: int, q_dtype: Optional[torch.dtype],
-                        block_shape: list[int]) -> FusedMoEModularKernel:
+def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
+                           max_tokens_per_rank: int, dp_size: int,
+                           hidden_size: int, q_dtype: Optional[torch.dtype],
+                           test_config: TestConfig) -> FusedMoEModularKernel:
+
+    assert test_config.low_latency
+    assert test_config.use_fp8_dispatch is not None
+
+    a2a: DeepEPLLPrepareAndFinalize = make_deepep_a2a(
+        pg=pg,
+        pgi=pgi,
+        dp_size=dp_size,
+        deepep_ht_args=None,
+        deepep_ll_args=DeepEPLLArgs(
+            max_tokens_per_rank=max_tokens_per_rank,
+            hidden_size=hidden_size,
+            num_experts=test_config.num_experts,
+            use_fp8_dispatch=test_config.use_fp8_dispatch),
+        q_dtype=q_dtype,
+        block_shape=test_config.block_size)
+
+    fused_experts = BatchedDeepGemmExperts(max_num_tokens=max_tokens_per_rank,
+                                           world_size=pgi.world_size,
+                                           dp_size=dp_size,
+                                           block_shape=test_config.block_size)
+    mk = FusedMoEModularKernel(prepare_finalize=a2a,
+                               fused_experts=fused_experts)
+    return mk
+
+
+def make_ht_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
+                           dp_size: int, num_local_experts: int,
+                           q_dtype: Optional[torch.dtype],
+                           test_config: TestConfig) -> FusedMoEModularKernel:
+
+    assert not test_config.low_latency
+    assert test_config.use_fp8_dispatch is None

     a2a: DeepEPHTPrepareAndFinalize = make_deepep_a2a(
         pg=pg,
@@ -181,20 +231,50 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
         deepep_ht_args=DeepEPHTArgs(num_local_experts=num_local_experts),
         deepep_ll_args=None,
         q_dtype=q_dtype,
-        block_shape=block_shape)
+        block_shape=test_config.block_size)

     fused_experts = DeepGemmExperts()
     mk = FusedMoEModularKernel(prepare_finalize=a2a,
                                fused_experts=fused_experts)
     return mk


-def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
-                     test_tensors: TestTensors, w1: torch.Tensor,
-                     w2: torch.Tensor, w1_scale: Optional[torch.Tensor],
-                     w2_scale: Optional[torch.Tensor],
-                     num_experts: int) -> torch.Tensor:
+def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
+                        num_local_experts: int,
+                        test_tensors: TestTensors) -> FusedMoEModularKernel:
+
+    q_dtype = torch.float8_e4m3fn
+    test_config = test_tensors.config
+
+    mk: FusedMoEModularKernel
+    # Make modular kernel
+    if test_config.low_latency:
+        max_tokens_per_rank = max(
+            64, next_power_of_2(test_tensors.rank_tokens.size(0)))
+        hidden_size = test_tensors.rank_tokens.size(-1)
+
+        mk = make_ll_modular_kernel(pg=pg,
+                                    pgi=pgi,
+                                    max_tokens_per_rank=max_tokens_per_rank,
+                                    dp_size=dp_size,
+                                    hidden_size=hidden_size,
+                                    q_dtype=q_dtype,
+                                    test_config=test_config)
+    else:
+        mk = make_ht_modular_kernel(pg, pgi, dp_size, num_local_experts,
+                                    q_dtype, test_config)
+
+    return mk
+
+
+def deepep_deepgemm_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
+                             dp_size: int, test_tensors: TestTensors,
+                             w1: torch.Tensor, w2: torch.Tensor,
+                             w1_scale: Optional[torch.Tensor],
+                             w2_scale: Optional[torch.Tensor]) -> torch.Tensor:

+    test_config = test_tensors.config
+    num_experts = test_config.num_experts
     num_local_experts = w1.size(0)

     def build_expert_map():
@@ -208,14 +288,17 @@ def build_expert_map():
         return expert_map.to(device=torch.cuda.current_device(),
                              dtype=torch.int32)

-    q_dtype = torch.float8_e4m3fn
-
     # Make modular kernel
     mk: FusedMoEModularKernel = make_modular_kernel(
-        pg, pgi, dp_size, num_local_experts, q_dtype,
-        test_tensors.config.block_size)
+        pg=pg,
+        pgi=pgi,
+        dp_size=dp_size,
+        num_local_experts=num_local_experts,
+        test_tensors=test_tensors)

-    a1_scale = test_tensors.rank_token_scales
+    # Low-Latency kernels can't dispatch scales.
+    a1_scale = (None
+                if test_config.low_latency else test_tensors.rank_token_scales)

     out = mk.forward(hidden_states=test_tensors.rank_tokens,
                      w1=w1,
@@ -258,7 +341,7 @@ def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor,
                     allow_deep_gemm=False)


-def _deep_ep_moe(
+def _test_deepep_deepgemm_moe(
     pgi: ProcessGroupInfo,
     dp_size: int,
     config: TestConfig,
@@ -302,7 +385,7 @@ def _deep_ep_moe(
     w1_scale_ep = w1_scale[e_start:e_end]
     w2_scale_ep = w2_scale[e_start:e_end]

-    deepep_moe = deep_ep_moe_impl(
+    deepep_moe = deepep_deepgemm_moe_impl(
         pg,
         pgi,
         dp_size,
@@ -311,7 +394,6 @@ def _deep_ep_moe(
         w2_ep,
         w1_scale_ep,
         w2_scale_ep,
-        config.num_experts,
     )

     torch.testing.assert_close(
@@ -335,15 +417,21 @@ def _deep_ep_moe(
     (222, 1024, 2048),
 ]

+TOPKS = [2, 6]
+NUM_EXPERTS = [32]
+

 @pytest.mark.parametrize("mnk", MNKs)
-@pytest.mark.parametrize("num_experts", [32])
-@pytest.mark.parametrize("topk", [2, 6])
+@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
+@pytest.mark.parametrize("topk", TOPKS)
 @pytest.mark.parametrize("world_dp_size", [(2, 1)])
 @requires_deep_ep
 @requires_deep_gemm
-def test_deep_ep_moe(mnk: tuple[int, int, int], num_experts: int, topk: int,
-                     world_dp_size: tuple[int, int]):
+def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
+                                topk: int, world_dp_size: tuple[int, int]):
+    """
+    Tests for High-Throughput DeepEP + DeepGemm integration.
+    """

     m, n, k = mnk
     current_platform.seed_everything(7)
@@ -354,6 +442,58 @@ def test_deep_ep_moe(mnk: tuple[int, int, int], num_experts: int, topk: int,
     block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
     block_size = [block_m, block_m]

+    world_size, dp_size = world_dp_size
+    config = TestConfig(topk=topk,
+                        m=m,
+                        k=k,
+                        n=n,
+                        num_experts=num_experts,
+                        block_size=block_size,
+                        low_latency=False,
+                        use_fp8_dispatch=None)
+
+    w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights(
+        num_experts, n, k, block_size)
+
+    parallel_launch(world_size, _test_deepep_deepgemm_moe, dp_size, config, w1,
+                    w2, w1_scale, w2_scale)
+
+
+MNKs = [
+    (1, 128, 2560),
+    (2, 128, 2560),
+    (3, 1024, 2560),
+    (32, 128, 2560),
+    (45, 512, 2560),
+    (64, 1024, 2560),
+    (222, 1024, 2560),
+]
+# Fix tests for USE_FP8_DISPATCH=True
+USE_FP8_DISPATCH = [False]
+
+
+@pytest.mark.parametrize("mnk", MNKs)
+@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
+@pytest.mark.parametrize("topk", TOPKS)
+@pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH)
+@pytest.mark.parametrize("block_size", [[128, 128]])
+@pytest.mark.parametrize("world_dp_size", [(2, 1)])
+@requires_deep_ep
+@requires_deep_gemm
+def test_ll_deepep_deepgemm_moe(mnk: tuple[int, int,
+                                           int], num_experts: int, topk: int,
+                                use_fp8_dispatch: bool, block_size: list[int],
+                                world_dp_size: tuple[int, int]):
+    """
+    Tests for Low-Latency DeepEP + DeepGemm integration.
+    """
+
+    m, n, k = mnk
+    current_platform.seed_everything(7)
+
+    if topk > num_experts:
+        pytest.skip(f"Skipping test: topk={topk} > E={num_experts}")
+
     world_size, dp_size = world_dp_size
     config = TestConfig(
         topk=topk,
@@ -362,10 +502,12 @@ def test_deep_ep_moe(mnk: tuple[int, int, int], num_experts: int, topk: int,
         n=n,
         num_experts=num_experts,
         block_size=block_size,
+        low_latency=True,
+        use_fp8_dispatch=use_fp8_dispatch,
     )

     w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights(
         num_experts, n, k, block_size)

-    parallel_launch(world_size, _deep_ep_moe, dp_size, config, w1, w2,
-                    w1_scale, w2_scale)
+    parallel_launch(world_size, _test_deepep_deepgemm_moe, dp_size, config, w1,
+                    w2, w1_scale, w2_scale)
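To exercise the renamed high-throughput test and the new low-latency test locally, something along these lines should work with pytest's -k selection (a sketch; it assumes a machine with at least two GPUs and both DeepEP and DeepGEMM installed, otherwise the requires_deep_ep / requires_deep_gemm markers skip the tests):

# Sketch: run only the DeepEP + DeepGEMM MoE tests touched by this commit.
import sys

import pytest

if __name__ == "__main__":
    sys.exit(
        pytest.main([
            "tests/kernels/moe/test_deepep_deepgemm_moe.py",
            "-k", "deepep_deepgemm_moe",  # matches the ht_ and ll_ tests
            "-v",
        ]))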
