
Commit 3febecc

Daisy-Ma-coder, qqma, and DarkLight1337 authored and committed
[CLI env var] Add VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH in env variables (vllm-project#25274)
Signed-off-by: qqma <qqma@amazon.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: qqma <qqma@amazon.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
1 parent 6b2799e commit 3febecc
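
The change replaces a hard-coded constant of 16 with an environment variable that defaults to the same value. Assuming the variable is read from the process environment when the attention backend is constructed (as the hunks below suggest), it can be overridden before engine startup. The snippet below is a minimal sketch; the override value of 32 is purely illustrative, and only the variable name and its default of 16 come from this commit.

    import os

    # Set the override before vLLM reads its environment variables
    # (the value "32" is an arbitrary example; the default is "16").
    os.environ["VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH"] = "32"

    from vllm import envs

    # Reflects the override; falls back to 16 when the variable is unset.
    print(envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH)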

5 files changed: 32 additions, 13 deletions

tests/compile/piecewise/test_full_cudagraph.py
Lines changed: 9 additions & 2 deletions

@@ -46,7 +46,10 @@ class BackendConfig:
     # FA3 on Hopper
     "FA3":
     BackendConfig(name="FA3",
-                  env_vars={"VLLM_FLASH_ATTN_VERSION": "3"},
+                  env_vars={
+                      "VLLM_FLASH_ATTN_VERSION": "3",
+                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
+                  },
                   comp_config={
                       "cudagraph_mode": "FULL",
                   },
@@ -66,6 +69,7 @@ class BackendConfig:
     BackendConfig(name="FlashAttentionMLA",
                   env_vars={
                       "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
+                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
                   },
                   comp_config={
                       "cudagraph_mode": "FULL_DECODE_ONLY",
@@ -89,7 +93,10 @@ class BackendConfig:
     # FA2
     "FA2":
     BackendConfig(name="FA2",
-                  env_vars={"VLLM_FLASH_ATTN_VERSION": "2"},
+                  env_vars={
+                      "VLLM_FLASH_ATTN_VERSION": "2",
+                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
+                  },
                   comp_config={
                       "cudagraph_mode": "FULL",
                   }),

tests/v1/cudagraph/test_cudagraph_mode.py
Lines changed: 9 additions & 2 deletions

@@ -47,7 +47,10 @@ class BackendConfig:
     # FA3 on Hopper
     "FA3":
     BackendConfig(name="FA3",
-                  env_vars={"VLLM_FLASH_ATTN_VERSION": "3"},
+                  env_vars={
+                      "VLLM_FLASH_ATTN_VERSION": "3",
+                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
+                  },
                   comp_config={
                       "cudagraph_mode": "FULL",
                   },
@@ -67,6 +70,7 @@ class BackendConfig:
     BackendConfig(name="FlashAttentionMLA",
                   env_vars={
                       "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
+                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
                   },
                   comp_config={
                       "cudagraph_mode": "FULL_DECODE_ONLY",
@@ -75,7 +79,10 @@ class BackendConfig:
     # FA2
     "FA2":
     BackendConfig(name="FA2",
-                  env_vars={"VLLM_FLASH_ATTN_VERSION": "2"},
+                  env_vars={
+                      "VLLM_FLASH_ATTN_VERSION": "2",
+                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
+                  },
                   comp_config={
                       "cudagraph_mode": "FULL_AND_PIECEWISE",
                   }),
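
Both test files above drive backend selection through an env_vars mapping like the ones extended here. As a rough illustration of how such a mapping can be applied per test, the pytest fixture below is a hedged sketch rather than the repository's actual harness; only the variable names and values come from the diffs.

    import os

    import pytest

    # Values taken from the FA3 entry above; the fixture itself is illustrative.
    FA3_ENV_VARS = {
        "VLLM_FLASH_ATTN_VERSION": "3",
        "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
    }

    @pytest.fixture
    def fa3_env(monkeypatch):
        # Apply the backend's env_vars for the duration of a single test.
        for key, value in FA3_ENV_VARS.items():
            monkeypatch.setenv(key, value)

    def test_fa3_env_applied(fa3_env):
        assert os.environ["VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH"] == "16"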

vllm/envs.py
Lines changed: 8 additions & 0 deletions

@@ -119,6 +119,7 @@
     VLLM_SERVER_DEV_MODE: bool = False
     VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128
     VLLM_MLA_DISABLE: bool = False
+    VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH: int = 16
     VLLM_RAY_PER_WORKER_GPUS: float = 1.0
     VLLM_RAY_BUNDLE_INDICES: str = ""
     VLLM_CUDART_SO_PATH: Optional[str] = None
@@ -946,6 +947,12 @@ def get_vllm_port() -> Optional[int]:
     "VLLM_MLA_DISABLE":
     lambda: bool(int(os.getenv("VLLM_MLA_DISABLE", "0"))),
 
+    # If set, vLLM will pick up the provided Flash Attention MLA
+    # max number splits for cuda graph decode
+    "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH":
+    lambda: int(os.getenv("VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH",
+                          "16")),
+
     # Number of GPUs per worker in Ray, if it is set to be a fraction,
     # it allows ray to schedule multiple actors on a single GPU,
     # so that users can colocate other actors on the same GPUs as vLLM.
@@ -1379,6 +1386,7 @@ def compute_hash() -> str:
     environment_variables_to_hash = [
         "VLLM_PP_LAYER_PARTITION",
         "VLLM_MLA_DISABLE",
+        "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH",
         "VLLM_USE_TRITON_FLASH_ATTN",
        "VLLM_USE_TRITON_AWQ",
        "VLLM_DP_RANK",

vllm/v1/attention/backends/flash_attn.py
Lines changed: 3 additions & 4 deletions

@@ -8,6 +8,7 @@
 import torch
 
 from vllm import _custom_ops as ops
+from vllm import envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType,
                                               is_quantized_kv_cache)
@@ -33,9 +34,6 @@
 
 logger = init_logger(__name__)
 
-# NOTE(woosuk): This is an arbitrary number. Tune it if needed.
-_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16
-
 
 class FlashAttentionBackend(AttentionBackend):
 
@@ -215,7 +213,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         # When using cuda graph, we need to set the upper bound of the
         # number of splits so that large enough intermediate buffers are
         # pre-allocated during capture.
-        self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
+        self.max_num_splits = (
+            envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH)
 
         # Sliding window size to be used with the AOT scheduler will be
         # populated on first build() call.
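
The comment in the hunk above states the constraint: with CUDA graphs, buffer shapes are frozen at capture time, so the split-KV workspace must be sized for the largest split count that may ever be used, which is what max_num_splits bounds. The PyTorch sketch below illustrates the idea; the buffer names and shapes are assumptions for illustration, not vLLM's actual allocations.

    import torch

    def allocate_splitkv_workspace(max_num_splits: int,
                                   max_num_tokens: int,
                                   num_heads: int,
                                   head_size: int,
                                   device: str = "cuda"):
        # Per-split partial outputs and log-sum-exp accumulators. A captured
        # CUDA graph replays with fixed shapes, so the workspace is sized for
        # the configured upper bound (the new env var, default 16) rather than
        # for whatever split count the kernel picks at runtime.
        out_partial = torch.empty(max_num_splits, max_num_tokens, num_heads,
                                  head_size, dtype=torch.float32, device=device)
        lse_partial = torch.empty(max_num_splits, max_num_tokens, num_heads,
                                  dtype=torch.float32, device=device)
        return out_partial, lse_partial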

vllm/v1/attention/backends/mla/flashattn_mla.py
Lines changed: 3 additions & 5 deletions

@@ -6,6 +6,7 @@
 
 import torch
 
+from vllm import envs
 from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
                                               is_quantized_kv_cache)
 from vllm.attention.utils.fa_utils import (flash_attn_supports_mla,
@@ -24,10 +25,6 @@
 
 logger = init_logger(__name__)
 
-# NOTE(matt): This is an arbitrary number, copied from
-# woosuk's implementation in standard FlashAttention backend
-_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16
-
 
 class FlashAttnMLABackend(MLACommonBackend):
 
@@ -97,7 +94,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         # When using cuda graph, we need to set the upper bound of the
         # number of splits so that large enough intermediate buffers are
         # pre-allocated during capture.
-        self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
+        self.max_num_splits = (
+            envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH)
 
         # TODO(lucas): Until we add support for the DCP custom masking we need
         # to restrict decodes to q_len == 1 when DCP is enabled.
