
Commit 551ef16

[Unit Test] Add unit test for deep gemm (#20090)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 2863bef commit 551ef16

File tree

1 file changed: +225 -0 lines changed


tests/kernels/moe/test_deepgemm.py

Lines changed: 225 additions & 0 deletions
@@ -0,0 +1,225 @@
# SPDX-License-Identifier: Apache-2.0
"""
Unit-test DeepGEMM FP8 kernels (no DeepEP).
Compare the DeepGEMM path against the Triton fallback inside vLLM's fused_experts.
"""

import importlib
import math

import pytest
import torch

# vLLM fused-expert reference (Triton fallback + DeepGEMM option)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8)
from vllm.utils import cdiv

has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None

if has_deep_gemm:
    import deep_gemm
    BLOCK_M = deep_gemm.get_m_alignment_for_contiguous_layout()
    BLOCK_SIZE = [BLOCK_M, BLOCK_M]

requires_deep_gemm = pytest.mark.skipif(
    not has_deep_gemm,
    reason="Requires deep_gemm kernels",
)


def calc_diff(x: torch.Tensor, y: torch.Tensor):
    """Relative difference metric: 0 when x and y are identical, larger as they diverge."""
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim


def per_block_cast_to_fp8(
        x: torch.Tensor,
        block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2-D tensor to FP8 (e4m3) with one scale per 128 x block_size_n block."""
    assert x.dim() == 2
    m, n = x.shape
    x_padded = torch.zeros(
        (cdiv(m, 128) * 128, cdiv(n, block_size_n) * block_size_n),
        dtype=x.dtype,
        device=x.device)
    x_padded[:m, :n] = x
    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n)
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
    x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
    scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
    return x_scaled_sub, scales

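
# The shape contract of per_block_cast_to_fp8, in one concrete case. This is
# an illustrative sketch added for clarity; it is not part of the original
# commit and is never called by the test itself.
def _example_per_block_cast_shapes():
    # A (256, 256) bf16 matrix quantized with the default 128-wide blocks
    # yields a (256, 256) float8_e4m3fn tensor plus a (2, 2) grid of float32
    # scales, one scale per 128 x 128 block.
    x = torch.randn(256, 256, device="cuda", dtype=torch.bfloat16)
    q, s = per_block_cast_to_fp8(x)
    assert q.shape == (256, 256) and q.dtype == torch.float8_e4m3fn
    assert s.shape == (2, 2) and s.dtype == torch.float32
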
def make_block_quant_fp8_weights(
    e: int,
    n: int,
    k: int,
    block_size: list[int],
):
    """
    Generate (w1, w2) expert weights and their per-block scale tensors
    in FP8 block-quantized format.

    w1 shape: (E, 2N, K)
    w2 shape: (E, K, N)
    """
    dtype = torch.bfloat16
    fp8_max, fp8_min = torch.finfo(torch.float8_e4m3fn).max, torch.finfo(
        torch.float8_e4m3fn).min

    # bf16 reference weights
    w1_bf16 = torch.randn(e, 2 * n, k, device="cuda", dtype=dtype) / 10
    w2_bf16 = torch.randn(e, k, n, device="cuda", dtype=dtype) / 10
    w1_bf16.clamp_(fp8_min, fp8_max)
    w2_bf16.clamp_(fp8_min, fp8_max)

    block_n, block_k = block_size
    n_tiles_w1 = math.ceil((2 * n) / block_n)
    k_tiles_w1 = math.ceil(k / block_k)
    n_tiles_w2 = math.ceil(k / block_n)
    k_tiles_w2 = math.ceil(n / block_k)

    w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn)
    w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn)
    w1_s = torch.empty(e,
                       n_tiles_w1,
                       k_tiles_w1,
                       device="cuda",
                       dtype=torch.float32)
    w2_s = torch.empty(e,
                       n_tiles_w2,
                       k_tiles_w2,
                       device="cuda",
                       dtype=torch.float32)

    for i in range(e):
        w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
        w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])

    return w1, w2, w1_s, w2_s


def run_single_case(m, n, k, topk, num_experts, block_size):
    """
    Run one (M, N, K) configuration on a single GPU and assert that the
    DeepGEMM output matches the Triton baseline within tolerance.
    """
    tokens_bf16 = torch.randn(
        m, k, device="cuda", dtype=torch.bfloat16).clamp_min_(-1).clamp_max_(1)
    _, a1_scale = per_token_group_quant_fp8(tokens_bf16, block_size[1])

    # expert weight tensors
    w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(num_experts, n, k,
                                                      block_size)

    router_logits = torch.randn(m,
                                num_experts,
                                device="cuda",
                                dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1)
    topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)

    # Triton reference
    out_triton = fused_experts(
        hidden_states=tokens_bf16,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=False,
        use_fp8_w8a8=True,
        w1_scale=w1_s,
        w2_scale=w2_s,
        a1_scale=a1_scale,
        block_shape=block_size,
        allow_deep_gemm=False,
    )

    # DeepGEMM
    out_deepgemm = fused_experts(
        hidden_states=tokens_bf16,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=False,
        use_fp8_w8a8=True,
        w1_scale=w1_s,
        w2_scale=w2_s,
        a1_scale=a1_scale,
        block_shape=block_size,
        allow_deep_gemm=True,
    )

    base = out_triton.abs().mean()
    atol = 0.1 * base.clamp(min=1e-2)  # 10% of the mean, but not lower than 1e-3
    rtol = 0.05
    # ----- Compare -----
    torch.testing.assert_close(
        out_deepgemm.to(torch.float32),
        out_triton.to(torch.float32),
        rtol=rtol,
        atol=float(atol),
    )


# Note: W1 has shape (E, 2N, K), so N = 512 is large enough
# to trigger the DeepGEMM path.
MNKs = [
    (1024, 512, 128),
    (1024, 512, 512),
    (2048, 512, 512),
    (512, 1024, 1024),
    (512, 2048, 2048),
    (4096, 4096, 1024),
]

TOPKS = [2, 6]
NUM_EXPERTS = [32]


@pytest.mark.parametrize("mnk", MNKs)
@pytest.mark.parametrize("topk", TOPKS)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@requires_deep_gemm
def test_deepgemm_vs_triton(mnk, topk, num_experts, monkeypatch):

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_DEEP_GEMM", "1")

        _fused_moe_mod = importlib.import_module(
            "vllm.model_executor.layers.fused_moe.fused_moe")

        # Spy on deep_gemm_moe_fp8 so we can verify the DeepGEMM path ran.
        call_counter = {"cnt": 0}

        orig_fn = _fused_moe_mod.deep_gemm_moe_fp8

        def _spy_deep_gemm_moe_fp8(*args, **kwargs):
            call_counter["cnt"] += 1
            return orig_fn(*args, **kwargs)

        monkeypatch.setattr(_fused_moe_mod, "deep_gemm_moe_fp8",
                            _spy_deep_gemm_moe_fp8)

        m, n, k = mnk

        if topk > num_experts:
            pytest.skip(f"topk={topk} > num_experts={num_experts}")

        run_single_case(
            m=m,
            n=n,
            k=k,
            topk=topk,
            num_experts=num_experts,
            block_size=BLOCK_SIZE,
        )

        # Ensure that the DeepGEMM path was indeed taken.
        assert call_counter["cnt"] == 1, \
            f"DeepGEMM path was not executed during the test. " \
            f"Call counter: {call_counter['cnt']}"

0 commit comments
