
Commit 6222e75

fadara01 and albertoperdomo2 authored and committed
[fix][cpu] fix prefill attention in CPU attention backend (vllm-project#27035)
Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com>
Signed-off-by: Alberto Perdomo <aperdomo@redhat.com>
1 parent da68fff commit 6222e75

File tree

2 files changed (+16 -4 lines changed)

vllm/engine/arg_utils.py

Lines changed: 9 additions & 1 deletion
@@ -1293,7 +1293,8 @@ def create_engine_config(
 
         # Set default arguments for V1 Engine.
         self._set_default_args(usage_context, model_config)
-        # Disable chunked prefill for POWER (ppc64le)/ARM/s390x/RISCV CPUs in V1
+        # Disable chunked prefill and prefix caching for:
+        # POWER (ppc64le)/ARM/s390x/RISCV CPUs in V1
         if current_platform.is_cpu() and current_platform.get_cpu_architecture() in (
             CpuArchEnum.POWERPC,
             CpuArchEnum.S390X,
@@ -1306,6 +1307,13 @@ def create_engine_config(
                 "disabling it for V1 backend."
             )
             self.enable_chunked_prefill = False
+            logger.info(
+                "Prefix caching is not supported for ARM and POWER, "
+                "S390X and RISC-V CPUs; "
+                "disabling it for V1 backend."
+            )
+            self.enable_prefix_caching = False
+
         assert self.enable_chunked_prefill is not None
 
         sliding_window: int | None = None
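For reference, a minimal standalone sketch of the gating pattern this hunk implements. The names here are simplified stand-ins, not vLLM code: CpuArchEnum is a local enum and apply_cpu_defaults is a hypothetical helper.

# Minimal standalone sketch of the gating added in this commit.
# CpuArchEnum and apply_cpu_defaults are simplified stand-ins, not vLLM code.
import enum
import logging

logger = logging.getLogger(__name__)


class CpuArchEnum(enum.Enum):
    X86 = "x86"
    POWERPC = "powerpc"
    S390X = "s390x"
    ARM = "arm"
    RISCV = "riscv"


def apply_cpu_defaults(
    arch: CpuArchEnum,
    enable_chunked_prefill: bool | None,
    enable_prefix_caching: bool | None,
) -> tuple[bool, bool]:
    """Disable chunked prefill and prefix caching on unsupported CPU archs."""
    if arch in (
        CpuArchEnum.POWERPC,
        CpuArchEnum.S390X,
        CpuArchEnum.ARM,
        CpuArchEnum.RISCV,
    ):
        logger.info("Chunked prefill is not supported on %s; disabling it.", arch.value)
        enable_chunked_prefill = False
        logger.info("Prefix caching is not supported on %s; disabling it.", arch.value)
        enable_prefix_caching = False
    return bool(enable_chunked_prefill), bool(enable_prefix_caching)


print(apply_cpu_defaults(CpuArchEnum.ARM, None, None))  # -> (False, False)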

vllm/v1/attention/backends/cpu_attn.py

Lines changed: 7 additions & 3 deletions
@@ -412,7 +412,7 @@ def build(
             num_decode_tokens=num_decode_tokens,
             slot_mapping=slot_mapping,
             # to ensure inference when chunked_prefill is disabled
-            seq_lens=seq_lens_cpu.tolist(),
+            seq_lens=seq_lens_cpu.tolist()[num_decodes:],  # prefill
             decode_seq_lens_tensor=seq_lens_cpu[:num_decodes],  # decode
             decode_max_seq_len=max_decode_seq_len,  # decode
             decode_block_tables=block_table_tensor[:num_decodes],  # decode
@@ -617,7 +617,6 @@ def forward(
                 prefill_meta.prefill_block_tables,
                 self.alibi_slopes,
             )
-
         if decode_meta := attn_metadata.decode_metadata:
             assert attn_type != AttentionType.ENCODER_ONLY, (
                 "Encoder-only models should not have decode metadata."
@@ -686,7 +685,12 @@ def _run_sdpa_forward(
         causal_attn = attn_type == AttentionType.DECODER
 
         seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type)
-        start_q, start_kv = 0, 0
+        # Incoming Q and KV contain decoded tokens as well, hence start at an offset
+        # equal to num_decode_tokens since decode requests appear first
+        start_q, start_kv = (
+            attn_metadata.num_decode_tokens,
+            attn_metadata.num_decode_tokens,
+        )
         for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv, attn_masks):
            end_q = start_q + seq_len_q
            end_kv = start_kv + seq_len_kv
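Why the offset matters: the CPU backend flattens the batch decode-first, so the prefill portion of Q and KV begins at num_decode_tokens; starting the SDPA loop at 0 made the prefill slices overlap the decode tokens. Below is a toy sketch of the corrected slicing in plain Python, with an assumed layout of 2 decode and 2 prefill requests; it is an illustration, not the backend code.

# Toy illustration of the decode-first batch layout the fix assumes.
# 2 decode requests (1 token each) are followed by 2 prefill requests
# (3 and 4 tokens), flattened along a single token axis.
num_decode_tokens = 2
prefill_seq_lens = [3, 4]

# 9 query positions total: [d0, d1, p0_0, p0_1, p0_2, p1_0, p1_1, p1_2, p1_3]
query = list(range(num_decode_tokens + sum(prefill_seq_lens)))

# Before the fix the prefill loop started at 0 and consumed decode tokens;
# starting at num_decode_tokens skips them, mirroring seq_lens[num_decodes:].
start_q = num_decode_tokens
for seq_len_q in prefill_seq_lens:
    end_q = start_q + seq_len_q
    print("prefill slice:", query[start_q:end_q])
    start_q = end_q
# prefill slice: [2, 3, 4]
# prefill slice: [5, 6, 7, 8]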
