
Conversation


@bigPYJ1151 bigPYJ1151 commented Nov 3, 2025

Purpose

This PR refactors the CPU attention backend. It includes:

  • cleanup of unused code and simplified metadata
  • a unified kernel that supports chunked prefill, sliding window, alibi, softcap, and sinks
  • renaming TorchSDPABackend to CPUAttentionBackend to avoid confusion; for now the TORCH_SDPA tag is only used for ViT attention
  • better performance for both prefill and decode
  • more related unit tests enabled

cc @fadara01 @Akashcodes732

This PR:

============ Serving Benchmark Result ============
Successful requests:                     64        
Failed requests:                         0         
Maximum request concurrency:             16        
Benchmark duration (s):                  119.44    
Total input tokens:                      65472     
Total generated tokens:                  65536     
Request throughput (req/s):              0.54      
Output token throughput (tok/s):         548.71    
Peak output token throughput (tok/s):    671.00    
Peak concurrent requests:                31.00     
Total Token throughput (tok/s):          1096.89   
---------------Time to First Token----------------
Mean TTFT (ms):                          2332.45   
Median TTFT (ms):                        2015.44   
P99 TTFT (ms):                           3497.95   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          26.90     
Median TPOT (ms):                        27.29     
P99 TPOT (ms):                           28.93     
---------------Inter-token Latency----------------
Mean ITL (ms):                           26.90     
Median ITL (ms):                         25.82     
P99 ITL (ms):                            27.93     
==================================================

main:

============ Serving Benchmark Result ============
Successful requests:                     64        
Failed requests:                         0         
Maximum request concurrency:             16        
Benchmark duration (s):                  156.00    
Total input tokens:                      65472     
Total generated tokens:                  65536     
Request throughput (req/s):              0.41      
Output token throughput (tok/s):         420.10    
Peak output token throughput (tok/s):    512.00    
Peak concurrent requests:                32.00     
Total Token throughput (tok/s):          839.78    
---------------Time to First Token----------------
Mean TTFT (ms):                          2488.38   
Median TTFT (ms):                        2144.20   
P99 TTFT (ms):                           3710.32   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          35.68     
Median TPOT (ms):                        36.07     
P99 TPOT (ms):                           37.81     
---------------Inter-token Latency----------------
Mean ITL (ms):                           35.68     
Median ITL (ms):                         34.64     
P99 ITL (ms):                            38.18     
==================================================

Test Plan

unit tests

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a significant refactoring of the CPU attention backend, replacing the previous implementation with a new unified kernel. This new kernel adds support for features like sliding window, alibi, softcap, and sink, and includes optimizations for AMX BF16 instructions. The changes are extensive, touching build configurations, C++ kernels, Python backend logic, and tests. While the refactoring is a great improvement, I've identified a critical regression that breaks support for non-causal attention, which is necessary for encoder-decoder models. My review includes suggestions to address this issue.

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.



fadara01 commented Nov 3, 2025

Amazing Work! Thanks for raising this :)
... I'll test and benchmark on Arm CPUs


fadara01 commented Nov 3, 2025

The changes this PR introduces are massive: +3,932 −1,912.
In the future, we should consider breaking massive PRs like this into multiple medium-sized PRs so that they're reviewable. Not sure if that's possible here, given this is a complete rewrite.


self.sinks = sinks
if self.sinks is not None:
    assert self.sinks.shape[0] == num_heads, (
Contributor

When use_sdpa_prefill is true we use vanilla SDPA, which does not support sinks.
Can we dispatch to cpu_attention_with_kv_cache when we have sinks, even if use_sdpa_prefill is true?

If that's not possible for whatever reason, we should raise an error, and I can address it for the non-Intel CPU path in a follow-up PR.
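
For illustration, a minimal sketch of the dispatch being suggested here; the function and parameter names are assumed and may not match the actual backend code:

```
# Illustrative only: prefer the unified kernel whenever sinks are present,
# since vanilla SDPA has no notion of attention sinks.
def choose_prefill_kernel(use_sdpa_prefill: bool, sinks) -> str:
    if sinks is not None:
        return "cpu_attention_with_kv_cache"
    return "sdpa" if use_sdpa_prefill else "cpu_attention_with_kv_cache"
```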

Member Author

It seems we can't do this, as we are unable to get the sink config in the builder. I just added an assertion for now.
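
A rough sketch of such an assertion, assuming the attribute and flag names; the actual guard added in the PR may differ:

```
# Assumed names; illustrates "just add an assertion for now".
if self.sinks is not None:
    assert not self.use_sdpa_prefill, (
        "Attention sinks are not supported by the SDPA prefill path on CPU."
    )
```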

Contributor

Yeah, let's just fail for now; I'll address this in a follow-up PR.

@@ -0,0 +1,31 @@
#ifndef SCRATCHPAD_MANAGER_H
@fadara01 fadara01 Nov 4, 2025

Can we leave the oneDNN changes out? This PR is already too big, and I don't think these changes are relevant to the new attention backend.

Member Author

Yes, it is a bit unrelated. But I think it is acceptable since it is just a few lines of code 😂

Contributor

Haha okay!!

s_aux=s_aux,
)

atol, rtol = 1.5e-2, 1e-2
@fadara01 fadara01 Nov 4, 2025

The absolute tolerance looks too high.
Can we use:

from tests.kernels.allclose_default import get_default_atol, get_default_rtol
atol = get_default_atol(output)
rtol = get_default_rtol(output)

similar to what we do in https://github.com/vllm-project/vllm/blob/main/tests/kernels/attention/test_attention.py

Member Author

This test file is based on https://github.com/vllm-project/vllm/blob/main/tests/kernels/attention/test_flash_attn.py

I think the absolute tolerance is stricter in test_attention.py because the input is initialized with uniform_(-scale, scale), which produces smaller values than the randn initialization used in test_flash_attn.py.

But I prefer to use randn, as I found that with small inputs the sink test cases sometimes fail to expose value differences.
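
For illustration, the two initialization styles being compared (shapes and scale are arbitrary):

```
import torch

scale = 1.0 / (96 ** 0.5)
# test_attention.py style: small-magnitude inputs
q_small = torch.empty(128, 8, 96, dtype=torch.bfloat16).uniform_(-scale, scale)
# test_flash_attn.py style: larger-magnitude inputs, more likely to expose value differences
q_large = torch.randn(128, 8, 96, dtype=torch.bfloat16)
```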

Contributor

Acknowledged, I wasn't aware that's the tolerance used for testing flash attention.

):
skip = True

# only tests features with bf16 to save time
Contributor

Then why not just do QTYPES = [torch.bfloat16]?

Member Author

This means we only test sink, alibi, and softcap with bf16, since the logits processing uses fp32. For the other cases, all dtypes should be tested.
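
A hedged sketch of the kind of condition being described (parameter names assumed, not the exact test code):

```
import pytest
import torch

def maybe_skip(dtype, use_alibi, use_sink, softcap):
    # alibi/softcap/sink process logits in fp32, so bf16 coverage is enough;
    # plain attention is still exercised with every dtype.
    if dtype != torch.bfloat16 and (use_alibi or use_sink or softcap is not None):
        pytest.skip("feature combinations are only tested with bf16")
```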

Contributor

Oh, I missed the second line of the condition. I agree with you.

"Qwen/Qwen3-8B", # qwen (text-only)
),
pytest.param("stabilityai/stablelm-3b-4e1t"), # stablelm
pytest.param("bigcode/starcoder2-3b"), # starcoder2
Contributor

Can we enable a test for google/gemma-2-2b-it and mark it as cpu_model?
This would be a great end-to-end smoke test for SWA and hybrid local-global attention models (with 2 KV-cache groups).
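
Roughly what such an entry could look like in the parametrized model list (marker usage assumed):

```
import pytest

MODELS = [
    pytest.param("bigcode/starcoder2-3b"),  # starcoder2
    pytest.param(
        "google/gemma-2-2b-it",  # hybrid SWA + global attention, 2 KV-cache groups
        marks=[pytest.mark.cpu_model],
    ),
]
```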

Member Author

Good idea, added a gemma-2 case.

" intel_extension_for_pytorch"
)
if cache_config.block_size % 32 != 0:
block_size = cache_config.block_size
Contributor

I'm in favor of just erroring out in this case saying that block_size needs to be divisible by 32 (instead of setting the block_size to a value that the user didn't choose)

Member Author

Yes... But my concern is that a lot of test cases use 16 by default, and I don't want to add more if-else branches in different files, so I just round it up here.
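
A minimal sketch of the rounding being described; the actual code in the platform config may differ:

```
def round_block_size(block_size: int) -> int:
    # Round up to the next multiple of 32, as required by the CPU attention kernel.
    if block_size % 32 != 0:
        block_size = (block_size + 31) // 32 * 32
    return block_size

assert round_block_size(16) == 32
assert round_block_size(64) == 64
```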

Contributor

Oh, that's a good point. I agree with you.


fadara01 commented Nov 4, 2025

I ran some end-to-end tests on Arm Neoverse-V2 with google/gemma-2-2b-it (which alternates between full attention and SWA with window=4096). The tests include prompts (e.g. asking the model to explain code) with lengths greater and smaller than the window size (e.g. 4740, 6618, 6047, 25, 26) and generations of up to 1024 tokens.

I eyeballed the end-to-end generations with this new attention backend and can confirm that all generations are meaningful and close enough to what one gets with Hugging Face Transformers.


fadara01 commented Nov 4, 2025

I can also confirm that there are no perf regressions on Arm after running this benchmark:

VLLM_CPU_OMP_THREADS_BIND=0-63 LD_PRELOAD="/usr/lib/aarch64-linux-gnu/libtcmalloc_minimal.so.4:/usr/lib/aarch64-linux-gnu/libgomp.so.1" VLLM_TARGET_DEVICE=cpu VLLM_CPU_KVCACHE_SPACE=32 vllm bench throughput --num-prompts 128 --seed 0 --dataset-name sharegpt --max-model-len 4096 --dataset-path /path/to/ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json  --model meta-llama/Llama-3.1-8B-Instruct --load-format dummy

int32_t reduction_split_num;
int32_t thread_num;
int32_t effective_thread_num;  // non-zero item num in cu_workitem_num_per_thread
Contributor

Could you please add a comment explaining this?
What's the significance of the cu_ prefix here?

Member Author

Oh, it is a mistake; it should be workitem_num_per_thread. The cu prefix means cumulative, and cu_workitem_num_per_thread is an array containing the prefix sums of workitem_num_per_thread.
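
For illustration, how the cumulative array relates to the per-thread counts (example values assumed):

```
from itertools import accumulate

workitem_num_per_thread = [4, 0, 7, 3]  # assumed example
cu_workitem_num_per_thread = [0, *accumulate(workitem_num_per_thread)]
# -> [0, 4, 4, 11, 14]
effective_thread_num = sum(n > 0 for n in workitem_num_per_thread)  # 3 non-zero entries
```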


fadara01 commented Nov 4, 2025

I got one test failure, with a single-element mismatch, while running tests/kernels/attention/test_cpu_attn.py on Arm.

This is the configuration that fails:

seq_lens =  [(2345, 2345), (5, 5), (3, 16), (134, 5131)]
num_heads =  (9, 3)
head_size =  96
sliding_window =  None
dtype =  torch.bfloat16
block_size =  96
softcap =  None
num_blocks =  1024
use_alibi =  False
use_sink =  False
isa =  vec

And this is the log:

tests/kernels/attention/test_cpu_attn.py:413: AssertionError
FAILED tests/kernels/attention/test_cpu_attn.py::test_varlen_with_paged_kv[vec-False-False-1024-None-dtype0-None-96-96-num_heads2-seq_lens1] - AssertionError: Tensor-likes are not close!

Mismatched elements: 1 / 2148768 (0.0%)
Greatest absolute difference: 0.03125 at index (2356, 8, 34) (up to 0.015 allowed)
Greatest relative difference: 0.025634765625 at index (2356, 8, 34) (up to 0.01 allowed)
===== 1 failed, 161 passed, 1566 skipped, 2 warnings in 288.89s (0:04:48) ======

@bigPYJ1151 I'm happy to take a deeper look at this, unless you have hints on what might be the issue?
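
For reference, a self-contained illustration of the failure mode: a single bf16 element off by 0.03125 exceeds the atol=1.5e-2 / rtol=1e-2 tolerances discussed above, even when everything else matches (the comparison call here is illustrative, not the exact test code):

```
import torch

ref = torch.zeros(4, dtype=torch.bfloat16)
out = ref.clone()
out[2] += 0.03125  # exactly representable in bf16
# Raises AssertionError: greatest absolute difference 0.03125 > 0.015 allowed.
torch.testing.assert_close(out, ref, atol=1.5e-2, rtol=1e-2)
```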

@jikunshang jikunshang merged commit 7f829be into vllm-project:main Nov 12, 2025
61 checks passed
fangyuchu pushed a commit to fangyuchu/vllm that referenced this pull request Nov 12, 2025
Signed-off-by: jiang1.li <jiang1.li@intel.com>
geodavic pushed a commit to geodavic/vllm that referenced this pull request Nov 16, 2025
Signed-off-by: jiang1.li <jiang1.li@intel.com>
Signed-off-by: George D. Torres <gdavtor@gmail.com>
fadara01 added a commit to fadara01/vllm that referenced this pull request Nov 21, 2025
PR vllm-project#27954 added cpu_attention_with_kv_cache which supports chunked prefill, prefix caching,
SWA, alibi, softcap and sinks.

However, it's currently disabled for prefill on Arm CPUs because it's slower than torch.sdpa
for relatively long prefills. Hence chunked prefill, prefix caching, sinks, etc remained unsupported on Arm.

This PR accelerates cpu_attention_with_kv_cache on Arm CPUs by introducing NEON accelerated GEMMs
(enabled with ISA::NEON) for QK and PV. With the new GEMMs, performance of cpu_attention_with_kv_cache
is similar to torch.sdpa for long prefills, which allows us to enable cpu_attention_with_kv_cache for
prefill path on Arm and thus enable chunked prefill, prefix caching, sinks, alibi, softcap, etc.

Performance:

Uplift with ISA::NEON vs ISA::VEC:
For batch size = 64, query tokens = kv tokens = 512, q heads = 32, kv heads = 8, head size = 128, block size = 128:
using ISA::NEON for cpu_attention_with_kv_cache accelerates prefill attention by 2x compared to the current state with ISA::VEC

For the throughput benchmark below on Arm Neoverse-V2, using cpu_attention_with_kv_cache for prefills and decodes:
ISA::NEON yields ~13% higher throughput than ISA::VEC and similar throughput to using torch.sdpa for prefill.
```
export VLLM_CPU_OMP_THREADS_BIND=0-63
export LD_PRELOAD="/usr/lib/aarch64-linux-gnu/libtcmalloc_minimal.so.4:/usr/lib/aarch64-linux-gnu/libgomp.so.1"
export VLLM_TARGET_DEVICE=cpu
export VLLM_CPU_KVCACHE_SPACE=64
vllm bench throughput \
  --num-prompts 128 \
  --seed 0 \
  --dataset-name sharegpt \
  --input-len 1024 \
  --output-len 128 \
  --max-model-len 2048 \
  --max-num-batched-tokens 8192 \
  --model  meta-llama/Llama-3.1-8B-Instruct \
  --load-format dummy
```

Future PRs will accelerate attention further by introducing faster/vectorized exp implementations
and leveraging bfmmla/bfdot for QK, PV on Arm CPUs with bf16.

Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com>
fadara01 added a commit to fadara01/vllm that referenced this pull request Nov 21, 2025
Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com>
fadara01 added a commit to fadara01/vllm that referenced this pull request Nov 21, 2025
Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com>
fadara01 added a commit to fadara01/vllm that referenced this pull request Nov 22, 2025
Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com>
devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025
Signed-off-by: jiang1.li <jiang1.li@intel.com>
charlotte12l pushed a commit to charlotte12l/vllm that referenced this pull request Dec 5, 2025
Signed-off-by: jiang1.li <jiang1.li@intel.com>
Signed-off-by: Xingyu Liu <charlotteliu12x@gmail.com>
@louie-tsai
Contributor

@bigPYJ1151 Is it possible to add the new build flag to the documentation? https://docs.vllm.ai/en/latest/getting_started/installation/cpu/#build-image-from-source


Labels

ci/build, documentation, ready, v1
