
Conversation

@yewentao256
Member

@yewentao256 yewentao256 commented Oct 9, 2025

Purpose

Fixes #26529

Test

Originally:

(EngineCore_DP5 pid=3148182)   File "/home/wentao/vllm-source/vllm/v1/worker/gpu_model_runner.py", line 3437, in _dummy_run
(EngineCore_DP5 pid=3148182)     outputs = self.model(
(EngineCore_DP5 pid=3148182)               ^^^^^^^^^^^
(EngineCore_DP5 pid=3148182)   File "/home/wentao/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
(EngineCore_DP5 pid=3148182)     return self._call_impl(*args, **kwargs)
(EngineCore_DP5 pid=3148182)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP5 pid=3148182)   File "/home/wentao/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
(EngineCore_DP5 pid=3148182)     return forward_call(*args, **kwargs)
(EngineCore_DP5 pid=3148182)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP5 pid=3148182)   File "/home/wentao/vllm-source/vllm/model_executor/models/deepseek_v2.py", line 1353, in forward
(EngineCore_DP5 pid=3148182)     hidden_states = self.model(
(EngineCore_DP5 pid=3148182)                     ^^^^^^^^^^^
(EngineCore_DP5 pid=3148182)   File "/home/wentao/vllm-source/vllm/compilation/decorators.py", line 228, in __call__
(EngineCore_DP5 pid=3148182)     return self.forward(*args, **kwargs)
(EngineCore_DP5 pid=3148182)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP5 pid=3148182)   File "/home/wentao/vllm-source/vllm/model_executor/models/deepseek_v2.py", line 1230, in forward
(EngineCore_DP5 pid=3148182)     hidden_states, residual = layer(positions, hidden_states, residual)
(EngineCore_DP5 pid=3148182)                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP5 pid=3148182)   File "/home/wentao/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
(EngineCore_DP5 pid=3148182)     return self._call_impl(*args, **kwargs)
(EngineCore_DP5 pid=3148182)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP5 pid=3148182)   File "/home/wentao/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
(EngineCore_DP5 pid=3148182)     return forward_call(*args, **kwargs)
(EngineCore_DP5 pid=3148182)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP5 pid=3148182)   File "/home/wentao/vllm-source/vllm/model_executor/models/deepseek_v2.py", line 1145, in forward
(EngineCore_DP5 pid=3148182)     hidden_states = self.mlp(hidden_states)
(EngineCore_DP5 pid=3148182)                     ^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP5 pid=3148182)   File "/home/wentao/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
(EngineCore_DP5 pid=3148182)     return self._call_impl(*args, **kwargs)
(EngineCore_DP5 pid=3148182)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP5 pid=3148182)   File "/home/wentao/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
(EngineCore_DP5 pid=3148182)     return forward_call(*args, **kwargs)
(EngineCore_DP5 pid=3148182)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP5 pid=3148182)   File "/home/wentao/vllm-source/vllm/model_executor/models/deepseek_v2.py", line 257, in forward
(EngineCore_DP5 pid=3148182)     fused_moe_out = self.experts(
(EngineCore_DP5 pid=3148182)                     ^^^^^^^^^^^^^
(EngineCore_DP5 pid=3148182)   File "/home/wentao/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
(EngineCore_DP5 pid=3148182)     return self._call_impl(*args, **kwargs)
(EngineCore_DP5 pid=3148182)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP5 pid=3148182)   File "/home/wentao/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
(EngineCore_DP5 pid=3148182)     return forward_call(*args, **kwargs)
(EngineCore_DP5 pid=3148182)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP5 pid=3148182)   File "/home/wentao/vllm-source/vllm/model_executor/layers/fused_moe/shared_fused_moe.py", line 61, in forward
(EngineCore_DP5 pid=3148182)     fused_out = super().forward(
(EngineCore_DP5 pid=3148182)                 ^^^^^^^^^^^^^^^^
(EngineCore_DP5 pid=3148182)   File "/home/wentao/vllm-source/vllm/model_executor/custom_op.py", line 47, in forward
(EngineCore_DP5 pid=3148182)     return self._forward_method(*args, **kwargs)
(EngineCore_DP5 pid=3148182)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP5 pid=3148182)   File "/home/wentao/vllm-source/vllm/model_executor/layers/fused_moe/layer.py", line 2051, in forward_cuda
(EngineCore_DP5 pid=3148182)     return self.forward_native(hidden_states, router_logits)
(EngineCore_DP5 pid=3148182)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP5 pid=3148182)   File "/home/wentao/vllm-source/vllm/model_executor/layers/fused_moe/layer.py", line 2026, in forward_native
(EngineCore_DP5 pid=3148182)     fused_output = torch.ops.vllm.moe_forward(
(EngineCore_DP5 pid=3148182)                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP5 pid=3148182)   File "/home/wentao/.venv/lib/python3.12/site-packages/torch/_ops.py", line 1243, in __call__
(EngineCore_DP5 pid=3148182)     return self._op(*args, **kwargs)
(EngineCore_DP5 pid=3148182)            ^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP5 pid=3148182)   File "/home/wentao/vllm-source/vllm/model_executor/layers/fused_moe/layer.py", line 2367, in moe_forward
(EngineCore_DP5 pid=3148182)     return self.forward_impl(hidden_states, router_logits)
(EngineCore_DP5 pid=3148182)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP5 pid=3148182)   File "/home/wentao/vllm-source/vllm/model_executor/layers/fused_moe/layer.py", line 2239, in forward_impl
(EngineCore_DP5 pid=3148182)     final_hidden_states = self.quant_method.apply(
(EngineCore_DP5 pid=3148182)                           ^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP5 pid=3148182)   File "/home/wentao/vllm-source/vllm/model_executor/layers/quantization/fp8.py", line 1216, in apply
(EngineCore_DP5 pid=3148182)     result = self.fused_experts(
(EngineCore_DP5 pid=3148182)              ^^^^^^^^^^^^^^^^^^^
(EngineCore_DP5 pid=3148182)   File "/home/wentao/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
(EngineCore_DP5 pid=3148182)     return self._call_impl(*args, **kwargs)
(EngineCore_DP5 pid=3148182)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP5 pid=3148182)   File "/home/wentao/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
(EngineCore_DP5 pid=3148182)     return forward_call(*args, **kwargs)
(EngineCore_DP5 pid=3148182)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP5 pid=3148182)   File "/home/wentao/vllm-source/vllm/model_executor/layers/fused_moe/modular_kernel.py", line 1180, in forward
(EngineCore_DP5 pid=3148182)     return self._finalize(
(EngineCore_DP5 pid=3148182)            ^^^^^^^^^^^^^^^
(EngineCore_DP5 pid=3148182)   File "/home/wentao/vllm-source/vllm/model_executor/layers/fused_moe/modular_kernel.py", line 1066, in _finalize
(EngineCore_DP5 pid=3148182)     finalize_ret = self.prepare_finalize.finalize_async(
(EngineCore_DP5 pid=3148182)                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP5 pid=3148182)   File "/home/wentao/vllm-source/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py", line 381, in finalize_async
(EngineCore_DP5 pid=3148182)     receiver = self._finalize(
(EngineCore_DP5 pid=3148182)                ^^^^^^^^^^^^^^^
(EngineCore_DP5 pid=3148182)   File "/home/wentao/vllm-source/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py", line 339, in _finalize
(EngineCore_DP5 pid=3148182)     combined_x, _, event = self.buffer.combine(
(EngineCore_DP5 pid=3148182)                            ^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP5 pid=3148182)   File "/home/wentao/ep_kernels_workspace/DeepEP/deep_ep/buffer.py", line 413, in combine
(EngineCore_DP5 pid=3148182)     recv_x, recv_topk_weights, event = self.runtime.intranode_combine(
(EngineCore_DP5 pid=3148182)                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP5 pid=3148182) RuntimeError: Failed: Assertion error /home/wentao/ep_kernels_workspace/DeepEP/csrc/kernels/intranode.cu:928 'false and "Unsupported type"'

Now:

(wentao) wentao@dgxB200-09:~$ curl -sS -i -H 'Content-Type: application/json' -X POST http://127.0.0.1:9256/v1/completions -d '{"model":"deepseek-ai/DeepSeek-R1","prompt":"Hello","max_tokens":10}'
HTTP/1.1 200 OK
date: Thu, 09 Oct 2025 20:29:23 GMT
server: uvicorn
content-length: 485
content-type: application/json

{"id":"cmpl-15015d51654148dfb0b28780ded3cb5f","object":"text_completion","created":…,"model":"deepseek-ai/DeepSeek-R1","choices":[{"index":0,"text":" 2020. So, w…","logprobs":null,"finish_reason":"length","stop_reason":null,"token_ids":null,"prompt_token_ids":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":2,"total_tokens":12,"completion_tokens":10,"prompt_tokens_details":nu…


Signed-off-by: yewentao256 <zhyanwentao@126.com>
@yewentao256 yewentao256 requested a review from mgoin as a code owner October 9, 2025 20:43
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request addresses a critical assertion error in the DeepEP high-throughput MoE kernel. The fix correctly identifies that the combine operation requires bfloat16 inputs and implements the necessary type casting. The input tensor is converted to bfloat16 before the operation, and the output is converted back to the original data type, ensuring correctness and compatibility with the rest of the model. These changes are applied consistently for both synchronous and asynchronous code paths. The fix is sound and effectively resolves the runtime crash.
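The cast-and-restore pattern the review describes can be sketched as follows. This is an illustrative sketch only, not the actual vLLM code: NumPy's float16 stands in for bfloat16, and `combine_bf16_only` / `combine_any_dtype` are made-up names for a kernel with a single supported dtype and its wrapper.

```python
import numpy as np

def combine_bf16_only(x: np.ndarray) -> np.ndarray:
    # Stand-in for a kernel that accepts exactly one dtype; float16 here
    # plays the role that bfloat16 plays for DeepEP's HT combine.
    if x.dtype != np.float16:
        raise TypeError("Unsupported type")
    return x * 2  # placeholder arithmetic for the real combine

def combine_any_dtype(x: np.ndarray) -> np.ndarray:
    orig_dtype = x.dtype
    # Cast down to the dtype the kernel requires...
    out = combine_bf16_only(x.astype(np.float16))
    # ...then restore the caller's original dtype afterwards.
    return out.astype(orig_dtype)

y = combine_any_dtype(np.ones(4, dtype=np.float32))
```

The caller never sees the intermediate dtype; only the kernel-facing tensor is narrowed.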

@yewentao256 yewentao256 added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 10, 2025
        combined_x, _, event = self.buffer.combine(
-            x=fused_expert_output,
+            # HT combine only supports BF16
+            x=fused_expert_output.to(torch.bfloat16),
Member

What is the dtype of x when this error is occurring?

It seems like if this is happening, that means we're quantizing and then dequantizing the outputs, which would be inefficient and potentially cause poor output. Is there something we should be doing so that the experts produce a bf16 output instead? cc @bnellnm @varun-sundar-rabindranath
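The inefficiency being flagged is that a down-cast/up-cast round trip is lossy. A tiny NumPy sketch (using a float32→float16→float32 round trip as an analogue for casting expert outputs down to bfloat16 and back) makes the point; the array and seed here are arbitrary illustration, not from the PR:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(1024).astype(np.float32)

# Round-trip through a narrower dtype, analogous to casting expert
# outputs down for combine and restoring the dtype afterwards.
roundtrip = x.astype(np.float16).astype(np.float32)

# float16 keeps ~10 mantissa bits, so the round trip is not exact.
max_err = float(np.max(np.abs(x - roundtrip)))
```

The error is small but nonzero, which is why the reviewers would rather have the experts emit bf16 directly than quantize and dequantize the outputs.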

Collaborator

@bnellnm bnellnm Oct 13, 2025

We definitely should not need to do this. Do you know which experts class is being used here? Debug level logging should show which variant is being used.

Collaborator

It's also possible that we are hitting the first condition here in modular_kernel.py and the workspace13 tensor is the wrong type. You could try forcing it to use the False branch.

        # Construct the entire output that can then be processed in chunks.
        # Reuse workspace13 for the output in the non-chunked case as long
        # as it is large enough. This will not always be the case for standard
        # format experts and with experts that have empty workspaces.
        if num_chunks == 1 and prod(workspace13_shape) >= prod(fused_out_shape):
            fused_out = _resize_cache(workspace13, fused_out_shape)
        else:
            fused_out = buffers.fused_out.get(
                fused_out_shape, device=device, dtype=out_dtype
            )
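The suspect branch reuses an existing workspace (carrying that workspace's dtype) instead of allocating a buffer with the intended output dtype. A hedged NumPy sketch of that reuse-if-large-enough logic, with `resize_cache` and `pick_output` as assumed stand-ins for vLLM's `_resize_cache` and the snippet above:

```python
import numpy as np
from math import prod

def resize_cache(buf: np.ndarray, shape: tuple) -> np.ndarray:
    # Rough analogue of _resize_cache: view the first prod(shape)
    # elements of the flat buffer as `shape`, without copying.
    return buf.reshape(-1)[: prod(shape)].reshape(shape)

def pick_output(workspace13: np.ndarray, fused_out_shape, num_chunks: int):
    # Reuse workspace13 only in the non-chunked case and only when it is
    # large enough; otherwise allocate a fresh output buffer.
    if num_chunks == 1 and prod(workspace13.shape) >= prod(fused_out_shape):
        # Note: the reused view inherits workspace13's dtype — the
        # suspected wrong-dtype path in this discussion.
        return resize_cache(workspace13, fused_out_shape)
    return np.empty(fused_out_shape, dtype=workspace13.dtype)

ws = np.zeros(64, dtype=np.float32)
out = pick_output(ws, (4, 8), num_chunks=1)
```

Forcing the `False` branch, as suggested, would always allocate fresh storage with an explicitly chosen dtype, which is one way to test whether the reused workspace's dtype is the culprit.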

Member Author

@yewentao256 yewentao256 Oct 13, 2025

Found the root cause and fixed it thoroughly.

Signed-off-by: yewentao256 <zhyanwentao@126.com>
@yewentao256 yewentao256 merged commit 7200a21 into main Oct 13, 2025
53 of 54 checks passed
@yewentao256 yewentao256 deleted the wentao-fix-deepepHT-issue branch October 13, 2025 22:26
1994 pushed a commit to 1994/vllm that referenced this pull request Oct 14, 2025
…e and Unsupported type' (vllm-project#26532)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: 1994 <1994@users.noreply.github.com>
Dhruvilbhatt pushed a commit to Dhruvilbhatt/vllm that referenced this pull request Oct 14, 2025
…e and Unsupported type' (vllm-project#26532)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Dhruvil Bhatt <bhattdbh@amazon.com>
bbartels pushed a commit to bbartels/vllm that referenced this pull request Oct 16, 2025
…e and Unsupported type' (vllm-project#26532)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: bbartels <benjamin@bartels.dev>
lywa1998 pushed a commit to lywa1998/vllm that referenced this pull request Oct 20, 2025
…e and Unsupported type' (vllm-project#26532)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
alhridoy pushed a commit to alhridoy/vllm that referenced this pull request Oct 24, 2025
…e and Unsupported type' (vllm-project#26532)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025
…e and Unsupported type' (vllm-project#26532)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
…e and Unsupported type' (vllm-project#26532)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>
rtourgeman pushed a commit to rtourgeman/vllm that referenced this pull request Nov 10, 2025
…e and Unsupported type' (vllm-project#26532)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Zhathw pushed a commit to Zhathw/vllm that referenced this pull request Nov 12, 2025
…e and Unsupported type' (vllm-project#26532)

Signed-off-by: yewentao256 <zhyanwentao@126.com>

Labels

ready ONLY add when PR is ready to merge/full CI is needed

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug]: Assertion error DeepEP/csrc/kernels/intranode.cu:928: 'false and "Unsupported type"'

4 participants