Skip to content

Commit b0955f8

Browse files
committed
add unit test
Signed-off-by: fsx950223 <fsx950223@outlook.com>
1 parent 23121d6 commit b0955f8

File tree

1 file changed

+202
-0
lines changed

1 file changed

+202
-0
lines changed
Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
from typing import Optional
5+
6+
import pytest
7+
import torch
8+
9+
from vllm.platforms import current_platform
10+
11+
NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
12+
HEAD_SIZES = [128, 256]
13+
BLOCK_SIZES = [16, 32]
14+
DTYPES = [torch.float16, torch.bfloat16]
15+
QDTYPES = [None]
16+
# one value large enough to test overflow in index calculation.
17+
# one value small enough to test the schema op check
18+
NUM_BLOCKS = [32768, 2048]
19+
20+
21+
def ref_paged_attn(
22+
query: torch.Tensor,
23+
key_cache: torch.Tensor,
24+
value_cache: torch.Tensor,
25+
query_lens: list[int],
26+
kv_lens: list[int],
27+
block_tables: torch.Tensor,
28+
scale: float,
29+
sliding_window: Optional[int] = None,
30+
soft_cap: Optional[float] = None,
31+
) -> torch.Tensor:
32+
num_seqs = len(query_lens)
33+
block_tables = block_tables.cpu().numpy()
34+
_, block_size, num_kv_heads, head_size = key_cache.shape
35+
36+
outputs: list[torch.Tensor] = []
37+
start_idx = 0
38+
for i in range(num_seqs):
39+
query_len = query_lens[i]
40+
kv_len = kv_lens[i]
41+
q = query[start_idx:start_idx + query_len]
42+
q *= scale
43+
44+
num_kv_blocks = (kv_len + block_size - 1) // block_size
45+
block_indices = block_tables[i, :num_kv_blocks]
46+
47+
k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
48+
k = k[:kv_len]
49+
v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
50+
v = v[:kv_len]
51+
52+
if q.shape[1] != k.shape[1]:
53+
k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
54+
v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
55+
attn = torch.einsum("qhd,khd->hqk", q, k).float()
56+
empty_mask = torch.ones(query_len, kv_len)
57+
mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
58+
if sliding_window is not None:
59+
sliding_window_mask = torch.triu(empty_mask,
60+
diagonal=kv_len -
61+
(query_len + sliding_window) +
62+
1).bool().logical_not()
63+
mask |= sliding_window_mask
64+
if soft_cap is not None:
65+
attn = soft_cap * torch.tanh(attn / soft_cap)
66+
attn.masked_fill_(mask, float("-inf"))
67+
attn = torch.softmax(attn, dim=-1).to(v.dtype)
68+
out = torch.einsum("hqk,khd->qhd", attn, v)
69+
70+
outputs.append(out)
71+
start_idx += query_len
72+
73+
return torch.cat(outputs, dim=0)
74+
75+
76+
@pytest.mark.skipif(not current_platform.is_rocm(),
77+
reason="Only ROCm is supported")
78+
@pytest.mark.parametrize("seq_lens",
79+
[[(10, 1328), (5, 18),
80+
(129, 463)], [(8, 523), (24, 37), (3, 2011)]])
81+
@pytest.mark.parametrize("num_heads", NUM_HEADS)
82+
@pytest.mark.parametrize("head_size", HEAD_SIZES)
83+
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
84+
@pytest.mark.parametrize("sliding_window", [None, 256])
85+
@pytest.mark.parametrize("dtype", DTYPES)
86+
@pytest.mark.parametrize("soft_cap", [None])
87+
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
88+
@pytest.mark.parametrize("q_dtype", QDTYPES)
89+
@torch.inference_mode()
90+
def test_varlen_with_paged_kv(
91+
seq_lens: list[tuple[int, int]],
92+
num_heads: tuple[int, int],
93+
head_size: int,
94+
sliding_window: Optional[int],
95+
dtype: torch.dtype,
96+
block_size: int,
97+
soft_cap: Optional[float],
98+
num_blocks: int,
99+
q_dtype: Optional[torch.dtype],
100+
) -> None:
101+
torch.set_default_device("cuda")
102+
current_platform.seed_everything(0)
103+
num_seqs = len(seq_lens)
104+
query_lens = [x[0] for x in seq_lens]
105+
kv_lens = [x[1] for x in seq_lens]
106+
num_query_heads = num_heads[0]
107+
num_kv_heads = num_heads[1]
108+
assert num_query_heads % num_kv_heads == 0
109+
max_query_len = max(query_lens)
110+
max_kv_len = max(kv_lens)
111+
window_size = ((sliding_window - 1, 0) if sliding_window is not None else
112+
(-1, -1))
113+
scale = head_size**-0.5
114+
115+
query = torch.randn(sum(query_lens),
116+
num_query_heads,
117+
head_size,
118+
dtype=dtype)
119+
key_cache = torch.randn(num_blocks,
120+
block_size,
121+
num_kv_heads,
122+
head_size,
123+
dtype=dtype)
124+
value_cache = torch.randn_like(key_cache)
125+
cu_query_lens = torch.tensor([0] + query_lens,
126+
dtype=torch.int32).cumsum(dim=0,
127+
dtype=torch.int32)
128+
129+
cu_seq_lens = torch.tensor([0] + kv_lens,
130+
dtype=torch.int32).cumsum(dim=0,
131+
dtype=torch.int32)
132+
kv_lens = torch.tensor(kv_lens, dtype=torch.int32)
133+
134+
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
135+
block_tables = torch.randint(0,
136+
num_blocks,
137+
(num_seqs, max_num_blocks_per_seq),
138+
dtype=torch.int32)
139+
140+
output = torch.empty_like(query)
141+
total_tokens = cu_seq_lens[-1].item()
142+
k_buffer = torch.empty(
143+
(total_tokens, num_kv_heads, head_size),
144+
dtype=dtype,
145+
device=torch.device("cuda"),
146+
)
147+
v_buffer = torch.empty(
148+
(total_tokens, num_kv_heads, head_size),
149+
dtype=dtype,
150+
device=torch.device("cuda"),
151+
)
152+
maybe_quantized_query = query
153+
maybe_quantized_key_cache = key_cache
154+
maybe_quantized_value_cache = value_cache
155+
k_descale = None
156+
v_descale = None
157+
if q_dtype is not None:
158+
# QKV are drawn from N(0, 1): no need for a fp8 scaling factor
159+
maybe_quantized_query = query.to(q_dtype)
160+
maybe_quantized_key_cache = key_cache.to(q_dtype)
161+
maybe_quantized_value_cache = value_cache.to(q_dtype)
162+
163+
scale_shape = (num_seqs, num_kv_heads)
164+
k_descale = torch.ones(scale_shape, dtype=torch.float32)
165+
v_descale = torch.ones(scale_shape, dtype=torch.float32)
166+
167+
torch.ops.vllm.flash_attn_varlen_func(
168+
maybe_quantized_query,
169+
maybe_quantized_key_cache,
170+
maybe_quantized_value_cache,
171+
k_buffer,
172+
v_buffer,
173+
out=output,
174+
cu_seqlens_q=cu_query_lens,
175+
max_seqlen_q=max_query_len,
176+
max_seqlen_k=max_kv_len,
177+
softmax_scale=scale,
178+
alibi_slopes=None,
179+
window_size=window_size,
180+
block_table=block_tables,
181+
cu_seqlens_k=cu_seq_lens,
182+
k_scale=k_descale,
183+
v_scale=v_descale,
184+
)
185+
186+
ref_output = ref_paged_attn(
187+
query=query,
188+
key_cache=key_cache,
189+
value_cache=value_cache,
190+
query_lens=query_lens,
191+
kv_lens=kv_lens,
192+
block_tables=block_tables,
193+
scale=scale,
194+
sliding_window=sliding_window,
195+
soft_cap=soft_cap,
196+
)
197+
198+
atol, rtol = 2e-2, 2e-2
199+
if q_dtype is not None:
200+
atol, rtol = 1.5e-1, 1.5e-1
201+
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \
202+
f"{torch.max(torch.abs(output - ref_output))}"

0 commit comments

Comments
 (0)