[FlashInfer] Truncate block tables for sliding window attention #23010
Conversation
Code Review
This pull request introduces an optimization for sliding window attention in FlashInfer by truncating block tables. The overall logic is sound, but I've identified a critical issue where the sequence lengths are not updated to match the truncated block tables. This mismatch could lead to incorrect attention results or memory access errors. I have provided a detailed comment and a code suggestion to address this problem.
decode_seq_lens_cpu = seq_lens_cpu[:num_decodes]
num_skipped_pages = (
    torch.relu(decode_seq_lens_cpu - self.sliding_window) //
    page_size)

block_table_bounds_cpu[:num_decodes] -= num_skipped_pages
mask[:num_decodes] &= (arange[:num_decodes]
                       >= num_skipped_pages.unsqueeze(1))
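For intuition, here is a small standalone rework of that page-skipping computation with made-up numbers (the page_size and sliding_window values are illustrative, not taken from this PR):

```python
import torch

page_size = 16        # illustrative
sliding_window = 128  # illustrative

# Tokens currently in the KV cache for four decode requests.
decode_seq_lens_cpu = torch.tensor([100, 128, 130, 300])

# Pages that lie entirely outside the sliding window can be dropped;
# relu() clamps at zero so short sequences skip nothing.
num_skipped_pages = (
    torch.relu(decode_seq_lens_cpu - sliding_window) // page_size)

print(num_skipped_pages)  # tensor([0, 0, 0, 10])
```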
The sequence lengths (seq_lens and seq_lens_cpu) are not updated to reflect the truncated block tables for sliding window attention. This can lead to incorrect attention computation or out-of-bounds memory access in the FlashInfer kernel because the kernel will receive mismatched sequence lengths and KV blocks.
The sequence lengths should be reduced by the number of tokens in the skipped pages. Note that if common_attn_metadata is shared across layers, seq_lens and seq_lens_cpu should be cloned before modification to avoid side effects.
decode_seq_lens_cpu = seq_lens_cpu[:num_decodes]
num_skipped_pages = (
    torch.relu(decode_seq_lens_cpu - self.sliding_window) //
    page_size)
skipped_tokens = num_skipped_pages * page_size
seq_lens_cpu[:num_decodes] -= skipped_tokens
seq_lens[:num_decodes] -= skipped_tokens.to(self.device,
                                            non_blocking=True)
block_table_bounds_cpu[:num_decodes] -= num_skipped_pages
mask[:num_decodes] &= (arange[:num_decodes]
                       >= num_skipped_pages.unsqueeze(1))
Good point. To my understanding, it is not necessary to update seq_lens_cpu in this case.
> To my understanding, it is not necessary to update seq_lens_cpu in this case.

This is true for the "standard" FlashInfer code path, which infers the sequence length from paged_kv_last_page_len and paged_kv_indptr (so manipulating block_table_bounds_cpu is sufficient). I'm not sure whether this is also true for the new trtllm code paths (i.e. trtllm_batch_decode_with_kv_cache); maybe someone from Nvidia can weigh in, @kushanam @pavanimajety
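For reference, a minimal sketch (under the usual paged-KV layout assumptions; the helper name is hypothetical) of how a sequence length can be inferred from `paged_kv_indptr` and `paged_kv_last_page_len`, which is why shrinking the page ranges is enough on that path:

```python
import torch

def inferred_kv_lens(paged_kv_indptr: torch.Tensor,
                     paged_kv_last_page_len: torch.Tensor,
                     page_size: int) -> torch.Tensor:
    # Number of pages per request, then full pages plus the partial last page.
    num_pages = paged_kv_indptr[1:] - paged_kv_indptr[:-1]
    return (num_pages - 1) * page_size + paged_kv_last_page_len

# Request 0: 3 pages with 5 tokens on the last page; request 1: 1 full page.
indptr = torch.tensor([0, 3, 4])
last_page_len = torch.tensor([5, 16])
print(inferred_kv_lens(indptr, last_page_len, page_size=16))  # tensor([37, 16])
```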
This pull request has merge conflicts that must be resolved before it can be merged.
…ix __syncthreads in mergestate (#1661)

## 📌 Description
This PR skips redundant kv-tile loading and computation in FA2's sliding window implementation, which is used in GPT-OSS inference scenarios. Main modifications:
1. Make `PrefillPlanSM80` aware of the sliding window, skipping redundant attention-state merges.
2. Add `kv-start-idx` offsets at the start of each device kernel to skip the out-of-window computation. Since it keeps the original 0-based `kv-tile-idx`, it should be compatible with the JIT `sink-token` implementation.

(image: https://github.com/user-attachments/assets/a061c602-2cf9-4867-ad1b-46905ab1f44d)

## 🔍 Related Issues
This should back up vllm-project/vllm#23010.

## 🧪 Tests
- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

Co-authored-by: happierpig <zhaoyilong217@sjtu.edn.cn>
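Schematically (this is not the FlashInfer kernel code, just a hedged sketch of the index math the quoted description refers to), the first KV tile a sliding-window decode kernel needs to visit can be derived as:

```python
def kv_start_tile(kv_len: int, window_size: int, tile_size: int) -> int:
    # First in-window key index, then the tile that contains it.
    first_in_window = max(0, kv_len - window_size)
    return first_in_window // tile_size

# With kv_len=300, window=128, tile=16, tiles 0..9 hold only
# out-of-window keys and can be skipped entirely.
assert kv_start_tile(300, 128, 16) == 10
```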
Am I right in thinking that the sliding window attention in this case will be page/block granular, not token granular?
- Enables Sliding Window Attention (SWA) for the CPU backend.
- Enables models with hybrid local-global attention (i.e., those with multiple KV cache groups) for the CPU backend.
- Fixes a bug in `TorchSDPAMetadataBuilderV1` where `query_start_loc_cpu` was updated in-place despite being shared between multiple KV cache groups.
- Adds a unit test for the decode phase in the CPU attention backend.
- Adds reference implementations used to test paged attention to a common utils module and reuses them for the decode attention CPU test.

SWA is enabled by truncating full blocks/pages that are outside the window (similar to vllm-project#23010). Hence, the implementation is block-granular and may attend to at most (block_size - 1) extra tokens; this did not affect the accuracy tests I ran. This is the best we can do with the current paged-attention API/implementation, which does not support SWA natively. See vllm-project#4768 for context and the attempt to enable native SWA support in paged_attn.

Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com>
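A small worked example of that block-granular bound (the numbers are made up):

```python
block_size, sliding_window, seq_len = 16, 128, 300

first_in_window = seq_len - sliding_window          # 172
num_skipped_pages = first_in_window // block_size   # 10 -> tokens 0..159 dropped
extra_tokens = first_in_window - num_skipped_pages * block_size  # 12 (tokens 160..171)

# The leftover is (seq_len - sliding_window) % block_size, so at most
# block_size - 1 tokens outside the window are still attended to.
assert extra_tokens == (seq_len - sliding_window) % block_size
```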
An optimization for using FlashInfer with sliding window attention.
FlashInfer's decode kernel does not appear to skip the KV outside the sliding window; it only applies masking (at least on Hopper GPUs; not sure about Blackwell GPUs).
Therefore, this PR manually manipulates the block table so that the skipping happens automatically.
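A minimal, self-contained sketch of the idea (not the PR's actual code: it uses an explicit gather for clarity, whereas the PR adjusts `block_table_bounds_cpu` and a page mask as in the snippet discussed above; names and values are illustrative):

```python
import torch

def truncate_block_table_for_swa(block_table: torch.Tensor,
                                 seq_lens: torch.Tensor,
                                 num_pages: torch.Tensor,
                                 sliding_window: int,
                                 page_size: int):
    """Drop leading pages that are entirely outside the sliding window by
    left-shifting each block-table row and shrinking the page counts."""
    num_skipped = torch.relu(seq_lens - sliding_window) // page_size
    max_pages = block_table.shape[1]
    arange = torch.arange(max_pages).unsqueeze(0)                    # [1, max_pages]
    gather_idx = (arange + num_skipped.unsqueeze(1)).clamp(max=max_pages - 1)
    shifted = torch.gather(block_table, 1, gather_idx)
    return shifted, num_pages - num_skipped

# Two decode requests, page_size=16, sliding_window=32.
block_table = torch.tensor([[3, 7, 9, 11],
                            [2, 4, 0, 0]])
seq_lens = torch.tensor([60, 30])   # request 0 can drop 1 leading page
num_pages = torch.tensor([4, 2])

shifted, bounds = truncate_block_table_for_swa(
    block_table, seq_lens, num_pages, sliding_window=32, page_size=16)
print(shifted)  # row 0 shifted left by one page: [7, 9, 11, 11]
print(bounds)   # tensor([3, 2])
```

Because only whole pages outside the window are dropped, the kernel then skips them automatically without any change to its masking logic; the trade-off is the block-granular attention window discussed above.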