[Bug] Fix Assertion error DeepEP/csrc/kernels/intranode.cu:928: 'false and Unsupported type' #26532
Conversation
Code Review
This pull request addresses a critical assertion error in the DeepEP high-throughput MoE kernel. The fix correctly identifies that the combine operation requires bfloat16 inputs and implements the necessary type casting. The input tensor is converted to bfloat16 before the operation, and the output is converted back to the original data type, ensuring correctness and compatibility with the rest of the model. These changes are applied consistently for both synchronous and asynchronous code paths. The fix is sound and effectively resolves the runtime crash.
combined_x, _, event = self.buffer.combine(
-    x=fused_expert_output,
+    # HT combine only supports BF16
+    x=fused_expert_output.to(torch.bfloat16),
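The cast-in/cast-back pattern used by the fix can be sketched as follows. This is a minimal illustration only: `combine_bf16_only` and `safe_combine` are hypothetical stand-ins, not vLLM's or DeepEP's actual API.

```python
import torch

def combine_bf16_only(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for DeepEP's high-throughput combine, which asserts
    # on any input dtype other than bfloat16.
    assert x.dtype == torch.bfloat16, "Unsupported type"
    return x * 2  # placeholder for the real token reduction

def safe_combine(fused_expert_output: torch.Tensor) -> torch.Tensor:
    orig_dtype = fused_expert_output.dtype
    # HT combine only supports BF16: cast in ...
    combined = combine_bf16_only(fused_expert_output.to(torch.bfloat16))
    # ... then cast the result back to the caller's original dtype.
    return combined.to(orig_dtype)

x = torch.randn(4, 8, dtype=torch.float32)
out = safe_combine(x)
assert out.dtype == torch.float32
```

The same wrapping applies to both the synchronous and asynchronous code paths, since both funnel through the same `buffer.combine` call.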
What is the dtype of x when this error is occurring?
It seems like if this is happening, that means we're quantizing and then dequantizing the outputs, which would be inefficient and potentially cause poor output. Is there something we should be doing so that the experts produce a bf16 output instead? cc @bnellnm @varun-sundar-rabindranath
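The precision concern raised here can be demonstrated directly: bfloat16 keeps only 7 explicit mantissa bits versus float32's 23, so a round-trip through bfloat16 generally does not return the original value. A minimal, self-contained illustration:

```python
import torch

# Round-tripping float32 -> bfloat16 -> float32 truncates the
# mantissa, so most values do not survive the cast unchanged.
x = torch.tensor([0.1234567], dtype=torch.float32)
roundtrip = x.to(torch.bfloat16).to(torch.float32)
assert not torch.equal(x, roundtrip)  # precision was lost
```

This is why casting expert outputs to bf16 just to satisfy the combine kernel, rather than having the experts emit bf16 in the first place, is both wasted work and a potential quality regression.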
We definitely should not need to do this. Do you know which experts class is being used here? Debug level logging should show which variant is being used.
It's also possible that we are hitting the first condition here in modular_kernel.py and the workspace13 tensor is the wrong type. You could try forcing it to use the False branch.
# Construct the entire output that can then be processed in chunks.
# Reuse workspace13 for the output in the non-chunked case as long
# as it is large enough. This will not always be the case for standard
# format experts and with experts that have empty workspaces.
if num_chunks == 1 and prod(workspace13_shape) >= prod(fused_out_shape):
fused_out = _resize_cache(workspace13, fused_out_shape)
else:
fused_out = buffers.fused_out.get(
fused_out_shape, device=device, dtype=out_dtype
)
Found the original cause and fixed it thoroughly.
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Purpose
Fixes #26529
Test
Originally: the run crashed with `Assertion error DeepEP/csrc/kernels/intranode.cu:928: 'false and "Unsupported type"'` (output screenshot omitted).
Now: the run completes without hitting the assertion (output screenshot omitted).