
Commit b6f16d3

mgoin and yewentao256 authored and committed
[CI] Add E2E Blackwell Quantized MoE Test (#25723)
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
1 parent 5157781 commit b6f16d3

File tree

4 files changed: +155 additions, -2 deletions

.buildkite/test-pipeline.yaml

Lines changed: 18 additions & 1 deletion
@@ -522,7 +522,7 @@ steps:
   # https://github.com/pytorch/ao/issues/2919, we'll have to skip new torchao tests for now
   # we can only upgrade after this is resolved
   - pip install --pre torchao==0.13.0.dev20250814 --index-url https://download.pytorch.org/whl/nightly/cu128
-  - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization
+  - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/

 - label: LM Eval Small Models # 53min
   timeout_in_minutes: 75

@@ -830,6 +830,23 @@ steps:
   - uv pip install --system 'gpt-oss[eval]==0.0.5'
   - pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py --model openai/gpt-oss-20b --metric 0.58 --server-args '--tensor-parallel-size 2'

+- label: Blackwell Quantized MoE Test
+  timeout_in_minutes: 60
+  working_dir: "/vllm-workspace/"
+  gpu: b200
+  source_file_dependencies:
+  - tests/quantization/test_blackwell_moe.py
+  - vllm/model_executor/models/deepseek_v2.py
+  - vllm/model_executor/models/gpt_oss.py
+  - vllm/model_executor/models/llama4.py
+  - vllm/model_executor/layers/fused_moe
+  - vllm/model_executor/layers/quantization/compressed_tensors
+  - vllm/model_executor/layers/quantization/modelopt.py
+  - vllm/model_executor/layers/quantization/mxfp4.py
+  - vllm/v1/attention/backends/flashinfer.py
+  commands:
+  - pytest -s -v tests/quantization/test_blackwell_moe.py
+
 ##### 1 GPU test #####
 ##### multi gpus test #####
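For a quick local check outside Buildkite, the same suite can also be driven from Python; a minimal sketch (assuming a vLLM dev checkout on a B200/SM100 machine with FlashInfer available; this is not part of the commit):

# Minimal sketch, not part of this commit: invoke the new suite via pytest's Python API.
import sys
import pytest

# Equivalent to the CI command `pytest -s -v tests/quantization/test_blackwell_moe.py`;
# add e.g. "-k", "deepseek" to run only a subset of the tests.
sys.exit(pytest.main(["-s", "-v", "tests/quantization/test_blackwell_moe.py"]))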

tests/quantization/test_blackwell_moe.py

Lines changed: 132 additions & 0 deletions
@@ -0,0 +1,132 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json
import os

import pytest

from tests.utils import RemoteOpenAIServer
from vllm.platforms import current_platform

if not current_platform.is_device_capability(100):
    pytest.skip("This test only runs on Blackwell GPUs (SM100).",
                allow_module_level=True)

os.environ["FLASHINFER_NVCC_THREADS"] = "16"

# dummy_hf_overrides = {"num_layers": 4, "num_hidden_layers": 4,
#     "text_config": {"num_layers": 4, "num_hidden_layers": 4}}
dummy_hf_overrides = {"num_layers": 4, "num_hidden_layers": 4}


def can_initialize(model: str, extra_args: list[str]):

    # Server arguments
    server_args = [
        "--max-model-len",
        "2048",
        "--max-num-batched-tokens",
        "256",
        "--load-format",
        "dummy",
        "--trust-remote-code",
        "--limit-mm-per-prompt",
        json.dumps({"image": 0}),
        *extra_args,
    ]

    # Launch server and make a simple request
    with RemoteOpenAIServer(
            model,
            server_args,
            max_wait_seconds=1000,  # Due to FlashInfer compile
            override_hf_configs=dummy_hf_overrides) as server:
        client = server.get_client()
        # Make a simple request to verify the server works
        completion = client.completions.create(
            model=model,
            prompt=["Hello, World!"],
            temperature=0,
            max_tokens=2,
        )
        print(completion)
        assert completion.choices[0].text is not None


## Llama4 ##


@pytest.mark.skip(reason=(
    "RuntimeError: run_moe() Expected a value of type "
    "'Optional[List[Tensor]]' for argument '_9' but instead found type "
    "'list'."))
def test_llama4_fp8_tensor_moe_flashinfer_cutlass(
        monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP8", "1")
    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput")
    can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", [])


@pytest.mark.skip(reason="Works, but takes too long to run")
def test_llama4_fp8_tensor_moe_flashinfer_trtllm(
        monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP8", "1")
    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency")
    can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", [])


@pytest.mark.skip(reason="Works, but takes too long to run")
def test_llama4_nvfp4_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput")
    can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", [])


@pytest.mark.skip(reason="RuntimeError: No kernel found for the given options")
def test_llama4_nvfp4_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency")
    can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", [])


## DeepSeekV3 ##


def test_deepseek_fp8_block_moe_deep_gemm(monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "1")
    can_initialize("deepseek-ai/DeepSeek-V3.1", [])


def test_deepseek_nvfp4_moe_flashinfer_cutlass(
        monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput")
    can_initialize("nvidia/DeepSeek-R1-0528-FP4-v2", [])


@pytest.mark.skip(reason="RuntimeError: No kernel found for the given options")
def test_deepseek_nvfp4_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency")
    can_initialize("nvidia/DeepSeek-R1-0528-FP4-v2", [])


## GPT-OSS ##


def test_gptoss_mxfp4bf16_moe_flashinfer(monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", "1")
    can_initialize("openai/gpt-oss-20b", [])


def test_gptoss_mxfp4mxfp8_moe_flashinfer_cutlass(
        monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS", "1")
    can_initialize("openai/gpt-oss-20b", [])


def test_gptoss_mxfp4mxfp8_moe_flashinfer_trtllm(
        monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "1")
    can_initialize("openai/gpt-oss-20b", [])

tests/utils.py

Lines changed: 3 additions & 1 deletion
@@ -91,8 +91,10 @@ def _start_server(self, model: str, vllm_serve_args: list[str],
         env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
         if env_dict is not None:
             env.update(env_dict)
+        serve_cmd = ["vllm", "serve", model, *vllm_serve_args]
+        print(f"Launching RemoteOpenAIServer with: {' '.join(serve_cmd)}")
         self.proc: subprocess.Popen = subprocess.Popen(
-            ["vllm", "serve", model, *vllm_serve_args],
+            serve_cmd,
             env=env,
             stdout=sys.stdout,
             stderr=sys.stderr,
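The net effect of this change is that the exact `vllm serve` invocation appears in the CI log before the server starts. Roughly, with illustrative model and arguments (not taken from a real log):

# Illustration only: how the new log line is assembled.
model = "deepseek-ai/DeepSeek-V3.1"
vllm_serve_args = ["--max-model-len", "2048", "--load-format", "dummy"]
serve_cmd = ["vllm", "serve", model, *vllm_serve_args]
print(f"Launching RemoteOpenAIServer with: {' '.join(serve_cmd)}")
# -> Launching RemoteOpenAIServer with: vllm serve deepseek-ai/DeepSeek-V3.1 --max-model-len 2048 --load-format dummy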

vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py

Lines changed: 2 additions & 0 deletions
@@ -40,6 +40,8 @@ def flashinfer_fused_moe_blockscale_fp8(
     assert global_num_experts % 4 == 0
     assert top_k < (topk_group * global_num_experts / num_expert_group)
     assert block_shape == [128, 128]
+    # Routing kernel expects #experts <= #threads 256
+    assert global_num_experts <= 256

     a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1])
     # NOTE: scales of hidden states have to be transposed!
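To see the new guard in context, here is a minimal sketch with DeepSeek-V3-style routing parameters (values assumed from the public model config, not taken from this diff):

# Sketch with DeepSeek-V3-style routing parameters (assumed values, for illustration).
global_num_experts = 256  # routed experts
num_expert_group = 8
topk_group = 4
top_k = 8

assert global_num_experts % 4 == 0
assert top_k < (topk_group * global_num_experts / num_expert_group)  # 8 < 128
# New guard: the FlashInfer routing kernel uses at most 256 threads,
# one per expert, so more than 256 experts is unsupported.
assert global_num_experts <= 256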
