
Conversation

@varun-sundar-rabindranath (Contributor) commented Oct 29, 2025

Purpose

Running the FlashInfer autotuner causes engine startup to fail when all of the following hold:

  • data parallelism or tensor parallelism is in use, and
  • a FlashInfer MXFP4 backend is selected, and
  • the engine is running in eager mode.

The startup fails with:

(EngineCore_DP1 pid=3074289)   File "/home/varun-sundar-rabindranath/code/vllm/vllm-test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
(EngineCore_DP1 pid=3074289)     return forward_call(*args, **kwargs)
(EngineCore_DP1 pid=3074289)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP1 pid=3074289)   File "/home/varun-sundar-rabindranath/code/vllm/vllm/model_executor/layers/fused_moe/modular_kernel.py", line 1168, in forward
(EngineCore_DP1 pid=3074289)     fused_out = self._fused_experts(
(EngineCore_DP1 pid=3074289)                 ^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP1 pid=3074289)   File "/home/varun-sundar-rabindranath/code/vllm/vllm/model_executor/layers/fused_moe/modular_kernel.py", line 1021, in _fused_experts
(EngineCore_DP1 pid=3074289)     self.fused_experts.apply(
(EngineCore_DP1 pid=3074289)   File "/home/varun-sundar-rabindranath/code/vllm/vllm/model_executor/layers/fused_moe/trtllm_moe.py", line 135, in apply
(EngineCore_DP1 pid=3074289)     trtllm_fp4_block_scale_routed_moe(**kwargs)
(EngineCore_DP1 pid=3074289)   File "/home/varun-sundar-rabindranath/code/vllm/vllm-test/lib/python3.12/site-packages/flashinfer/fused_moe/core.py", line 1850, in trtllm_fp4_block_scale_routed_moe
(EngineCore_DP1 pid=3074289)     return get_trtllm_moe_sm100_module().trtllm_fp4_block_scale_moe(
(EngineCore_DP1 pid=3074289)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP1 pid=3074289)   File "/home/varun-sundar-rabindranath/code/vllm/vllm-test/lib/python3.12/site-packages/flashinfer/fused_moe/core.py", line 1348, in trtllm_fp4_block_scale_moe_op
(EngineCore_DP1 pid=3074289)     _, tactic = tuner.choose_one(
(EngineCore_DP1 pid=3074289)                 ^^^^^^^^^^^^^^^^^
(EngineCore_DP1 pid=3074289)   File "/home/varun-sundar-rabindranath/code/vllm/vllm-test/lib/python3.12/site-packages/flashinfer/autotuner.py", line 457, in choose_one
(EngineCore_DP1 pid=3074289)     profiles = self._generate_optimization_profiles(tuning_config, inputs)
(EngineCore_DP1 pid=3074289)                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP1 pid=3074289)   File "/home/varun-sundar-rabindranath/code/vllm/vllm-test/lib/python3.12/site-packages/flashinfer/autotuner.py", line 643, in _generate_optimization_profiles
(EngineCore_DP1 pid=3074289)     assert len(opt_shapes) > 0, "Empty tuning buckets are not allowed"
(EngineCore_DP1 pid=3074289)            ^^^^^^^^^^^^^^^^^^^
(EngineCore_DP1 pid=3074289) AssertionError: Empty tuning buckets are not allowed

This error was initially thought to be limited to DP combined with certain MXFP4 backend choices, but the test matrix below shows it also affects TP. This PR updates the condition under which FlashInfer autotuning is skipped so that it covers the failing configurations above.
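
For a sense of what the updated skip condition looks like, here is a minimal sketch under stated assumptions: the function name, parameters, and structure are hypothetical and chosen only to illustrate the check described in this PR; they are not the actual identifiers in the vLLM change.

# Hedged sketch: hypothetical names, illustrating the skip condition described
# above, not the actual vLLM implementation.
def should_skip_flashinfer_autotune(
    data_parallel_size: int,
    tensor_parallel_size: int,
    uses_flashinfer_mxfp4_backend: bool,
    enforce_eager: bool,
) -> bool:
    """Return True when FlashInfer autotuning should be skipped at engine startup."""
    uses_dp_or_tp = data_parallel_size > 1 or tensor_parallel_size > 1
    return uses_dp_or_tp and uses_flashinfer_mxfp4_backend and enforce_eager

# Example: TP=2 with an MXFP4 backend in eager mode -> the autotuner is skipped.
assert should_skip_flashinfer_autotune(1, 2, True, True)
# A single-GPU run keeps autotuning enabled.
assert not should_skip_flashinfer_autotune(1, 1, True, True)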

Test Plan and Test Results

B200 DP

VLLM_ALL2ALL_BACKEND="deepep_high_throughput" canhazgpu run -g2 -- vllm serve openai/gpt-oss-20b --data-parallel-size 2 --tensor-parallel-size 1 --enable-expert-parallel --no-enable-prefix-caching --port 9010

  • PR Pass
  • main Fail

VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1 canhazgpu run -g2 -- vllm serve openai/gpt-oss-20b --data-parallel-size 2 --tensor-parallel-size 1 --enable-expert-parallel --no-enable-prefix-caching --port 9010 --enforce-eager

  • PR Pass
  • main Pass

VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS=1 canhazgpu run -g2 -- vllm serve openai/gpt-oss-20b --data-parallel-size 2 --tensor-parallel-size 1 --enable-expert-parallel --no-enable-prefix-caching --port 9010 --enforce-eager

  • PR Pass
  • main Fail

VLLM_USE_FLASHINFER_MOE_MXFP4_BF16=1 canhazgpu run -g2 -- vllm serve openai/gpt-oss-20b --data-parallel-size 2 --tensor-parallel-size 1 --enable-expert-parallel --no-enable-prefix-caching --port 9010 --enforce-eager

  • PR Pass
  • main Pass

B200 TP

VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1 canhazgpu run -g2 -- vllm serve openai/gpt-oss-20b --data-parallel-size 1 --tensor-parallel-size 2 --no-enable-prefix-caching --port 9010 --enforce-eager

  • PR Pass
  • main Fail

VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS=1 canhazgpu run -g2 -- vllm serve openai/gpt-oss-20b --data-parallel-size 1 --tensor-parallel-size 2 --no-enable-prefix-caching --port 9010 --enforce-eager

  • PR Pass
  • main Fail

VLLM_USE_FLASHINFER_MOE_MXFP4_BF16=1 canhazgpu run -g2 -- vllm serve openai/gpt-oss-20b --data-parallel-size 1 --tensor-parallel-size 2 --no-enable-prefix-caching --port 9010 --enforce-eager

  • PR Pass
  • main Fail

H100 DP

VLLM_USE_FLASHINFER_MOE_MXFP4_BF16=1 canhazgpu run -g2 -- vllm serve openai/gpt-oss-20b --data-parallel-size 2 --tensor-parallel-size 1 --enable-expert-parallel --no-enable-prefix-caching --port 9010 --enforce-eager

  • PR Pass
  • main Pass

H100 TP

VLLM_USE_FLASHINFER_MOE_MXFP4_BF16=1 canhazgpu run -g2 -- vllm serve openai/gpt-oss-20b --data-parallel-size 1 --tensor-parallel-size 2 --no-enable-prefix-caching --port 9010 --enforce-eager

  • PR Pass
  • main Fail

Varun Sundar Rabindranath added 2 commits October 29, 2025 10:56
@gemini-code-assist bot left a comment

Code Review

This pull request addresses a bug where the FlashInfer autotuner fails during engine startup under specific conditions: data or tensor parallelism combined with a FlashInfer MXFP4 backend in eager mode. The fix identifies this problematic configuration by checking for tensor or data parallelism, the use of any FlashInfer MXFP4 backend, and whether execution is in eager mode. The changes are well implemented, using clearly named boolean variables to keep the condition readable. The fix appears correct and complete based on the problem description and test results.

@varun-sundar-rabindranath (Contributor, Author)

@nvpohanh I verified that with the PR,

VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1   vllm serve openai/gpt-oss-20b --data-parallel-size 2 --tensor-parallel-size 1 --enable-expert-parallel   --no-enable-prefix-caching  --port 9010

FlashInfer autotuning happens.

@varun-sundar-rabindranath (Contributor, Author)

cc @zyongye @mgoin PTAL! Thanks.

@mgoin added the bug and ready labels on Oct 30, 2025
@github-project-automation bot moved this from To Triage to Ready in gpt-oss Issues & Enhancements on Oct 30, 2025
@nvpohanh (Contributor)

@elvischenv has also verified this. Thanks for the quick fix!

@nvpohanh (Contributor)

@mgoin It would be great if we could include this in v0.11.1 to avoid an unexpected performance regression. Thanks!

@mgoin (Member) commented Oct 30, 2025

@vllm-bot merged commit e5e076c into vllm-project:main on Oct 30, 2025
46 of 48 checks passed

Labels

bug (Something isn't working), gpt-oss (Related to GPT-OSS models), ready (ONLY add when PR is ready to merge/full CI is needed)

Projects

Status: Done

4 participants