
Commit 9c6c81d

qqma committed:

get rid of _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH and extend to non-MLA flash attention

Signed-off-by: qqma <qqma@amazon.com>

1 parent: 982937a
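
With the module-level default gone, both call sites rely on envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH always resolving to an integer, so the fallback has to live with the env-var definition itself rather than in each backend. The snippet below is only a sketch of the pattern vllm's envs module uses for such definitions; the actual envs.py hunk is not part of this commit, and the default of 16 is simply carried over from the constant being deleted here.

# Sketch only (not part of this diff): vllm/envs.py maps variable names to
# parser lambdas, roughly like this. The real definition and its default live
# in envs.py; 16 below merely mirrors the removed constant.
import os
from typing import Any, Callable

environment_variables: dict[str, Callable[[], Any]] = {
    "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH":
    lambda: int(
        os.getenv("VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH", "16")),
}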

File tree

2 files changed: +5 -17 lines changed

vllm/v1/attention/backends/flash_attn.py

Lines changed: 3 additions & 4 deletions
@@ -7,6 +7,7 @@
 import numpy as np
 import torch
 
+from vllm import envs
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType,
@@ -33,9 +34,6 @@
 
 logger = init_logger(__name__)
 
-# NOTE(woosuk): This is an arbitrary number. Tune it if needed.
-_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16
-
 
 class FlashAttentionBackend(AttentionBackend):
 
@@ -215,7 +213,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
             # When using cuda graph, we need to set the upper bound of the
             # number of splits so that large enough intermediate buffers are
             # pre-allocated during capture.
-            self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
+            self.max_num_splits = (
+                envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH)
 
         # Sliding window size to be used with the AOT scheduler will be
         # populated on first build() call.
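
The upper bound matters because, under full CUDA graph capture, the split-KV kernel's intermediate buffers cannot grow at replay time; as the comment in the hunk above says, they must be pre-allocated large enough during capture. A rough sketch of the idea, with assumed buffer names and shapes rather than vllm's actual allocations:

# Illustration only; names, shapes, and dtypes are assumptions, not vllm's
# real buffers. Split-KV flash attention writes one partial output and one
# log-sum-exp per split and reduces them afterwards, so a CUDA-graph-safe
# implementation allocates for the worst-case split count up front.
import torch


def preallocate_split_buffers(max_num_splits: int, max_num_tokens: int,
                              num_heads: int, head_dim: int,
                              device: torch.device):
    # One slice per potential split; captured graphs replay into these
    # fixed-size buffers no matter how many splits a given batch ends up using.
    out_partial = torch.empty(max_num_splits, max_num_tokens, num_heads,
                              head_dim, dtype=torch.float32, device=device)
    lse_partial = torch.empty(max_num_splits, max_num_tokens, num_heads,
                              dtype=torch.float32, device=device)
    return out_partial, lse_partial

Raising VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH therefore trades extra captured memory for more split parallelism (useful for long-context, small-batch decode); lowering it does the opposite.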

vllm/v1/attention/backends/mla/flashattn_mla.py

Lines changed: 2 additions & 13 deletions
@@ -25,10 +25,6 @@
 
 logger = init_logger(__name__)
 
-# NOTE(matt): This is an arbitrary number, copied from
-# woosuk's implementation in standard FlashAttention backend
-_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16
-
 
 class FlashAttnMLABackend(MLACommonBackend):
 
@@ -98,15 +94,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
             # When using cuda graph, we need to set the upper bound of the
             # number of splits so that large enough intermediate buffers are
             # pre-allocated during capture.
-            if envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH is not None:
-                logger.info_once(
-                    "Getting flash attention max num splits for "
-                    "cuda graph from environment variable, value=%s",
-                    envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH)
-                self.max_num_splits = (
-                    envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH)
-            else:
-                self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
+            self.max_num_splits = (
+                envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH)
 
         # TODO(lucas): Until we add support for the DCP custom masking we need
         # to restrict decodes to q_len == 1 when DCP is enabled.
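
With the None check removed, the MLA and non-MLA builders now read the same knob unconditionally, and the info_once log line disappears with it. A usage sketch, under the assumption that vllm's envs module resolves attributes lazily from the process environment on access:

# Usage sketch; assumes the value is read from os.environ when the attribute
# is accessed, so the override must be set before vllm builds its attention
# metadata builders.
import os

os.environ["VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH"] = "32"

from vllm import envs

# Both builders touched by this commit now pick this value up directly
# whenever full CUDA graphs are enabled.
print(envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH)  # expected: 32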
