
Commit 3f98da2

aarnphm authored and lk-chen committed
[Chore] added stubs for vllm_flash_attn during development mode (vllm-project#17228)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
1 parent ac4e84c commit 3f98da2

File tree

pyproject.toml
setup.py
vllm/vllm_flash_attn/__init__.py
vllm/vllm_flash_attn/flash_attn_interface.pyi

4 files changed: +269 -2 lines changed


pyproject.toml

Lines changed: 2 additions & 1 deletion
@@ -58,7 +58,8 @@ ignore_patterns = [
 line-length = 80
 exclude = [
     # External file, leaving license intact
-    "examples/other/fp8/quantizer/quantize.py"
+    "examples/other/fp8/quantizer/quantize.py",
+    "vllm/vllm_flash_attn/flash_attn_interface.pyi"
 ]

 [tool.ruff.lint.per-file-ignores]

setup.py

Lines changed: 0 additions & 1 deletion
@@ -378,7 +378,6 @@ def run(self) -> None:
             "vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so",
             "vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so",
             "vllm/vllm_flash_attn/flash_attn_interface.py",
-            "vllm/vllm_flash_attn/__init__.py",
             "vllm/cumem_allocator.abi3.so",
             # "vllm/_version.py", # not available in nightly wheels yet
         ]

vllm/vllm_flash_attn/__init__.py

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import importlib.metadata
+
+try:
+    __version__ = importlib.metadata.version("vllm-flash-attn")
+except importlib.metadata.PackageNotFoundError:
+    # in this case, vllm-flash-attn is built from installing vllm editable
+    __version__ = "0.0.0.dev0"
+
+from .flash_attn_interface import (fa_version_unsupported_reason,
+                                   flash_attn_varlen_func,
+                                   flash_attn_with_kvcache,
+                                   get_scheduler_metadata,
+                                   is_fa_version_supported, sparse_attn_func,
+                                   sparse_attn_varlen_func)
+
+__all__ = [
+    'flash_attn_varlen_func', 'flash_attn_with_kvcache',
+    'get_scheduler_metadata', 'sparse_attn_func', 'sparse_attn_varlen_func',
+    'is_fa_version_supported', 'fa_version_unsupported_reason'
+]
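
As a point of reference, here is a minimal sketch (not part of this commit) of how the new shim package behaves in a development install, assuming the vllm_flash_attn sources have already been built into the source tree:

    # Hypothetical usage sketch; "vllm-flash-attn" is not installed as its own
    # distribution in an editable vllm checkout, so the fallback branch applies.
    import vllm.vllm_flash_attn as vfa

    # importlib.metadata.version("vllm-flash-attn") raises PackageNotFoundError
    # in that case, so __version__ falls back to the "0.0.0.dev0" placeholder.
    print(vfa.__version__)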
vllm/vllm_flash_attn/flash_attn_interface.pyi

Lines changed: 245 additions & 0 deletions
@@ -0,0 +1,245 @@
+# ruff: ignore
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from typing import Any, Literal, overload
+
+import torch
+
+def get_scheduler_metadata(
+    batch_size: int,
+    max_seqlen_q: int,
+    max_seqlen_k: int,
+    num_heads_q: int,
+    num_heads_kv: int,
+    headdim: int,
+    cache_seqlens: torch.Tensor,
+    qkv_dtype: torch.dtype = ...,
+    headdim_v: int | None = ...,
+    cu_seqlens_q: torch.Tensor | None = ...,
+    cu_seqlens_k_new: torch.Tensor | None = ...,
+    cache_leftpad: torch.Tensor | None = ...,
+    page_size: int = ...,
+    max_seqlen_k_new: int = ...,
+    causal: bool = ...,
+    window_size: tuple[int, int] = ...,
+    has_softcap: bool = ...,
+    num_splits: int = ...,
+    pack_gqa: Any | None = ...,
+    sm_margin: int = ...,
+): ...
+@overload
+def flash_attn_varlen_func(
+    q: tuple[int, int, int],
+    k: tuple[int, int, int],
+    v: tuple[int, int, int],
+    max_seqlen_q: int,
+    cu_seqlens_q: torch.Tensor | None,
+    max_seqlen_k: int,
+    cu_seqlens_k: torch.Tensor | None = ...,
+    seqused_k: Any | None = ...,
+    q_v: Any | None = ...,
+    dropout_p: float = ...,
+    causal: bool = ...,
+    window_size: list[int] | None = ...,
+    softmax_scale: float = ...,
+    alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
+    deterministic: bool = ...,
+    return_attn_probs: bool = ...,
+    block_table: Any | None = ...,
+    return_softmax_lse: Literal[False] = ...,
+    out: Any = ...,
+    # FA3 Only
+    scheduler_metadata: Any | None = ...,
+    q_descale: Any | None = ...,
+    k_descale: Any | None = ...,
+    v_descale: Any | None = ...,
+    # Version selector
+    fa_version: int = ...,
+) -> tuple[int, int, int]: ...
+@overload
+def flash_attn_varlen_func(
+    q: tuple[int, int, int],
+    k: tuple[int, int, int],
+    v: tuple[int, int, int],
+    max_seqlen_q: int,
+    cu_seqlens_q: torch.Tensor | None,
+    max_seqlen_k: int,
+    cu_seqlens_k: torch.Tensor | None = ...,
+    seqused_k: Any | None = ...,
+    q_v: Any | None = ...,
+    dropout_p: float = ...,
+    causal: bool = ...,
+    window_size: list[int] | None = ...,
+    softmax_scale: float = ...,
+    alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
+    deterministic: bool = ...,
+    return_attn_probs: bool = ...,
+    block_table: Any | None = ...,
+    return_softmax_lse: Literal[True] = ...,
+    out: Any = ...,
+    # FA3 Only
+    scheduler_metadata: Any | None = ...,
+    q_descale: Any | None = ...,
+    k_descale: Any | None = ...,
+    v_descale: Any | None = ...,
+    # Version selector
+    fa_version: int = ...,
+) -> tuple[tuple[int, int, int], tuple[int, int]]: ...
+@overload
+def flash_attn_with_kvcache(
+    q: tuple[int, int, int, int],
+    k_cache: tuple[int, int, int, int],
+    v_cache: tuple[int, int, int, int],
+    k: tuple[int, int, int, int] | None = ...,
+    v: tuple[int, int, int, int] | None = ...,
+    rotary_cos: tuple[int, int] | None = ...,
+    rotary_sin: tuple[int, int] | None = ...,
+    cache_seqlens: int | torch.Tensor | None = None,
+    cache_batch_idx: torch.Tensor | None = None,
+    cache_leftpad: torch.Tensor | None = ...,
+    block_table: torch.Tensor | None = ...,
+    softmax_scale: float = ...,
+    causal: bool = ...,
+    window_size: tuple[int, int] = ...,  # -1 means infinite context window
+    softcap: float = ...,
+    rotary_interleaved: bool = ...,
+    alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
+    num_splits: int = ...,
+    return_softmax_lse: Literal[False] = ...,
+    *,
+    out: Any = ...,
+    # FA3 Only
+    scheduler_metadata: Any | None = ...,
+    q_descale: Any | None = ...,
+    k_descale: Any | None = ...,
+    v_descale: Any | None = ...,
+    # Version selector
+    fa_version: int = ...,
+) -> tuple[int, int, int, int]: ...
+@overload
+def flash_attn_with_kvcache(
+    q: tuple[int, int, int, int],
+    k_cache: tuple[int, int, int, int],
+    v_cache: tuple[int, int, int, int],
+    k: tuple[int, int, int, int] | None = ...,
+    v: tuple[int, int, int, int] | None = ...,
+    rotary_cos: tuple[int, int] | None = ...,
+    rotary_sin: tuple[int, int] | None = ...,
+    cache_seqlens: int | torch.Tensor | None = None,
+    cache_batch_idx: torch.Tensor | None = None,
+    cache_leftpad: torch.Tensor | None = ...,
+    block_table: torch.Tensor | None = ...,
+    softmax_scale: float = ...,
+    causal: bool = ...,
+    window_size: tuple[int, int] = ...,  # -1 means infinite context window
+    softcap: float = ...,
+    rotary_interleaved: bool = ...,
+    alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
+    num_splits: int = ...,
+    return_softmax_lse: Literal[True] = ...,
+    *,
+    out: Any = ...,
+    # FA3 Only
+    scheduler_metadata: Any | None = ...,
+    q_descale: Any | None = ...,
+    k_descale: Any | None = ...,
+    v_descale: Any | None = ...,
+    # Version selector
+    fa_version: int = ...,
+) -> tuple[tuple[int, int, int], tuple[int, int]]: ...
+@overload
+def sparse_attn_func(
+    q: tuple[int, int, int, int],
+    k: tuple[int, int, int, int],
+    v: tuple[int, int, int, int],
+    block_count: tuple[int, int, float],
+    block_offset: tuple[int, int, float, int],
+    column_count: tuple[int, int, float],
+    column_index: tuple[int, int, float, int],
+    dropout_p: float = ...,
+    softmax_scale: float = ...,
+    causal: bool = ...,
+    softcap: float = ...,
+    alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
+    deterministic: bool = ...,
+    return_attn_probs: bool = ...,
+    *,
+    return_softmax_lse: Literal[False] = ...,
+    out: Any = ...,
+) -> tuple[int, int, int]: ...
+@overload
+def sparse_attn_func(
+    q: tuple[int, int, int, int],
+    k: tuple[int, int, int, int],
+    v: tuple[int, int, int, int],
+    block_count: tuple[int, int, float],
+    block_offset: tuple[int, int, float, int],
+    column_count: tuple[int, int, float],
+    column_index: tuple[int, int, float, int],
+    dropout_p: float = ...,
+    softmax_scale: float = ...,
+    causal: bool = ...,
+    softcap: float = ...,
+    alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
+    deterministic: bool = ...,
+    return_attn_probs: bool = ...,
+    *,
+    return_softmax_lse: Literal[True] = ...,
+    out: Any = ...,
+) -> tuple[tuple[int, int, int], tuple[int, int]]: ...
+@overload
+def sparse_attn_varlen_func(
+    q: tuple[int, int, int],
+    k: tuple[int, int, int],
+    v: tuple[int, int, int],
+    block_count: tuple[int, int, float],
+    block_offset: tuple[int, int, float, int],
+    column_count: tuple[int, int, float],
+    column_index: tuple[int, int, float, int],
+    cu_seqlens_q: torch.Tensor | None,
+    cu_seqlens_k: torch.Tensor | None,
+    max_seqlen_q: int,
+    max_seqlen_k: int,
+    dropout_p: float = ...,
+    softmax_scale: float = ...,
+    causal: bool = ...,
+    softcap: float = ...,
+    alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
+    deterministic: bool = ...,
+    return_attn_probs: bool = ...,
+    *,
+    return_softmax_lse: Literal[False] = ...,
+    out: Any = ...,
+) -> tuple[int, int, int]: ...
+@overload
+def sparse_attn_varlen_func(
+    q: tuple[int, int, int],
+    k: tuple[int, int, int],
+    v: tuple[int, int, int],
+    block_count: tuple[int, int, float],
+    block_offset: tuple[int, int, float, int],
+    column_count: tuple[int, int, float],
+    column_index: tuple[int, int, float, int],
+    cu_seqlens_q: torch.Tensor | None,
+    cu_seqlens_k: torch.Tensor | None,
+    max_seqlen_q: int,
+    max_seqlen_k: int,
+    dropout_p: float = ...,
+    softmax_scale: float = ...,
+    causal: bool = ...,
+    softcap: float = ...,
+    alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
+    deterministic: bool = ...,
+    return_attn_probs: bool = ...,
+    *,
+    return_softmax_lse: Literal[True] = ...,
+    out: Any = ...,
+) -> tuple[tuple[int, int, int], tuple[int, int]]: ...
+def is_fa_version_supported(
+    fa_version: int, device: torch.device | None = None
+) -> bool: ...
+def fa_version_unsupported_reason(
+    fa_version: int, device: torch.device | None = None
+) -> str | None: ...
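
For context, a minimal sketch (not part of this commit) of how the two runtime helpers stubbed at the end of the file can be used; the device handling here is an assumption, and running it requires a built development tree:

    # Hypothetical guard for FA3-only code paths, using only the signatures
    # declared in the stub above.
    import torch
    from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
                                      is_fa_version_supported)

    device = torch.device("cuda") if torch.cuda.is_available() else None
    for fa_version in (2, 3):
        if not is_fa_version_supported(fa_version, device):
            # fa_version_unsupported_reason returns a human-readable string
            # (or None) explaining why this FA version cannot be used.
            print(fa_version, fa_version_unsupported_reason(fa_version, device))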
