@@ -52,6 +52,7 @@ def kernel_unified_attention_2d(
     query_ptr,  # [num_tokens, num_query_heads, head_size]
     key_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
     value_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
+    sink_ptr,  # [num_query_heads]
     block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
     seq_lens_ptr,  # [num_seqs]
     alibi_slopes_ptr,  # [num_query_heads]
@@ -131,7 +132,15 @@ def kernel_unified_attention_2d(
 
     block_table_offset = seq_idx * block_table_stride
 
-    M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    if sink_ptr is None:
+        M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    else:
+        M = tl.load(
+            sink_ptr + query_offset_1,
+            mask=query_mask_1,
+            other=float("-inf"),
+        ).to(dtype=tl.float32)
+
     L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
     acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)
 
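Seeding M with the per-head sink logit while keeping L = 1.0 and acc = 0 is equivalent to having already pushed one extra "sink" key through the online softmax: its logit is the sink value and its value vector is zero, so it only inflates the denominator. A minimal sketch in plain PyTorch (illustrative names and shapes, not the Triton kernel):

# Equivalence check in plain PyTorch; function names and sizes are illustrative only.
import torch

def softmax_with_sink(scores, values, sink):
    # Reference: the sink enters the softmax denominator but carries no value.
    logits = torch.cat([sink.reshape(1), scores])      # [1 + seq_len]
    probs = torch.softmax(logits, dim=-1)[1:]          # drop the sink slot
    return probs @ values                              # [head_size]

def online_softmax_with_sink(scores, values, sink):
    # Flash-style online accumulation, started the way the kernel seeds it:
    # M = sink logit, L = 1.0, acc = 0, i.e. the sink is already folded in.
    M = sink.clone()
    L = torch.tensor(1.0)
    acc = torch.zeros(values.shape[-1])
    for s, v in zip(scores, values):
        M_new = torch.maximum(M, s)
        alpha = torch.exp(M - M_new)   # rescale factor for the old running sums
        p = torch.exp(s - M_new)
        L = L * alpha + p
        acc = acc * alpha + p * v
        M = M_new
    return acc / L

scores, values, sink = torch.randn(8), torch.randn(8, 4), torch.tensor(0.7)
assert torch.allclose(softmax_with_sink(scores, values, sink),
                      online_softmax_with_sink(scores, values, sink), atol=1e-5)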
@@ -292,6 +301,7 @@ def kernel_unified_attention_3d(
     query_ptr,  # [num_tokens, num_query_heads, head_size]
     key_cache_ptr,  # [num_blks, num_kv_heads, head_size // x, blk_size, x]
     value_cache_ptr,  # [num_blks, num_kv_heads, head_size, blk_size]
+    sink_ptr,  # [num_query_heads]
     block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
     seq_lens_ptr,  # [num_seqs]
     alibi_slopes_ptr,  # [num_query_heads]
@@ -383,7 +393,15 @@ def kernel_unified_attention_3d(
 
     block_table_offset = seq_idx * block_table_stride
 
-    M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    if sink_ptr is None or segm_idx != 0:
+        M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    else:
+        M = tl.load(
+            sink_ptr + query_offset_1,
+            mask=query_mask_1,
+            other=float("-inf"),
+        ).to(dtype=tl.float32)
+
     L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
     acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)
 
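The extra `segm_idx != 0` condition exists because the 3D kernel splits the KV range into segments, each emitting flash-style partials (M, L, acc) that a later reduction merges with a log-sum-exp rescale; if every segment were seeded with the sink, exp(sink) would enter the denominator once per segment. A sketch of that merge in plain PyTorch (illustrative only; the reduce kernel itself is not part of this diff):

# Split-KV sketch: only segment 0 starts from the sink logit, so the merged
# denominator counts the sink exactly once. Names and sizes are made up.
import torch

def segment(scores, values, m0):
    # One segment's online softmax, started from running max m0
    # (the sink logit for segment 0, -inf otherwise), L = 1.0, acc = 0.
    M, L, acc = m0, torch.tensor(1.0), torch.zeros(values.shape[-1])
    for s, v in zip(scores, values):
        M_new = torch.maximum(M, s)
        alpha, p = torch.exp(M - M_new), torch.exp(s - M_new)
        L, acc, M = L * alpha + p, acc * alpha + p * v, M_new
    return M, L, acc

def merge(parts):
    # Log-sum-exp combine of the per-segment partials.
    M = torch.stack([m for m, _, _ in parts]).max()
    L = sum(l * torch.exp(m - M) for m, l, _ in parts)
    acc = sum(a * torch.exp(m - M) for m, _, a in parts)
    return acc / L

scores, values, sink = torch.randn(8), torch.randn(8, 4), torch.tensor(0.7)
parts = [segment(scores[:4], values[:4], sink),                        # segment 0 owns the sink
         segment(scores[4:], values[4:], torch.tensor(float("-inf")))] # later segments do not
ref = torch.softmax(torch.cat([sink.reshape(1), scores]), -1)[1:] @ values
assert torch.allclose(merge(parts), ref, atol=1e-5)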
@@ -627,6 +645,8 @@ def unified_attention(
     v_descale,
     alibi_slopes=None,
     qq_bias=None,
+    # Optional tensor for sinks
+    sinks=None,
 ):
     assert causal, "Only causal attention is supported"
     assert q_descale is None, "Q scales not supported"
@@ -635,6 +655,10 @@ def unified_attention(
     assert q.element_size() >= 2 or block_size >= 32, \
         "Block size must be at least 32 for fp8"
 
+    if sinks is not None:
+        assert sinks.shape[0] == q.shape[1], \
+            "Sinks must be num_query_heads size"
+
     use_alibi_slopes = alibi_slopes is not None
     use_qq_bias = qq_bias is not None
 
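For callers, the only contract on the new argument is that sinks holds one logit per query head. A small illustrative example of shapes that satisfy the assert (sizes and dtypes below are made up for the example):

import torch

num_tokens, num_query_heads, head_size = 256, 32, 128   # hypothetical sizes
q = torch.randn(num_tokens, num_query_heads, head_size, dtype=torch.bfloat16)
sinks = torch.zeros(num_query_heads, dtype=torch.float32)  # one learned logit per query head

assert sinks.shape[0] == q.shape[1], "Sinks must be num_query_heads size"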
@@ -669,6 +693,7 @@ def unified_attention(
             query_ptr=q,
             key_cache_ptr=k,
             value_cache_ptr=v,
+            sink_ptr=sinks,
             block_tables_ptr=block_table,
             seq_lens_ptr=seqused_k,
             alibi_slopes_ptr=alibi_slopes,
@@ -741,6 +766,7 @@ def unified_attention(
             query_ptr=q,
             key_cache_ptr=k,
             value_cache_ptr=v,
+            sink_ptr=sinks,
             block_tables_ptr=block_table,
             seq_lens_ptr=seqused_k,
             alibi_slopes_ptr=alibi_slopes,