
Conversation

@lumina37 (Contributor) commented Sep 15, 2025

Purpose

  1. Use cuda::std::isfinite instead of std::isfinite(__half2float(...)) to check the finiteness of the input without an extra type cast.
  2. Use neg_inf<T> to generate -inf for type T and simplify the code. However, there appears to be a bug in cuda::std::numeric_limits<T>::infinity() and cuda::std::numeric_limits<T>::lowest(): both always return 0.0 for T=bf16 or fp16, so the helper instead casts fp32 -inf to T. My nvcc version is cuda_12.8.r12.8/compiler.35404655_0. (A rough sketch of both changes follows this list.)
  3. (Misc) Add the missing cg namespace qualifier for cg::reduce.
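
The following is a minimal sketch of items 1 and 2, not the exact code merged in this PR: cuda_cast here is a stand-in for the conversion helper already present in the kernel file, and the cuda::std::isfinite overload for __half assumes a sufficiently new libcu++ (see the compilation issue discussed further down).

#include <cuda_fp16.h>
#include <cuda/std/cmath>
#include <cuda/std/limits>

// Stand-in for the kernel file's existing fp32 -> T conversion helper.
template <typename T>
__device__ inline T cuda_cast(float v) {
  return static_cast<T>(v);
}

// Item 2: produce -inf in type T by casting down from fp32, since
// cuda::std::numeric_limits<T>::infinity()/lowest() were observed to
// return 0 for fp16/bf16 with the nvcc version noted above.
template <typename T>
__device__ inline T neg_inf() {
  return cuda_cast<T>(-cuda::std::numeric_limits<float>::infinity());
}

// Item 1: check finiteness directly on the half-precision value,
// without the previous __half2float round trip.
__device__ inline bool keep_score(__half score) {
  return cuda::std::isfinite(score);
}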

Test Plan

pytest -s -v tests/kernels/moe/test_grouped_topk.py

Test Result

tests/kernels/moe/test_grouped_topk.py::test_grouped_topk[dtype0-1.0-softmax-2-8-True-2-16-1024-1] PASSED
tests/kernels/moe/test_grouped_topk.py::test_grouped_topk[dtype0-1.0-softmax-2-8-True-2-16-1024-33] PASSED
tests/kernels/moe/test_grouped_topk.py::test_grouped_topk[dtype0-1.0-softmax-2-8-True-2-16-1024-64] PASSED
...<more cases>

=========================================== warnings summary ============================================
.venv/lib/python3.12/site-packages/schemathesis/generation/coverage.py:305
  /workspace/tz/code/vllm/.venv/lib/python3.12/site-packages/schemathesis/generation/coverage.py:305: DeprecationWarning: jsonschema.exceptions.RefResolutionError is deprecated as of version 4.18.0. If you wish to catch potential reference resolution errors, directly catch referencing.exceptions.Unresolvable.
    ref_error: type[Exception] = jsonschema.RefResolutionError,

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
==================================== 144 passed, 1 warning in 45.65s ====================================



👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@gemini-code-assist (bot, Contributor) left a comment

Code Review

This pull request introduces several good improvements for handling infinity and finiteness checks in the grouped topk CUDA kernel. The changes are well-motivated and correctly implemented. Using cuda::std::isfinite directly on half-precision types is more efficient, and the new neg_inf<T>() helper function improves type safety. The code is cleaner and more robust. I have one suggestion to improve the long-term maintainability of the neg_inf function by scoping the bug workaround to specific CUDA versions.

Comment on lines +414 to +419 (gemini-code-assist bot; severity: high)

The workaround for the cuda::std::numeric_limits bug is noted. To improve long-term maintainability and ensure the code automatically benefits from future bug fixes in the CUDA toolkit, it's best to scope this workaround to the specific compiler versions where the bug is present using preprocessor directives. This will allow the simpler and more direct cuda::std::numeric_limits<T>::lowest() to be used once the underlying bug is resolved.

Additionally, using cuda::std::numeric_limits<float>::lowest() is more idiomatic for getting negative infinity than -cuda::std::numeric_limits<float>::infinity().

#if (defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8)
// Workaround for a bug in libcu++ in CUDA 12.8+ where numeric_limits for
// fp16/bf16 types return 0.
template <typename T>
__device__ inline T neg_inf() {
  // cuda::std::numeric_limits<T>::infinity() and lowest() return `0` for [T=bf16 or fp16]
  // so we need to cast from fp32.
  return cuda_cast<T, float>(cuda::std::numeric_limits<float>::lowest());
}
#else
// Default implementation for compilers where the bug is not present.
template <typename T>
__device__ inline T neg_inf() {
  return cuda::std::numeric_limits<T>::lowest();
}
#endif
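
For anyone wanting to check which behavior their toolkit exhibits, a small standalone probe along these lines could be compiled with nvcc. This is illustrative only and not part of this PR; the numeric_limits<__half> line assumes a toolkit where that specialization compiles, as it apparently does on the author's CUDA 12.8 setup.

// probe_neg_inf.cu -- compare the fp32-cast approach with numeric_limits<__half>.
#include <cstdio>
#include <cuda_fp16.h>
#include <cuda/std/limits>

template <typename T>
__device__ inline T neg_inf_via_cast() {
  // The approach taken in this PR: cast fp32 -inf down to T.
  return static_cast<T>(-cuda::std::numeric_limits<float>::infinity());
}

__global__ void probe(float* out) {
  out[0] = __half2float(neg_inf_via_cast<__half>());
  out[1] = __half2float(cuda::std::numeric_limits<__half>::lowest());
}

int main() {
  float* out = nullptr;
  cudaMallocManaged(&out, 2 * sizeof(float));
  probe<<<1, 1>>>(out);
  cudaDeviceSynchronize();
  // Expected: the first value prints -inf; the second should be the most
  // negative finite half (about -65504), but reportedly prints 0 on the
  // affected toolkit versions.
  std::printf("cast: %f, numeric_limits lowest: %f\n", out[0], out[1]);
  cudaFree(out);
  return 0;
}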

@lumina37 force-pushed the grouped-topk-inf branch 2 times, most recently from b267663 to afcb182 on September 15, 2025 15:41
Signed-off-by: lumina37 <starry.qvq@gmail.com>
@lumina37 (Contributor, Author)

@jikunshang @xyang16 May I have a review for this minor change? It won't take much time.

@mayuyuace (Contributor)

Does this PR improve performance? If so, please provide detailed comparison results.

@lumina37 (Contributor, Author)

> Does this PR improve performance? If so, please provide detailed comparison results.

@mayuyuace I'll run the benchmark later, but I expect the performance difference to be negligible.

@lumina37 (Contributor, Author)

@mayuyuace The benchmark results are almost identical.

Command & Env

python benchmarks/kernels/benchmark_moe.py --model deepseek-ai/DeepSeek-V3

Run on my instance with two RTX 3090 GPUs.

Main Branch

Batch size: 1, config: {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1}
Kernel time: 423.12 us
Batch size: 2, config: {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1}
Kernel time: 840.17 us
Batch size: 4, config: {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1}
Kernel time: 1580.68 us
Batch size: 8, config: {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1}
Kernel time: 2908.79 us
Batch size: 16, config: {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1}
Kernel time: 5107.58 us
Batch size: 24, config: {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1}
Kernel time: 6827.53 us
Batch size: 32, config: {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1}
Kernel time: 8156.11 us
Batch size: 48, config: {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1}
Kernel time: 10000.19 us
Batch size: 64, config: {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1}
Kernel time: 11172.37 us
Batch size: 96, config: {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1}
Kernel time: 12235.45 us
Batch size: 128, config: {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1}
Kernel time: 13336.55 us
Batch size: 256, config: {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1}
Kernel time: 13082.13 us
Batch size: 512, config: {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
Kernel time: 23994.93 us
Batch size: 1024, config: {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
Kernel time: 25080.60 us
Batch size: 1536, config: {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
Kernel time: 31592.08 us
Batch size: 2048, config: {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
Kernel time: 36259.47 us
Batch size: 3072, config: {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
Kernel time: 65925.27 us
Batch size: 4096, config: {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
Kernel time: 59805.92 us

This PR

Batch size: 1, config: {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1}
Kernel time: 423.35 us
Batch size: 2, config: {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1}
Kernel time: 840.05 us
Batch size: 4, config: {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1}
Kernel time: 1580.89 us
Batch size: 8, config: {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1}
Kernel time: 2909.20 us
Batch size: 16, config: {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1}
Kernel time: 5107.83 us
Batch size: 24, config: {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1}
Kernel time: 6827.95 us
Batch size: 32, config: {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1}
Kernel time: 8159.69 us
Batch size: 48, config: {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1}
Kernel time: 10000.39 us
Batch size: 64, config: {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1}
Kernel time: 11196.69 us
Batch size: 96, config: {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1}
Kernel time: 12235.43 us
Batch size: 128, config: {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1}
Kernel time: 14001.90 us
Batch size: 256, config: {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1}
Kernel time: 13082.59 us
Batch size: 512, config: {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
Kernel time: 24673.28 us
Batch size: 1024, config: {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
Kernel time: 25087.66 us
Batch size: 1536, config: {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
Kernel time: 33082.01 us
Batch size: 2048, config: {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
Kernel time: 36259.31 us
Batch size: 3072, config: {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
Kernel time: 68464.85 us
Batch size: 4096, config: {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
Kernel time: 59731.16 us

@jikunshang (Collaborator) left a comment

LGTM. Thanks for contributing!

@jikunshang enabled auto-merge (squash) on September 17, 2025 00:12
@github-actions bot added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Sep 17, 2025
@jikunshang merged commit 81b16a2 into vllm-project:main on Sep 18, 2025
91 checks passed
845473182 pushed a commit to dsxsteven/vllm_splitPR that referenced this pull request Sep 18, 2025
…litPR into model_register

* 'model_register' of https://github.com/dsxsteven/vllm_splitPR: (138 commits)
  Retrieve `sliding_window` from text config in Gemma3 MM (vllm-project#25085)
  [Docs] Fix API Reference (vllm-project#25140)
  [Kernel] Better inf handling for grouped topk cu (vllm-project#24886)
  [CLI] Use streaming in CLI chat and completion commands (vllm-project#23769)
  [benchmark] add peak throughput metrics and plot (vllm-project#23867)
  [Spec Decode] Efficient padded speculation (vllm-project#24539)
  [V0 Deprecation] Remove more V0 tests (vllm-project#25117)
  [EPLB] Add EPLB support for hunyuan_v1 (vllm-project#23078)
  [XPU] Whisper model support on XPU Platform (vllm-project#25123)
  Mark prompt logprobs as incompatible with prompt embeds at API level (vllm-project#25077)
  [Model] enable data parallel for InternVL vision encoder (vllm-project#23909)
  [Kernels] Overlap shared experts with combine instead of dispatch (vllm-project#24254)
  [Bugfix][Qwen3-Next] add prefixes to shared_expert in qwen3-next and mlp in qwen2moe to successfully load ignored params in quantized models (vllm-project#24960)
  [Core][MM] Cleanup `MultiModalCache` (vllm-project#25006)
  [Docs] Clean up the contributing README (vllm-project#25099)
  [MM Encoder] Apply DP ViT for Qwen3-VL model series (vllm-project#24955)
  [Kernels] Enable DeepGEMM by default (vllm-project#24462)
  [V0 Deprecation] Skip PP test (vllm-project#25128)
  [V0 Deprecation] Remove misc V0 tests (vllm-project#25118)
  [V0 Deprecation] Remove V0 Tracing & Metrics tests (vllm-project#25115)
  ...
debroy-rh pushed a commit to debroy-rh/vllm that referenced this pull request Sep 19, 2025
Signed-off-by: lumina37 <starry.qvq@gmail.com>
@acdart commented Sep 22, 2025

@jikunshang @lumina37

errors:

/root/vllm/csrc/moe/grouped_topk_kernels.cu(536): error: namespace "cuda::sta" has no member "isfinite"

@lumina37 (Contributor, Author)

@acdart What's your nvcc version? Also, is that error actually "cuda::std" rather than "cuda::sta"?

@zhuohan123 (Member)

@lumina37 I'm getting the following error during compilation:

      /mnt/code/vllm/csrc/moe/grouped_topk_kernels.cu(536): error: namespace "cuda::std" has no member "isfinite"
            if (lane_id < n_group && cuda::std::isfinite(group_scores[lane_id])) {
                                                ^
      
      /mnt/code/vllm/csrc/moe/grouped_topk_kernels.cu(573): error: namespace "cuda::std" has no member "isfinite"
                              cuda::std::isfinite(scores_with_bias[offset + i])
                                         ^
      
      2 errors detected in the compilation of "/mnt/code/vllm/csrc/moe/grouped_topk_kernels.cu".

Does this PR change the nvcc version requirement?

@zhuohan123 (Member)

Related fix: #25346

@lumina37 (Contributor, Author)

I apologize for the mistake caused by my insufficient testing. The cuda::std::isfinite function was added in libcu++ v2.1.0, which shipped with CUDA 12.2, but I'm not sure why it's still missing in CUDA 12.6. I will test on more versions next time.
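
For reference, one hedged way to paper over such an availability gap is to gate on the nvcc version and fall back to the fp32 round trip that this PR originally replaced. This is only a sketch, not the code in #25346, and the 12.8 cutoff is a guess based on the versions mentioned in this thread.

#include <cmath>
#include <cuda_fp16.h>

#if defined(__CUDACC_VER_MAJOR__) &&                          \
    (__CUDACC_VER_MAJOR__ > 12 ||                             \
     (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))
#include <cuda/std/cmath>
// Newer toolkits: use the half-precision overload directly.
__device__ inline bool is_finite_half(__half x) {
  return cuda::std::isfinite(x);
}
#else
// Older toolkits without a cuda::std::isfinite overload for __half:
// fall back to the original fp32 round trip.
__device__ inline bool is_finite_half(__half x) {
  return std::isfinite(__half2float(x));
}
#endif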

FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
Signed-off-by: lumina37 <starry.qvq@gmail.com>
charlifu pushed a commit to ROCm/vllm that referenced this pull request Sep 25, 2025
Signed-off-by: lumina37 <starry.qvq@gmail.com>
Signed-off-by: charlifu <charlifu@amd.com>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025
Signed-off-by: lumina37 <starry.qvq@gmail.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
choprahetarth pushed a commit to Tandemn-Labs/vllm that referenced this pull request Oct 11, 2025
Signed-off-by: lumina37 <starry.qvq@gmail.com>
