From b82ace97ce3bfb0e325fa07ddad8ff0b5fbedb01 Mon Sep 17 00:00:00 2001
From: mgoin
Date: Thu, 25 Sep 2025 17:21:58 -0400
Subject: [PATCH 1/9] Add e2e model run for SM100 Quantized MoEs

Signed-off-by: mgoin
---
 .buildkite/test-pipeline.yaml                 | 14 ++-
 tests/quantization/test_blackwell_moe.py      | 89 +++++++++++++++++++
 .../layers/fused_moe/flashinfer_trtllm_moe.py |  2 +
 3 files changed, 104 insertions(+), 1 deletion(-)
 create mode 100644 tests/quantization/test_blackwell_moe.py

diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 7a1f38606062..6ab9df2dac88 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -520,7 +520,7 @@ steps:
   # https://github.com/pytorch/ao/issues/2919, we'll have to skip new torchao tests for now
   # we can only upgrade after this is resolved
   - pip install --pre torchao==0.13.0.dev20250814 --index-url https://download.pytorch.org/whl/nightly/cu128
-  - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization
+  - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/
 
 - label: LM Eval Small Models # 53min
   timeout_in_minutes: 75
@@ -828,6 +828,18 @@ steps:
   - uv pip install --system 'gpt-oss[eval]==0.0.5'
   - pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py --model openai/gpt-oss-20b --metric 0.58 --server-args '--tensor-parallel-size 2'
 
+- label: Blackwell Quantized MoE Test
+  timeout_in_minutes: 60
+  working_dir: "/vllm-workspace/"
+  gpu: b200
+  source_file_dependencies:
+  - tests/evals/gpt_oss
+  - vllm/model_executor/models/gpt_oss.py
+  - vllm/model_executor/layers/quantization/mxfp4.py
+  - vllm/v1/attention/backends/flashinfer.py
+  commands:
+  - pytest -s -v tests/quantization/test_blackwell_moe.py
+
 ##### 1 GPU test #####
 ##### multi gpus test #####
 
diff --git a/tests/quantization/test_blackwell_moe.py b/tests/quantization/test_blackwell_moe.py
new file mode 100644
index 000000000000..01ec572051e8
--- /dev/null
+++ b/tests/quantization/test_blackwell_moe.py
@@ -0,0 +1,89 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from functools import partial
+from unittest.mock import patch
+
+import pytest
+from transformers import PretrainedConfig
+
+from vllm import LLM
+from vllm.sampling_params import SamplingParams
+from vllm.platforms import current_platform
+from tests.utils import create_new_process_for_each_test
+
+if not current_platform.is_device_capability(100):
+    pytest.skip("This test only runs on Blackwell GPUs (SM100).",
+                allow_module_level=True)
+
+def dummy_hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
+    """
+    Dummy HF overrides function used to create dummy model
+    with only minimum nums of layer.
+    """
+    text_config = hf_config.get_text_config()
+
+    # Do 4 backbone layers to include dense and MoE layers
+    text_config.update({
+        "num_layers": 4,
+        "num_hidden_layers": 4,
+    })
+
+    if hasattr(hf_config, "vision_config"):
+        hf_config.vision_config.update({
+            "num_layers": 1,
+            "num_hidden_layers": 1,
+        })
+    # e.g.: ibm-granite/granite-speech-3.3-2b
+    if hasattr(hf_config, "encoder_config"):
+        hf_config.encoder_config.update({
+            "num_layers": 1,
+            "num_hidden_layers": 1,
+        })
+    # e.g.: Qwen/Qwen2-Audio-7B-Instruct
+    if hasattr(hf_config, "audio_config"):
+        hf_config.audio_config.update({
+            "num_layers": 1,
+            "num_hidden_layers": 1,
+            "encoder_layers": 1,
+        })
+
+    return hf_config
+
+
+# @create_new_process_for_each_test()
+def can_initialize(vllm_runner, model_name: str, **model_kwargs):
+
+    default_model_kwargs = {
+        "enforce_eager": True,
+        "trust_remote_code": True,
+        "max_model_len": 1024,
+        "gpu_memory_utilization": 0.8,
+        "load_format": "dummy",
+        # "hf_overrides": dummy_hf_overrides,
+    }
+    default_model_kwargs.update(model_kwargs)
+
+    with vllm_runner(model_name, **default_model_kwargs) as llm:
+        sp = SamplingParams(temperature=0, max_tokens=2)
+        llm.generate("Hello, world!", sampling_params=sp)
+
+def test_blackwell_fp8_tensor_moe_flashinfer_trtllm(vllm_runner, monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP8", "1")
+    can_initialize(vllm_runner, "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", tensor_parallel_size=1)
+
+def test_blackwell_fp8_block_moe_deep_gemm(vllm_runner, monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "1")
+    can_initialize(vllm_runner, "deepseek-ai/DeepSeek-V3.1", tensor_parallel_size=1)
+
+def test_blackwell_nvfp4_moe_flashinfer_cutlass(vllm_runner, monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
+    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput")
+    can_initialize(vllm_runner, "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", tensor_parallel_size=1)
+
+def test_blackwell_nvfp4_moe_flashinfer_trtllm(vllm_runner, monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
+    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency")
+    can_initialize(vllm_runner, "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", tensor_parallel_size=1)
+    can_initialize(vllm_runner, "nvidia/DeepSeek-R1-0528-FP4-v2", tensor_parallel_size=1)
+
diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
index fe586a22e250..74bcffd8ca03 100644
--- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
@@ -40,6 +40,8 @@ def flashinfer_fused_moe_blockscale_fp8(
     assert global_num_experts % 4 == 0
     assert top_k < (topk_group * global_num_experts / num_expert_group)
     assert block_shape == [128, 128]
+    # Routing kernel expects #experts <= #threads 256
+    assert global_num_experts <= 256
     a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1])
     # NOTE: scales of hidden states have to be transposed!

From f8fa0ce51ddef80d846a2c8314f73d700f57468c Mon Sep 17 00:00:00 2001
From: mgoin
Date: Thu, 25 Sep 2025 16:25:24 -0700
Subject: [PATCH 2/9] Add round of tests

Signed-off-by: mgoin
---
 tests/quantization/test_blackwell_moe.py | 180 ++++++++++++++---------
 tests/utils.py                           |   4 +-
 2 files changed, 116 insertions(+), 68 deletions(-)

diff --git a/tests/quantization/test_blackwell_moe.py b/tests/quantization/test_blackwell_moe.py
index 01ec572051e8..ab21ddd62a4b 100644
--- a/tests/quantization/test_blackwell_moe.py
+++ b/tests/quantization/test_blackwell_moe.py
@@ -1,89 +1,135 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from functools import partial
-from unittest.mock import patch
+import json
+import os
 
 import pytest
-from transformers import PretrainedConfig
 
-from vllm import LLM
-from vllm.sampling_params import SamplingParams
+from tests.utils import RemoteOpenAIServer
 from vllm.platforms import current_platform
-from tests.utils import create_new_process_for_each_test
 
 if not current_platform.is_device_capability(100):
     pytest.skip("This test only runs on Blackwell GPUs (SM100).",
                 allow_module_level=True)
 
-def dummy_hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
-    """
-    Dummy HF overrides function used to create dummy model
-    with only minimum nums of layer.
-    """
-    text_config = hf_config.get_text_config()
-
-    # Do 4 backbone layers to include dense and MoE layers
-    text_config.update({
-        "num_layers": 4,
-        "num_hidden_layers": 4,
-    })
-
-    if hasattr(hf_config, "vision_config"):
-        hf_config.vision_config.update({
-            "num_layers": 1,
-            "num_hidden_layers": 1,
-        })
-    # e.g.: ibm-granite/granite-speech-3.3-2b
-    if hasattr(hf_config, "encoder_config"):
-        hf_config.encoder_config.update({
-            "num_layers": 1,
-            "num_hidden_layers": 1,
-        })
-    # e.g.: Qwen/Qwen2-Audio-7B-Instruct
-    if hasattr(hf_config, "audio_config"):
-        hf_config.audio_config.update({
-            "num_layers": 1,
-            "num_hidden_layers": 1,
-            "encoder_layers": 1,
-        })
-
-    return hf_config
-
-
-# @create_new_process_for_each_test()
-def can_initialize(vllm_runner, model_name: str, **model_kwargs):
-
-    default_model_kwargs = {
-        "enforce_eager": True,
-        "trust_remote_code": True,
-        "max_model_len": 1024,
-        "gpu_memory_utilization": 0.8,
-        "load_format": "dummy",
-        # "hf_overrides": dummy_hf_overrides,
-    }
-    default_model_kwargs.update(model_kwargs)
-
-    with vllm_runner(model_name, **default_model_kwargs) as llm:
-        sp = SamplingParams(temperature=0, max_tokens=2)
-        llm.generate("Hello, world!", sampling_params=sp)
-
-def test_blackwell_fp8_tensor_moe_flashinfer_trtllm(vllm_runner, monkeypatch: pytest.MonkeyPatch):
+os.environ["FLASHINFER_NVCC_THREADS"] = "16"
+
+# dummy_hf_overrides = {"num_layers": 4, "num_hidden_layers": 4,
+# "text_config": {"num_layers": 4, "num_hidden_layers": 4}}
+dummy_hf_overrides = {"num_layers": 4, "num_hidden_layers": 4}
+
+
+def can_initialize(model: str, extra_args: list[str]):
+
+    # Server arguments
+    server_args = [
+        "--max-model-len",
+        "8192",
+        "--max-num-batched-tokens",
+        "1024",
+        "--load-format",
+        "dummy",
+        "--trust-remote-code",
+        "--limit-mm-per-prompt",
+        json.dumps({"image": 0}),
+        *extra_args,
+    ]
+
+    # Launch server and make a simple request
+    with RemoteOpenAIServer(model,
+                            server_args,
+                            max_wait_seconds=480,
+                            override_hf_configs=dummy_hf_overrides) as server:
+        client = server.get_client()
+        # Make a simple request to verify the server works
+        completion = client.chat.completions.create(
+            model=model,
+            messages=[{
+                "role": "system",
+                "content": "You are a helpful assistant."
+            }, {
+                "role": "user",
+                "content": "Hello!"
+            }],
+            temperature=0,
+            max_completion_tokens=2,
+        )
+        generated_text = completion.choices[0].message.content
+        assert generated_text is not None
+
+
+## Llama4 ##
+
+
+@pytest.mark.skip(
+    reason="This gets stuck during/after graph capture for some reason")
+def test_llama4_fp8_tensor_moe_flashinfer_cutlass(
+        monkeypatch: pytest.MonkeyPatch):
     monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP8", "1")
-    can_initialize(vllm_runner, "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", tensor_parallel_size=1)
+    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput")
+    can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", [])
+
+
+def test_llama4_fp8_tensor_moe_flashinfer_trtllm(
+        monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP8", "1")
+    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency")
+    can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", [])
+
+
+@pytest.mark.skip(
+    reason="This gets stuck during/after graph capture for some reason")
+def test_llama4_nvfp4_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
+    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput")
+    can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", [])
+
+
+@pytest.mark.skip(reason="RuntimeError: No kernel found for the given options")
+def test_llama4_nvfp4_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
+    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency")
+    can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", [])
 
-def test_blackwell_fp8_block_moe_deep_gemm(vllm_runner, monkeypatch: pytest.MonkeyPatch):
+
+## DeepSeekV3 ##
+
+
+def test_deepseek_fp8_block_moe_deep_gemm(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "1")
-    can_initialize(vllm_runner, "deepseek-ai/DeepSeek-V3.1", tensor_parallel_size=1)
+    can_initialize("deepseek-ai/DeepSeek-V3.1", [])
 
-def test_blackwell_nvfp4_moe_flashinfer_cutlass(vllm_runner, monkeypatch: pytest.MonkeyPatch):
+
+def test_deepseek_nvfp4_moe_flashinfer_cutlass(
+        monkeypatch: pytest.MonkeyPatch):
     monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
     monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput")
-    can_initialize(vllm_runner, "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", tensor_parallel_size=1)
+    can_initialize("nvidia/DeepSeek-R1-0528-FP4-v2", [])
+
 
-def test_blackwell_nvfp4_moe_flashinfer_trtllm(vllm_runner, monkeypatch: pytest.MonkeyPatch):
+@pytest.mark.skip(reason="RuntimeError: routing_bias must be bfloat16.")
+def test_deepseek_nvfp4_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
     monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency")
-    can_initialize(vllm_runner, "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", tensor_parallel_size=1)
-    can_initialize(vllm_runner, "nvidia/DeepSeek-R1-0528-FP4-v2", tensor_parallel_size=1)
+    can_initialize("nvidia/DeepSeek-R1-0528-FP4-v2", [])
+
+
+## GPT-OSS ##
+
+
+def test_gptoss_mxfp4bf16_moe_flashinfer(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", "1")
+    can_initialize("openai/gpt-oss-20b", [])
+
+
+def test_gptoss_mxfp4mxfp8_moe_flashinfer_cutlass(
+        monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS", "1")
+    can_initialize("openai/gpt-oss-20b", [])
+
+def test_gptoss_mxfp4mxfp8_moe_flashinfer_trtllm(
+        monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "1")
+    can_initialize("openai/gpt-oss-20b", [])
diff --git a/tests/utils.py b/tests/utils.py
index f630c57f46d8..ab6ccc7ad9f9 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -91,8 +91,10 @@ def _start_server(self, model: str, vllm_serve_args: list[str],
             env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
         if env_dict is not None:
             env.update(env_dict)
+        serve_cmd = ["vllm", "serve", model, *vllm_serve_args]
+        print(f"Launching RemoteOpenAIServer with: {' '.join(serve_cmd)}")
         self.proc: subprocess.Popen = subprocess.Popen(
-            ["vllm", "serve", model, *vllm_serve_args],
+            serve_cmd,
             env=env,
             stdout=sys.stdout,
             stderr=sys.stderr,

From 40fbf309cb0310e8cc43fefd6fd4604aa09001e5 Mon Sep 17 00:00:00 2001
From: mgoin
Date: Thu, 25 Sep 2025 16:40:14 -0700
Subject: [PATCH 3/9] Cleanup

Signed-off-by: mgoin
---
 .buildkite/test-pipeline.yaml            | 4 ++++
 tests/quantization/test_blackwell_moe.py | 1 +
 2 files changed, 5 insertions(+)

diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 6ab9df2dac88..373fad2c9ba1 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -834,7 +834,11 @@ steps:
   gpu: b200
   source_file_dependencies:
   - tests/evals/gpt_oss
+  - vllm/model_executor/models/deepseek_v2.py
   - vllm/model_executor/models/gpt_oss.py
+  - vllm/model_executor/models/llama4.py
+  - vllm/model_executor/layers/quantization/compressed_tensors
+  - vllm/model_executor/layers/quantization/modelopt.py
   - vllm/model_executor/layers/quantization/mxfp4.py
   - vllm/v1/attention/backends/flashinfer.py
   commands:
diff --git a/tests/quantization/test_blackwell_moe.py b/tests/quantization/test_blackwell_moe.py
index ab21ddd62a4b..57c421ff7011 100644
--- a/tests/quantization/test_blackwell_moe.py
+++ b/tests/quantization/test_blackwell_moe.py
@@ -71,6 +71,7 @@ def test_llama4_fp8_tensor_moe_flashinfer_cutlass(
     can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", [])
 
 
+@pytest.mark.skip(reason="Takes too long to run")
 def test_llama4_fp8_tensor_moe_flashinfer_trtllm(
         monkeypatch: pytest.MonkeyPatch):
     monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP8", "1")

From e8ae290114b93aa301e6e9bb74356408c35a6217 Mon Sep 17 00:00:00 2001
From: mgoin
Date: Thu, 25 Sep 2025 16:42:02 -0700
Subject: [PATCH 4/9] Fix trigger

Signed-off-by: mgoin
---
 .buildkite/test-pipeline.yaml | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 373fad2c9ba1..5cdd2c388f37 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -833,10 +833,11 @@ steps:
   working_dir: "/vllm-workspace/"
   gpu: b200
   source_file_dependencies:
-  - tests/evals/gpt_oss
+  - tests/quantization/test_blackwell_moe.py
   - vllm/model_executor/models/deepseek_v2.py
   - vllm/model_executor/models/gpt_oss.py
   - vllm/model_executor/models/llama4.py
+  - vllm/model_executor/layers/fused_moe
   - vllm/model_executor/layers/quantization/compressed_tensors
   - vllm/model_executor/layers/quantization/modelopt.py

From 75ee6ca70fcb5a982d1fa7f3247a0ef549eb31f3 Mon Sep 17 00:00:00 2001
From: mgoin
Date: Thu, 25 Sep 2025 17:03:32 -0700
Subject: [PATCH 5/9] Updates

Signed-off-by: mgoin
---
 tests/quantization/test_blackwell_moe.py | 27 ++++++++++--------------
 1 file changed, 11 insertions(+), 16 deletions(-)

diff --git a/tests/quantization/test_blackwell_moe.py b/tests/quantization/test_blackwell_moe.py
index 57c421ff7011..cbd84501b189 100644
--- a/tests/quantization/test_blackwell_moe.py
+++ b/tests/quantization/test_blackwell_moe.py
@@ -43,27 +43,23 @@ def can_initialize(model: str, extra_args: list[str]):
                             override_hf_configs=dummy_hf_overrides) as server:
         client = server.get_client()
         # Make a simple request to verify the server works
-        completion = client.chat.completions.create(
+        completion = client.completions.create(
             model=model,
-            messages=[{
-                "role": "system",
-                "content": "You are a helpful assistant."
-            }, {
-                "role": "user",
-                "content": "Hello!"
-            }],
+            prompt=["Hello!"],
             temperature=0,
-            max_completion_tokens=2,
+            max_tokens=2,
         )
-        generated_text = completion.choices[0].message.content
-        assert generated_text is not None
+        print(completion)
+        assert completion.choices[0].text is not None
 
 
 ## Llama4 ##
 
 
-@pytest.mark.skip(
-    reason="This gets stuck during/after graph capture for some reason")
+@pytest.mark.skip(reason=(
+    "RuntimeError: run_moe() Expected a value of type "
+    "'Optional[List[Tensor]]' for argument '_9' but instead found type "
+    "'list'."))
 def test_llama4_fp8_tensor_moe_flashinfer_cutlass(
         monkeypatch: pytest.MonkeyPatch):
     monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP8", "1")
@@ -71,7 +67,7 @@ def test_llama4_fp8_tensor_moe_flashinfer_cutlass(
     can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", [])
 
 
-@pytest.mark.skip(reason="Takes too long to run")
+@pytest.mark.skip(reason="Works, but takes too long to run")
 def test_llama4_fp8_tensor_moe_flashinfer_trtllm(
         monkeypatch: pytest.MonkeyPatch):
     monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP8", "1")
@@ -79,8 +75,7 @@ def test_llama4_fp8_tensor_moe_flashinfer_trtllm(
     can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", [])
 
 
-@pytest.mark.skip(
-    reason="This gets stuck during/after graph capture for some reason")
+@pytest.mark.skip(reason="Works, but takes too long to run")
 def test_llama4_nvfp4_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
     monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput")

From bdfbf321c1b66c10f7cf0a4ef685f76efa94ef94 Mon Sep 17 00:00:00 2001
From: mgoin
Date: Thu, 25 Sep 2025 17:07:57 -0700
Subject: [PATCH 6/9] Update

Signed-off-by: mgoin
---
 tests/quantization/test_blackwell_moe.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/quantization/test_blackwell_moe.py b/tests/quantization/test_blackwell_moe.py
index cbd84501b189..14f0e4055b81 100644
--- a/tests/quantization/test_blackwell_moe.py
+++ b/tests/quantization/test_blackwell_moe.py
@@ -104,7 +104,7 @@ def test_deepseek_nvfp4_moe_flashinfer_cutlass(
     can_initialize("nvidia/DeepSeek-R1-0528-FP4-v2", [])
 
 
-@pytest.mark.skip(reason="RuntimeError: routing_bias must be bfloat16.")
+@pytest.mark.skip(reason="RuntimeError: No kernel found for the given options")
 def test_deepseek_nvfp4_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
     monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency")

From 317f8a52bce8d78c93173afae2d8ef3ebc309afa Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Fri, 26 Sep 2025 08:20:54 -0600
Subject: [PATCH 7/9] Increase max wait time for RemoteOpenAIServer

Increased max wait time for server to 600 seconds due to FlashInfer compile.
---
 tests/quantization/test_blackwell_moe.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/quantization/test_blackwell_moe.py b/tests/quantization/test_blackwell_moe.py
index 14f0e4055b81..67ebee256fe6 100644
--- a/tests/quantization/test_blackwell_moe.py
+++ b/tests/quantization/test_blackwell_moe.py
@@ -39,7 +39,7 @@ def can_initialize(model: str, extra_args: list[str]):
     # Launch server and make a simple request
     with RemoteOpenAIServer(model,
                             server_args,
-                            max_wait_seconds=480,
+                            max_wait_seconds=600,  # Due to FlashInfer compile
                             override_hf_configs=dummy_hf_overrides) as server:
         client = server.get_client()
         # Make a simple request to verify the server works

From 526e36af6f8d67f2bf56634b6ab7ee7d628628c6 Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Fri, 26 Sep 2025 08:29:03 -0600
Subject: [PATCH 8/9] precommit

---
 tests/quantization/test_blackwell_moe.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/tests/quantization/test_blackwell_moe.py b/tests/quantization/test_blackwell_moe.py
index 67ebee256fe6..f947926bddbf 100644
--- a/tests/quantization/test_blackwell_moe.py
+++ b/tests/quantization/test_blackwell_moe.py
@@ -37,10 +37,11 @@ def can_initialize(model: str, extra_args: list[str]):
     ]
 
     # Launch server and make a simple request
-    with RemoteOpenAIServer(model,
-                            server_args,
-                            max_wait_seconds=600,  # Due to FlashInfer compile
-                            override_hf_configs=dummy_hf_overrides) as server:
+    with RemoteOpenAIServer(
+            model,
+            server_args,
+            max_wait_seconds=600,  # Due to FlashInfer compile
+            override_hf_configs=dummy_hf_overrides) as server:
         client = server.get_client()
         # Make a simple request to verify the server works

From ca15b33f1d580a1dbc03985bcba40bfa303a4753 Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Fri, 26 Sep 2025 10:39:42 -0600
Subject: [PATCH 9/9] Update server arguments and prompt in tests

---
 tests/quantization/test_blackwell_moe.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tests/quantization/test_blackwell_moe.py b/tests/quantization/test_blackwell_moe.py
index f947926bddbf..c021126720af 100644
--- a/tests/quantization/test_blackwell_moe.py
+++ b/tests/quantization/test_blackwell_moe.py
@@ -25,9 +25,9 @@ def can_initialize(model: str, extra_args: list[str]):
     # Server arguments
     server_args = [
         "--max-model-len",
-        "8192",
+        "2048",
         "--max-num-batched-tokens",
-        "1024",
+        "256",
         "--load-format",
         "dummy",
         "--trust-remote-code",
@@ -40,13 +40,13 @@ def can_initialize(model: str, extra_args: list[str]):
     with RemoteOpenAIServer(
             model,
             server_args,
-            max_wait_seconds=600,  # Due to FlashInfer compile
+            max_wait_seconds=1000,  # Due to FlashInfer compile
             override_hf_configs=dummy_hf_overrides) as server:
         client = server.get_client()
         # Make a simple request to verify the server works
         completion = client.completions.create(
             model=model,
-            prompt=["Hello!"],
+            prompt=["Hello, World!"],
             temperature=0,
             max_tokens=2,
         )
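
Note (not part of the patch series): a minimal sketch of how the new suite could be exercised locally, assuming a vLLM development install on an SM100/B200 GPU with FlashInfer available; it mirrors the Buildkite command added in PATCH 1/9.

    # Run the whole Blackwell quantized-MoE suite, as the new CI step does:
    pytest -s -v tests/quantization/test_blackwell_moe.py

    # Or select a single backend combination; each test sets its own VLLM_USE_* /
    # VLLM_FLASHINFER_MOE_BACKEND variables via monkeypatch, so no extra
    # environment setup should be needed:
    pytest -s -v tests/quantization/test_blackwell_moe.py -k fp8_block_moe_deep_gemm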