19 changes: 18 additions & 1 deletion .buildkite/test-pipeline.yaml
@@ -520,7 +520,7 @@ steps:
  # https://github.com/pytorch/ao/issues/2919, we'll have to skip new torchao tests for now
  # we can only upgrade after this is resolved
  - pip install --pre torchao==0.13.0.dev20250814 --index-url https://download.pytorch.org/whl/nightly/cu128
-  - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization
+  - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/

- label: LM Eval Small Models # 53min
timeout_in_minutes: 75
@@ -828,6 +828,23 @@ steps:
  - uv pip install --system 'gpt-oss[eval]==0.0.5'
  - pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py --model openai/gpt-oss-20b --metric 0.58 --server-args '--tensor-parallel-size 2'

- label: Blackwell Quantized MoE Test
  timeout_in_minutes: 60
  working_dir: "/vllm-workspace/"
  gpu: b200
  source_file_dependencies:
  - tests/quantization/test_blackwell_moe.py
  - vllm/model_executor/models/deepseek_v2.py
  - vllm/model_executor/models/gpt_oss.py
  - vllm/model_executor/models/llama4.py
  - vllm/model_executor/layers/fused_moe
  - vllm/model_executor/layers/quantization/compressed_tensors
  - vllm/model_executor/layers/quantization/modelopt.py
  - vllm/model_executor/layers/quantization/mxfp4.py
  - vllm/v1/attention/backends/flashinfer.py
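  # NOTE: Buildkite schedules this step only when one of the files listed
  # under source_file_dependencies changes.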
  commands:
  - pytest -s -v tests/quantization/test_blackwell_moe.py

##### 1 GPU test #####
##### multi gpus test #####

132 changes: 132 additions & 0 deletions tests/quantization/test_blackwell_moe.py
@@ -0,0 +1,132 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json
import os

import pytest

from tests.utils import RemoteOpenAIServer
from vllm.platforms import current_platform

if not current_platform.is_device_capability(100):
    pytest.skip("This test only runs on Blackwell GPUs (SM100).",
                allow_module_level=True)
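# NOTE: is_device_capability(100) is an exact match on compute capability
# 10.0 (B200-class GPUs), so the module is skipped on all other hardware.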

os.environ["FLASHINFER_NVCC_THREADS"] = "16"
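# FlashInfer JIT-compiles kernels on first use; extra nvcc threads only
# shorten that compilation, they do not change the resulting kernels.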

# dummy_hf_overrides = {"num_layers": 4, "num_hidden_layers": 4,
# "text_config": {"num_layers": 4, "num_hidden_layers": 4}}
dummy_hf_overrides = {"num_layers": 4, "num_hidden_layers": 4}
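# Shrinking every model to 4 layers keeps server startup cheap: together with
# --load-format dummy below, no full checkpoint is ever downloaded.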


def can_initialize(model: str, extra_args: list[str]):

    # Server arguments
    server_args = [
        "--max-model-len",
        "2048",
        "--max-num-batched-tokens",
        "256",
        "--load-format",
        "dummy",
        "--trust-remote-code",
        "--limit-mm-per-prompt",
        json.dumps({"image": 0}),
        *extra_args,
    ]
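    # "--load-format dummy" fills weights with random values instead of
    # downloading them; the {"image": 0} limit rejects image inputs, keeping
    # the smoke test text-only even for multimodal checkpoints.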

    # Launch server and make a simple request
    with RemoteOpenAIServer(
            model,
            server_args,
            max_wait_seconds=1000,  # Due to FlashInfer compile
            override_hf_configs=dummy_hf_overrides) as server:
        client = server.get_client()
        # Make a simple request to verify the server works
        completion = client.completions.create(
            model=model,
            prompt=["Hello, World!"],
            temperature=0,
            max_tokens=2,
        )
        print(completion)
        assert completion.choices[0].text is not None


## Llama4 ##


@pytest.mark.skip(reason=(
    "RuntimeError: run_moe() Expected a value of type "
    "'Optional[List[Tensor]]' for argument '_9' but instead found type "
    "'list'."))
def test_llama4_fp8_tensor_moe_flashinfer_cutlass(
        monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP8", "1")
    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput")
    can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", [])


@pytest.mark.skip(reason="Works, but takes too long to run")
def test_llama4_fp8_tensor_moe_flashinfer_trtllm(
monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP8", "1")
monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency")
can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", [])


@pytest.mark.skip(reason="Works, but takes too long to run")
def test_llama4_nvfp4_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput")
can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", [])


@pytest.mark.skip(reason="RuntimeError: No kernel found for the given options")
def test_llama4_nvfp4_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency")
can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", [])


## DeepSeekV3 ##


def test_deepseek_fp8_block_moe_deep_gemm(monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "1")
    can_initialize("deepseek-ai/DeepSeek-V3.1", [])


def test_deepseek_nvfp4_moe_flashinfer_cutlass(
        monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput")
    can_initialize("nvidia/DeepSeek-R1-0528-FP4-v2", [])


@pytest.mark.skip(reason="RuntimeError: No kernel found for the given options")
def test_deepseek_nvfp4_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency")
can_initialize("nvidia/DeepSeek-R1-0528-FP4-v2", [])


## GPT-OSS ##


def test_gptoss_mxfp4bf16_moe_flashinfer(monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", "1")
    can_initialize("openai/gpt-oss-20b", [])


def test_gptoss_mxfp4mxfp8_moe_flashinfer_cutlass(
        monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS", "1")
    can_initialize("openai/gpt-oss-20b", [])


def test_gptoss_mxfp4mxfp8_moe_flashinfer_trtllm(
        monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "1")
    can_initialize("openai/gpt-oss-20b", [])
4 changes: 3 additions & 1 deletion tests/utils.py
@@ -91,8 +91,10 @@ def _start_server(self, model: str, vllm_serve_args: list[str],
        env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
        if env_dict is not None:
            env.update(env_dict)
+       serve_cmd = ["vllm", "serve", model, *vllm_serve_args]
+       print(f"Launching RemoteOpenAIServer with: {' '.join(serve_cmd)}")
        self.proc: subprocess.Popen = subprocess.Popen(
-           ["vllm", "serve", model, *vllm_serve_args],
+           serve_cmd,
            env=env,
            stdout=sys.stdout,
            stderr=sys.stderr,
@@ -40,6 +40,8 @@ def flashinfer_fused_moe_blockscale_fp8(
    assert global_num_experts % 4 == 0
    assert top_k < (topk_group * global_num_experts / num_expert_group)
    assert block_shape == [128, 128]
+   # The routing kernel expects #experts <= #threads (256)
+   assert global_num_experts <= 256

    a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1])
    # NOTE: scales of hidden states have to be transposed!