[Kernel] cuda kernels for upcoming decode context parallel feature #23791
Conversation
Code Review
This pull request introduces new CUDA kernels, cp_fused_concat_and_cache_mla and cp_gather_cache, to support an upcoming context parallel feature. The changes include the kernel implementations, their PyTorch bindings, and corresponding Python wrappers. Overall, the kernel implementations appear correct and follow existing patterns in the codebase. My main feedback is on the testing coverage. The new cp_fused_concat_and_cache_mla kernel is missing a unit test, and the test for cp_gather_cache is incomplete as it doesn't cover the key functionality it introduces. Adding comprehensive tests is crucial for ensuring the correctness and maintainability of these new kernels.
```python
def test_cp_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
                             num_blocks, max_seq_len, batch_size, dtype,
                             kv_cache_dtype, device):
```
The test for cp_gather_cache is incomplete. The main purpose of this new kernel is to support arbitrary seq_starts, but the test only covers the case where seq_starts is None. Please add test cases that use non-zero seq_starts to validate this new functionality.
Additionally, the test only uses batch_size=8. The kernel has different logic for num_splits based on batch_size. It would be beneficial to test with a wider range of batch sizes to cover all branches, for example [8, 70, 130].
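Something along these lines could cover both points. The test name, tensor shapes, the plain-PyTorch reference loop, and especially the seq_starts keyword are assumptions about the new wrapper's interface, not the PR's actual code:

```python
# Sketch of the suggested coverage; the seq_starts kwarg is assumed from the PR
# description -- adjust names and shapes to the actual Python wrapper.
import pytest
import torch
from vllm import _custom_ops as ops


@pytest.mark.parametrize("batch_size", [8, 70, 130])  # exercise all num_splits branches
def test_cp_gather_cache_with_seq_starts(batch_size):
    kv_lora_rank, qk_rope_head_dim, block_size = 512, 64, 16
    entry_size = kv_lora_rank + qk_rope_head_dim
    num_blocks, max_seq_len, device = 128, 256, "cuda"

    src_cache = torch.randn(num_blocks, block_size, entry_size, device=device)
    block_table = torch.randint(0, num_blocks,
                                (batch_size, max_seq_len // block_size),
                                dtype=torch.int32, device=device)
    # Leave room for a non-zero start offset inside each block table row.
    seq_lens = torch.randint(1, max_seq_len - 2 * block_size,
                             (batch_size,), device=device)
    cu_seq_lens = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
    cu_seq_lens[1:] = seq_lens.cumsum(0)
    seq_starts = torch.randint(0, 2 * block_size, (batch_size,),
                               dtype=torch.int32, device=device)
    dst = torch.zeros(int(cu_seq_lens[-1]), entry_size,
                      dtype=src_cache.dtype, device=device)

    # Plain-PyTorch reference: gather each sequence's tokens, starting at
    # seq_starts[i] instead of token 0.
    expected = torch.empty_like(dst)
    for i in range(batch_size):
        start, length = int(seq_starts[i]), int(seq_lens[i])
        rows = []
        for tok in range(start, start + length):
            blk = int(block_table[i, tok // block_size])
            rows.append(src_cache[blk, tok % block_size])
        expected[int(cu_seq_lens[i]):int(cu_seq_lens[i + 1])] = torch.stack(rows)

    ops.cp_gather_cache(src_cache, dst, block_table, cu_seq_lens,
                        batch_size, seq_starts=seq_starts)
    torch.testing.assert_close(dst, expected)
```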
```python
ops.cp_gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size)
torch.testing.assert_close(dst, expected)
```
LGTM since this only adds two new kernels. cc @WoosukKwon @LucasWilkinson if you have more comments.
kernel tests passed, failed tests are unrelated. merging.
Not really. The AMD build failure is not unrelated.
I am encountering an issue that is very likely caused by this PR when building on AMD MI300X. If I switch back to commit c07a733 (2 commits ahead), the error disappears.
Pre-PR for #1367
Suggestion from @youkaichao: to speed up review and merging (especially CI testing), we can split the kernel-side changes into a separate PR and get that merged first.