
Commit 62859a9

Add maybe_s_aux to single decode module additional_tensor_names
Update the single decode module's additional_tensor_names to include 'maybe_s_aux', along with the corresponding 'float' entry in additional_tensor_dtypes. This matches the batch decode definition and enables sink attention support for single decode operations.
1 parent 58f66f8 commit 62859a9

File tree

3 files changed: +20 −15 lines changed

flashinfer/jit/attention/modules.py

Lines changed: 7 additions & 3 deletions

@@ -467,8 +467,8 @@ def gen_single_decode_module(
         dtype_o,
         head_dim_qk,
         head_dim_vo,
-        ["maybe_alibi_slopes"],  # additional_tensor_names
-        ["float"],  # additional_tensor_dtypes
+        ["maybe_alibi_slopes", "maybe_s_aux"],  # additional_tensor_names
+        ["float", "float"],  # additional_tensor_dtypes
         [
             "logits_soft_cap",
             "sm_scale",
@@ -516,7 +516,11 @@ def gen_single_prefill_module(
 
     if backend == "fa2":
         assert not fp8_enabled, "fp8 tensor core is not supported in fa2 backend"
-        additional_tensor_names = ["maybe_custom_mask", "maybe_alibi_slopes", "maybe_s_aux"]
+        additional_tensor_names = [
+            "maybe_custom_mask",
+            "maybe_alibi_slopes",
+            "maybe_s_aux",
+        ]
         additional_tensor_dtypes = ["uint8_t", "float", "float"]
         additional_scalar_names = [
             "logits_soft_cap",

include/flashinfer/attention/default_prefill_params.cuh

Lines changed: 7 additions & 10 deletions

@@ -88,8 +88,7 @@ struct SinglePrefillParams {
         partition_kv(false) {}
 
   __host__ SinglePrefillParams(DTypeQ* q, DTypeKV* k, DTypeKV* v, uint8_t* maybe_custom_mask,
-                               DTypeO* o, float* lse, float* maybe_alibi_slopes,
-                               float* maybe_s_aux,
+                               DTypeO* o, float* lse, float* maybe_alibi_slopes, float* maybe_s_aux,
                                uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len,
                                uint32_t kv_len, uint32_t q_stride_n, uint32_t q_stride_h,
                                uint32_t kv_stride_n, uint32_t kv_stride_h, uint32_t head_dim,
@@ -230,10 +229,9 @@ struct BatchPrefillRaggedParams {
                                     IdType* q_indptr, IdType* kv_indptr, IdType* maybe_mask_indptr,
                                     IdType* maybe_q_rope_offset, IdType* maybe_k_rope_offset,
                                     DTypeO* o, float* lse, float* maybe_alibi_slopes,
-                                    float* maybe_s_aux,
-                                    uint32_t num_qo_heads, uint32_t num_kv_heads,
-                                    uint32_t q_stride_n, uint32_t q_stride_h, uint32_t kv_stride_n,
-                                    uint32_t kv_stride_h, int32_t window_left,
+                                    float* maybe_s_aux, uint32_t num_qo_heads,
+                                    uint32_t num_kv_heads, uint32_t q_stride_n, uint32_t q_stride_h,
+                                    uint32_t kv_stride_n, uint32_t kv_stride_h, int32_t window_left,
                                     float logits_soft_cap, float sm_scale, float rope_scale,
                                     float rope_theta)
       : q(q),
@@ -371,10 +369,9 @@ struct BatchPrefillPagedParams {
                                    uint8_t* maybe_custom_mask, IdType* q_indptr,
                                    IdType* maybe_mask_indptr, IdType* maybe_q_rope_offset,
                                    DTypeO* o, float* lse, float* maybe_alibi_slopes,
-                                   float* maybe_s_aux,
-                                   uint32_t num_qo_heads, IdType q_stride_n, IdType q_stride_h,
-                                   int32_t window_left, float logits_soft_cap, float sm_scale,
-                                   float rope_scale, float rope_theta)
+                                   float* maybe_s_aux, uint32_t num_qo_heads, IdType q_stride_n,
+                                   IdType q_stride_h, int32_t window_left, float logits_soft_cap,
+                                   float sm_scale, float rope_scale, float rope_theta)
       : q(q),
         paged_kv(paged_kv),
         maybe_custom_mask(maybe_custom_mask),

tests/attention/test_decode_sink_attention.py

Lines changed: 6 additions & 2 deletions

@@ -386,8 +386,12 @@ def test_single_decode_sink_attention_tensor_cores(
         k_cache_ref = k.unsqueeze(0)  # [1, kv_len, num_kv_heads, head_dim]
         v_cache_ref = v.unsqueeze(0)  # [1, kv_len, num_kv_heads, head_dim]
     else:  # HND -> transpose to NHD
-        k_cache_ref = k.transpose(0, 1).unsqueeze(0)  # [1, kv_len, num_kv_heads, head_dim]
-        v_cache_ref = v.transpose(0, 1).unsqueeze(0)  # [1, kv_len, num_kv_heads, head_dim]
+        k_cache_ref = k.transpose(0, 1).unsqueeze(
+            0
+        )  # [1, kv_len, num_kv_heads, head_dim]
+        v_cache_ref = v.transpose(0, 1).unsqueeze(
+            0
+        )  # [1, kv_len, num_kv_heads, head_dim]
 
     # Compute reference output
     out_ref = sink_attention_decode_ref(
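
As context for the reference path exercised above, here is a hedged sketch of a decode-time sink-attention reference: the HND cache is converted to NHD with transpose(0, 1).unsqueeze(0) (matching the shape comments in the diff), and s_aux is treated as one sink logit per head that joins the softmax normalization without an associated value row. This illustrates the idea behind sink_attention_decode_ref; it is not the test's actual helper, and the s_aux semantics are an assumption.

# Hedged sketch of a decode-time sink-attention reference. Assumes s_aux is one
# sink logit per head that participates in the softmax denominator but has no
# value row; not the actual sink_attention_decode_ref implementation.
import math
import torch

num_heads, kv_len, head_dim = 4, 16, 64

q = torch.randn(num_heads, head_dim)
k_hnd = torch.randn(num_heads, kv_len, head_dim)  # HND layout
v_hnd = torch.randn(num_heads, kv_len, head_dim)
s_aux = torch.randn(num_heads)                    # per-head sink logit (assumed)

# HND -> NHD, then add a batch dim: [1, kv_len, num_heads, head_dim]
k_cache_ref = k_hnd.transpose(0, 1).unsqueeze(0)
v_cache_ref = v_hnd.transpose(0, 1).unsqueeze(0)

sm_scale = 1.0 / math.sqrt(head_dim)
k_nhd, v_nhd = k_cache_ref[0], v_cache_ref[0]     # [kv_len, num_heads, head_dim]

scores = torch.einsum("hd,khd->hk", q, k_nhd) * sm_scale  # [num_heads, kv_len]
scores = torch.cat([scores, s_aux[:, None]], dim=-1)      # append sink logit
probs = torch.softmax(scores, dim=-1)[:, :-1]             # sink mass is dropped
out_ref = torch.einsum("hk,khd->hd", probs, v_nhd)        # [num_heads, head_dim]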
