Commit

[Bugfix]: fix v1/v2 paged attention kernel unit test.
vllmellm committed Jan 23, 2025
1 parent b8e66a9 commit 0f6ff75
Showing 1 changed file with 8 additions and 4 deletions.
tests/kernels/test_attention.py: 8 additions & 4 deletions
@@ -7,7 +7,7 @@
 from tests.kernels.utils import opcheck
 from vllm import _custom_ops as ops
 from vllm.platforms import current_platform
-from vllm.utils import get_max_shared_memory_bytes
+from vllm.utils import get_max_shared_memory_bytes, is_navi

 from .allclose_default import get_default_atol, get_default_rtol

@@ -122,7 +122,7 @@ def ref_single_query_cached_kv_attention(

 @pytest.mark.parametrize(
     "version",
-    ["v1", "v2"] if not current_platform.is_rocm() else ["rocm"])
+    ["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"])
 @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
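With this change, ROCm runs exercise the generic v1/v2 kernels in addition to the ROCm-specific kernel, rather than the ROCm-only kernel alone. Restated outside the decorator, as a minimal sketch using current_platform exactly as the test imports it:

    from vllm.platforms import current_platform

    # Non-ROCm platforms test only the generic kernels; ROCm now tests
    # the platform-specific "rocm" kernel alongside v1 and v2.
    versions = (["v1", "v2"] if not current_platform.is_rocm() else
                ["v1", "v2", "rocm"])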
@@ -189,6 +189,10 @@ def test_paged_attention(
     # Using default kv_scale
     k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32)

+    # additional argument for v1/v2 pa kernel
+    num_threads = 1024 if current_platform.is_rocm() \
+        and not is_navi() else 128
+
     # Call the paged attention kernel.
     output = torch.empty_like(query)
     if version == "v1":
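The new num_threads value selects the kernel launch width: 1024 threads on ROCm GPUs that are not Navi, 128 threads otherwise. A standalone restatement of the added lines; attributing the 1024-thread path to CDNA-class hardware is an assumption inferred from the is_navi() gate, not stated in this diff:

    from vllm.platforms import current_platform
    from vllm.utils import is_navi

    # Assumption: the 1024-thread configuration targets non-Navi ROCm
    # GPUs (e.g. CDNA accelerators); Navi GPUs and non-ROCm platforms
    # keep the default 128 threads.
    num_threads = 1024 if current_platform.is_rocm() and not is_navi() else 128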
@@ -212,7 +216,7 @@
         opcheck(torch.ops._C.paged_attention_v1,
                 (output, query, key_cache, value_cache, num_kv_heads, scale,
                  block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
-                 kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
+                 kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0, num_threads),
                 cond=(head_size == HEAD_SIZES[0]
                       and block_size == BLOCK_SIZES[0]))

@@ -257,7 +261,7 @@ def test_paged_attention(
                 (output, exp_sums, max_logits, tmp_output, query,
                  key_cache, value_cache, num_kv_heads, scale, block_tables,
                  seq_lens, block_size, max_seq_len, alibi_slopes,
-                 kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
+                 kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0, num_threads),
                 cond=(head_size == HEAD_SIZES[0]
                       and block_size == BLOCK_SIZES[0]))
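Both opcheck calls validate the registered op against its full positional argument tuple, so once the kernel schema gained a trailing thread-count parameter the old 19-element tuples no longer matched, which is likely why these tests needed fixing. Below is a hedged sketch of the direct v1 invocation the tuple mirrors; the names suggested for the five trailing integers are assumptions based on vLLM's custom-op schema, not taken from this diff:

    # Sketch only: mirrors the opcheck argument tuple above.
    torch.ops._C.paged_attention_v1(
        output, query, key_cache, value_cache, num_kv_heads, scale,
        block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
        kv_cache_dtype, k_scale, v_scale,
        0,            # tp_rank (assumed name)
        0,            # blocksparse_local_blocks (assumed name)
        0,            # blocksparse_vert_stride (assumed name)
        64,           # blocksparse_block_size (assumed name)
        0,            # blocksparse_head_sliding_step (assumed name)
        num_threads)  # the new trailing argument added by this commit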

