Skip to content

Conversation

@LucasWilkinson
Copy link
Collaborator

@LucasWilkinson LucasWilkinson commented Jul 28, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

Temporary fix for ChunkedLocalAttention when the hybrid kv-cache is disabled while we work on #21588

Test Plan

lm_eval

Test Result

Main - Hybrid KV-cache

python -m lm_eval --model vllm --model_args pretrained=meta-llama/Llama-4-Scout-17B-16E-Instruct,tensor_parallel_size=4,gpu_memory_utilization=0.8,trust_remote_code=True,max_model_len=32768,disable_hybrid_kv_cache_manager=True --tasks ruler_qa_squad --limit 100 --batch_size auto --metadata='{"max_seq_lengths":[16384]}'
...
|    Tasks     |Version|Filter|n-shot|Metric|   | Value |   |Stderr|
|--------------|------:|------|-----:|-----:|---|------:|---|------|
|ruler_qa_squad|      1|none  |     0| 16384|↑  | 0.7092|±  |   N/A|
|              |       |none  |     0|  4096|↑  |-1.0000|±  |   N/A|

Main - No-Hybrid KV-cache

python -m lm_eval --model vllm --model_args pretrained=meta-llama/Llama-4-Scout-17B-16E-Instruct,tensor_parallel_size=4,gpu_memory_utilization=0.8,trust_remote_code=True,max_model_len=32768,disable_hybrid_kv_cache_manager=True --tasks ruler_qa_squad --limit 100 --batch_size auto --metadata='{"max_seq_lengths":[16384]}'
...
|    Tasks     |Version|Filter|n-shot|Metric|   | Value |   |Stderr|
|--------------|------:|------|-----:|-----:|---|------:|---|------|
|ruler_qa_squad|      1|none  |     0| 16384|↑  | 0.6292|±  |   N/A|
|              |       |none  |     0|  4096|↑  |-1.0000|±  |   N/A|

This PR

python -m lm_eval --model vllm --model_args pretrained=meta-llama/Llama-4-Scout-17B-16E-Instruct,tensor_parallel_size=4,gpu_memory_utilization=0.8,trust_remote_code=True,max_model_len=32768,disable_hybrid_kv_cache_manager=True --tasks ruler_qa_squad --limit 100 --batch_size auto --metadata='{"max_seq_lengths":[16384]}'
...
|    Tasks     |Version|Filter|n-shot|Metric|   | Value |   |Stderr|
|--------------|------:|------|-----:|-----:|---|------:|---|------|
|ruler_qa_squad|      1|none  |     0| 16384|↑  | 0.7058|±  |   N/A|
|              |       |none  |     0|  4096|↑  |-1.0000|±  |   N/A|

(Optional) Documentation Update

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
@github-actions
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Jul 28, 2025
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This PR provides a temporary fix for an issue with chunked local attention when the hybrid KV cache is disabled, which seems to work based on the test results. However, I've found a critical issue in the implementation of the fix that could lead to incorrect behavior for models with multiple KV cache groups. The fix is applied inside a loop and could overwrite correct metadata with incorrect data depending on the processing order of the groups. I've provided a suggestion to correct this.

Comment on lines +817 to +835
if self.attention_chunk_size is not None \
and self.scheduler_config.disable_hybrid_kv_cache_manager:
if not hasattr(self, "local_attention_layers"):
self.local_attention_layers = []
attn_layers = get_layers_from_vllm_config(
self.vllm_config, Attention)
for layer_name, attn_module in attn_layers.items():
if attn_module.use_irope:
self.local_attention_layers.append(layer_name)

local_attn_metadata_i = (builder.build(
common_prefix_len=0,
common_attn_metadata=make_local_attention_virtual_batches(
self.attention_chunk_size, common_attn_metadata,
self.cache_config.block_size),
))

for layer_name in self.local_attention_layers:
attn_metadata[layer_name] = local_attn_metadata_i
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This temporary fix has a potential correctness issue. It's placed inside the loop over kv_cache_groups. In each iteration, it rebuilds metadata and overwrites it for all layers in self.local_attention_layers.

If there are multiple kv_cache_groups (e.g., one for full attention and one for local attention), the metadata for local attention layers will be determined by the context of the last group processed in the loop. If the last group is not the local attention group, the metadata will be incorrect. This could lead to correctness issues depending on the model architecture and group processing order.

A safer approach is to scope this logic to only execute for the kv_cache_group that corresponds to chunked local attention. This also simplifies the code by removing the need to build and maintain self.local_attention_layers.

Here is a suggested implementation that addresses this:

Suggested change
if self.attention_chunk_size is not None \
and self.scheduler_config.disable_hybrid_kv_cache_manager:
if not hasattr(self, "local_attention_layers"):
self.local_attention_layers = []
attn_layers = get_layers_from_vllm_config(
self.vllm_config, Attention)
for layer_name, attn_module in attn_layers.items():
if attn_module.use_irope:
self.local_attention_layers.append(layer_name)
local_attn_metadata_i = (builder.build(
common_prefix_len=0,
common_attn_metadata=make_local_attention_virtual_batches(
self.attention_chunk_size, common_attn_metadata,
self.cache_config.block_size),
))
for layer_name in self.local_attention_layers:
attn_metadata[layer_name] = local_attn_metadata_i
if (self.attention_chunk_size is not None
and self.scheduler_config.disable_hybrid_kv_cache_manager
and isinstance(kv_cache_group_spec.kv_cache_spec,
ChunkedLocalAttentionSpec)):
# This group is for chunked local attention.
# Rebuild its metadata with make_local_attention_virtual_batches
# and common_prefix_len=0.
local_attn_metadata_i = builder.build(
common_prefix_len=0,
common_attn_metadata=make_local_attention_virtual_batches(
self.attention_chunk_size, common_attn_metadata,
self.cache_config.block_size),
)
for layer_name in kv_cache_group_spec.layer_names:
attn_metadata[layer_name] = local_attn_metadata_i

@heheda12345 heheda12345 enabled auto-merge (squash) July 28, 2025 04:28
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 28, 2025
@facebook-github-bot
Copy link

@yeqcharlotte has imported this pull request. If you are a Meta employee, you can view this in D79067990.

@heheda12345 heheda12345 merged commit 139a7f0 into vllm-project:main Jul 28, 2025
72 of 74 checks passed
liuyumoye pushed a commit to liuyumoye/vllm that referenced this pull request Jul 31, 2025
…ed (vllm-project#21707)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
HsChen-sys pushed a commit to HsChen-sys/vllm that referenced this pull request Aug 1, 2025
…ed (vllm-project#21707)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
x22x22 pushed a commit to x22x22/vllm that referenced this pull request Aug 5, 2025
…ed (vllm-project#21707)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: x22x22 <wadeking@qq.com>
Pradyun92 pushed a commit to Pradyun92/vllm that referenced this pull request Aug 6, 2025
…ed (vllm-project#21707)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
npanpaliya pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Aug 6, 2025
…ed (vllm-project#21707)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
…ed (vllm-project#21707)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
paulpak58 pushed a commit to paulpak58/vllm that referenced this pull request Aug 13, 2025
…ed (vllm-project#21707)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Paul Pak <paulpak58@gmail.com>
diegocastanibm pushed a commit to diegocastanibm/vllm that referenced this pull request Aug 15, 2025
…ed (vllm-project#21707)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Diego-Castan <diego.castan@ibm.com>
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025
…ed (vllm-project#21707)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Aug 28, 2025
…ed (vllm-project#21707)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready ONLY add when PR is ready to merge/full CI is needed v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants