
Commit 868f3de

elvischenv committed
add sinks attn unit tests
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
1 parent b18c17f commit 868f3de

File tree: 1 file changed, +76 -41 lines

tests/kernels/attention/test_flashinfer_trtllm_attention.py

Lines changed: 76 additions & 41 deletions
@@ -49,6 +49,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
 BLOCK_SIZE = [16]
 WINDOW_LEFT = [-1, 127]
 SOFT_CAP = [None, 50.0]
+HAS_SINKS = [True, False]
 
 NUM_BLOCKS = 32768  # Large enough to test overflow in index calculation.
 
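Background on the new `HAS_SINKS` axis: an attention sink is a learned per-head logit that participates in the softmax normalization without contributing a value vector, so it soaks up probability mass (the formulation used by GPT-OSS-style models). A minimal single-query, single-head sketch in plain PyTorch (the function name and shapes are illustrative, not the FlashInfer API):

```python
import torch

def sink_attention_single_query(q, k, v, sink, scale):
    """Reference semantics for one head and one query token (illustrative)."""
    # q: [head_dim], k/v: [kv_len, head_dim], sink: 0-dim logit tensor.
    logits = (k @ q) * scale                               # [kv_len]
    # The sink logit joins the softmax denominator but owns no value row,
    # so its share of the probability mass is simply dropped.
    probs = torch.softmax(torch.cat([logits, sink.reshape(1)]), dim=-1)
    return probs[:-1] @ v                                  # [head_dim]

# Hypothetical shapes, loosely matching one head of the test:
out = sink_attention_single_query(
    torch.randn(128), torch.randn(100, 128), torch.randn(100, 128),
    sink=torch.rand(()) * 5, scale=128 ** -0.5,
)
```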
@@ -63,6 +64,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
 @pytest.mark.parametrize("block_size", BLOCK_SIZE)
 @pytest.mark.parametrize("window_left", WINDOW_LEFT)
 @pytest.mark.parametrize("soft_cap", SOFT_CAP)
+@pytest.mark.parametrize("has_sinks", HAS_SINKS)
 @torch.inference_mode
 def test_flashinfer_trtllm_decode_with_baseline(
     dtype: torch.dtype,
@@ -77,9 +79,10 @@ def test_flashinfer_trtllm_decode_with_baseline(
     block_size: int,
     window_left: int,
     soft_cap: Optional[float],
+    has_sinks: bool,
 ) -> None:
     torch.set_default_device("cuda")
-    current_platform.seed_everything(0)
+    current_platform.seed_everything(42)
 
     q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
     q_quant_dtype = q_quant_dtype or dtype
@@ -101,7 +104,16 @@ def test_flashinfer_trtllm_decode_with_baseline(
     else:
         raise ValueError(f"Invalid kv_layout: {kv_layout}")
 
-    query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
+    # max_q_len = 1
+    q_lens = torch.ones((batch_size,), dtype=torch.int32)
+    q_indptr = torch.cat(
+        [
+            torch.tensor([0], dtype=torch.int32),
+            torch.cumsum(q_lens, dim=0, dtype=torch.int32),
+        ]
+    )
+
+    query = torch.randn(torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype)
     if q_quant_dtype == FP8_DTYPE:
         query, q_scale = to_float8(query)
         ref_query = query.to(dtype) * q_scale
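`q_indptr` follows the usual CSR convention: an exclusive prefix sum over per-request query lengths, so request `i` owns query rows `q_indptr[i]:q_indptr[i+1]` of the flattened `query` tensor. A standalone illustration with hypothetical lengths:

```python
import torch

q_lens = torch.tensor([1, 1, 1], dtype=torch.int32)  # decode: one new token each
q_indptr = torch.cat(
    [torch.tensor([0], dtype=torch.int32), torch.cumsum(q_lens, dim=0, dtype=torch.int32)]
)
print(q_indptr)  # tensor([0, 1, 2, 3]) -> request 1 owns query rows 1:2
```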
@@ -112,7 +124,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
     kv_lens = torch.randint(1, max_kv_len, (batch_size,), dtype=torch.int32)
     kv_lens[-1] = max_kv_len
 
-    seq_lens = kv_lens
+    seq_lens = kv_lens + q_lens
     max_seq_len = torch.max(seq_lens).item()
 
     kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
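`seq_lens` previously equaled the cached KV length alone; with query tokens modeled explicitly, the total per-request sequence length becomes context plus new tokens. A toy check with hypothetical lengths:

```python
import torch

kv_lens = torch.tensor([7, 3], dtype=torch.int32)  # cached context per request
q_lens = torch.ones(2, dtype=torch.int32)          # one new decode token each
seq_lens = kv_lens + q_lens                        # tensor([8, 4], dtype=torch.int32)
```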
@@ -148,27 +160,36 @@ def test_flashinfer_trtllm_decode_with_baseline(
     workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
 
     # Baseline Decode
-    wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
-        workspace_buffer, kv_layout, use_tensor_cores=True
-    )
+    if has_sinks:
+        sinks = torch.rand(num_qo_heads, dtype=torch.float32) * 5
+        wrapper = flashinfer.BatchAttentionWithAttentionSinkWrapper(
+            float_workspace_buffer=workspace_buffer, kv_layout=kv_layout, backend="fa2"
+        )
+    else:
+        sinks = None
+        wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
+            float_workspace_buffer=workspace_buffer, kv_layout=kv_layout, backend="fa2"
+        )
+
     wrapper.plan(
-        kv_indptr,
-        kv_indices,
-        kv_last_page_lens,
-        num_qo_heads,
-        num_kv_heads,
-        head_size,
-        block_size,
-        "NONE",
+        qo_indptr=q_indptr,
+        paged_kv_indptr=kv_indptr,
+        paged_kv_indices=kv_indices,
+        paged_kv_last_page_len=kv_last_page_lens,
+        num_qo_heads=num_qo_heads,
+        num_kv_heads=num_kv_heads,
+        head_dim_qk=head_size,
+        page_size=block_size,
+        causal=True,
         sm_scale=sm_scale,
-        q_data_type=dtype,
-        kv_data_type=dtype,
         window_left=window_left,
         logits_soft_cap=soft_cap,
+        q_data_type=dtype,
+        kv_data_type=dtype,
     )
-
     output = torch.empty(ref_query.shape, dtype=dtype)
-    wrapper.run(ref_query, ref_kv_cache, out=output)
+    wrapper.run(ref_query, ref_kv_cache, sinks, sm_scale, out=output)
+
     o_scale = 1.0
     o_sf_scale_float = None
     if o_quant_dtype == FP8_DTYPE:
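Note on the baseline swap: the decode baseline now runs through the prefill-style wrappers (`BatchAttentionWithAttentionSinkWrapper` when sinks are active, `BatchPrefillWithPagedKVCacheWrapper` otherwise) with `causal=True`. Because each request contributes a single query token that logically sits at the end of its sequence, a bottom-right-aligned causal mask excludes nothing, so the planned prefill reduces exactly to decode. A plain-SDPA sketch of that equivalence (the explicit mask and shapes are illustrative; it assumes bottom-right mask alignment, the convention paged prefill kernels use):

```python
import torch
import torch.nn.functional as F

q = torch.randn(8, 1, 64)                 # [heads, q_len=1, head_dim]
k, v = torch.randn(8, 100, 64), torch.randn(8, 100, 64)

# Bottom-right-aligned causal mask: query j may attend kv positions
# <= j + (kv_len - q_len). With q_len == 1 that is every position.
q_len, kv_len = q.shape[-2], k.shape[-2]
mask = torch.arange(kv_len) <= torch.arange(q_len)[:, None] + (kv_len - q_len)

causal = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
full = F.scaled_dot_product_attention(q, k, v)
assert torch.allclose(causal, full, atol=1e-6)  # identical when q_len == 1
```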
@@ -202,6 +223,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
         bmm1_scale=q_scale * k_scale * sm_scale,
         bmm2_scale=v_scale / o_scale,
         window_left=window_left,
+        sinks=sinks,
         o_sf_scale=o_sf_scale_float,
         out=output_trtllm,
     )
@@ -217,11 +239,13 @@ def test_flashinfer_trtllm_decode_with_baseline(
     output_trtllm = output_trtllm.reshape(-1, query.shape[1], query.shape[2])
 
     if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
-        rtol, atol = 3e-1, 1e0
+        rtol, atol = 7e-2, 9e-2
     elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
-        rtol, atol = 5e-2, 7e-2
-    else:
+        rtol, atol = 2e-2, 4e-2
+    elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype:
         rtol, atol = 1e-2, 2e-2
+    else:
+        rtol, atol = 1e-2, 1e-2
 
     (
         torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol),
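The tolerance ladder tracks the output precision: float8_e4m3fn has a 3-bit mantissa, so its machine epsilon is 2^-3 = 0.125 and FP8-involved paths cannot be meaningfully compared below the few-percent level, while FP4 outputs are coarser still, hence the looser FP4 bounds. A quick confirmation of the FP8 figures (assumes a PyTorch recent enough to expose finfo for float8 dtypes):

```python
import torch

info = torch.finfo(torch.float8_e4m3fn)
print(info.eps, info.max)  # 0.125 448.0 -> roughly one decimal digit of precision
```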
@@ -239,6 +263,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
 @pytest.mark.parametrize("block_size", BLOCK_SIZE)
 @pytest.mark.parametrize("window_left", WINDOW_LEFT)
 @pytest.mark.parametrize("soft_cap", [None])
+@pytest.mark.parametrize("has_sinks", HAS_SINKS)
 @torch.inference_mode
 def test_flashinfer_trtllm_prefill_with_baseline(
     dtype: torch.dtype,
@@ -253,9 +278,10 @@ def test_flashinfer_trtllm_prefill_with_baseline(
     block_size: int,
     window_left: int,
     soft_cap: Optional[float],
+    has_sinks: bool,
 ) -> None:
     torch.set_default_device("cuda")
-    current_platform.seed_everything(0)
+    current_platform.seed_everything(42)
 
     q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
     q_quant_dtype = q_quant_dtype or dtype
@@ -297,7 +323,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
         q_scale = 1.0
         ref_query = query
 
-    kv_lens = torch.randint(0, max_kv_len, (batch_size,), dtype=torch.int32)
+    kv_lens = torch.randint(1, max_kv_len, (batch_size,), dtype=torch.int32)
     kv_lens[-1] = max_kv_len
 
     seq_lens = kv_lens + q_lens
@@ -336,28 +362,36 @@ def test_flashinfer_trtllm_prefill_with_baseline(
     workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
 
     # Baseline Prefill
-    wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
-        workspace_buffer, kv_layout
-    )
+    if has_sinks:
+        sinks = torch.rand(num_qo_heads, dtype=torch.float32) * 5
+        wrapper = flashinfer.BatchAttentionWithAttentionSinkWrapper(
+            float_workspace_buffer=workspace_buffer, kv_layout=kv_layout, backend="fa2"
+        )
+    else:
+        sinks = None
+        wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
+            float_workspace_buffer=workspace_buffer, kv_layout=kv_layout, backend="fa2"
+        )
+
     wrapper.plan(
-        q_indptr,
-        kv_indptr,
-        kv_indices,
-        kv_last_page_lens,
-        num_qo_heads,
-        num_kv_heads,
-        head_size,
-        block_size,
+        qo_indptr=q_indptr,
+        paged_kv_indptr=kv_indptr,
+        paged_kv_indices=kv_indices,
+        paged_kv_last_page_len=kv_last_page_lens,
+        num_qo_heads=num_qo_heads,
+        num_kv_heads=num_kv_heads,
+        head_dim_qk=head_size,
+        page_size=block_size,
         causal=True,
         sm_scale=sm_scale,
-        q_data_type=dtype,
-        kv_data_type=dtype,
         window_left=window_left,
         logits_soft_cap=soft_cap,
+        q_data_type=dtype,
+        kv_data_type=dtype,
     )
-
     output = torch.empty(ref_query.shape, dtype=dtype)
-    wrapper.run(ref_query, ref_kv_cache, out=output)
+    wrapper.run(ref_query, ref_kv_cache, sinks, sm_scale, out=output)
+
     o_scale = 1.0
     o_sf_scale_float = None
     if o_quant_dtype == FP8_DTYPE:
@@ -395,6 +429,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
         cum_seq_lens_q=q_indptr,
         cum_seq_lens_kv=kv_indptr,
         window_left=window_left,
+        sinks=sinks,
         o_sf_scale=o_sf_scale_float,
         out=output_trtllm,
     )
@@ -410,11 +445,11 @@ def test_flashinfer_trtllm_prefill_with_baseline(
     output_trtllm = output_trtllm.reshape(-1, query.shape[1], query.shape[2])
 
     if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
-        rtol, atol = 4e-1, 1e0
+        rtol, atol = 1e-1, 2e-1
     elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
-        rtol, atol = 5e-2, 7e-2
-    elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype:
         rtol, atol = 4e-2, 6e-2
+    elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype:
+        rtol, atol = 2e-2, 3e-2
     else:
         rtol, atol = 1e-2, 1e-2
 