Skip to content

Conversation

@yewentao256
Copy link
Member

@yewentao256 yewentao256 commented Jul 11, 2025

Purpose

Fix DeepGemm for EP low latency case

Test

Original:

(EngineCore_7 pid=1248)   File "/app/vllm/vllm/model_executor/layers/fused_moe/modular_kernel.py", line 717, in forward
(EngineCore_7 pid=1248)     fused_out = self._maybe_chunk_fused_experts(
(EngineCore_7 pid=1248)                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_7 pid=1248)   File "/app/vllm/vllm/model_executor/layers/fused_moe/modular_kernel.py", line 526, in _maybe_chunk_fused_experts
(EngineCore_7 pid=1248)     return self._do_fused_experts(
(EngineCore_7 pid=1248)            ^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_7 pid=1248)   File "/app/vllm/vllm/model_executor/layers/fused_moe/modular_kernel.py", line 488, in _do_fused_experts
(EngineCore_7 pid=1248)     self.fused_experts.apply(fused_out,
(EngineCore_7 pid=1248)   File "/app/vllm/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py", line 156, in apply
(EngineCore_7 pid=1248)     experts.apply(output, hidden_states, w1, w2, topk_ids, activation,
(EngineCore_7 pid=1248)   File "/app/vllm/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py", line 292, in apply
(EngineCore_7 pid=1248)     fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, w1_scale),
(EngineCore_7 pid=1248)   File "/app/vllm/vllm/utils/deep_gemm.py", line 101, in fp8_m_grouped_gemm_nt_masked
(EngineCore_7 pid=1248)     return _grouped_masked_impl(*args, **kwargs)
(EngineCore_7 pid=1248)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_7 pid=1248) TypeError: fp8_m_grouped_gemm_nt_masked() got an unexpected keyword argument 'out'

Now:

export VLLM_ALL2ALL_BACKEND="deepep_low_latency"
VLLM_USE_DEEP_GEMM=1 lm_eval   --model vllm   --model_args "pretrained=Qwen/Qwen3-30B-A3B-FP8,data_parallel_size=2,max_model_len=32768,enable_expert_parallel=True,enforce_eager=True"   --trust_remote_code   --tasks gsm8k   --num_fewshot 5   --batch_size auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match||0.8317|±  |0.0103|
|     |       |strict-match    |     5|exact_match||0.8294|±  |0.0104|

Vs no deepgemm:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match||0.8294|±  |0.0104|
|     |       |strict-match    |     5|exact_match||0.8931|±  |0.0085|

Signed-off-by: yewentao256 <zhyanwentao@126.com>
@github-actions
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @yewentao256, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request primarily addresses a critical bug that prevented DeepGemm from operating correctly in Expert Parallel (EP) low latency configurations, which manifested as a TypeError during grouped GEMM operations. In addition to the bug fix, it introduces an important optimization by adding a new FP8 quantization kernel tailored for NVIDIA Blackwell GPUs, and intelligently applies this optimization based on the detected hardware, aiming to enhance performance and stability on newer architectures.

Highlights

  • Bug Fix: Resolved a TypeError in the fp8_m_grouped_gemm_nt_masked function by adjusting argument passing from keyword to positional. This bug was causing crashes in DeepGemm when used in Expert Parallel (EP) low latency scenarios.
  • Hardware-Specific Optimization: Introduced a new Triton kernel, _silu_mul_fp8_quant_deep_gemm_ue8m0, specifically designed for FP8 quantization with UE8M0 (Unsigned E8M0) scaling. This is intended to optimize performance and compatibility on NVIDIA Blackwell GPUs.
  • Conditional Execution: Implemented logic to dynamically select between the standard FP8 quantization kernel and the new UE8M0-specific kernel based on whether a Blackwell GPU is detected using is_blackwell_deep_gemm_used(), ensuring the most appropriate and performant path is taken.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request fixes a TypeError when using DeepGEMM in the expert-parallel low-latency case. It also introduces support for Blackwell-specific UE8M0 quantization. The review suggests simplifying a conditional block to make the code more concise.

@smarterclayton
Copy link
Contributor

This fixed the failure I was seeing in the B200 DP=16,EP=16 2 node configuration.

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Copy link
Member

@mgoin mgoin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks LGTM

@mgoin mgoin enabled auto-merge (squash) July 11, 2025 18:58
@mgoin mgoin added bug Something isn't working deepseek Related to DeepSeek models labels Jul 11, 2025
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 11, 2025
@WoosukKwon
Copy link
Collaborator

Just curious: Isn't 6% accuracy diff in gsm8k-strict significant? Is it acceptable?

@yewentao256
Copy link
Member Author

Just curious: Isn't 6% accuracy diff in gsm8k-strict significant? Is it acceptable?

There is a auccracy loss for DeepGemm on B200 currently, deepseek-ai/DeepGEMM#112

If you run the unit test in test_core.py of their branch, you will find that they are now using an global comparison instead of per-element comparison, because that will fail.

So that's why we use

def calc_diff(x: torch.Tensor, y: torch.Tensor):
for the test.

But seems that doesn't affect too much for the R1, and they seems now trying to make it better, so we temporally do not care too much about that now.

@mgoin
Copy link
Member

mgoin commented Jul 11, 2025

Yeah what Wentao said is correct. It is the unfortunate result of DeepGEMM switching from float to E8M0 scales for SM100.

@vllm-bot vllm-bot merged commit 0d4891c into vllm-project:main Jul 12, 2025
74 of 76 checks passed
@yewentao256 yewentao256 deleted the wye/fix-batched-deepgemm-error branch July 23, 2025 18:11
x22x22 pushed a commit to x22x22/vllm that referenced this pull request Aug 5, 2025
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: x22x22 <wadeking@qq.com>
Pradyun92 pushed a commit to Pradyun92/vllm that referenced this pull request Aug 6, 2025
Signed-off-by: yewentao256 <zhyanwentao@126.com>
npanpaliya pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Aug 6, 2025
Signed-off-by: yewentao256 <zhyanwentao@126.com>
jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
paulpak58 pushed a commit to paulpak58/vllm that referenced this pull request Aug 13, 2025
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Paul Pak <paulpak58@gmail.com>
diegocastanibm pushed a commit to diegocastanibm/vllm that referenced this pull request Aug 15, 2025
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Diego-Castan <diego.castan@ibm.com>
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 27, 2025
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn't working deepseek Related to DeepSeek models ready ONLY add when PR is ready to merge/full CI is needed

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants