
Commit 94f2d04

fix pre-commit
Signed-off-by: Huamin Li <3ericli@gmail.com>
1 parent 91864b7 commit 94f2d04


11 files changed: +44 -1 lines changed


vllm/attention/backends/abstract.py

Lines changed: 12 additions & 0 deletions
@@ -142,6 +142,15 @@ def supports_sink(cls) -> bool:
     def is_sparse(cls) -> bool:
         return False
 
+    @classmethod
+    def supports_attn_type(cls, attn_type: str) -> bool:
+        """Check if backend supports a given attention type.
+
+        By default, returns True (all types supported).
+        Backends should override this to restrict supported types.
+        """
+        return True
+
     @classmethod
     def supports_compute_capability(cls, capability: "DeviceCapability") -> bool:
         return True
@@ -171,6 +180,7 @@ def validate_configuration(
         has_sink: bool,
         use_sparse: bool,
         device_capability: "DeviceCapability",
+        attn_type: str = AttentionType.DECODER,
     ) -> list[str]:
         invalid_reasons = []
         if not cls.supports_head_size(head_size):
@@ -195,6 +205,8 @@ def validate_configuration(
             invalid_reasons.append("non-sparse not supported")
         if not cls.supports_compute_capability(device_capability):
             invalid_reasons.append("compute capability not supported")
+        if not cls.supports_attn_type(attn_type):
+            invalid_reasons.append(f"attention type {attn_type} not supported")
         combination_reason = cls.supports_combination(
             head_size,
             dtype,
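
For illustration only, the sketch below shows how a backend could use the new hook. MyDecoderOnlyBackend is a hypothetical class, not part of this commit, and the import of the AttentionBackend base class from abstract.py is assumed from the file shown above; with such an override, validate_configuration() would report "attention type ... not supported" for any non-decoder request.

from vllm.attention import AttentionType
from vllm.attention.backends.abstract import AttentionBackend


class MyDecoderOnlyBackend(AttentionBackend):  # hypothetical example backend
    @classmethod
    def supports_attn_type(cls, attn_type: str) -> bool:
        # Restrict this backend to plain decoder self-attention; every other
        # attention type is rejected during backend validation.
        return attn_type == AttentionType.DECODER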

vllm/attention/layer.py

Lines changed: 1 addition & 0 deletions
@@ -295,6 +295,7 @@ def __init__(
                 block_size,
                 use_mla=False,
                 has_sink=self.has_sink,
+                attn_type=attn_type,
             )
         else:
             self.attn_backend = attn_backend

vllm/attention/layers/encoder_only_attention.py

Lines changed: 5 additions & 1 deletion
@@ -74,7 +74,11 @@ def __init__(
         block_size = 16
 
         underlying_attn_backend = get_attn_backend(
-            head_size, dtype, kv_cache_dtype, block_size
+            head_size,
+            dtype,
+            kv_cache_dtype,
+            block_size,
+            attn_type=AttentionType.ENCODER_ONLY,
         )
 
         attn_backend = create_encoder_only_attention_backend(underlying_attn_backend)

vllm/attention/selector.py

Lines changed: 4 additions & 0 deletions
@@ -75,6 +75,7 @@ def get_attn_backend(
     use_mla: bool = False,
     has_sink: bool = False,
     use_sparse: bool = False,
+    attn_type: str | None = None,
 ) -> type[AttentionBackend]:
     """Selects which attention backend to use and lazily imports it."""
 
@@ -93,6 +94,7 @@ def get_attn_backend(
         use_mla=use_mla,
         has_sink=has_sink,
         use_sparse=use_sparse,
+        attn_type=attn_type,
     )
 
 
@@ -105,6 +107,7 @@ def _cached_get_attn_backend(
     use_mla: bool = False,
     has_sink: bool = False,
     use_sparse: bool = False,
+    attn_type: str | None = None,
 ) -> type[AttentionBackend]:
     # Check whether a particular choice of backend was
     # previously forced.
@@ -151,6 +154,7 @@ def _cached_get_attn_backend(
         use_mla,
         has_sink,
         use_sparse,
+        attn_type,
     )
     if not attention_cls:
         raise ValueError(
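
As a usage sketch, a caller can now pass the attention type straight through get_attn_backend, mirroring the encoder_only_attention.py call site above. The positional argument values below are placeholders; only the attn_type keyword is new. Leaving attn_type unset keeps the previous behaviour, since (on CUDA at least, per the platform change below) it falls back to AttentionType.DECODER.

import torch

from vllm.attention import AttentionType
from vllm.attention.selector import get_attn_backend

# Placeholder values for head size, dtype, KV-cache dtype, and block size.
backend_cls = get_attn_backend(
    64,
    torch.float16,
    "auto",
    16,
    attn_type=AttentionType.ENCODER_ONLY,  # new keyword added in this commit
)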

vllm/platforms/cpu.py

Lines changed: 1 addition & 0 deletions
@@ -135,6 +135,7 @@ def get_attn_backend_cls(
         use_mla: bool,
         has_sink: bool,
         use_sparse: bool,
+        attn_type: str | None = None,
     ) -> str:
         from vllm.attention.backends.registry import AttentionBackendEnum
 
vllm/platforms/cuda.py

Lines changed: 10 additions & 0 deletions
@@ -298,6 +298,7 @@ def get_valid_backends(
         has_sink,
         use_sparse,
         device_capability,
+        attn_type,
     ) -> tuple[
         list[tuple["AttentionBackendEnum", int]],
         dict["AttentionBackendEnum", list[str]],
@@ -318,6 +319,7 @@ def get_valid_backends(
                     has_sink,
                     use_sparse,
                     device_capability,
+                    attn_type,
                 )
             except ImportError:
                 invalid_reasons_i = ["ImportError"]
@@ -340,13 +342,19 @@ def get_attn_backend_cls(
         use_mla: bool,
         has_sink: bool,
         use_sparse: bool,
+        attn_type: str | None = None,
     ) -> str:
         if not use_v1:
             raise RuntimeError(
                 "V0 attention backends have been removed. Set VLLM_USE_V1=1 "
                 "to select a supported backend."
             )
 
+        from vllm.attention import AttentionType
+
+        if attn_type is None:
+            attn_type = AttentionType.DECODER
+
         device_capability = cls.get_device_capability()
         assert device_capability is not None
 
@@ -363,6 +371,7 @@ def get_attn_backend_cls(
                 has_sink,
                 use_sparse,
                 device_capability,
+                attn_type,
             )
         except ImportError:
             invalid_reasons = ["ImportError"]
@@ -386,6 +395,7 @@ def get_attn_backend_cls(
                 has_sink,
                 use_sparse,
                 device_capability,
+                attn_type,
             )
             reasons_str = (
                 "{"

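Taken together, the CUDA changes mean the requested attention type now participates in backend filtering alongside head size, dtype, and compute capability. The standalone sketch below (illustrative names only, not code from this commit) shows the shape of that filtering step under those assumptions.

from vllm.attention import AttentionType


def filter_backends(candidates, attn_type=None):
    """Keep only backend classes whose supports_attn_type() accepts the request."""
    if attn_type is None:
        attn_type = AttentionType.DECODER  # mirrors the new default in cuda.py
    valid, invalid_reasons = [], {}
    for backend_cls in candidates:
        if backend_cls.supports_attn_type(attn_type):
            valid.append(backend_cls)
        else:
            invalid_reasons[backend_cls] = [
                f"attention type {attn_type} not supported"
            ]
    return valid, invalid_reasons
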
vllm/platforms/interface.py

Lines changed: 1 addition & 0 deletions
@@ -219,6 +219,7 @@ def get_attn_backend_cls(
         use_mla: bool,
         has_sink: bool,
         use_sparse: bool,
+        attn_type: str | None = None,
     ) -> str:
         """Get the attention backend class of a device."""
         return ""

vllm/platforms/rocm.py

Lines changed: 1 addition & 0 deletions
@@ -217,6 +217,7 @@ def get_attn_backend_cls(
         use_mla,
         has_sink,
         use_sparse,
+        attn_type: str | None = None,
     ) -> str:
         from vllm._aiter_ops import rocm_aiter_ops
         from vllm.attention.backends.registry import AttentionBackendEnum

vllm/platforms/tpu.py

Lines changed: 1 addition & 0 deletions
@@ -62,6 +62,7 @@ def get_attn_backend_cls(
         use_mla: bool,
         has_sink,
         use_sparse,
+        attn_type: str | None = None,
     ) -> str:
         from vllm.attention.backends.registry import AttentionBackendEnum
 

vllm/platforms/xpu.py

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@ def get_attn_backend_cls(
         use_mla: bool,
         has_sink: bool,
         use_sparse,
+        attn_type: str | None = None,
     ) -> str:
         from vllm.v1.attention.backends.utils import set_kv_cache_layout
 