
Commit 58f66f8

Add attention sink support to tensor core template for decode attention
Add maybe_s_aux support to the prefill template used for decode attention when use_tensor_cores=True. Includes updates to the params structures, variant handling, JIT generation, and Python wrappers, plus test coverage validated against a reference implementation.
1 parent 99067e4 commit 58f66f8

File tree

6 files changed: +116 −5 lines

flashinfer/decode.py
flashinfer/jit/attention/modules.py
flashinfer/prefill.py
include/flashinfer/attention/default_prefill_params.cuh
include/flashinfer/attention/variants.cuh
tests/attention/test_decode_sink_attention.py

flashinfer/decode.py
Lines changed: 3 additions & 0 deletions

@@ -388,6 +388,7 @@ def single_decode_with_kv_cache(
     rope_scale: Optional[float] = None,
     rope_theta: Optional[float] = None,
     return_lse: Literal[True] = True,
+    sinks: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]: ...


@@ -407,6 +408,7 @@ def single_decode_with_kv_cache(
     rope_scale: Optional[float] = None,
     rope_theta: Optional[float] = None,
     return_lse: bool = False,
+    sinks: Optional[torch.Tensor] = None,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     r"""Decode attention with KV Cache for single request, return attention output.
@@ -533,6 +535,7 @@ def single_decode_with_kv_cache(
         window_left,
         None,  # packed_custom_mask
         _get_cache_alibi_slopes_buf(num_qo_heads, q.device),
+        sinks,  # maybe_s_aux
         logits_soft_cap,
         sm_scale,
         None,  # scale_q, not supported yet
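The new sinks argument surfaces in the public single_decode_with_kv_cache API above. A minimal usage sketch, mirroring the shapes, dtypes, and keyword arguments of the test added in this commit (an illustration, not taken verbatim from the library docs):

import math
import torch
import flashinfer

num_qo_heads, num_kv_heads, head_dim, kv_len = 32, 8, 128, 512
device, dtype = torch.device("cuda:0"), torch.bfloat16

q = torch.randn(num_qo_heads, head_dim, dtype=dtype, device=device)
k = torch.randn(kv_len, num_kv_heads, head_dim, dtype=dtype, device=device)  # NHD layout
v = torch.randn(kv_len, num_kv_heads, head_dim, dtype=dtype, device=device)

# One sink logit per query head, kept in float32.
sinks = torch.randn(num_qo_heads, dtype=torch.float32, device=device) * 0.5

out = flashinfer.single_decode_with_kv_cache(
    q, k, v,
    kv_layout="NHD",
    pos_encoding_mode="NONE",
    use_tensor_cores=True,  # routes decode through the prefill (tensor core) template
    sm_scale=1.0 / math.sqrt(head_dim),
    sinks=sinks,            # new optional argument
)
assert out.shape == (num_qo_heads, head_dim)

With use_tensor_cores=True the decode call is dispatched through the prefill (fa2) template, which is why the sink plumbing in the rest of this commit targets the prefill params and JIT path.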

flashinfer/jit/attention/modules.py
Lines changed: 4 additions & 4 deletions

@@ -467,8 +467,8 @@ def gen_single_decode_module(
         dtype_o,
         head_dim_qk,
         head_dim_vo,
-        ["maybe_alibi_slopes", "maybe_s_aux"],  # additional_tensor_names
-        ["float", "float"],  # additional_tensor_dtypes
+        ["maybe_alibi_slopes"],  # additional_tensor_names
+        ["float"],  # additional_tensor_dtypes
         [
             "logits_soft_cap",
             "sm_scale",
@@ -516,8 +516,8 @@ def gen_single_prefill_module(

     if backend == "fa2":
         assert not fp8_enabled, "fp8 tensor core is not supported in fa2 backend"
-        additional_tensor_names = ["maybe_custom_mask", "maybe_alibi_slopes"]
-        additional_tensor_dtypes = ["uint8_t", "float"]
+        additional_tensor_names = ["maybe_custom_mask", "maybe_alibi_slopes", "maybe_s_aux"]
+        additional_tensor_dtypes = ["uint8_t", "float", "float"]
         additional_scalar_names = [
             "logits_soft_cap",
             "sm_scale",

flashinfer/prefill.py
Lines changed: 3 additions & 0 deletions

@@ -277,6 +277,7 @@ def run_single_prefill(
     window_left: int,
     maybe_packed_custom_mask: Optional[torch.Tensor],
     maybe_alibi_slopes: Optional[torch.Tensor],
+    maybe_s_aux: Optional[torch.Tensor],
     logits_soft_cap: float,
     sm_scale: float,
     scale_q: Optional[torch.Tensor],
@@ -330,6 +331,7 @@ def run_single_prefill(
         window_left,
         maybe_packed_custom_mask,
         maybe_alibi_slopes,
+        maybe_s_aux,
         logits_soft_cap,
         sm_scale,
         1.0 / rope_scale,  # rope_rcp_scale
@@ -350,6 +352,7 @@ def _fake_run_single_prefill(
     window_left: int,
     maybe_packed_custom_mask: Optional[torch.Tensor],
     maybe_alibi_slopes: Optional[torch.Tensor],
+    maybe_s_aux: Optional[torch.Tensor],
     logits_soft_cap: float,
     sm_scale: float,
     rope_scale: float,

include/flashinfer/attention/default_prefill_params.cuh
Lines changed: 12 additions & 0 deletions

@@ -38,6 +38,7 @@ struct SinglePrefillParams {
   DTypeO* o;
   float* lse;
   float* maybe_alibi_slopes;
+  float* maybe_s_aux;
   uint_fastdiv group_size;
   uint32_t qo_len;
   uint32_t kv_len;
@@ -66,6 +67,7 @@ struct SinglePrefillParams {
         o(nullptr),
         lse(nullptr),
         maybe_alibi_slopes(nullptr),
+        maybe_s_aux(nullptr),
         group_size(),
         qo_len(0),
         kv_len(0),
@@ -87,6 +89,7 @@ struct SinglePrefillParams {

   __host__ SinglePrefillParams(DTypeQ* q, DTypeKV* k, DTypeKV* v, uint8_t* maybe_custom_mask,
                                DTypeO* o, float* lse, float* maybe_alibi_slopes,
+                               float* maybe_s_aux,
                                uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len,
                                uint32_t kv_len, uint32_t q_stride_n, uint32_t q_stride_h,
                                uint32_t kv_stride_n, uint32_t kv_stride_h, uint32_t head_dim,
@@ -99,6 +102,7 @@ struct SinglePrefillParams {
         o(o),
         lse(lse),
         maybe_alibi_slopes(maybe_alibi_slopes),
+        maybe_s_aux(maybe_s_aux),
         group_size(num_qo_heads / num_kv_heads),
         num_qo_heads(num_qo_heads),
         num_kv_heads(num_kv_heads),
@@ -146,6 +150,7 @@ struct BatchPrefillRaggedParams {
   DTypeO* o;
   float* lse;
   float* maybe_alibi_slopes;
+  float* maybe_s_aux;
   uint_fastdiv group_size;
   uint32_t num_qo_heads;
   uint32_t num_kv_heads;
@@ -190,6 +195,7 @@ struct BatchPrefillRaggedParams {
         o(nullptr),
         lse(nullptr),
         maybe_alibi_slopes(nullptr),
+        maybe_s_aux(nullptr),
         group_size(),
         num_qo_heads(0),
         num_kv_heads(0),
@@ -224,6 +230,7 @@ struct BatchPrefillRaggedParams {
                                     IdType* q_indptr, IdType* kv_indptr, IdType* maybe_mask_indptr,
                                     IdType* maybe_q_rope_offset, IdType* maybe_k_rope_offset,
                                     DTypeO* o, float* lse, float* maybe_alibi_slopes,
+                                    float* maybe_s_aux,
                                     uint32_t num_qo_heads, uint32_t num_kv_heads,
                                     uint32_t q_stride_n, uint32_t q_stride_h, uint32_t kv_stride_n,
                                     uint32_t kv_stride_h, int32_t window_left,
@@ -241,6 +248,7 @@ struct BatchPrefillRaggedParams {
         o(o),
         lse(lse),
         maybe_alibi_slopes(maybe_alibi_slopes),
+        maybe_s_aux(maybe_s_aux),
         group_size(num_qo_heads / num_kv_heads),
         num_qo_heads(num_qo_heads),
         num_kv_heads(num_kv_heads),
@@ -296,6 +304,7 @@ struct BatchPrefillPagedParams {
   DTypeO* o;
   float* lse;
   float* maybe_alibi_slopes;
+  float* maybe_s_aux;
   uint_fastdiv group_size;
   uint32_t num_qo_heads;
   IdType q_stride_n;
@@ -332,6 +341,7 @@ struct BatchPrefillPagedParams {
         o(nullptr),
         lse(nullptr),
         maybe_alibi_slopes(nullptr),
+        maybe_s_aux(nullptr),
         group_size(),
         num_qo_heads(0),
         q_stride_n(0),
@@ -361,6 +371,7 @@ struct BatchPrefillPagedParams {
                                    uint8_t* maybe_custom_mask, IdType* q_indptr,
                                    IdType* maybe_mask_indptr, IdType* maybe_q_rope_offset,
                                    DTypeO* o, float* lse, float* maybe_alibi_slopes,
+                                   float* maybe_s_aux,
                                    uint32_t num_qo_heads, IdType q_stride_n, IdType q_stride_h,
                                    int32_t window_left, float logits_soft_cap, float sm_scale,
                                    float rope_scale, float rope_theta)
@@ -373,6 +384,7 @@ struct BatchPrefillPagedParams {
         o(o),
         lse(lse),
         maybe_alibi_slopes(maybe_alibi_slopes),
+        maybe_s_aux(maybe_s_aux),
         group_size(num_qo_heads / paged_kv.num_heads),
         num_qo_heads(num_qo_heads),
         q_stride_n(q_stride_n),

include/flashinfer/attention/variants.cuh
Lines changed: 10 additions & 0 deletions

@@ -90,6 +90,16 @@ struct DefaultAttention : AttentionVariantBase {
     }
     return mask;
   })
+
+  REGISTER_M_D_UPDATE(params, kv_tile_idx, qo_head_idx, m, d, scale, {
+    if constexpr (use_softmax) {
+      if (params.maybe_s_aux != nullptr) {
+        constexpr float LOG2_E = 1.4426950408889634f;  // log2(e)
+        float s_aux_val = params.maybe_s_aux[qo_head_idx];
+        d += math::ptx_exp2((s_aux_val - m) * LOG2_E);
+      }
+    }
+  })
 };

 };  // namespace flashinfer
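The REGISTER_M_D_UPDATE hook above folds the per-head sink logit into the softmax denominator only: ptx_exp2((s_aux_val - m) * LOG2_E) equals exp(s_aux_val - m), added to the running sum d computed against the running max m. In plain PyTorch, the effect for a single query and a single head is roughly the following (a sketch of the math, not the kernel code; names are illustrative):

import torch

def sink_softmax_decode_sketch(q, k, v, s_aux, sm_scale):
    # q: [head_dim], k/v: [kv_len, head_dim], s_aux: scalar sink logit for this head
    logits = (k @ q) * sm_scale          # [kv_len] attention scores
    m = logits.max()                     # running max from the online softmax
    p = torch.exp(logits - m)            # unnormalized probabilities
    d = p.sum() + torch.exp(s_aux - m)   # the hook adds exp(s_aux - m) to d
    return (p / d) @ v                   # the sink has no value row; it only
                                         # down-weights the real KV entries

Because the sink joins the normalization but carries no value row, the attention weights over the real KV entries sum to less than one, letting the model shift probability mass onto the sink instead of the actual tokens.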

tests/attention/test_decode_sink_attention.py
Lines changed: 84 additions & 1 deletion

@@ -257,7 +257,7 @@ def test_batch_decode_without_sink_attention(

 @pytest.mark.parametrize("batch_size", [2])
 @pytest.mark.parametrize("kv_len", [64])
-@pytest.mark.parametrize("num_qo_heads", [8])
+@pytest.mark.parametrize("num_qo_heads", [16])
 @pytest.mark.parametrize("num_kv_heads", [8])
 @pytest.mark.parametrize("head_dim", [64])
 def test_batch_decode_sink_attention_gqa(
@@ -321,5 +321,88 @@ def test_batch_decode_sink_attention_gqa(
     assert not torch.isinf(out).any()


+@pytest.mark.parametrize("kv_len", [32, 128, 512])
+@pytest.mark.parametrize(
+    "num_qo_heads,num_kv_heads",
+    [
+        (8, 8),  # MHA: equal heads
+        (16, 8),  # GQA: 2:1 ratio
+        (32, 8),  # GQA: 4:1 ratio
+        (32, 32),  # MHA: equal heads
+    ],
+)
+@pytest.mark.parametrize("head_dim", [64, 128])
+@pytest.mark.parametrize("kv_layout", ["NHD", "HND"])
+def test_single_decode_sink_attention_tensor_cores(
+    kv_len, num_qo_heads, num_kv_heads, head_dim, kv_layout
+):
+    """Test sink attention with single decode using tensor cores (prefill template)."""
+    torch.manual_seed(42)
+    device = torch.device("cuda:0")
+    dtype = torch.bfloat16
+
+    sm_scale = 1.0 / math.sqrt(head_dim)
+    window_left = -1  # No sliding window
+
+    # Create query tensor
+    q = torch.randn(num_qo_heads, head_dim, dtype=dtype, device=device)
+
+    # Create KV cache based on layout
+    if kv_layout == "NHD":
+        k = torch.randn(kv_len, num_kv_heads, head_dim, dtype=dtype, device=device)
+        v = torch.randn(kv_len, num_kv_heads, head_dim, dtype=dtype, device=device)
+    else:  # HND
+        k = torch.randn(num_kv_heads, kv_len, head_dim, dtype=dtype, device=device)
+        v = torch.randn(num_kv_heads, kv_len, head_dim, dtype=dtype, device=device)
+
+    # Sink tensor should have num_qo_heads elements
+    # Sink values should be on a similar scale to the logits (QK^T * sm_scale)
+    sinks = torch.randn(num_qo_heads, device=device, dtype=torch.float32) * 0.5
+
+    # Test with tensor cores enabled (uses prefill template)
+    out = flashinfer.single_decode_with_kv_cache(
+        q,
+        k,
+        v,
+        kv_layout=kv_layout,
+        pos_encoding_mode="NONE",
+        use_tensor_cores=True,
+        sm_scale=sm_scale,
+        sinks=sinks,
+    )
+
+    # Basic sanity check: output should have correct shape
+    assert out.shape == (num_qo_heads, head_dim)
+    assert out.dtype == dtype
+    assert not torch.isnan(out).any()
+    assert not torch.isinf(out).any()
+
+    # Validate against reference implementation
+    # Convert to batch format for reference (add batch dimension)
+    q_batch = q.unsqueeze(0)  # [1, num_qo_heads, head_dim]
+
+    # Convert KV cache to reference format [batch_size, kv_len, num_kv_heads, head_dim]
+    if kv_layout == "NHD":
+        k_cache_ref = k.unsqueeze(0)  # [1, kv_len, num_kv_heads, head_dim]
+        v_cache_ref = v.unsqueeze(0)  # [1, kv_len, num_kv_heads, head_dim]
+    else:  # HND -> transpose to NHD
+        k_cache_ref = k.transpose(0, 1).unsqueeze(0)  # [1, kv_len, num_kv_heads, head_dim]
+        v_cache_ref = v.transpose(0, 1).unsqueeze(0)  # [1, kv_len, num_kv_heads, head_dim]
+
+    # Compute reference output
+    out_ref = sink_attention_decode_ref(
+        q_batch, k_cache_ref, v_cache_ref, sinks, window_left, sm_scale
+    )
+
+    # Remove batch dimension from reference output
+    out_ref = out_ref.squeeze(0)  # [num_qo_heads, head_dim]

+    # Compare results
+    # bfloat16 may have slightly larger numerical differences due to lower precision,
+    # differences in order of operations between the reference and the CUDA kernel, and
+    # GQA scenarios where multiple query heads share KV heads
+    torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=3.5e-2)
+
+
 if __name__ == "__main__":
     pytest.main([__file__, "-v"])
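The helper sink_attention_decode_ref used for validation is defined earlier in this test file and is not part of this diff. As a rough stand-in only, assuming it expands KV heads for GQA and folds the sinks into the softmax denominator the same way the kernel hook does, it could look like the sketch below (hypothetical; the real helper may differ, and window_left is ignored because the test passes -1, i.e. no sliding window):

import torch

def sink_attention_decode_ref_sketch(q, k, v, sinks, window_left, sm_scale):
    # q: [batch, num_qo_heads, head_dim]
    # k, v: [batch, kv_len, num_kv_heads, head_dim]; sinks: [num_qo_heads] float32
    batch, num_qo_heads, head_dim = q.shape
    num_kv_heads = k.shape[2]
    group = num_qo_heads // num_kv_heads
    # Expand KV heads so every query head sees its shared KV head (GQA -> MHA view)
    k = k.repeat_interleave(group, dim=2).float()  # [batch, kv_len, num_qo_heads, head_dim]
    v = v.repeat_interleave(group, dim=2).float()
    logits = torch.einsum("bhd,bnhd->bhn", q.float(), k) * sm_scale  # [batch, heads, kv_len]
    m = logits.amax(dim=-1, keepdim=True)
    p = torch.exp(logits - m)
    d = p.sum(dim=-1, keepdim=True) + torch.exp(sinks.view(1, -1, 1) - m)
    return torch.einsum("bhn,bnhd->bhd", p / d, v).to(q.dtype)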
