Commit ac1611c

add sinks attn unit tests
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
1 parent 94fa69a commit ac1611c

File tree

2 files changed: +93 additions, -50 deletions

tests/kernels/attention/test_flashinfer_trtllm_attention.py

Lines changed: 87 additions & 48 deletions
@@ -6,9 +6,8 @@
 import pytest
 import torch
 
-from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
-                                                    FLOAT8_E4M3_MAX,
-                                                    dequantize_nvfp4_to_dtype)
+from tests.kernels.quantization.nvfp4_utils import (dequantize_nvfp4_to_dtype,
+                                                    get_nvfp4_global_scale)
 from vllm.platforms import current_platform
 from vllm.utils import round_up
 
@@ -47,6 +46,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
 BLOCK_SIZE = [16]
 WINDOW_LEFT = [-1, 127]
 SOFT_CAP = [None, 50.0]
+HAS_SINKS = [True, False]
 
 NUM_BLOCKS = 32768  # Large enough to test overflow in index calculation.
 
@@ -61,6 +61,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
 @pytest.mark.parametrize("block_size", BLOCK_SIZE)
 @pytest.mark.parametrize("window_left", WINDOW_LEFT)
 @pytest.mark.parametrize("soft_cap", SOFT_CAP)
+@pytest.mark.parametrize("has_sinks", HAS_SINKS)
 @torch.inference_mode
 def test_flashinfer_trtllm_decode_with_baseline(
     dtype: torch.dtype,
@@ -74,9 +75,10 @@ def test_flashinfer_trtllm_decode_with_baseline(
     block_size: int,
     window_left: int,
     soft_cap: Optional[float],
+    has_sinks: bool,
 ) -> None:
     torch.set_default_device("cuda")
-    current_platform.seed_everything(0)
+    current_platform.seed_everything(42)
 
     q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
     q_quant_dtype = q_quant_dtype or dtype
@@ -98,7 +100,17 @@ def test_flashinfer_trtllm_decode_with_baseline(
     else:
         raise ValueError(f"Invalid kv_layout: {kv_layout}")
 
-    query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
+    # max_q_len = 1
+    q_lens = torch.ones((batch_size, ), dtype=torch.int32)
+    q_indptr = torch.cat([
+        torch.tensor([0], dtype=torch.int32),
+        torch.cumsum(q_lens, dim=0, dtype=torch.int32),
+    ])
+
+    query = torch.randn(torch.sum(q_lens).item(),
+                        num_qo_heads,
+                        head_size,
+                        dtype=dtype)
     if q_quant_dtype == FP8_DTYPE:
         query, q_scale = to_float8(query)
         ref_query = query.to(dtype) * q_scale
@@ -109,7 +121,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
     kv_lens = torch.randint(1, max_kv_len, (batch_size, ), dtype=torch.int32)
     kv_lens[-1] = max_kv_len
 
-    seq_lens = kv_lens
+    seq_lens = kv_lens + q_lens
     max_seq_len = torch.max(seq_lens).item()
 
     kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
@@ -146,31 +158,42 @@ def test_flashinfer_trtllm_decode_with_baseline(
     workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
 
     # Baseline Decode
-    wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
-        workspace_buffer, kv_layout, use_tensor_cores=True)
-    wrapper.plan(kv_indptr,
-                 kv_indices,
-                 kv_last_page_lens,
-                 num_qo_heads,
-                 num_kv_heads,
-                 head_size,
-                 block_size,
-                 "NONE",
+    if has_sinks:
+        sinks = torch.rand(num_qo_heads, dtype=torch.float32) * 5
+        wrapper = flashinfer.BatchAttentionWithAttentionSinkWrapper(
+            float_workspace_buffer=workspace_buffer,
+            kv_layout=kv_layout,
+            backend="fa2")
+    else:
+        sinks = None
+        wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
+            float_workspace_buffer=workspace_buffer,
+            kv_layout=kv_layout,
+            backend="fa2")
+
+    wrapper.plan(qo_indptr=q_indptr,
+                 paged_kv_indptr=kv_indptr,
+                 paged_kv_indices=kv_indices,
+                 paged_kv_last_page_len=kv_last_page_lens,
+                 num_qo_heads=num_qo_heads,
+                 num_kv_heads=num_kv_heads,
+                 head_dim_qk=head_size,
+                 page_size=block_size,
+                 causal=True,
                  sm_scale=sm_scale,
-                 q_data_type=dtype,
-                 kv_data_type=dtype,
                  window_left=window_left,
-                 logits_soft_cap=soft_cap)
-
+                 logits_soft_cap=soft_cap,
+                 q_data_type=dtype,
+                 kv_data_type=dtype)
     output = torch.empty(ref_query.shape, dtype=dtype)
-    wrapper.run(ref_query, ref_kv_cache, out=output)
+    wrapper.run(ref_query, ref_kv_cache, sinks, sm_scale, out=output)
+
     o_scale = 1.0
     o_sf_scale = None
     if o_quant_dtype == FP8_DTYPE:
         _, o_scale = to_float8(output)
     elif o_quant_dtype == FP4_DTYPE:
-        o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
-                      torch.amax(output.flatten(), dim=-1)).to(torch.float32)
+        o_sf_scale = get_nvfp4_global_scale(output)
 
     # TRTLLM Decode
     if o_quant_dtype == FP4_DTYPE:
@@ -194,6 +217,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
         bmm1_scale=q_scale * k_scale * sm_scale,
         bmm2_scale=v_scale / o_scale,
         window_left=window_left,
+        sinks=sinks,
         o_sf_scale=o_sf_scale,
         out=output_trtllm,
     )
@@ -210,11 +234,13 @@ def test_flashinfer_trtllm_decode_with_baseline(
                                             query.shape[2])
 
     if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
-        rtol, atol = 3e-1, 1e0
+        rtol, atol = 7e-2, 9e-2
     elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
-        rtol, atol = 5e-2, 7e-2
-    else:
+        rtol, atol = 2e-2, 4e-2
+    elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype:
         rtol, atol = 1e-2, 2e-2
+    else:
+        rtol, atol = 1e-2, 1e-2
 
     torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \
         f"{torch.max(torch.abs(output - output_trtllm))}"
@@ -230,6 +256,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
 @pytest.mark.parametrize("block_size", BLOCK_SIZE)
 @pytest.mark.parametrize("window_left", WINDOW_LEFT)
 @pytest.mark.parametrize("soft_cap", [None])
+@pytest.mark.parametrize("has_sinks", HAS_SINKS)
 @torch.inference_mode
 def test_flashinfer_trtllm_prefill_with_baseline(
     dtype: torch.dtype,
@@ -243,9 +270,10 @@ def test_flashinfer_trtllm_prefill_with_baseline(
     block_size: int,
     window_left: int,
     soft_cap: Optional[float],
+    has_sinks: bool,
 ) -> None:
     torch.set_default_device("cuda")
-    current_platform.seed_everything(0)
+    current_platform.seed_everything(42)
 
     q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
     q_quant_dtype = q_quant_dtype or dtype
@@ -288,7 +316,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
         q_scale = 1.0
         ref_query = query
 
-    kv_lens = torch.randint(0, max_kv_len, (batch_size, ), dtype=torch.int32)
+    kv_lens = torch.randint(1, max_kv_len, (batch_size, ), dtype=torch.int32)
     kv_lens[-1] = max_kv_len
 
     seq_lens = kv_lens + q_lens
@@ -328,32 +356,42 @@ def test_flashinfer_trtllm_prefill_with_baseline(
     workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
 
     # Baseline Prefill
-    wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
-        workspace_buffer, kv_layout)
-    wrapper.plan(q_indptr,
-                 kv_indptr,
-                 kv_indices,
-                 kv_last_page_lens,
-                 num_qo_heads,
-                 num_kv_heads,
-                 head_size,
-                 block_size,
+    if has_sinks:
+        sinks = torch.rand(num_qo_heads, dtype=torch.float32) * 5
+        wrapper = flashinfer.BatchAttentionWithAttentionSinkWrapper(
+            float_workspace_buffer=workspace_buffer,
+            kv_layout=kv_layout,
+            backend="fa2")
+    else:
+        sinks = None
+        wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
+            float_workspace_buffer=workspace_buffer,
+            kv_layout=kv_layout,
+            backend="fa2")
+
+    wrapper.plan(qo_indptr=q_indptr,
+                 paged_kv_indptr=kv_indptr,
+                 paged_kv_indices=kv_indices,
+                 paged_kv_last_page_len=kv_last_page_lens,
+                 num_qo_heads=num_qo_heads,
+                 num_kv_heads=num_kv_heads,
+                 head_dim_qk=head_size,
+                 page_size=block_size,
                  causal=True,
                  sm_scale=sm_scale,
-                 q_data_type=dtype,
-                 kv_data_type=dtype,
                  window_left=window_left,
-                 logits_soft_cap=soft_cap)
-
+                 logits_soft_cap=soft_cap,
+                 q_data_type=dtype,
+                 kv_data_type=dtype)
     output = torch.empty(ref_query.shape, dtype=dtype)
-    wrapper.run(ref_query, ref_kv_cache, out=output)
+    wrapper.run(ref_query, ref_kv_cache, sinks, sm_scale, out=output)
+
     o_scale = 1.0
     o_sf_scale = None
    if o_quant_dtype == FP8_DTYPE:
         _, o_scale = to_float8(output)
     elif o_quant_dtype == FP4_DTYPE:
-        o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
-                      torch.amax(output.flatten(), dim=-1)).to(torch.float32)
+        o_sf_scale = get_nvfp4_global_scale(output)
 
     # TRTLLM Prefill
     if o_quant_dtype == FP4_DTYPE:
@@ -381,6 +419,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
         cum_seq_lens_q=q_indptr,
         cum_seq_lens_kv=kv_indptr,
         window_left=window_left,
+        sinks=sinks,
         o_sf_scale=o_sf_scale,
         out=output_trtllm,
     )
@@ -397,11 +436,11 @@ def test_flashinfer_trtllm_prefill_with_baseline(
                                             query.shape[2])
 
     if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
-        rtol, atol = 4e-1, 1e0
+        rtol, atol = 1e-1, 2e-1
     elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
-        rtol, atol = 5e-2, 7e-2
-    elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype:
         rtol, atol = 4e-2, 6e-2
+    elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype:
+        rtol, atol = 2e-2, 3e-2
     else:
         rtol, atol = 1e-2, 1e-2
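
For readers unfamiliar with the feature the new has_sinks cases exercise, below is a minimal, hedged reference of per-head attention sinks. It assumes the usual semantics in which each head's sink logit joins the softmax normalization but carries no value row, so it only absorbs probability mass; it is illustrative only, and the tests above exercise the real FlashInfer and TRTLLM kernels.

import torch


def ref_attention_with_sink(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                            sink: torch.Tensor, scale: float) -> torch.Tensor:
    # q: [heads, d]; k, v: [seq, heads, d]; sink: [heads] per-head sink logit.
    logits = torch.einsum("hd,shd->hs", q, k) * scale           # [heads, seq]
    logits = torch.cat([logits, sink.unsqueeze(-1)], dim=-1)    # append sink column
    probs = torch.softmax(logits, dim=-1)[..., :-1]             # drop the sink slot
    # The sink has no value vector, so attention rows no longer sum to 1 by design.
    return torch.einsum("hs,shd->hd", probs, v)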

tests/kernels/quantization/nvfp4_utils.py

Lines changed: 6 additions & 2 deletions
@@ -68,8 +68,12 @@ def break_fp4_bytes(a, dtype):
     return values.reshape(m, n * 2).to(dtype=dtype)
 
 
+def get_nvfp4_global_scale(a: torch.Tensor):
+    return ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
+            torch.abs(a).max().to(torch.float32))
+
+
 def quant_nvfp4_tensor(a: torch.Tensor):
-    a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
-                      torch.abs(a).max().to(torch.float32))
+    a_global_scale = get_nvfp4_global_scale(a)
     a_quant, a_block_scale = scaled_fp4_quant(a, a_global_scale)
     return a_quant, a_block_scale, a_global_scale
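
The refactor only lifts the scale computation that quant_nvfp4_tensor already performed into a helper, so the decode and prefill tests can reuse it for o_sf_scale. A minimal standalone sketch, assuming the two MAX constants defined in nvfp4_utils.py are the usual format maxima (448 for FP8 E4M3, 6 for FP4 E2M1):

import torch

# Assumed to match the constants defined in nvfp4_utils.py.
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0
FLOAT4_E2M1_MAX = 6.0                                   # largest FP4 E2M1 value


def get_nvfp4_global_scale(a: torch.Tensor) -> torch.Tensor:
    # Choose the scale so the tensor's max magnitude maps onto the largest
    # value representable by an FP4 element times its FP8 block scale.
    return ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
            torch.abs(a).max().to(torch.float32))


# In the tests above, the baseline attention output feeds this helper directly:
output = torch.randn(16, 8, 128)
o_sf_scale = get_nvfp4_global_scale(output)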
