Conversation

@WoosukKwon (Collaborator) commented Aug 15, 2025

An optimization for using FlashInfer with sliding window attention.
FlashInfer's decode kernel does not seem to skip the KV entries outside the sliding window; it only applies masking (at least on Hopper GPUs; unclear about Blackwell GPUs).
Therefore, this PR manually manipulates the block table so that the skipping happens automatically.
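The core idea, as a minimal standalone sketch with illustrative values (not the PR's actual code): whole pages that lie entirely outside the window are dropped from the front of the per-request block table, so the kernel never loads them.

    import torch

    # Toy example: one decode request with a 300-token KV cache, 16-token pages,
    # and a 128-token sliding window.
    page_size = 16
    sliding_window = 128
    seq_len = 300
    block_table = torch.arange((seq_len + page_size - 1) // page_size)  # 19 pages

    # Pages whose last token is still before index seq_len - sliding_window are
    # entirely outside the window and can be dropped from the front.
    num_skipped_pages = max(seq_len - sliding_window, 0) // page_size   # 10
    truncated_block_table = block_table[num_skipped_pages:]             # 9 pages left

    print(num_skipped_pages, truncated_block_table.tolist())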

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which runs a small, essential subset of tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify bot added the v1 label Aug 15, 2025
@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces an optimization for sliding window attention in FlashInfer by truncating block tables. The overall logic is sound, but I've identified a critical issue where the sequence lengths are not updated to match the truncated block tables. This mismatch could lead to incorrect attention results or memory access errors. I have provided a detailed comment and a code suggestion to address this problem.

Comment on lines +502 to +509
            decode_seq_lens_cpu = seq_lens_cpu[:num_decodes]
            num_skipped_pages = (
                torch.relu(decode_seq_lens_cpu - self.sliding_window) //
                page_size)

            block_table_bounds_cpu[:num_decodes] -= num_skipped_pages
            mask[:num_decodes] &= (arange[:num_decodes]
                                   >= num_skipped_pages.unsqueeze(1))

critical

The sequence lengths (seq_lens and seq_lens_cpu) are not updated to reflect the truncated block tables for sliding window attention. This can lead to incorrect attention computation or out-of-bounds memory access in the FlashInfer kernel because the kernel will receive mismatched sequence lengths and KV blocks.

The sequence lengths should be reduced by the number of tokens in the skipped pages. Note that if common_attn_metadata is shared across layers, seq_lens and seq_lens_cpu should be cloned before modification to avoid side effects.

            decode_seq_lens_cpu = seq_lens_cpu[:num_decodes]
            num_skipped_pages = (
                torch.relu(decode_seq_lens_cpu - self.sliding_window) //
                page_size)

            skipped_tokens = num_skipped_pages * page_size
            seq_lens_cpu[:num_decodes] -= skipped_tokens
            seq_lens[:num_decodes] -= skipped_tokens.to(self.device,
                                                        non_blocking=True)

            block_table_bounds_cpu[:num_decodes] -= num_skipped_pages
            mask[:num_decodes] &= (arange[:num_decodes]
                                   >= num_skipped_pages.unsqueeze(1))
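To illustrate the cloning concern in the note above (a toy, self-contained example rather than vLLM code): subtracting in place on a tensor that is shared across layers would mutate it for every consumer, whereas cloning first keeps the shared metadata intact.

    import torch

    # Stand-ins for shared seq_lens metadata and the per-layer adjustment for
    # skipped pages (illustrative values only).
    shared_seq_lens = torch.tensor([300, 40, 512])
    skipped_tokens = torch.tensor([160, 0, 384])

    local_seq_lens = shared_seq_lens.clone()  # copy before the in-place update
    local_seq_lens -= skipped_tokens          # shared_seq_lens stays untouched

    print(shared_seq_lens.tolist())  # [300, 40, 512]
    print(local_seq_lens.tolist())   # [140, 40, 128]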

@WoosukKwon (Collaborator, Author) replied:

Good point. To my understanding, it is not necessary to update seq_lens_cpu in this case.

Another collaborator replied, quoting the above:

> To my understanding, it is not necessary to update seq_lens_cpu in this case.

This is true for the "standard" FlashInfer code path, which infers the sequence length from `paged_kv_last_page_len` and `paged_kv_indptr` (so manipulating `block_table_bounds_cpu` is sufficient). I'm not sure if this is true for the new trtllm code paths (i.e. `trtllm_batch_decode_with_kv_cache`); maybe someone from Nvidia can weigh in, @kushanam @pavanimajety
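For concreteness, a small sketch of how the per-request KV length falls out of the paged metadata (assuming the usual FlashInfer convention; illustrative values, not code from either project). Truncating the block table shrinks `paged_kv_indptr`, so the inferred length shrinks with it.

    import torch

    page_size = 16
    # Two decode requests: 9 pages with 12 valid tokens in the last page, and
    # 3 pages with a full last page.
    paged_kv_indptr = torch.tensor([0, 9, 12])
    paged_kv_last_page_len = torch.tensor([12, 16])

    num_pages = paged_kv_indptr[1:] - paged_kv_indptr[:-1]
    kv_lens = (num_pages - 1) * page_size + paged_kv_last_page_len
    print(kv_lens.tolist())  # [140, 48]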

@WoosukKwon added the ready label (ONLY add when PR is ready to merge/full CI is needed) Aug 15, 2025
@mergify bot commented Aug 20, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @WoosukKwon.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify bot added the needs-rebase label Aug 20, 2025
yzh119 pushed a commit to flashinfer-ai/flashinfer that referenced this pull request Sep 11, 2025
…ix __syncthreads in mergestate (#1661)


## 📌 Description

This PR skips redundant KV-tile loading and computation in FA2's sliding
window implementation, which is used in GPT-OSS inference scenarios.
Main modifications:
1. Make `PrefillPlanSM80` aware of the sliding window, skipping redundant
attention-state merges.
2. Add `kv-start-idx` offsets at the start of each device kernel to skip
computation outside the window (see the sketch below). As this maintains the
original 0-based `kv-tile-idx`, it should be compatible with the JIT
`sink-token` implementation.
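A rough Python sketch (not the CUDA kernel) of the offset idea described in item 2, with illustrative tile sizes: whole KV tiles that end before the first token visible to the query can be skipped outright.

    def kv_start_tile(q_pos: int, window: int, tile_size: int) -> int:
        """First KV tile index that still overlaps the sliding window."""
        first_visible_kv = max(q_pos - window + 1, 0)
        return first_visible_kv // tile_size

    # A decode query at position 299 with a 128-token window and 16-token tiles
    # can skip tiles 0-9 and start at tile 10.
    print(kv_start_tile(q_pos=299, window=128, tile_size=16))  # 10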

## 🔍 Related Issues
This should back up vllm-project/vllm#23010.


## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).


---------

Co-authored-by: happierpig <zhaoyilong217@sjtu.edn.cn>
@fadara01 (Contributor) commented:

Am I right in thinking that the sliding window attention in this case will be page/block granular rather than token granular?
Does this have any effect on accuracy? I'm asking because I'm considering a similar strategy for the initial enablement of SWA on the CPU path (with cpu_attn.py, which uses the CPU implementation of paged_attention_v1 for decode).

fadara01 added a commit to fadara01/vllm that referenced this pull request Oct 28, 2025
- Enables Sliding Window Attention (SWA) for the CPU backend.
- Enables models with hybrid local-global attention (i.e., those with multiple KV cache groups) for the CPU backend.
- Fixes a bug in `TorchSDPAMetadataBuilderV1` where `query_start_loc_cpu` was updated in-place despite being shared between multiple KV cache groups.
- Adds a unit test for the decode phase in the CPU attention backend.
- Adds reference implementations used to test paged attention to a common utils module and reuses it for the decode attention CPU test.

SWA is enabled by truncating full blocks/pages that are outside the window (similar to vllm-project#23010).
Hence, the implementation is block-granular and may attend to at most (block_size - 1) extra tokens; this did not affect the accuracy tests I ran.
This is the best we can do with the current paged-attention API/implementation, which does not support SWA natively.
See vllm-project#4768 for context and the attempt to enable native SWA support in paged_attn.

Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com>
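To make the (block_size - 1) bound above concrete, a toy calculation with illustrative numbers:

    # With block_size = 16, window = 1024, and seq_len = 1030, the first token
    # inside the window is index 6, but the earliest kept block starts at token
    # 0, so 6 extra tokens are attended, always fewer than block_size.
    block_size, window, seq_len = 16, 1024, 1030
    first_visible = max(seq_len - window, 0)            # 6
    first_kept_block = first_visible // block_size      # block 0
    extra_tokens = first_visible - first_kept_block * block_size
    print(extra_tokens)  # 6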
@mergify bot added the nvidia label Nov 11, 2025

Labels

needs-rebase, nvidia, ready (ONLY add when PR is ready to merge/full CI is needed), v1
