
Commit 1e8cef7

LucasWilkinson authored and epwalsh committed
[Perf] Disable chunked local attention by default with llama4 (vllm-project#21761)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent 00db788 commit 1e8cef7

File tree: vllm/config.py, vllm/envs.py

2 files changed, 29 additions and 6 deletions

vllm/config.py

Lines changed: 17 additions & 6 deletions
@@ -4769,12 +4769,23 @@ def __post_init__(self):
             # Hybrid KV cache manager is not compatible with KV events.
             self.scheduler_config.disable_hybrid_kv_cache_manager = True
         if self.model_config is not None and \
-            self.model_config.attention_chunk_size is not None and \
-            self.speculative_config is not None and \
-            self.speculative_config.use_eagle():
-            # Hybrid KV cache manager is not yet supported with chunked
-            # local attention + eagle.
-            self.scheduler_config.disable_hybrid_kv_cache_manager = True
+            self.model_config.attention_chunk_size is not None:
+            if self.speculative_config is not None and \
+                self.speculative_config.use_eagle():
+                # Hybrid KV cache manager is not yet supported with chunked
+                # local attention + eagle.
+                self.scheduler_config.disable_hybrid_kv_cache_manager = True
+            elif \
+                not envs.VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE:
+                logger.warning(
+                    "There is a latency regression when using chunked local"
+                    " attention with the hybrid KV cache manager. Disabling"
+                    " it, by default. To enable it, set the environment "
+                    "VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE=1."
+                )
+                # Hybrid KV cache manager is not yet supported with chunked
+                # local attention.
+                self.scheduler_config.disable_hybrid_kv_cache_manager = True
 
     def update_sizes_for_sequence_parallelism(self,
                                               possible_sizes: list) -> list:
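For readers of the hunk above, here is a minimal standalone sketch of the decision the patched __post_init__ now makes when a model uses chunked local attention. The function name and arguments are illustrative only, not vLLM's actual API; the sketch just mirrors the branch structure of the diff:

import logging
import os
from typing import Optional

logger = logging.getLogger(__name__)


def should_disable_hybrid_kv_cache_manager(
        attention_chunk_size: Optional[int],
        uses_eagle_spec_decode: bool) -> bool:
    """Illustrative mirror of the new branches in VllmConfig.__post_init__."""
    if attention_chunk_size is None:
        # Model does not use chunked local attention; nothing changes.
        return False
    if uses_eagle_spec_decode:
        # Chunked local attention + EAGLE speculative decoding is not yet
        # supported with the hybrid KV cache manager.
        return True
    allow = bool(int(os.getenv(
        "VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE", "0")))
    if not allow:
        # Default path after this commit: warn and fall back to the
        # non-hybrid KV cache manager to avoid the latency regression.
        logger.warning("Disabling hybrid KV cache manager for chunked "
                       "local attention (latency regression).")
        return True
    return False


# A Llama4-style setup (chunked local attention, no EAGLE, no override):
print(should_disable_hybrid_kv_cache_manager(8192, False))  # True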

vllm/envs.py

Lines changed: 12 additions & 0 deletions
@@ -143,6 +143,7 @@
     VLLM_USE_CUDNN_PREFILL: bool = False
     VLLM_ENABLE_CUDAGRAPH_GC: bool = False
     VLLM_LOOPBACK_IP: str = ""
+    VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False
 
 
 def get_default_cache_root():
@@ -991,6 +992,17 @@ def get_vllm_port() -> Optional[int]:
     # The default value is "VLLM".
     "VLLM_PROCESS_NAME_PREFIX":
     lambda: os.getenv("VLLM_PROCESS_NAME_PREFIX", "VLLM"),
+
+    # Allow chunked local attention with hybrid kv cache manager.
+    # Currently using the Hybrid KV cache manager with chunked local attention
+    # in the Llama4 models (the only models currently using chunked local attn)
+    # causes a latency regression. For this reason, we disable it by default.
+    # This flag is used to allow users to enable it if they want to (to save on
+    # kv-cache memory usage and enable longer contexts)
+    # TODO(lucas): Remove this flag once latency regression is resolved.
+    "VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE":
+    lambda: bool(int(os.getenv(\
+        "VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE", "0"))),
 }
 
 # --8<-- [end:env-vars-definition]
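To opt back in (accepting the latency regression in exchange for lower KV-cache memory usage and longer contexts), the flag has to be present in the environment before vLLM reads it. A hedged usage sketch; engine construction is omitted and only the flag handling is shown:

import os

# Export the flag before vLLM reads its environment variables, e.g. before
# constructing the engine for a Llama4 model with chunked local attention.
os.environ["VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE"] = "1"

# envs.py parses the value with bool(int(...)), so it must be an integer
# string such as "0" or "1"; a value like "true" would raise ValueError.
allow = bool(int(os.environ.get(
    "VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE", "0")))
print(allow)  # True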
