Commit a9f55dc

Varun Sundar Rabindranath authored
[Misc] Add triton_kernels dependency (#27370)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
1 parent 81d5bb7 commit a9f55dc

File tree

2 files changed (+4, -117 lines)


requirements/cuda.txt

Lines changed: 2 additions & 0 deletions
@@ -13,3 +13,5 @@ torchvision==0.24.0 # Required for phi3v processor. See https://github.com/pytor
 # xformers==0.0.32.post1; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.8
 # FlashInfer should be updated together with the Dockerfile
 flashinfer-python==0.4.1
+# Triton Kernels are needed for mxfp4 fused moe. (Should be updated alongside torch)
+triton_kernels @ git+https://github.com/triton-lang/triton.git@v3.5.0#subdirectory=python/triton_kernels
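
The pinned build can also be installed directly with pip; a minimal one-liner, assuming a CUDA environment whose torch build matches Triton v3.5.0 (the requirement string is taken verbatim from the added line above):

    pip install "triton_kernels @ git+https://github.com/triton-lang/triton.git@v3.5.0#subdirectory=python/triton_kernels"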

tests/kernels/moe/test_gpt_oss_triton_kernels.py

Lines changed: 2 additions & 117 deletions
@@ -23,15 +23,9 @@
 from triton_kernels.testing import assert_close
 
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
-from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
-    BatchedPrepareAndFinalize,
-)
-from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
 from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
-    BatchedOAITritonExperts,
     triton_kernel_moe_forward,
 )
-from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
 from vllm.model_executor.layers.utils import shuffle_weight
 from vllm.utils import round_up
 
@@ -302,8 +296,8 @@ def test_equiv(num_token, a_dtype, w_dtype, tp):
     quant_config = FusedMoEQuantConfig.make(
         w1_bias=w1_bias_tri,
         w2_bias=w2_bias_tri,
-        w1_precision=pc1,
-        w2_precision=pc2,
+        w1_scale=pc1,
+        w2_scale=pc2,
     )
 
     out_triton_monolithic = triton_kernel_moe_forward(
@@ -329,115 +323,6 @@ def test_equiv(num_token, a_dtype, w_dtype, tp):
     assert_close(ref=out_ref, tri=out_triton_monolithic, maxtol=0.025, rmstol=0.005)
 
 
-def batched_moe(
-    a: torch.Tensor,
-    w1,
-    w2,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-    w1_bias: torch.Tensor,
-    w2_bias: torch.Tensor,
-    w1_precision: PrecisionConfig,
-    w2_precision: PrecisionConfig,
-) -> torch.Tensor:
-    max_num_tokens = round_up(a.shape[0], 64)
-
-    quant_config = FusedMoEQuantConfig.make(
-        w1_precision=w1_precision,
-        w2_precision=w2_precision,
-        w1_bias=w1_bias,
-        w2_bias=w2_bias,
-    )
-
-    fused_experts = FusedMoEModularKernel(
-        BatchedPrepareAndFinalize(
-            max_num_tokens,
-            num_dispatchers=1,
-            num_local_experts=w1.shape[0],
-            rank=0,
-        ),
-        BatchedOAITritonExperts(
-            max_num_tokens=max_num_tokens,
-            num_dispatchers=1,
-            quant_config=quant_config,
-        ),
-    )
-
-    topk_weight, topk_ids, _ = fused_topk(a, gating_output, topk, renormalize)
-
-    return fused_experts(
-        a,
-        w1,
-        w2,
-        topk_weight,
-        topk_ids,
-    )
-
-
-@pytest.mark.parametrize(
-    ", ".join(f.name for f in fields(Case)),
-    [
-        tuple(getattr(case, f.name) for f in fields(Case))
-        for case in [
-            # Case(a_dtype="bf16", w_dtype="bf16"),
-            # Case(a_dtype="fp8_e4m3", w_dtype="fp8_e5m2"),
-            Case(a_dtype="bf16", w_dtype="mx4")
-        ]
-    ],
-)
-@pytest.mark.parametrize("num_token", [64])
-@pytest.mark.parametrize("ep", [1, 2, 4, 8])
-def test_triton_kernel_batched_moe(num_token, a_dtype, w_dtype, ep):
-    M = num_token
-    E = ModelConfig.num_experts // ep
-    K = ModelConfig.hidden_size
-    N = ModelConfig.intermediate_size
-    topk = ModelConfig.experts_per_token
-
-    (
-        x,
-        w1,
-        w1_bias,
-        w2,
-        w2_bias,
-        exp_data,
-        x_tri,
-        w1_tri,
-        w2_tri,
-        exp_data_tri,
-        w1_bias_tri,
-        w2_bias_tri,
-        pc1,
-        pc2,
-    ) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=4)
-
-    out_tri = batched_moe(
-        a=x_tri,
-        w1=w1_tri,
-        w2=w2_tri,
-        gating_output=exp_data_tri,
-        topk=topk,
-        renormalize=True,
-        w1_bias=w1_bias_tri,
-        w2_bias=w2_bias_tri,
-        w1_precision=pc1,
-        w2_precision=pc2,
-    )
-    out_tri = out_tri[..., :K]
-
-    out_ref = oai_moe_forward(
-        hidden_states=x,
-        w1=w1,
-        w1_bias=w1_bias,
-        w2=w2,
-        w2_bias=w2_bias,
-        gating_output=exp_data,
-        topk=topk,
-    )
-    assert_close(ref=out_ref, tri=out_tri, maxtol=0.025, rmstol=0.005)
-
-
 def test_unit_shuffle():
     N = ModelConfig.intermediate_size
     K = ModelConfig.hidden_size
