
Conversation

@fadara01 (Contributor) commented Oct 28, 2025

Purpose

  • Enables Sliding Window Attention (SWA) for the CPU backend.
  • Enables models with hybrid local-global attention (i.e., those with multiple KV cache groups) for the CPU backend.
  • Fixes a bug in TorchSDPAMetadataBuilderV1 where query_start_loc_cpu was updated in-place despite being shared between multiple KV cache groups (see the sketch after this list).
  • Adds a unit test for the decode phase in the CPU attention backend.
  • Adds reference implementations used to test paged attention to a common utils module and reuses them for the decode attention CPU test.
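
A minimal, hypothetical illustration of the shared-tensor pitfall behind that fix (the name query_start_loc_cpu comes from the PR; the builder functions and the specific update are made up for the example, not the actual TorchSDPAMetadataBuilderV1 code):

import torch

# The same query_start_loc tensor is handed to the metadata builder of every
# KV cache group; an in-place update made for one group leaks into the others.
query_start_loc_cpu = torch.tensor([0, 4, 9, 15], dtype=torch.int32)

def build_metadata_buggy(query_start_loc: torch.Tensor) -> torch.Tensor:
    # In-place edit mutates the shared tensor (the bug pattern).
    query_start_loc[1:].add_(1)
    return query_start_loc

def build_metadata_fixed(query_start_loc: torch.Tensor) -> torch.Tensor:
    # Work on a private copy so the other KV cache groups keep the original values.
    local = query_start_loc.clone()
    local[1:].add_(1)
    return local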

SWA is enabled by truncating full blocks/pages that fall entirely outside the window (similar to #23010). Hence, the implementation is block-granular and may attend to at most (block_size - 1) extra tokens (with no degradation in generation quality for the tests I ran).
This is the best we can do with the current paged-attention API/implementation, which does not support SWA natively. See #4768 for context and the attempt to enable native SWA support in paged_attn.
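
For illustration, a minimal sketch of the block-granular truncation idea (variable names such as block_table, context_lens, window, and block_size are assumptions for the example, not the PR's exact code):

import torch

def truncate_blocks_for_swa(
    block_table: torch.Tensor,   # [num_seqs, max_blocks_per_seq], int32 block ids
    context_lens: torch.Tensor,  # [num_seqs], int32
    window: int,
    block_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Number of whole blocks that lie entirely before the sliding window;
    # partially covered blocks are kept, which is why up to block_size - 1
    # extra tokens may still be attended to.
    blocks_to_skip = torch.clamp_min((context_lens - window) // block_size, 0)
    # Shift each row of the block table left by its skip count and shorten the
    # effective context length; entries past the new length are assumed to be
    # ignored by the attention kernel.
    max_blocks = block_table.shape[1]
    idx = torch.arange(max_blocks).unsqueeze(0) + blocks_to_skip.unsqueeze(1)
    idx = idx.clamp_max(max_blocks - 1)
    new_block_table = torch.gather(block_table, 1, idx)
    new_context_lens = context_lens - blocks_to_skip * block_size
    return new_block_table, new_context_lens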

Test Plan

  • Added a unit test for CPU backend decode attention (with multiple window values)
  • Ran end-to-end tests with google/gemma-2-2b-it (which alternates between full attention and SWA with window=4096). The tests include prompts (e.g. asking the model to explain code) with lengths both above and below the window size (e.g. 4740, 6618, 6047, 25, 26) and generations of up to 1024 tokens. A rough sketch of such a run follows this list.
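
A rough sketch of the kind of end-to-end check described above (prompt contents and sampling settings are placeholders, not the exact ones used; assumes a CPU build of vLLM so the CPU backend path is exercised):

from vllm import LLM, SamplingParams

# google/gemma-2-2b-it alternates full-attention and SWA (window=4096) layers.
llm = LLM(model="google/gemma-2-2b-it")
prompts = [
    "Explain what the following Python function does: def f(x): return x * x",
    "Summarize the idea of sliding window attention in two sentences.",
]
params = SamplingParams(temperature=0.0, max_tokens=1024)
for output in llm.generate(prompts, params):
    print(output.outputs[0].text)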

Test Result

  • Unit test for CPU backend decode attention passes for all configurations
  • Eye-balled the end-to-end generations of google/gemma-2-2b-it with vLLM and made sure that all generations are meaningful and close enough to what one gets with transformers

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@gemini-code-assist bot left a comment


Code Review

This pull request enables Sliding Window Attention (SWA) and hybrid local-global attention models for the CPU backend. It also includes a bug fix for an in-place tensor modification and refactors attention test utilities into a common module. The changes are well-structured, and the addition of a new unit test for the CPU decode attention phase is a valuable improvement. The SWA implementation is block-granular, which is a reasonable approach given the current paged attention API. However, I've identified a critical bug in the refactored test utility code that needs to be addressed.

@chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

@fadara01 force-pushed the swa_cpu_backend branch 2 times, most recently from 48fa62e to 89b58e3 on October 28, 2025 at 12:07
@fadara01 (Contributor Author) commented:

@codex review

@chatgpt-codex-connector bot commented:

💡 Codex Review

# full blocks to drop from the beginning
blocks_to_skip = torch.relu((context_lens - window) // block_size)
num_sequences, max_num_blocks_per_seq = block_table.shape

P0: Avoid ReLU on integer tensors in SWA block truncation

The new _PagedAttention.truncate_blocks_for_swa computes blocks_to_skip = torch.relu((context_lens - window) // block_size) where context_lens comes from the decode metadata and is an int32 tensor. torch.relu does not support integer dtypes (RuntimeError: 'relu_cpu' not implemented for 'Int'), so any request whose context length exceeds the sliding window will crash before attention runs. Using torch.clamp_min(..., 0) or casting to a floating type avoids the failure and allows SWA on CPU to work.
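
For reference, a tiny standalone check of the clamp-based variant suggested above (tensor values are made up; torch.clamp_min is defined for integer dtypes, so no cast of the int32 context_lens is needed):

import torch

context_lens = torch.tensor([8000, 100, 4096], dtype=torch.int32)
window, block_size = 4096, 16
# Whole blocks before the window; negative values (context shorter than the
# window) are clamped to zero.
blocks_to_skip = torch.clamp_min((context_lens - window) // block_size, 0)
print(blocks_to_skip)  # -> tensor([244, 0, 0], dtype=torch.int32)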


@bigPYJ1151 (Member) commented:

Oh, thanks for the efforts. Looks like we have some overlap. I'm also working on refactoring the CPU attention backend to clean up the code and enable features including GQA, SWA, softcap, ALiBi, attention sink, and chunked prefill together. The code is hosted in vLLM so all architectures can leverage it. See: https://github.com/bigPYJ1151/vllm/tree/new_attn

@fadara01 (Contributor Author) commented Oct 29, 2025

@bigPYJ1151 thank you for your comment and the heads-up.
I think we should sync to avoid duplicating work, and I'm happy to help with this rewrite since we have similar items on our roadmap.

W.r.t. timelines, when do you expect your changes to land in vLLM?
We need to support SWA (and attention sinks, which will be my next PR :D) in vLLM relatively soon, i.e. in the next two to three weeks.

@fadara01 (Contributor Author) commented Nov 4, 2025

Closing in favor of #27954
