
Commit eebd548

LucasWilkinson authored and diegocastanibm committed
[BugFix] Fix ChunkedLocalAttention when the hybrid kv-cache is disabled (vllm-project#21707)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Diego-Castan <diego.castan@ibm.com>
1 parent 88fcf4e commit eebd548

File tree

1 file changed: 25 additions, 0 deletions

vllm/v1/worker/gpu_model_runner.py

@@ -809,6 +809,31 @@ def _prepare_inputs(
             for layer_name in kv_cache_group_spec.layer_names:
                 attn_metadata[layer_name] = attn_metadata_i
 
+            # Hack for now to fix chunked local attention + no hybrid kv cache
+            # manager; we can remove this once
+            # https://github.com/vllm-project/vllm/pull/21588
+            # is merged (i.e. properly handle different attention backends for
+            # the same kv_cache_spec)
+            if self.attention_chunk_size is not None \
+                    and self.scheduler_config.disable_hybrid_kv_cache_manager:
+                if not hasattr(self, "local_attention_layers"):
+                    self.local_attention_layers = []
+                    attn_layers = get_layers_from_vllm_config(
+                        self.vllm_config, Attention)
+                    for layer_name, attn_module in attn_layers.items():
+                        if attn_module.use_irope:
+                            self.local_attention_layers.append(layer_name)
+
+                local_attn_metadata_i = (builder.build(
+                    common_prefix_len=0,
+                    common_attn_metadata=make_local_attention_virtual_batches(
+                        self.attention_chunk_size, common_attn_metadata,
+                        self.cache_config.block_size),
+                ))
+
+                for layer_name in self.local_attention_layers:
+                    attn_metadata[layer_name] = local_attn_metadata_i
+
         attention_cuda_graphs = all(
             b.can_run_in_cudagraph(common_attn_metadata)
             for b in self.attn_metadata_builders)
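The crux of the patch is make_local_attention_virtual_batches: because chunked local attention lets each query token attend only within its own chunk of attention_chunk_size tokens, a batch of long sequences can be re-expressed as many short "virtual" requests that an ordinary attention metadata builder can handle. Below is a minimal, self-contained sketch of that chunking idea; split_into_virtual_batches is a hypothetical illustration, not vLLM's actual helper or its signature.

# Sketch only: shows how sequences are split into chunk-sized virtual
# batches. Names and structure are illustrative, not vLLM's internals.
from typing import List, Tuple


def split_into_virtual_batches(
        seq_lens: List[int],
        attention_chunk_size: int) -> List[Tuple[int, int, int]]:
    """Emit (request_idx, chunk_start, chunk_len) triples.

    Tokens in a chunk attend only to tokens in the same chunk, so each
    triple can be treated as an independent short sequence.
    """
    virtual_batches = []
    for req_idx, seq_len in enumerate(seq_lens):
        for start in range(0, seq_len, attention_chunk_size):
            chunk_len = min(attention_chunk_size, seq_len - start)
            virtual_batches.append((req_idx, start, chunk_len))
    return virtual_batches


if __name__ == "__main__":
    # Two requests of lengths 10 and 5 with chunk size 4: the first becomes
    # chunks of 4 + 4 + 2, the second 4 + 1.
    print(split_into_virtual_batches([10, 5], attention_chunk_size=4))
    # [(0, 0, 4), (0, 4, 4), (0, 8, 2), (1, 0, 4), (1, 4, 1)]

If this reading of the transformation is right, it explains why the diff can reuse the same builder.build(...) path for the local-attention layers: once common_attn_metadata has been rewritten into per-chunk virtual batches, the backend never sees a span longer than attention_chunk_size, and the resulting metadata is simply swapped in for every layer flagged by use_irope.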
