
Conversation

@yewentao256 yewentao256 commented Sep 15, 2025

Purpose

/home/wentao/vllm/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh(147): error: a value of type "const ElementBlockScale *" cannot be used to initialize an entity of type "uint32_t" (aka "unsigned int")
          a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB
          ^
          detected during:
            instantiation of "void vllm::cutlass_gemm_caller_blockwise<Gemm>(at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Tensor &) [with Gemm=vllm::cutlass_3x_gemm_fp8_blockwise<cutlass::half_t, 1, 128, 128, cute::tuple<cute::_128, cute::_128, cute::_128>, cute::tuple<cute::_1, cute::_2, cute::_1>, cutlass::epilogue::TmaWarpSpecializedCooperative, cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum>]" at line 170
            instantiation of "void vllm::cutlass_gemm_blockwise_sm90_fp8_dispatch<OutType>(at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Tensor &) [with OutType=cutlass::half_t]" at line 19 of /home/wentao/vllm/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu
8 errors detected in the compilation of "/home/wentao/vllm/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu".
ninja: build stopped: subcommand failed.

Test

Now:

(wentao) wentao@H100-GPU17:~/vllm$ cmake --build --preset release --target install
[2/3] Install the project...
-- Install configuration: "Release"
-- Up-to-date: /home/wentao/vllm/vllm/cumem_allocator.abi3.so
-- Installing: /home/wentao/vllm/vllm/_C.abi3.so
-- Set non-toolchain portion of runtime path of "/home/wentao/vllm/vllm/_C.abi3.so" to ""
-- Up-to-date: /home/wentao/vllm/vllm/_moe_C.abi3.so
-- Up-to-date: /home/wentao/vllm/vllm/_flashmla_C.abi3.so
-- Up-to-date: /home/wentao/vllm/vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so
-- Up-to-date: /home/wentao/vllm/vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so
-- Up-to-date: /home/wentao/vllm/vllm/vllm_flash_attn
-- Up-to-date: /home/wentao/vllm/vllm/vllm_flash_attn/ops
-- Up-to-date: /home/wentao/vllm/vllm/vllm_flash_attn/ops/triton
-- Up-to-date: /home/wentao/vllm/vllm/vllm_flash_attn/ops/triton/rotary.py
-- Up-to-date: /home/wentao/vllm/vllm/vllm_flash_attn/ops/triton/__init__.py
-- Up-to-date: /home/wentao/vllm/vllm/vllm_flash_attn/layers
-- Up-to-date: /home/wentao/vllm/vllm/vllm_flash_attn/layers/rotary.py
-- Up-to-date: /home/wentao/vllm/vllm/vllm_flash_attn/layers/__init__.py
-- Up-to-date: /home/wentao/vllm/vllm/vllm_flash_attn/flash_attn_interface.py
-- Up-to-date: /home/wentao/vllm/vllm/vllm_flash_attn/__init__.py

Signed-off-by: yewentao256 <zhyanwentao@126.com>

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request effectively resolves a compilation error in the CUTLASS scaled matrix multiplication kernels. The primary fix involves replacing a problematic aggregate initialization of MainloopArguments with explicit member assignments, which enhances code clarity and robustness. Additionally, the changes correctly enforce const correctness for input tensor data pointers. The fix is well-implemented and consistently applied across different SM architectures. I have one suggestion to further improve code maintainability by reducing redundancy.

Signed-off-by: yewentao256 <zhyanwentao@126.com>

@mgoin mgoin left a comment


Seems reasonable to me. Do you know what caused the issue locally for you while it works in CI? Are you using CUDA 13 or something?

@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Sep 15, 2025

yewentao256 commented Sep 15, 2025

Seems reasonable to me. Do you know what caused the issue locally for you while it works in CI? Are you using CUDA 13 or something?

This can be reproduced on Beaker (H100, CUDA 12.8).
I am not sure why this doesn't trigger in CI; do we have CI specifically for Hopper? @mgoin

yewentao256 and others added 2 commits September 15, 2025 17:21
Signed-off-by: yewentao256 <zhyanwentao@126.com>

mgoin commented Sep 15, 2025

@yewentao256 we don't have Hopper runners in CI, but we do build for the arch obviously and use cu128 by default. I would expect this to be an nvcc version issue, so that is strange..


yewentao256 commented Sep 15, 2025

@yewentao256 we don't have Hopper runners in CI, but we do build for the arch obviously and use cu128 by default. I would expect this to be an nvcc version issue, so that is strange..

vllm/CMakeLists.txt
set(CUTLASS_REVISION "v4.0.0" CACHE STRING "CUTLASS revision to use")

Do we need to update this to a newer CUTLASS version? @mgoin

@yewentao256

Let's get this landed since all CI tests pass, so others don't hit the same issue.
We can open a follow-up PR/issue to discuss the version question further.

@yewentao256 yewentao256 merged commit e757a62 into vllm-project:main Sep 15, 2025
80 checks passed
@yewentao256 yewentao256 deleted the wye-fix-cutlass-compilation-error branch September 15, 2025 21:21
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
Signed-off-by: yewentao256 <zhyanwentao@126.com>
QierLi pushed a commit to QierLi/vllm that referenced this pull request Oct 5, 2025
Signed-off-by: bbartels <benjamin@bartels.dev>

[gpt-oss] Add IncompleteDetails to ResponsesRepsonse (vllm-project#24561)

Signed-off-by: Andrew Xia <axia@meta.com>

[gpt-oss][1a] create_responses stream outputs BaseModel type, api server is SSE still (vllm-project#24759)

Signed-off-by: Andrew Xia <axia@meta.com>

[Performance] Remove redundant clone() calls in cutlass_mla (vllm-project#24891)

[Bug] Fix Cutlass Scaled MM Compilation Error (vllm-project#24887)

Signed-off-by: yewentao256 <zhyanwentao@126.com>

[ci] fix wheel names for arm wheels (vllm-project#24898)

Signed-off-by: simon-mo <simon.mo@hey.com>

[Tests] fix initialization of kv hash in tests (vllm-project#24273)

Signed-off-by: Mickael Seznec <mickael@mistral.ai>

[Compile] Fix noop_elimination pass and add tests for noop_elimination (vllm-project#24880)

Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>

Propagate entire tokens to connector for resumed preemptions

Signed-off-by: Qier Li <kevin44036@gmail.com>

Fix pre-commit

Signed-off-by: Qier Li <kevin44036@gmail.com>

Rename field and nullify empty lists

Signed-off-by: Qier Li <kevin44036@gmail.com>

Update vllm/v1/core/sched/scheduler.py

Co-authored-by: Nick Hill <nhill@redhat.com>
Signed-off-by: Qier Li <kevin44036@gmail.com>

Add unit test for preemption resumption

Signed-off-by: Qier Li <kevin44036@gmail.com>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
choprahetarth pushed a commit to Tandemn-Labs/vllm that referenced this pull request Oct 11, 2025
Signed-off-by: yewentao256 <zhyanwentao@126.com>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
