diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 7a1f38606062..5cdd2c388f37 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -520,7 +520,7 @@ steps:
   # https://github.com/pytorch/ao/issues/2919, we'll have to skip new torchao tests for now
   # we can only upgrade after this is resolved
   - pip install --pre torchao==0.13.0.dev20250814 --index-url https://download.pytorch.org/whl/nightly/cu128
-  - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization
+  - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/
 
 - label: LM Eval Small Models # 53min
   timeout_in_minutes: 75
@@ -828,6 +828,23 @@ steps:
   - uv pip install --system 'gpt-oss[eval]==0.0.5'
   - pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py --model openai/gpt-oss-20b --metric 0.58 --server-args '--tensor-parallel-size 2'
 
+- label: Blackwell Quantized MoE Test
+  timeout_in_minutes: 60
+  working_dir: "/vllm-workspace/"
+  gpu: b200
+  source_file_dependencies:
+  - tests/quantization/test_blackwell_moe.py
+  - vllm/model_executor/models/deepseek_v2.py
+  - vllm/model_executor/models/gpt_oss.py
+  - vllm/model_executor/models/llama4.py
+  - vllm/model_executor/layers/fused_moe
+  - vllm/model_executor/layers/quantization/compressed_tensors
+  - vllm/model_executor/layers/quantization/modelopt.py
+  - vllm/model_executor/layers/quantization/mxfp4.py
+  - vllm/v1/attention/backends/flashinfer.py
+  commands:
+  - pytest -s -v tests/quantization/test_blackwell_moe.py
+
 ##### 1 GPU test #####
 ##### multi gpus test #####
 
diff --git a/tests/quantization/test_blackwell_moe.py b/tests/quantization/test_blackwell_moe.py
new file mode 100644
index 000000000000..c021126720af
--- /dev/null
+++ b/tests/quantization/test_blackwell_moe.py
@@ -0,0 +1,132 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import json
+import os
+
+import pytest
+
+from tests.utils import RemoteOpenAIServer
+from vllm.platforms import current_platform
+
+if not current_platform.is_device_capability(100):
+    pytest.skip("This test only runs on Blackwell GPUs (SM100).",
+                allow_module_level=True)
+
+os.environ["FLASHINFER_NVCC_THREADS"] = "16"
+
+# dummy_hf_overrides = {"num_layers": 4, "num_hidden_layers": 4,
+#                       "text_config": {"num_layers": 4, "num_hidden_layers": 4}}
+dummy_hf_overrides = {"num_layers": 4, "num_hidden_layers": 4}
+
+
+def can_initialize(model: str, extra_args: list[str]):
+
+    # Server arguments
+    server_args = [
+        "--max-model-len",
+        "2048",
+        "--max-num-batched-tokens",
+        "256",
+        "--load-format",
+        "dummy",
+        "--trust-remote-code",
+        "--limit-mm-per-prompt",
+        json.dumps({"image": 0}),
+        *extra_args,
+    ]
+
+    # Launch server and make a simple request
+    with RemoteOpenAIServer(
+            model,
+            server_args,
+            max_wait_seconds=1000,  # Due to FlashInfer compile
+            override_hf_configs=dummy_hf_overrides) as server:
+        client = server.get_client()
+        # Make a simple request to verify the server works
+        completion = client.completions.create(
+            model=model,
+            prompt=["Hello, World!"],
+            temperature=0,
+            max_tokens=2,
+        )
+        print(completion)
+        assert completion.choices[0].text is not None
+
+
+## Llama4 ##
+
+
+@pytest.mark.skip(reason=(
+    "RuntimeError: run_moe() Expected a value of type "
+    "'Optional[List[Tensor]]' for argument '_9' but instead found type "
+    "'list'."))
+def test_llama4_fp8_tensor_moe_flashinfer_cutlass(
+        monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP8", "1")
+    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput")
+    can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", [])
+
+
+@pytest.mark.skip(reason="Works, but takes too long to run")
+def test_llama4_fp8_tensor_moe_flashinfer_trtllm(
+        monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP8", "1")
+    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency")
+    can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", [])
+
+
+@pytest.mark.skip(reason="Works, but takes too long to run")
+def test_llama4_nvfp4_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
+    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput")
+    can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", [])
+
+
+@pytest.mark.skip(reason="RuntimeError: No kernel found for the given options")
+def test_llama4_nvfp4_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
+    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency")
+    can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", [])
+
+
+## DeepSeekV3 ##
+
+
+def test_deepseek_fp8_block_moe_deep_gemm(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "1")
+    can_initialize("deepseek-ai/DeepSeek-V3.1", [])
+
+
+def test_deepseek_nvfp4_moe_flashinfer_cutlass(
+        monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
+    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput")
+    can_initialize("nvidia/DeepSeek-R1-0528-FP4-v2", [])
+
+
+@pytest.mark.skip(reason="RuntimeError: No kernel found for the given options")
+def test_deepseek_nvfp4_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
+    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency")
+    can_initialize("nvidia/DeepSeek-R1-0528-FP4-v2", [])
+
+
+## GPT-OSS ##
+
+
+def test_gptoss_mxfp4bf16_moe_flashinfer(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", "1")
+    can_initialize("openai/gpt-oss-20b", [])
+
+
+def test_gptoss_mxfp4mxfp8_moe_flashinfer_cutlass(
+        monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS", "1")
+    can_initialize("openai/gpt-oss-20b", [])
+
+
+def test_gptoss_mxfp4mxfp8_moe_flashinfer_trtllm(
+        monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "1")
+    can_initialize("openai/gpt-oss-20b", [])
diff --git a/tests/utils.py b/tests/utils.py
index f630c57f46d8..ab6ccc7ad9f9 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -91,8 +91,10 @@ def _start_server(self, model: str, vllm_serve_args: list[str],
             env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
         if env_dict is not None:
             env.update(env_dict)
+        serve_cmd = ["vllm", "serve", model, *vllm_serve_args]
+        print(f"Launching RemoteOpenAIServer with: {' '.join(serve_cmd)}")
         self.proc: subprocess.Popen = subprocess.Popen(
-            ["vllm", "serve", model, *vllm_serve_args],
+            serve_cmd,
             env=env,
             stdout=sys.stdout,
             stderr=sys.stderr,
diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
index fe586a22e250..74bcffd8ca03 100644
--- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
@@ -40,6 +40,8 @@ def flashinfer_fused_moe_blockscale_fp8(
     assert global_num_experts % 4 == 0
     assert top_k < (topk_group * global_num_experts / num_expert_group)
     assert block_shape == [128, 128]
+    # Routing kernel expects #experts <= #threads 256
+    assert global_num_experts <= 256
 
     a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1])
     # NOTE: scales of hidden states have to be transposed!