
Commit 0b73736

[Kernel] Raise verbose error and consolidate num_heads/num_kv_heads divisibility check (#19339)
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
1 parent ee1531b commit 0b73736
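
The hunks below remove the per-backend "assert self.num_heads % self.num_kv_heads == 0" checks and add a regression test; the consolidated, verbose check itself lives in vllm/attention/layer.py, which is not included in this excerpt. A minimal sketch of what such a consolidated check could look like, assuming the constructor arguments exercised by the new test (the class body, error wording, and argument handling here are illustrative assumptions, not the commit's code):

from typing import Optional

import torch


# Hypothetical sketch of the consolidated check in vllm/attention/layer.py.
# The exact error message and argument handling are assumptions.
class Attention(torch.nn.Module):

    def __init__(self, num_heads: int, head_size: int, scale: float,
                 num_kv_heads: Optional[int] = None, **kwargs) -> None:
        super().__init__()
        num_kv_heads = num_kv_heads or num_heads
        # A single verbose check here replaces the bare asserts that each
        # backend used to carry.
        assert num_heads % num_kv_heads == 0, (
            f"num_heads ({num_heads}) is not divisible by "
            f"num_kv_heads ({num_kv_heads})")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.num_queries_per_kv = num_heads // num_kv_heads

An AssertionError raised at this point is what the new parametrized test below expects for num_heads=16 and num_kv_heads=5.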

File tree

17 files changed (+24, -19 lines changed)


tests/kernels/attention/test_attention.py

Lines changed: 16 additions & 0 deletions
@@ -10,6 +10,7 @@
 from tests.kernels.allclose_default import get_default_atol, get_default_rtol
 from tests.kernels.utils import opcheck
 from vllm import _custom_ops as ops
+from vllm.attention.layer import Attention, MultiHeadAttention
 from vllm.platforms import current_platform
 from vllm.utils import get_max_shared_memory_bytes
 
@@ -506,3 +507,18 @@ def test_multi_query_kv_attention_with_alibi(
         device,
         use_alibi=True,
     )
+
+
+@pytest.mark.parametrize("attention_cls", [Attention, MultiHeadAttention])
+def test_num_heads_not_divisble_by_num_kv_heads(attention_cls: type) -> None:
+    head_size = 64
+    scale = float(1.0 / (head_size**0.5))
+    num_heads = 16
+    num_kv_heads = 5
+    with pytest.raises(AssertionError):
+        _ = attention_cls(
+            num_heads=num_heads,
+            head_size=head_size,
+            scale=scale,
+            num_kv_heads=num_kv_heads,
+        )
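
With the check consolidated in the layer, one parametrized test covers both Attention and MultiHeadAttention. It can be run on its own with, for example, pytest tests/kernels/attention/test_attention.py -k num_heads_not_divisble (the -k pattern follows the test's own spelling); since 16 is not divisible by 5, constructing either class is expected to raise an AssertionError.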

vllm/attention/backends/blocksparse_attn.py

Lines changed: 1 addition & 3 deletions
@@ -65,7 +65,6 @@ def __post_init__(self):
         assert self.block_size > 0
         assert self.local_blocks >= 0
         assert self.vert_stride >= 1
-        assert self.num_heads % self.num_kv_heads == 0
 
         tp_size = get_tensor_model_parallel_world_size()
         tp_rank = get_tensor_model_parallel_rank()
@@ -329,9 +328,8 @@ def __init__(
         self.head_size = head_size
         self.scale = float(scale)
         self.alibi_slopes = alibi_slopes
-        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+        self.num_kv_heads = num_kv_heads
 
-        assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
         self.local_blocks = self.blocksparse_params.local_blocks
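
This hunk, like the matching one in pallas.py further down, also drops the "num_heads if num_kv_heads is None else num_kv_heads" fallback: the backend now stores num_kv_heads as passed in, on the assumption that the calling attention layer has already resolved the default and validated divisibility, as in the sketch near the top of this page.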

vllm/attention/backends/dual_chunk_flash_attn.py

Lines changed: 0 additions & 1 deletion
@@ -307,7 +307,6 @@ def __init__(
                               if sliding_window is not None else (-1, -1))
         self.kv_cache_dtype = kv_cache_dtype
 
-        assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         if sliding_window is not None:
             # NOTE(woosuk): flash-attn's sliding window does not work with

vllm/attention/backends/flash_attn.py

Lines changed: 0 additions & 1 deletion
@@ -654,7 +654,6 @@ def __init__(
             logits_soft_cap = 0
         self.logits_soft_cap = logits_soft_cap
 
-        assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
         support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()

vllm/attention/backends/flashinfer.py

Lines changed: 0 additions & 1 deletion
@@ -957,7 +957,6 @@ def __init__(
         self.kv_cache_dtype = kv_cache_dtype
         self.logits_soft_cap = logits_soft_cap
 
-        assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
         if attn_type != AttentionType.DECODER:

vllm/attention/backends/hpu_attn.py

Lines changed: 0 additions & 1 deletion
@@ -148,7 +148,6 @@ def __init__(
             alibi_slopes_tensor = torch.tensor(alibi_slopes,
                                                dtype=torch.bfloat16)
             self.alibi_slopes = alibi_slopes_tensor
-        assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
         if self.prefill_impl == 'fsdpa':

vllm/attention/backends/ipex_attn.py

Lines changed: 0 additions & 1 deletion
@@ -145,7 +145,6 @@ def __init__(
         self.sliding_window = sliding_window
         self.kv_cache_dtype = kv_cache_dtype
 
-        assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         self.need_mask = (self.sliding_window is not None)
         if logits_soft_cap is None:

vllm/attention/backends/pallas.py

Lines changed: 1 addition & 2 deletions
@@ -121,9 +121,8 @@ def __init__(
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
-        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+        self.num_kv_heads = num_kv_heads
 
-        assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         self.logits_soft_cap = logits_soft_cap
         if head_size % 128 != 0:

vllm/attention/backends/rocm_flash_attn.py

Lines changed: 0 additions & 1 deletion
@@ -528,7 +528,6 @@ def __init__(
                               if sliding_window is not None else (-1, -1))
         self.kv_cache_dtype = kv_cache_dtype
 
-        assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
         self.paged_attn_module = _get_paged_attn_module()

vllm/attention/backends/torch_sdpa.py

Lines changed: 0 additions & 1 deletion
@@ -433,7 +433,6 @@ def __init__(
         self.sliding_window = sliding_window
         self.kv_cache_dtype = kv_cache_dtype
 
-        assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         self.need_mask = (self.alibi_slopes is not None
                           or self.sliding_window is not None)
