From 7af2c101327647fa58c6dfbd1941004e30f099ae Mon Sep 17 00:00:00 2001
From: yewentao256
Date: Wed, 25 Jun 2025 18:21:34 +0000
Subject: [PATCH 1/8] add unit test for deep gemm

Signed-off-by: yewentao256
---
 tests/kernels/moe/test_deepgemm.py | 215 +++++++++++++++++++++++++++++
 1 file changed, 215 insertions(+)
 create mode 100644 tests/kernels/moe/test_deepgemm.py

diff --git a/tests/kernels/moe/test_deepgemm.py b/tests/kernels/moe/test_deepgemm.py
new file mode 100644
index 000000000000..b03fc1c96e69
--- /dev/null
+++ b/tests/kernels/moe/test_deepgemm.py
@@ -0,0 +1,215 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Unit-test DeepGEMM FP8 kernels (no DeepEP).
+Compare DeepGEMM path against the Triton fallback inside vLLM's fused_experts.
+"""
+
+import importlib
+import math
+
+import pytest
+import torch
+
+# vLLM fused-expert reference (Triton fallback + DeepGEMM option)
+from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    per_token_group_quant_fp8)
+from vllm.utils import cdiv
+
+has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
+
+if has_deep_gemm:
+    import deep_gemm
+    BLOCK_M = deep_gemm.get_m_alignment_for_contiguous_layout()
+    BLOCK_SIZE = [BLOCK_M, BLOCK_M]
+
+requires_deep_gemm = pytest.mark.skipif(
+    not has_deep_gemm,
+    reason="Requires deep_gemm kernels",
+)
+
+
+def calc_diff(x: torch.Tensor, y: torch.Tensor):
+    x, y = x.double(), y.double()
+    denominator = (x * x + y * y).sum()
+    sim = 2 * (x * y).sum() / denominator
+    return 1 - sim
+
+
+def per_block_cast_to_fp8(
+        x: torch.Tensor,
+        block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
+    assert x.dim() == 2
+    m, n = x.shape
+    x_padded = torch.zeros(
+        (cdiv(m, 128) * 128, cdiv(n, block_size_n) * block_size_n),
+        dtype=x.dtype,
+        device=x.device)
+    x_padded[:m, :n] = x
+    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n)
+    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
+    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
+    x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
+    scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
+    return x_scaled_sub, scales
+
+
+def make_block_quant_fp8_weights(
+    e: int,
+    n: int,
+    k: int,
+    block_size: list[int],
+):
+    """
+    Generate (w1, w2) expert weights and their per-block scale tensors
+    in FP8 block-quantized format.
+
+    w1 shape: (E, 2N, K)
+    w2 shape: (E, K, N)
+    """
+    dtype = torch.bfloat16
+    fp8_max, fp8_min = torch.finfo(torch.float8_e4m3fn).max, torch.finfo(
+        torch.float8_e4m3fn).min
+
+    # bf16 reference weights
+    w1_bf16 = torch.randn(e, 2 * n, k, device="cuda", dtype=dtype) / 10
+    w2_bf16 = torch.randn(e, k, n, device="cuda", dtype=dtype) / 10
+    w1_bf16.clamp_(fp8_min, fp8_max)
+    w2_bf16.clamp_(fp8_min, fp8_max)
+
+    block_n, block_k = block_size
+    n_tiles_w1 = math.ceil((2 * n) / block_n)
+    k_tiles_w1 = math.ceil(k / block_k)
+    n_tiles_w2 = math.ceil(k / block_n)
+    k_tiles_w2 = math.ceil(n / block_k)
+
+    w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn)
+    w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn)
+    w1_s = torch.empty(e,
+                       n_tiles_w1,
+                       k_tiles_w1,
+                       device="cuda",
+                       dtype=torch.float32)
+    w2_s = torch.empty(e,
+                       n_tiles_w2,
+                       k_tiles_w2,
+                       device="cuda",
+                       dtype=torch.float32)
+
+    for i in range(e):
+        w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
+        w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])
+
+    return w1, w2, w1_s, w2_s
+
+
+def run_single_case(m, n, k, topk, num_experts, block_size):
+    """
+    Run one (M,N,K) configuration on a single GPU and assert DeepGEMM ==
+    Triton baseline within tolerance.
+    """
+    tokens_bf16 = torch.randn(
+        m, k, device="cuda", dtype=torch.bfloat16).clamp_min_(-1).clamp_max_(1)
+    _, a1_scale = per_token_group_quant_fp8(tokens_bf16, block_size[1])
+
+    # expert weight tensors
+    w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(num_experts, n, k,
+                                                      block_size)
+
+    router_logits = torch.randn(m,
+                                num_experts,
+                                device="cuda",
+                                dtype=torch.float32)
+    topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1)
+    topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)
+
+    # Triton reference
+    out_triton = fused_experts(
+        hidden_states=tokens_bf16,
+        w1=w1,
+        w2=w2,
+        topk_weights=topk_weights,
+        topk_ids=topk_ids,
+        inplace=False,
+        use_fp8_w8a8=True,
+        w1_scale=w1_s,
+        w2_scale=w2_s,
+        a1_scale=a1_scale,
+        block_shape=block_size,
+        allow_deep_gemm=False,
+    )
+
+    # DeepGEMM
+    out_deepgemm = fused_experts(
+        hidden_states=tokens_bf16,
+        w1=w1,
+        w2=w2,
+        topk_weights=topk_weights,
+        topk_ids=topk_ids,
+        inplace=False,
+        use_fp8_w8a8=True,
+        w1_scale=w1_s,
+        w2_scale=w2_s,
+        a1_scale=a1_scale,
+        block_shape=block_size,
+        allow_deep_gemm=True,
+    )
+
+    # ----- Compare -----
+    rel_diff = (torch.mean(
+        torch.abs(
+            out_deepgemm.to(torch.float32) - out_triton.to(torch.float32))) /
+                torch.mean(torch.abs(out_triton.to(torch.float32))))
+
+    assert rel_diff < 0.005, f'{m=}, {k=}, {n=}, {rel_diff:.5f}'
+
+    diff = calc_diff(out_deepgemm, out_triton)
+    assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}'
+
+
+MNKs = [
+    (1024, 512, 128),
+    (1024, 512, 512),
+    (2048, 512, 512),
+    (512, 1024, 1024),
+    (512, 2048, 2048),
+    (4096, 4096, 1024),
+]
+
+TOPKS = [2, 6]
+NUM_EXPERTS = [32]
+
+
+@pytest.mark.parametrize("mnk", MNKs)
+@pytest.mark.parametrize("topk", TOPKS)
+@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
+@requires_deep_gemm
+def test_deepgemm_vs_triton(mnk, topk, num_experts):
+    import os
+    os.environ['VLLM_USE_DEEP_GEMM'] = "1"
+    torch.manual_seed(7)
+    m, n, k = mnk
+
+    if topk > num_experts:
+        pytest.skip(f"topk={topk} > num_experts={num_experts}")
+
+    run_single_case(
+        m=m,
+        n=n,
+        k=k,
+        topk=topk,
+        num_experts=num_experts,
+        block_size=BLOCK_SIZE,
+    )
+
+
+if __name__ == "__main__":
+    run_single_case(
+        m=1024,
+        n=1024,
+        k=512,
+        topk=2,
+        num_experts=32,
+        block_size=BLOCK_SIZE,
+    )
+    print("DeepGEMM standalone test passed ✅")
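
The per-block scale shapes produced by make_block_quant_fp8_weights can be
sanity-checked without a GPU. The sketch below is illustrative only (not part
of the patch series); it uses one of the test configurations, and assumes
BLOCK_M == 128, which is what DeepGEMM's m-alignment for the contiguous layout
typically resolves to:

    import math

    e, n, k = 32, 512, 512         # num_experts, intermediate dim N, hidden dim K
    block_n, block_k = 128, 128    # BLOCK_SIZE = [BLOCK_M, BLOCK_M]

    # w1 is (E, 2N, K): its scales tile the 2N axis by block_n, the K axis by block_k
    n_tiles_w1 = math.ceil((2 * n) / block_n)
    k_tiles_w1 = math.ceil(k / block_k)

    # w2 is (E, K, N): its scales tile the K axis by block_n, the N axis by block_k
    n_tiles_w2 = math.ceil(k / block_n)
    k_tiles_w2 = math.ceil(n / block_k)

    assert (n_tiles_w1, k_tiles_w1) == (8, 4)  # w1_s per expert: (8, 4)
    assert (n_tiles_w2, k_tiles_w2) == (4, 4)  # w2_s per expert: (4, 4)
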

From 8e2594d09974e4d1be8dfc233b7a79b49995da24 Mon Sep 17 00:00:00 2001
From: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Date: Wed, 25 Jun 2025 14:25:54 -0400
Subject: [PATCH 2/8] Update tests/kernels/moe/test_deepgemm.py

Update per Gemini's suggestion

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: yewentao256
---
 tests/kernels/moe/test_deepgemm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/kernels/moe/test_deepgemm.py b/tests/kernels/moe/test_deepgemm.py
index b03fc1c96e69..d068e87a0555 100644
--- a/tests/kernels/moe/test_deepgemm.py
+++ b/tests/kernels/moe/test_deepgemm.py
@@ -161,7 +161,7 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
             out_deepgemm.to(torch.float32) - out_triton.to(torch.float32))) /
                 torch.mean(torch.abs(out_triton.to(torch.float32))))
 
-    assert rel_diff < 0.005, f'{m=}, {k=}, {n=}, {rel_diff:.5f}'
+    assert rel_diff < 0.005, f'Relative difference exceeds tolerance: {rel_diff:.5f} (m={m}, k={k}, n={n})'
 
     diff = calc_diff(out_deepgemm, out_triton)
     assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}'

From d2f4905240438b8e0c29d7aad99ac7da103a2e4a Mon Sep 17 00:00:00 2001
From: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Date: Wed, 25 Jun 2025 14:26:02 -0400
Subject: [PATCH 3/8] Update tests/kernels/moe/test_deepgemm.py

Update per Gemini's suggestion

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: yewentao256
---
 tests/kernels/moe/test_deepgemm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/kernels/moe/test_deepgemm.py b/tests/kernels/moe/test_deepgemm.py
index d068e87a0555..c5b419ee5b07 100644
--- a/tests/kernels/moe/test_deepgemm.py
+++ b/tests/kernels/moe/test_deepgemm.py
@@ -164,7 +164,7 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
     assert rel_diff < 0.005, f'Relative difference exceeds tolerance: {rel_diff:.5f} (m={m}, k={k}, n={n})'
 
     diff = calc_diff(out_deepgemm, out_triton)
-    assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}'
+    assert diff < 0.001, f'Difference exceeds tolerance: {diff:.5f} (m={m}, k={k}, n={n})'
 
 
 MNKs = [

From ddce6b87d6492bc3f54d19096eaf27f13d07a699 Mon Sep 17 00:00:00 2001
From: yewentao256
Date: Wed, 25 Jun 2025 18:43:54 +0000
Subject: [PATCH 4/8] fix precommit issue

Signed-off-by: yewentao256
---
 tests/kernels/moe/test_deepgemm.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/tests/kernels/moe/test_deepgemm.py b/tests/kernels/moe/test_deepgemm.py
index c5b419ee5b07..c88be6d1ba52 100644
--- a/tests/kernels/moe/test_deepgemm.py
+++ b/tests/kernels/moe/test_deepgemm.py
@@ -161,10 +161,11 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
             out_deepgemm.to(torch.float32) - out_triton.to(torch.float32))) /
                 torch.mean(torch.abs(out_triton.to(torch.float32))))
 
-    assert rel_diff < 0.005, f'Relative difference exceeds tolerance: {rel_diff:.5f} (m={m}, k={k}, n={n})'
+    assert rel_diff < 0.005, \
+        f'Relative error: {rel_diff:.5f} (m={m}, k={k}, n={n})'
 
     diff = calc_diff(out_deepgemm, out_triton)
-    assert diff < 0.001, f'Difference exceeds tolerance: {diff:.5f} (m={m}, k={k}, n={n})'
+    assert diff < 0.001, f'Dice error: {diff:.5f} (m={m}, k={k}, n={n})'
 
 
 MNKs = [

From 6a9f3c1064f77a844e5db6004c2727056f35b11d Mon Sep 17 00:00:00 2001
From: yewentao256
Date: Thu, 26 Jun 2025 15:29:18 +0000
Subject: [PATCH 5/8] remove the temporary code for __main__

Signed-off-by: yewentao256
---
 tests/kernels/moe/test_deepgemm.py | 12 ------------
 1 file changed, 12 deletions(-)

diff --git a/tests/kernels/moe/test_deepgemm.py b/tests/kernels/moe/test_deepgemm.py
index c88be6d1ba52..d51efc2a4a33 100644
--- a/tests/kernels/moe/test_deepgemm.py
+++ b/tests/kernels/moe/test_deepgemm.py
@@ -202,15 +202,3 @@ def test_deepgemm_vs_triton(mnk, topk, num_experts):
         num_experts=num_experts,
         block_size=BLOCK_SIZE,
     )
-
-
-if __name__ == "__main__":
-    run_single_case(
-        m=1024,
-        n=1024,
-        k=512,
-        topk=2,
-        num_experts=32,
-        block_size=BLOCK_SIZE,
-    )
-    print("DeepGEMM standalone test passed ✅")

From 2fb3ec777c22e549de70d82804bb2a91e6d3bec8 Mon Sep 17 00:00:00 2001
From: yewentao256
Date: Fri, 27 Jun 2025 19:47:45 +0000
Subject: [PATCH 6/8] assert call for deepgemm

Signed-off-by: yewentao256
---
 tests/kernels/moe/test_deepgemm.py | 30 ++++++++++++++++++++++++++----
 1 file changed, 26 insertions(+), 4 deletions(-)

diff --git a/tests/kernels/moe/test_deepgemm.py b/tests/kernels/moe/test_deepgemm.py
index d51efc2a4a33..f69d9b7ffdb7 100644
--- a/tests/kernels/moe/test_deepgemm.py
+++ b/tests/kernels/moe/test_deepgemm.py
@@ -6,6 +6,7 @@
 
 import importlib
 import math
+import os
 
 import pytest
 import torch
@@ -168,6 +169,8 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
     assert diff < 0.001, f'Dice error: {diff:.5f} (m={m}, k={k}, n={n})'
 
 
+# Note: W1 has shape (E, 2N, K), so N = 512
+# can trigger the deepgemm path.
 MNKs = [
     (1024, 512, 128),
     (1024, 512, 512),
@@ -185,10 +188,24 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
 @pytest.mark.parametrize("topk", TOPKS)
 @pytest.mark.parametrize("num_experts", NUM_EXPERTS)
 @requires_deep_gemm
-def test_deepgemm_vs_triton(mnk, topk, num_experts):
-    import os
-    os.environ['VLLM_USE_DEEP_GEMM'] = "1"
-    torch.manual_seed(7)
+def test_deepgemm_vs_triton(mnk, topk, num_experts, monkeypatch):
+
+    os.environ["VLLM_USE_DEEP_GEMM"] = "1"
+
+    _fused_moe_mod = importlib.import_module(
+        "vllm.model_executor.layers.fused_moe.fused_moe")
+
+    call_counter = {"cnt": 0}
+
+    orig_fn = _fused_moe_mod.deep_gemm_moe_fp8
+
+    def _spy_deep_gemm_moe_fp8(*args, **kwargs):
+        call_counter["cnt"] += 1
+        return orig_fn(*args, **kwargs)
+
+    monkeypatch.setattr(_fused_moe_mod, "deep_gemm_moe_fp8",
+                        _spy_deep_gemm_moe_fp8)
+
     m, n, k = mnk
 
     if topk > num_experts:
@@ -202,3 +219,8 @@ def test_deepgemm_vs_triton(mnk, topk, num_experts):
         num_experts=num_experts,
         block_size=BLOCK_SIZE,
     )
+
+    # ensure that the DeepGEMM path was indeed taken.
+    assert call_counter["cnt"] == 1, \
+        f"DeepGEMM path was not executed during the test. " \
+        f"Call counter: {call_counter['cnt']}"
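
The call-counting wrapper above is a plain spy pattern built on pytest's
monkeypatch fixture. A minimal, self-contained sketch of the same idea
(illustrative only; math.sqrt stands in for deep_gemm_moe_fp8):

    import math

    def test_spy_counts_calls(monkeypatch):
        calls = {"cnt": 0}
        orig = math.sqrt

        def _spy(x):
            calls["cnt"] += 1  # record the call
            return orig(x)     # then defer to the real implementation

        monkeypatch.setattr(math, "sqrt", _spy)

        assert math.sqrt(9.0) == 3.0  # behavior is unchanged
        assert calls["cnt"] == 1      # and the call was observed

monkeypatch restores the original attribute at teardown, so the spy never
leaks into other tests.
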
" \ + f"Call counter: {call_counter['cnt']}" From 4c125c2353df16456769ec25508babbf40f1fd7b Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Fri, 27 Jun 2025 20:05:23 +0000 Subject: [PATCH 7/8] use torch.testing.assert_close Signed-off-by: yewentao256 --- tests/kernels/moe/test_deepgemm.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/tests/kernels/moe/test_deepgemm.py b/tests/kernels/moe/test_deepgemm.py index f69d9b7ffdb7..3e115ba4f376 100644 --- a/tests/kernels/moe/test_deepgemm.py +++ b/tests/kernels/moe/test_deepgemm.py @@ -156,17 +156,16 @@ def run_single_case(m, n, k, topk, num_experts, block_size): allow_deep_gemm=True, ) + base = out_triton.abs().mean() + atol = 0.1 * base.clamp(min=1e-2) # 10% of mean, but not lower than 1e-3 + rtol = 0.05 # ----- Compare ----- - rel_diff = (torch.mean( - torch.abs( - out_deepgemm.to(torch.float32) - out_triton.to(torch.float32))) / - torch.mean(torch.abs(out_triton.to(torch.float32)))) - - assert rel_diff < 0.005, \ - f'Relative error: {rel_diff:.5f} (m={m}, k={k}, n={n})' - - diff = calc_diff(out_deepgemm, out_triton) - assert diff < 0.001, f'Dice error: {diff:.5f} (m={m}, k={k}, n={n})' + torch.testing.assert_close( + out_deepgemm.to(torch.float32), + out_triton.to(torch.float32), + rtol=rtol, + atol=float(atol), + ) # Note: W1 has shape (E, 2N, K), so N = 512 From c2a1e145ab43be2683f8baaa8dc0c8ebe29395af Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Fri, 27 Jun 2025 21:07:53 +0000 Subject: [PATCH 8/8] fix monkeypatch Signed-off-by: yewentao256 --- tests/kernels/moe/test_deepgemm.py | 52 +++++++++++++++--------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/tests/kernels/moe/test_deepgemm.py b/tests/kernels/moe/test_deepgemm.py index 3e115ba4f376..5d2690904cea 100644 --- a/tests/kernels/moe/test_deepgemm.py +++ b/tests/kernels/moe/test_deepgemm.py @@ -6,7 +6,6 @@ import importlib import math -import os import pytest import torch @@ -189,37 +188,38 @@ def run_single_case(m, n, k, topk, num_experts, block_size): @requires_deep_gemm def test_deepgemm_vs_triton(mnk, topk, num_experts, monkeypatch): - os.environ["VLLM_USE_DEEP_GEMM"] = "1" + with monkeypatch.context() as m: + m.setenv("VLLM_USE_DEEP_GEMM", "1") - _fused_moe_mod = importlib.import_module( - "vllm.model_executor.layers.fused_moe.fused_moe") + _fused_moe_mod = importlib.import_module( + "vllm.model_executor.layers.fused_moe.fused_moe") - call_counter = {"cnt": 0} + call_counter = {"cnt": 0} - orig_fn = _fused_moe_mod.deep_gemm_moe_fp8 + orig_fn = _fused_moe_mod.deep_gemm_moe_fp8 - def _spy_deep_gemm_moe_fp8(*args, **kwargs): - call_counter["cnt"] += 1 - return orig_fn(*args, **kwargs) + def _spy_deep_gemm_moe_fp8(*args, **kwargs): + call_counter["cnt"] += 1 + return orig_fn(*args, **kwargs) - monkeypatch.setattr(_fused_moe_mod, "deep_gemm_moe_fp8", - _spy_deep_gemm_moe_fp8) + monkeypatch.setattr(_fused_moe_mod, "deep_gemm_moe_fp8", + _spy_deep_gemm_moe_fp8) - m, n, k = mnk + m, n, k = mnk - if topk > num_experts: - pytest.skip(f"topk={topk} > num_experts={num_experts}") + if topk > num_experts: + pytest.skip(f"topk={topk} > num_experts={num_experts}") - run_single_case( - m=m, - n=n, - k=k, - topk=topk, - num_experts=num_experts, - block_size=BLOCK_SIZE, - ) + run_single_case( + m=m, + n=n, + k=k, + topk=topk, + num_experts=num_experts, + block_size=BLOCK_SIZE, + ) - # ensure that the DeepGEMM path was indeed taken. - assert call_counter["cnt"] == 1, \ - f"DeepGEMM path was not executed during the test. 
" \ - f"Call counter: {call_counter['cnt']}" + # ensure that the DeepGEMM path was indeed taken. + assert call_counter["cnt"] == 1, \ + f"DeepGEMM path was not executed during the test. " \ + f"Call counter: {call_counter['cnt']}"