
Conversation

@pavanimajety (Collaborator) commented on Sep 29, 2025

Purpose

Fixes #25189: accuracy issues for the TRTLLM DSR1 latency kernels. The bugs were introduced in #23991 and #23640.

Also fixes incorrect log messages about which kernels are used: when FlashInfer MOE is enabled, DeepGemm is automatically disabled for SM100.

(EngineCore_DP0 pid=101213) (Worker_TP3 pid=101225) INFO 09-29 10:14:46 [fp8.py:454] Detected Blackwell GPUs, using FlashInfer TensorRT-LLM kernels for FP8 MOE.
(EngineCore_DP0 pid=101213) (Worker_TP3 pid=101225) INFO 09-29 10:14:46 [fp8.py:468] DeepGemm disabled: FlashInfer MOE is enabled.
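For context, a minimal sketch (not the actual fp8.py code) of the precedence that produces the two log lines above; the attribute names `block_quant` and `flashinfer_moe_backend` follow the diff discussed in the review below, everything else is simplified:

```python
import logging
from typing import Optional

logger = logging.getLogger("fp8")


def deep_gemm_allowed(is_blackwell: bool,
                      flashinfer_moe_backend: Optional[str],
                      block_quant: bool) -> bool:
    """Return True if DeepGemm may be used for the FP8 MoE path."""
    if is_blackwell and flashinfer_moe_backend is not None:
        logger.info("Detected Blackwell GPUs, using FlashInfer "
                    "TensorRT-LLM kernels for FP8 MOE.")
    if not block_quant:
        logger.warning("Model is not block quantized. Not using"
                       " DeepGemm kernels")
        return False
    if flashinfer_moe_backend is not None:
        # FlashInfer MOE takes precedence, so DeepGemm is disabled.
        logger.info("DeepGemm disabled: FlashInfer MOE is enabled.")
        return False
    return True


# Example: Blackwell + latency backend + block-quantized model
# deep_gemm_allowed(True, "latency", True)  # -> False
```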

Test Plan

LM Eval with VLLM_USE_FLASHINFER_MOE_FP8=1 and VLLM_FLASHINFER_MOE_BACKEND="latency"

root@umbriel-b200-017:/workspace/scratch-pmaj-1/gh-pm-vllm# ./run_lm_eval.sh                                                                        
+ export FORCE_NUM_KV_SPLITS=1                                                                                                                      
+ FORCE_NUM_KV_SPLITS=1                                                                                                                             
+ export VLLM_USE_FLASHINFER_MOE_FP8=1                                                                                                              
+ VLLM_USE_FLASHINFER_MOE_FP8=1                                                                                                                     
+ export VLLM_FLASHINFER_MOE_BACKEND=latency                                                                                                        
+ VLLM_FLASHINFER_MOE_BACKEND=latency                                                                                                               
+ echo 'Running lm_eval with:'                                                                                                                      
Running lm_eval with:                                                                                                                               
                     
  Model: /models/DSR1-FP8/models--deepseek-ai--DeepSeek-R1-0528/snapshots/4236a6af538feda4548eca9ab308586007567f52/                                 
  Tensor Parallel Size: 8                                                                                                                           
  Quantization: fp8                                                                                                                                 
  GPU Memory Utilization: 0.90                                                                                                                      
                                                                      
+ lm_eval --model vllm --model_args pretrained=/models/DSR1-FP8/models--deepseek-ai--DeepSeek-R1-0528/snapshots/4236a6af538feda4548eca9ab308586007567f52/,quantization=fp8,tensor_parallel_size=8,gpu_memory_utilization=0.90,add_bos_token=True --gen_kwargs temperature=0.0,max_gen_toks=32768 --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size 200 --limit 1319 --output_path results/4236a6af538feda4548eca9ab308586007567f52-gsm8k-1319samples_20250929_101419 --log_samples

Test Result

vllm (pretrained=/models/DSR1-FP8/models--deepseek-ai--DeepSeek-R1-0528/snapshots/4236a6af538feda4548eca9ab308586007567f52/,quantization=fp8,tensor_parallel_size=8,gpu_memory_utilization=0.90,add_bos_token=True,trust_remote_code=True), gen_kwargs: (temperature=0.0,max_gen_toks=32768), limit: 1319.0, num_fewshot: 5, batch_size: 200
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9666|±  |0.0049|
|     |       |strict-match    |     5|exact_match|↑  |0.9651|±  |0.0051|

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • [N/A] (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
@gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request introduces important fixes for FP8 MoE, particularly for TRT-LLM latency kernels. The changes correctly address an accuracy issue by ensuring e_score_correction_bias has the expected dtype, and fix a critical control flow bug by adding an early return after the MoE operation is complete. Additionally, the logic for selecting between FlashInfer, DeepGEMM, and Cutlass kernels is improved to prevent conflicts, disabling other kernels when FlashInfer MoE is active. The logging has also been updated to be more informative about which kernel is being used. The changes are well-structured and significantly improve the correctness and robustness of the FP8 MoE implementation.

Comment on lines 944 to 946
e_score_correction_bias = (e_score_correction_bias.to(x.dtype)
                           if e_score_correction_bias is not None else None)
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(

critical

This is a great fix that addresses two important issues:

  1. Casting e_score_correction_bias to x.dtype resolves a potential dtype mismatch that could lead to accuracy problems, which is critical for correctness.
  2. Adding the return statement corrects a significant control flow bug. Previously, the code would fall through and execute select_experts and other logic even after the complete MoE operation was performed by flashinfer_fused_moe_blockscale_fp8. This early return ensures the function exits correctly.

This change is crucial for both correctness and logic of the FP8 MoE path.
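For illustration, a condensed, self-contained sketch of the two fixes described above; the real vLLM function takes many more arguments, and the stub below only stands in for `torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8`:

```python
import torch


def flashinfer_fused_moe_blockscale_fp8_stub(x, bias):
    # Stand-in for torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8.
    return x


def apply_blockscale_fp8_moe(x: torch.Tensor,
                             e_score_correction_bias,
                             use_flashinfer: bool) -> torch.Tensor:
    if use_flashinfer:
        # Fix 1: cast the routing bias to the activation dtype so the
        # kernel does not see a mixed-dtype correction term.
        e_score_correction_bias = (e_score_correction_bias.to(x.dtype)
                                   if e_score_correction_bias is not None
                                   else None)
        # Fix 2: return immediately; previously control fell through to
        # select_experts even though the MoE op had already run.
        return flashinfer_fused_moe_blockscale_fp8_stub(
            x, e_score_correction_bias)
    # ... generic select_experts path would continue here ...
    return x
```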

@alexm-redhat (Collaborator) left a comment


LGTM

@mgoin added the ready and bug labels on Sep 29, 2025
Comment on lines 463 to +468
  elif not self.block_quant:
-     logger.warning_once("Model is not block quantized. Not using "
-                         "DeepGemm kernels")
+     logger.warning_once("Model is not block quantized. Not using"
+                         " DeepGemm kernels")
+ elif self.flashinfer_moe_backend:
+     logger.info_once("DeepGemm disabled: FlashInfer MOE is"
+                      " enabled.")
A Member commented:

Note for future self: we should clean up these logs now that VLLM_USE_DEEP_GEMM=1 by default

else:
    assert (not renormalize
            and custom_routing_function is not None)
    result = apply_flashinfer_per_tensor_scale_fp8(
A Member commented:

Is this a bug too? It seems like we need to return here rather than write to result

@pavanimajety (Collaborator, Author) replied:

Yeah, it could potentially be a bug; I haven't tested that path, so I was hesitant to change it.
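For illustration only, a hypothetical and untested sketch of what the suggested early return could look like; it is not part of this PR, and the stub merely stands in for vLLM's apply_flashinfer_per_tensor_scale_fp8:

```python
def apply_flashinfer_per_tensor_scale_fp8_stub(x):
    # Stand-in for vLLM's apply_flashinfer_per_tensor_scale_fp8.
    return x


def per_tensor_fp8_branch(x, renormalize, custom_routing_function):
    assert (not renormalize
            and custom_routing_function is not None)
    result = apply_flashinfer_per_tensor_scale_fp8_stub(x)
    # Returning here, instead of falling through to the generic path,
    # is the change the review comment suggests; it remains untested.
    return result
```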

@mgoin merged commit ef28354 into vllm-project:main on Sep 30, 2025
49 of 50 checks passed
Commits referencing this PR were later pushed by pdasigi (Oct 2, 2025), yewentao256 (Oct 3, 2025), tomeras91 (Oct 6, 2025), xuebwang-amd (Oct 10 and Oct 24, 2025), lywa1998 (Oct 20, 2025), and alhridoy (Oct 24, 2025).

Labels

bug, ready


Development

Successfully merging this pull request may close these issues.

[Bug]: B200 FlashInfer FP8 MoE low-latency - incorrect results
