[WIP][Attention] Sharded kv-cache for MLA #22789
Conversation
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> wip Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a small subset of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either: Add 🚀
Code Review
This pull request introduces support for a sharded KV cache in Multi-head Latent Attention (MLA) to enable tensor parallelism over the cache. The changes are extensive, spanning CUDA kernels, Python-level attention logic, and configuration. The core idea is to shard the KV-cache blocks across tensor-parallel ranks and use all-gather/all-to-all communication patterns to compute the full attention output.
While the overall approach seems sound, my review identified several critical issues in the implementation, particularly within the new sharding logic in vllm/v1/attention/backends/mla/common.py. These include incorrect operator precedence, flawed mathematical formulas for calculating sequence lengths, Python syntax errors, and incorrect tensor reshaping that would lead to runtime errors or incorrect behavior. Given the WIP nature of this PR and the author's note about AI-generated code, these findings are not unexpected. Addressing these issues is crucial for the feature's correctness and functionality.
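For orientation, here is a minimal sketch of the block-ownership convention the diff appears to follow: the `decode_block_table // self.tp_size` conversion further down implies round-robin striping of logical block ids across ranks. The modulo-based owner test is an assumption, not something shown in the hunks.

```python
def block_owner(block_id: int, tp_size: int) -> int:
    # Assumed convention: logical KV-cache block b lives on TP rank b % tp_size.
    return block_id % tp_size

def local_block_index(block_id: int, tp_size: int) -> int:
    # Index of that block inside the owning rank's (tp_size-times smaller) local
    # cache, matching the `decode_block_table // self.tp_size` conversion below.
    return block_id // tp_size
```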
# Per-request used blocks and mask across existing table width
blocks_per_req = (context_lens +
                  (B - 1)) // B  # [num_prefills]
max_blocks_per_req = (+B - 1) // B
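As a quick sanity check of the ceil-division pattern used here (illustrative values; `B` is the KV-cache block size):

```python
import torch

B = 16
context_lens = torch.tensor([1, 16, 17, 48])
blocks_per_req = (context_lens + (B - 1)) // B
# tensor([1, 1, 2, 3]) -- i.e. ceil(context_len / B) blocks in use per request
```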
last_page_len = torch.where(
    context_lens_cpu > 0 & partial_last & last_owned,
    context_lens_cpu % B + 1, B)
There are two issues in this block:

- The condition `context_lens_cpu > 0 & partial_last & last_owned` is likely incorrect due to operator precedence: the bitwise AND `&` binds more tightly than `>`, so the expression parses as `context_lens_cpu > (0 & partial_last & last_owned)`. Wrap the comparison in parentheses: `(context_lens_cpu > 0) & ...`.
- The formula `context_lens_cpu % B + 1` for `last_page_len` is incorrect. For a context length that is a multiple of `B` it returns 1 instead of `B`, and in all other cases it returns one more than the actual length. A correct formula for the length of the last page is `(context_lens_cpu - 1) % B + 1`.
Suggested change:

last_page_len = torch.where(
    context_lens_cpu > 0 & partial_last & last_owned,
    context_lens_cpu % B + 1, B)

becomes

last_page_len = torch.where(
    (context_lens_cpu > 0) & partial_last & last_owned,
    (context_lens_cpu - 1) % B + 1, B)
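A small self-contained check of both points (values are made up; `B`, `context_lens_cpu`, `partial_last`, and `last_owned` are the names from the diff):

```python
import torch

B = 16
context_lens_cpu = torch.tensor([0, 1, 16, 17, 32])
partial_last = torch.tensor([False, True, False, True, False])
last_owned = torch.ones(5, dtype=torch.bool)

# Without parentheses, `context_lens_cpu > 0 & partial_last & last_owned`
# parses as `context_lens_cpu > (0 & partial_last & last_owned)`, whose
# right-hand side is all zeros -- the two flags are silently ignored.
cond = (context_lens_cpu > 0) & partial_last & last_owned

last_page_len = torch.where(cond, (context_lens_cpu - 1) % B + 1, B)
# tensor([16, 1, 16, 1, 16]); with `context_lens_cpu % B + 1` the 17-token
# request would wrongly report 2 tokens in its last page instead of 1.
```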
used_blocks_mask = torch.arange(
    max_blocks_per_req, device=decode_block_table.device
) < blocks_per_req.unsqueeze(1)
decode_block_table = decode_block_table[~used_blocks_mask] = -1
This line is a chained assignment rather than a syntax error: `decode_block_table = decode_block_table[~used_blocks_mask] = -1` first rebinds `decode_block_table` to the integer -1 and then fails with a `TypeError` when it tries to index that integer. You probably intended to modify the tensor in place.
Suggested change:

decode_block_table = decode_block_table[~used_blocks_mask] = -1

becomes

decode_block_table[~used_blocks_mask] = -1
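A minimal reproduction of the failure mode and the intended in-place update (tensor values are made up):

```python
import torch

decode_block_table = torch.arange(8).reshape(2, 4)
used_blocks_mask = torch.tensor([[True, True, False, False],
                                 [True, False, False, False]])

# Chained assignment assigns its targets left to right:
#   decode_block_table = -1                       # the name now holds an int
#   decode_block_table[~used_blocks_mask] = -1    # TypeError: 'int' object
#                                                 # does not support item assignment
# decode_block_table = decode_block_table[~used_blocks_mask] = -1

# Intended: blank out the unused slots in place, keeping the 2D tensor.
decode_block_table[~used_blocks_mask] = -1
# tensor([[ 0,  1, -1, -1],
#         [ 4, -1, -1, -1]])
```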
decode_block_table = decode_block_table[owned]
# Convert to local physical indices
decode_block_table = decode_block_table // self.tp_size
Indexing with a boolean mask owned will flatten the decode_block_table tensor. This is likely unintended, as the block table is expected to be a 2D tensor for subsequent operations. You might want to use torch.where to preserve the tensor's shape, similar to the logic in the prefill path (line 658).
Suggested change:

decode_block_table = decode_block_table[owned]
# Convert to local physical indices
decode_block_table = decode_block_table // self.tp_size

becomes

decode_block_table = torch.where(
    owned, decode_block_table // self.tp_size, -1)
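A quick illustration of why the shape matters (made-up block table and ownership mask; `tp_size = 2`):

```python
import torch

tp_size = 2
decode_block_table = torch.tensor([[0, 1, 2, 3],
                                   [4, 5, 6, 7]])
owned = torch.tensor([[True, False, True, False],
                      [True, False, True, False]])  # made-up ownership mask

flattened = decode_block_table[owned]
print(flattened.shape)   # torch.Size([4]) -- per-request rows are lost

kept = torch.where(owned, decode_block_table // tp_size, -1)
print(kept)
# tensor([[ 0, -1,  1, -1],
#         [ 2, -1,  3, -1]])  -- still [num_decodes, max_blocks]
```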
parts = context_output_local.view(B, self.num_heads, D) \
    .view(B, -1, heads_owned, D).movedim(1, 0)
self.num_heads refers to the number of local attention heads, but context_output_local at this point contains the output for all heads (global) because it was computed with q_all. You should use self.num_global_heads here to correctly reshape the tensor.
Suggested change:

parts = context_output_local.view(B, self.num_heads, D) \
    .view(B, -1, heads_owned, D).movedim(1, 0)

becomes

parts = context_output_local.view(B, self.num_global_heads, D) \
    .view(B, -1, heads_owned, D).movedim(1, 0)
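A shape-only sketch of what the reshape does (`num_global_heads` and `heads_owned` follow the review comment; all values are made up):

```python
import torch

B, D = 4, 8                 # tokens in this batch chunk, head dim
tp_size = 2
num_global_heads = 16       # total attention heads across all TP ranks
heads_owned = num_global_heads // tp_size

# The output was computed with q_all, so it covers *all* heads; the first view
# must therefore use the global head count or the element count won't match.
context_output_local = torch.randn(B * num_global_heads * D)

parts = (context_output_local.view(B, num_global_heads, D)
         .view(B, -1, heads_owned, D)   # split heads into per-rank chunks
         .movedim(1, 0))                # -> [tp_size, B, heads_owned, D]
assert parts.shape == (tp_size, B, heads_owned, D)
```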
This pull request has merge conflicts that must be resolved before it can be merged.
superseded by: #23734 |
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.
Purpose
NOTE: this is partially AI-generated, so the code is not yet working and is pretty ugly.
See number 3 in: https://docs.google.com/document/d/1L4MmOA3JnVlahjZq5CsQhB0d3G5SayuSR5Y24rKdGpU/edit?usp=sharing
Test Plan
Test Result
(Optional) Documentation Update