|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +""" |
| 3 | +Unit-test DeepGEMM FP8 kernels (no DeepEP). |
| 4 | +Compare DeepGEMM path against the Triton fallback inside vLLM's fused_experts. |
| 5 | +""" |
| 6 | + |
| 7 | +import importlib |
| 8 | +import math |
| 9 | + |
| 10 | +import pytest |
| 11 | +import torch |
| 12 | + |
| 13 | +# vLLM fused-expert reference (Triton fallback + DeepGEMM option) |
| 14 | +from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts |
| 15 | +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( |
| 16 | + per_token_group_quant_fp8) |
| 17 | +from vllm.utils import cdiv |
| 18 | + |
| 19 | +has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None |
| 20 | + |
| 21 | +if has_deep_gemm: |
| 22 | + import deep_gemm |
| 23 | + BLOCK_M = deep_gemm.get_m_alignment_for_contiguous_layout() |
| 24 | + BLOCK_SIZE = [BLOCK_M, BLOCK_M] |
| 25 | + |
| 26 | +requires_deep_gemm = pytest.mark.skipif( |
| 27 | + not has_deep_gemm, |
| 28 | + reason="Requires deep_gemm kernels", |
| 29 | +) |
| 30 | + |
| 31 | + |
| 32 | +def calc_diff(x: torch.Tensor, y: torch.Tensor): |
| 33 | + x, y = x.double(), y.double() |
| 34 | + denominator = (x * x + y * y).sum() |
| 35 | + sim = 2 * (x * y).sum() / denominator |
| 36 | + return 1 - sim |
| 37 | + |
| 38 | + |
| 39 | +def per_block_cast_to_fp8( |
| 40 | + x: torch.Tensor, |
| 41 | + block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]: |
| 42 | + assert x.dim() == 2 |
| 43 | + m, n = x.shape |
| 44 | + x_padded = torch.zeros( |
| 45 | + (cdiv(m, 128) * 128, cdiv(n, block_size_n) * block_size_n), |
| 46 | + dtype=x.dtype, |
| 47 | + device=x.device) |
| 48 | + x_padded[:m, :n] = x |
| 49 | + x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n) |
| 50 | + x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) |
| 51 | + x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) |
| 52 | + x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous() |
| 53 | + scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) |
| 54 | + return x_scaled_sub, scales |
| 55 | + |
| 56 | + |
| 57 | +def make_block_quant_fp8_weights( |
| 58 | + e: int, |
| 59 | + n: int, |
| 60 | + k: int, |
| 61 | + block_size: list[int], |
| 62 | +): |
| 63 | + """ |
| 64 | + Generate (w1, w2) expert weights and their per-block scale tensors |
| 65 | + in FP8 block-quantized format. |
| 66 | +
|
| 67 | + w1 shape: (E, 2N, K) |
| 68 | + w2 shape: (E, K, N) |
| 69 | + """ |
| 70 | + dtype = torch.bfloat16 |
| 71 | + fp8_max, fp8_min = torch.finfo(torch.float8_e4m3fn).max, torch.finfo( |
| 72 | + torch.float8_e4m3fn).min |
| 73 | + |
| 74 | + # bf16 reference weights |
| 75 | + w1_bf16 = torch.randn(e, 2 * n, k, device="cuda", dtype=dtype) / 10 |
| 76 | + w2_bf16 = torch.randn(e, k, n, device="cuda", dtype=dtype) / 10 |
| 77 | + w1_bf16.clamp_(fp8_min, fp8_max) |
| 78 | + w2_bf16.clamp_(fp8_min, fp8_max) |
| 79 | + |
| 80 | + block_n, block_k = block_size |
| 81 | + n_tiles_w1 = math.ceil((2 * n) / block_n) |
| 82 | + k_tiles_w1 = math.ceil(k / block_k) |
| 83 | + n_tiles_w2 = math.ceil(k / block_n) |
| 84 | + k_tiles_w2 = math.ceil(n / block_k) |
| 85 | + |
| 86 | + w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn) |
| 87 | + w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn) |
| 88 | + w1_s = torch.empty(e, |
| 89 | + n_tiles_w1, |
| 90 | + k_tiles_w1, |
| 91 | + device="cuda", |
| 92 | + dtype=torch.float32) |
| 93 | + w2_s = torch.empty(e, |
| 94 | + n_tiles_w2, |
| 95 | + k_tiles_w2, |
| 96 | + device="cuda", |
| 97 | + dtype=torch.float32) |
| 98 | + |
| 99 | + for i in range(e): |
| 100 | + w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) |
| 101 | + w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) |
| 102 | + |
| 103 | + return w1, w2, w1_s, w2_s |
| 104 | + |
| 105 | + |
| 106 | +def run_single_case(m, n, k, topk, num_experts, block_size): |
| 107 | + """ |
| 108 | + Run one (M,N,K) configuration on a single GPU and assert DeepGEMM == |
| 109 | + Triton baseline within tolerance. |
| 110 | + """ |
| 111 | + tokens_bf16 = torch.randn( |
| 112 | + m, k, device="cuda", dtype=torch.bfloat16).clamp_min_(-1).clamp_max_(1) |
| 113 | + _, a1_scale = per_token_group_quant_fp8(tokens_bf16, block_size[1]) |
| 114 | + |
| 115 | + # expert weight tensors |
| 116 | + w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(num_experts, n, k, |
| 117 | + block_size) |
| 118 | + |
| 119 | + router_logits = torch.randn(m, |
| 120 | + num_experts, |
| 121 | + device="cuda", |
| 122 | + dtype=torch.float32) |
| 123 | + topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1) |
| 124 | + topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1) |
| 125 | + |
| 126 | + # triton referrence |
| 127 | + out_triton = fused_experts( |
| 128 | + hidden_states=tokens_bf16, |
| 129 | + w1=w1, |
| 130 | + w2=w2, |
| 131 | + topk_weights=topk_weights, |
| 132 | + topk_ids=topk_ids, |
| 133 | + inplace=False, |
| 134 | + use_fp8_w8a8=True, |
| 135 | + w1_scale=w1_s, |
| 136 | + w2_scale=w2_s, |
| 137 | + a1_scale=a1_scale, |
| 138 | + block_shape=block_size, |
| 139 | + allow_deep_gemm=False, |
| 140 | + ) |
| 141 | + |
| 142 | + # DeepGemm |
| 143 | + out_deepgemm = fused_experts( |
| 144 | + hidden_states=tokens_bf16, |
| 145 | + w1=w1, |
| 146 | + w2=w2, |
| 147 | + topk_weights=topk_weights, |
| 148 | + topk_ids=topk_ids, |
| 149 | + inplace=False, |
| 150 | + use_fp8_w8a8=True, |
| 151 | + w1_scale=w1_s, |
| 152 | + w2_scale=w2_s, |
| 153 | + a1_scale=a1_scale, |
| 154 | + block_shape=block_size, |
| 155 | + allow_deep_gemm=True, |
| 156 | + ) |
| 157 | + |
| 158 | + base = out_triton.abs().mean() |
| 159 | + atol = 0.1 * base.clamp(min=1e-2) # 10% of mean, but not lower than 1e-3 |
| 160 | + rtol = 0.05 |
| 161 | + # ----- Compare ----- |
| 162 | + torch.testing.assert_close( |
| 163 | + out_deepgemm.to(torch.float32), |
| 164 | + out_triton.to(torch.float32), |
| 165 | + rtol=rtol, |
| 166 | + atol=float(atol), |
| 167 | + ) |
| 168 | + |
| 169 | + |
| 170 | +# Note: W1 has shape (E, 2N, K), so N = 512 |
| 171 | +# can trigger the deepgemm path. |
| 172 | +MNKs = [ |
| 173 | + (1024, 512, 128), |
| 174 | + (1024, 512, 512), |
| 175 | + (2048, 512, 512), |
| 176 | + (512, 1024, 1024), |
| 177 | + (512, 2048, 2048), |
| 178 | + (4096, 4096, 1024), |
| 179 | +] |
| 180 | + |
| 181 | +TOPKS = [2, 6] |
| 182 | +NUM_EXPERTS = [32] |
| 183 | + |
| 184 | + |
| 185 | +@pytest.mark.parametrize("mnk", MNKs) |
| 186 | +@pytest.mark.parametrize("topk", TOPKS) |
| 187 | +@pytest.mark.parametrize("num_experts", NUM_EXPERTS) |
| 188 | +@requires_deep_gemm |
| 189 | +def test_deepgemm_vs_triton(mnk, topk, num_experts, monkeypatch): |
| 190 | + |
| 191 | + with monkeypatch.context() as m: |
| 192 | + m.setenv("VLLM_USE_DEEP_GEMM", "1") |
| 193 | + |
| 194 | + _fused_moe_mod = importlib.import_module( |
| 195 | + "vllm.model_executor.layers.fused_moe.fused_moe") |
| 196 | + |
| 197 | + call_counter = {"cnt": 0} |
| 198 | + |
| 199 | + orig_fn = _fused_moe_mod.deep_gemm_moe_fp8 |
| 200 | + |
| 201 | + def _spy_deep_gemm_moe_fp8(*args, **kwargs): |
| 202 | + call_counter["cnt"] += 1 |
| 203 | + return orig_fn(*args, **kwargs) |
| 204 | + |
| 205 | + monkeypatch.setattr(_fused_moe_mod, "deep_gemm_moe_fp8", |
| 206 | + _spy_deep_gemm_moe_fp8) |
| 207 | + |
| 208 | + m, n, k = mnk |
| 209 | + |
| 210 | + if topk > num_experts: |
| 211 | + pytest.skip(f"topk={topk} > num_experts={num_experts}") |
| 212 | + |
| 213 | + run_single_case( |
| 214 | + m=m, |
| 215 | + n=n, |
| 216 | + k=k, |
| 217 | + topk=topk, |
| 218 | + num_experts=num_experts, |
| 219 | + block_size=BLOCK_SIZE, |
| 220 | + ) |
| 221 | + |
| 222 | + # ensure that the DeepGEMM path was indeed taken. |
| 223 | + assert call_counter["cnt"] == 1, \ |
| 224 | + f"DeepGEMM path was not executed during the test. " \ |
| 225 | + f"Call counter: {call_counter['cnt']}" |
0 commit comments