@@ -73,6 +73,7 @@ def kernel_unified_attention_2d(
         output_stride_1: tl.int64,  # int, should be equal to head_size
         qq_bias_stride_0: tl.int64,  # int
         BLOCK_SIZE: tl.constexpr,  # int
+        TILE_SIZE: tl.constexpr,  # int, must be power of 2
         HEAD_SIZE: tl.constexpr,  # int
         HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
         USE_ALIBI_SLOPES: tl.constexpr,  # bool
@@ -118,6 +119,7 @@ def kernel_unified_attention_2d(
 
     offs_m = tl.arange(0, BLOCK_M)
     offs_d = tl.arange(0, HEAD_SIZE_PADDED)
+    offs_t = tl.arange(0, TILE_SIZE)
     query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv
 
     query_offset_0 = cur_batch_in_all_start_index + query_pos
@@ -177,31 +179,32 @@ def kernel_unified_attention_2d(
     # actual sequence length
     max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len)
 
-    # calculate the number of tiles (blocks) that need to be processed to
-    # cover the longest sequence prefix (due to causal masking, blocks beyond
+    # calculate the number of tiles that need to be processed to
+    # cover the longest sequence prefix (due to causal masking, tiles beyond
     # this prefix can be skipped)
-    num_blocks = cdiv_fn(max_seq_prefix_len, BLOCK_SIZE)
+    num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE)
 
     # iterate through tiles
-    for j in range(0, num_blocks):
+    for j in range(0, num_tiles):
+        seq_offset = j * TILE_SIZE + offs_t
+        tile_mask = seq_offset < max_seq_prefix_len
 
-        physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j)
+        physical_block_idx = tl.load(block_tables_ptr + block_table_offset +
+                                     seq_offset // BLOCK_SIZE).to(tl.int64)
 
-        offs_n = tl.arange(0, BLOCK_SIZE)
-
-        v_offset = (physical_block_idx * stride_v_cache_0 +
+        v_offset = (physical_block_idx[:, None] * stride_v_cache_0 +
                     kv_head_idx * stride_v_cache_2 +
                     offs_d[None, :] * stride_v_cache_3 +
-                    offs_n[:, None] * stride_v_cache_1)
+                    (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1)
 
-        k_offset = (physical_block_idx * stride_k_cache_0 +
+        k_offset = (physical_block_idx[None, :] * stride_k_cache_0 +
                     kv_head_idx * stride_k_cache_2 +
                     offs_d[:, None] * stride_k_cache_3 +
-                    offs_n[None, :] * stride_k_cache_1)
+                    (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1)
 
-        # K : (HEAD_SIZE, BLOCK_SIZE)
+        # K : (HEAD_SIZE, TILE_SIZE)
         K_load = tl.load(key_cache_ptr + k_offset,
-                         mask=dim_mask[:, None],
+                         mask=dim_mask[:, None] & tile_mask[None, :],
                          other=0.0)
 
         if K_load.dtype.is_fp8():
@@ -212,9 +215,9 @@ def kernel_unified_attention_2d(
         else:
             K = K_load
 
-        # V : (BLOCK_SIZE, HEAD_SIZE)
+        # V : (TILE_SIZE, HEAD_SIZE)
         V_load = tl.load(value_cache_ptr + v_offset,
-                         mask=dim_mask[None, :],
+                         mask=dim_mask[None, :] & tile_mask[:, None],
                          other=0.0)
 
         if V_load.dtype.is_fp8():
@@ -225,12 +228,10 @@ def kernel_unified_attention_2d(
         else:
             V = V_load
 
-        seq_offset = j * BLOCK_SIZE + offs_n
-
         seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
 
-        # S : (BLOCK_M, BLOCK_SIZE)
-        S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32)
+        # S : (BLOCK_M, TILE_SIZE)
+        S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
 
         S += scale * tl.dot(Q, K)
 
@@ -262,11 +263,12 @@ def kernel_unified_attention_2d(
         # compute running maximum
         # m_j : (BLOCK_M,)
         m_j = tl.maximum(M, tl.max(S, axis=1))
+
         # For sliding window there's a chance the max is -inf due to masking of
         # the entire row. In this case we need to set m_j 0 to avoid NaN
         m_j = tl.where(m_j > float("-inf"), m_j, 0.0)
 
-        # P : (BLOCK_M, BLOCK_SIZE)
+        # P : (BLOCK_M, TILE_SIZE)
         P = tl.exp(S - m_j[:, None])
 
         # l_j : (BLOCK_M,)
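
The hunks above decouple the processing tile width (TILE_SIZE) from the KV-cache block size (BLOCK_SIZE): instead of loading one block-table entry per loop iteration, the kernel now gathers one physical block index per token position in the tile (seq_offset // BLOCK_SIZE) and addresses the slot inside that block with seq_offset % BLOCK_SIZE. A minimal NumPy sketch of this index arithmetic, with made-up sizes and block table (nothing below is taken from the kernel itself):

    import numpy as np

    BLOCK_SIZE = 16   # assumed KV-cache block size
    TILE_SIZE = 32    # assumed processing tile size (may span several cache blocks)

    # hypothetical block table for one sequence: logical block -> physical block
    block_table = np.array([7, 3, 9, 0])

    j = 1                                 # second tile of the sequence
    offs_t = np.arange(TILE_SIZE)
    seq_offset = j * TILE_SIZE + offs_t   # token positions 32..63

    physical_block_idx = block_table[seq_offset // BLOCK_SIZE]  # gather, one entry per token
    slot_in_block = seq_offset % BLOCK_SIZE                     # offset inside each cache block

    print(physical_block_idx[::BLOCK_SIZE])  # [9 0] -> this tile spans two physical blocks
    print(slot_in_block[:4])                 # [0 1 2 3]

The tile_mask derived from max_seq_prefix_len additionally masks the K/V loads, so the partially filled tail of the last tile does not contribute stale cache slots.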
@@ -327,6 +329,7 @@ def kernel_unified_attention_3d(
         query_stride_1: tl.int64,  # int, should be equal to head_size
         qq_bias_stride_0: tl.int64,  # int
         BLOCK_SIZE: tl.constexpr,  # int
+        TILE_SIZE: tl.constexpr,  # int, must be power of 2
         HEAD_SIZE: tl.constexpr,  # int
         HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
         USE_ALIBI_SLOPES: tl.constexpr,  # bool
@@ -374,20 +377,19 @@ def kernel_unified_attention_3d(
 
     # number of segments for this particular sequence
     num_segments = NUM_SEGMENTS_PER_SEQ
-    blocks_per_segment = cdiv_fn(seq_len, num_segments * BLOCK_SIZE)
+    tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE)
 
-    if segm_idx * blocks_per_segment * BLOCK_SIZE >= seq_len:
+    if segm_idx * tiles_per_segment * TILE_SIZE >= seq_len:
         return
 
     offs_m = tl.arange(0, BLOCK_M)
     offs_d = tl.arange(0, HEAD_SIZE_PADDED)
-
+    offs_t = tl.arange(0, TILE_SIZE)
     query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv
 
     query_offset_0 = cur_batch_in_all_start_index + query_pos
     query_offset_1 = kv_head_idx * num_queries_per_kv + \
         offs_m % num_queries_per_kv
-
     query_offset = (query_offset_0[:, None] * query_stride_0 +
                     query_offset_1[:, None] * query_stride_1 + offs_d[None, :])
 
@@ -433,30 +435,44 @@ def kernel_unified_attention_3d(
     qq_bias_row_ptrs = (qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0
                         )  # shape: [BLOCK_M]
 
-    num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)
+    # compute the length of the longest sequence prefix spanned by any
+    # query token in the current q_block (q_block_local_idx)
+    max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + (
+        BLOCK_M - 1) // num_queries_per_kv + 1
+
+    # adjust for potential padding in the last q_block by considering the
+    # actual sequence length
+    max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len)
+
+    # calculate the number of tiles that need to be processed to
+    # cover the longest sequence prefix (due to causal masking, tiles beyond
+    # this prefix can be skipped)
+    num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE)
 
     # iterate through tiles within current segment
     for j in range(
-            segm_idx * blocks_per_segment,
-            min((segm_idx + 1) * blocks_per_segment, num_blocks),
+            segm_idx * tiles_per_segment,
+            min((segm_idx + 1) * tiles_per_segment, num_tiles),
     ):
-        physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j)
+        seq_offset = j * TILE_SIZE + offs_t
+        tile_mask = seq_offset < max_seq_prefix_len
 
-        offs_n = tl.arange(0, BLOCK_SIZE)
+        physical_block_idx = tl.load(block_tables_ptr + block_table_offset +
+                                     seq_offset // BLOCK_SIZE).to(tl.int64)
 
-        v_offset = (physical_block_idx * stride_v_cache_0 +
+        v_offset = (physical_block_idx[:, None] * stride_v_cache_0 +
                     kv_head_idx * stride_v_cache_2 +
                     offs_d[None, :] * stride_v_cache_3 +
-                    offs_n[:, None] * stride_v_cache_1)
+                    (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1)
 
-        k_offset = (physical_block_idx * stride_k_cache_0 +
+        k_offset = (physical_block_idx[None, :] * stride_k_cache_0 +
                     kv_head_idx * stride_k_cache_2 +
                     offs_d[:, None] * stride_k_cache_3 +
-                    offs_n[None, :] * stride_k_cache_1)
+                    (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1)
 
-        # K : (HEAD_SIZE, BLOCK_SIZE)
+        # K : (HEAD_SIZE, TILE_SIZE)
         K_load = tl.load(key_cache_ptr + k_offset,
-                         mask=dim_mask[:, None],
+                         mask=dim_mask[:, None] & tile_mask[None, :],
                          other=0.0)
 
         if K_load.dtype.is_fp8():
@@ -467,9 +483,9 @@ def kernel_unified_attention_3d(
         else:
             K = K_load
 
-        # V : (BLOCK_SIZE, HEAD_SIZE)
+        # V : (TILE_SIZE, HEAD_SIZE)
         V_load = tl.load(value_cache_ptr + v_offset,
-                         mask=dim_mask[None, :],
+                         mask=dim_mask[None, :] & tile_mask[:, None],
                          other=0.0)
 
         if V_load.dtype.is_fp8():
@@ -480,13 +496,10 @@ def kernel_unified_attention_3d(
         else:
             V = V_load
 
-        seq_offset = j * BLOCK_SIZE + offs_n
-
         seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
 
-        # S : (BLOCK_M, BLOCK_SIZE)
-        S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32)
-
+        # S : (BLOCK_M, TILE_SIZE)
+        S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
         S += scale * tl.dot(Q, K)
 
         if USE_SOFTCAP:
@@ -517,11 +530,12 @@ def kernel_unified_attention_3d(
         # compute running maximum
         # m_j : (BLOCK_M,)
         m_j = tl.maximum(M, tl.max(S, axis=1))
+
        # For sliding window there's a chance the max is -inf due to masking of
        # the entire row. In this case we need to set m_j 0 to avoid NaN
         m_j = tl.where(m_j > float("-inf"), m_j, 0.0)
 
-        # P : (BLOCK_M, BLOCK_SIZE,)
+        # P : (BLOCK_M, TILE_SIZE,)
         P = tl.exp(S - m_j[:, None])
 
         # l_j : (BLOCK_M,)
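
With this hunk the 3d kernel also bounds its loop by the longest causal prefix visible to any query row in the q_block, as the 2d kernel already did, so tiles past that prefix are skipped entirely via the min(...) loop bound. A worked example of the prefix arithmetic with assumed shapes (all numbers below are illustrative only):

    def cdiv(a: int, b: int) -> int:
        # ceiling division, mirroring cdiv_fn in the kernel
        return (a + b - 1) // b

    # assumed values, purely illustrative
    context_len = 100
    seq_len = 4096
    BLOCK_M = 64
    num_queries_per_kv = 4
    BLOCK_Q = BLOCK_M // num_queries_per_kv   # 16
    TILE_SIZE = 32
    q_block_local_idx = 2

    # longest causal prefix any query row in this q_block may attend to
    max_seq_prefix_len = (context_len + q_block_local_idx * BLOCK_Q +
                          (BLOCK_M - 1) // num_queries_per_kv + 1)
    max_seq_prefix_len = min(max_seq_prefix_len, seq_len)   # 148

    num_tiles = cdiv(max_seq_prefix_len, TILE_SIZE)         # 5 tiles instead of cdiv(4096, 32) = 128
    print(max_seq_prefix_len, num_tiles)  # 148 5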
@@ -573,7 +587,7 @@ def reduce_segments(
         output_stride_0: tl.int64,  # int
         output_stride_1: tl.int64,  # int, should be equal to head_size
         block_table_stride: tl.int64,  # int
-        BLOCK_SIZE: tl.constexpr,  # int
+        TILE_SIZE: tl.constexpr,  # int
         HEAD_SIZE: tl.constexpr,  # int, must be power of 2
         HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
         query_start_len_ptr,  # [num_seqs+1]
@@ -594,10 +608,10 @@ def reduce_segments(
 
     # number of segments for this particular sequence
     num_segments = NUM_SEGMENTS_PER_SEQ
-    blocks_per_segment = cdiv_fn(seq_len, num_segments * BLOCK_SIZE)
+    tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE)
 
     # create masks for subsequent loads
-    act_num_segments = cdiv_fn(seq_len, blocks_per_segment * BLOCK_SIZE)
+    act_num_segments = cdiv_fn(seq_len, tiles_per_segment * TILE_SIZE)
     segm_mask = tl.arange(0, NUM_SEGMENTS_PER_SEQ) < tl.full(
         [NUM_SEGMENTS_PER_SEQ], act_num_segments, dtype=tl.int32)
     dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1,
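
In the split-KV path each sequence is divided across NUM_SEGMENTS_PER_SEQ programs, each of which now owns tiles_per_segment tiles of TILE_SIZE tokens; reduce_segments repeats the same arithmetic to work out how many segments actually contain data before combining the partial results. A small worked example with assumed numbers (the concrete seq_len, tile size, and segment count are illustrative only):

    def cdiv(a: int, b: int) -> int:
        # ceiling division, mirroring cdiv_fn in the kernel
        return (a + b - 1) // b

    NUM_SEGMENTS_PER_SEQ = 16   # assumed launch constant
    TILE_SIZE = 16              # assumed decode tile size
    seq_len = 200               # assumed sequence length

    tiles_per_segment = cdiv(seq_len, NUM_SEGMENTS_PER_SEQ * TILE_SIZE)  # cdiv(200, 256) = 1
    act_num_segments = cdiv(seq_len, tiles_per_segment * TILE_SIZE)      # cdiv(200, 16) = 13

    # Only 13 of the 16 segments contain tokens; segm_mask zeroes out the
    # remaining 3 when the per-segment partial results are reduced.
    print(tiles_per_segment, act_num_segments)  # 1 13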
@@ -671,13 +685,10 @@ def unified_attention(
     # Optional tensor for sinks
     sinks=None,
 ):
+
     assert causal, "Only causal attention is supported"
     assert q_descale is None, "Q scales not supported"
 
-    block_size = v.shape[1]
-    assert q.element_size() >= 2 or block_size >= 32, \
-        "Block size must be at least 32 for fp8"
-
     if sinks is not None:
         assert sinks.shape[0] == q.shape[1], \
             "Sinks must be num_query_heads size"
@@ -707,6 +718,12 @@ def unified_attention(
     # = floor(q.shape[0] / BLOCK_Q) + num_seqs
     total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs
 
+    # Assigning default tile sizes for prefill and decode.
+    # Note: each tile size must be at least 32 for "fp8" (q.element_size() == 1)
+    # and at least 16 for all other data types.
+    TILE_SIZE_PREFILL = 32
+    TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32
+
     # if batch contains a prefill
     if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128:
         kernel_unified_attention_2d[(
@@ -736,6 +753,7 @@ def unified_attention(
             output_stride_1=out.stride(1),
             qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0,
             BLOCK_SIZE=block_size,
+            TILE_SIZE=TILE_SIZE_PREFILL,
             HEAD_SIZE=head_size,
             HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
             USE_ALIBI_SLOPES=use_alibi_slopes,
@@ -809,6 +827,7 @@ def unified_attention(
             query_stride_1=q.stride(1),
             qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0,
             BLOCK_SIZE=block_size,
+            TILE_SIZE=TILE_SIZE_DECODE,
             HEAD_SIZE=head_size,
             HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
             USE_ALIBI_SLOPES=use_alibi_slopes,
@@ -830,7 +849,6 @@ def unified_attention(
             BLOCK_M=BLOCK_M,
             NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
         )
-
         reduce_segments[(q.shape[0], num_query_heads)](
             output_ptr=out,
             segm_output_ptr=segm_output,
@@ -844,7 +862,7 @@ def unified_attention(
             output_stride_0=out.stride(0),
             output_stride_1=out.stride(1),
             block_table_stride=block_table.stride(0),
-            BLOCK_SIZE=block_size,
+            TILE_SIZE=TILE_SIZE_DECODE,
             HEAD_SIZE=head_size,
             HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
             query_start_len_ptr=cu_seqlens_q,
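
The host-side wrapper replaces the old block-size assertion with per-phase tile-size defaults: since tl.dot now operates on TILE_SIZE-wide tiles rather than BLOCK_SIZE-wide cache blocks, the minimum-width requirement appears to be satisfied by the tile defaults instead of constraining the KV-cache block size. A tiny illustrative helper restating that selection (the function name is made up; the values mirror the comment in the diff: at least 32 for fp8 queries, 16 otherwise):

    def default_tile_sizes(q_element_size: int) -> tuple[int, int]:
        # prefill always uses a 32-wide tile; decode drops to 16 unless the
        # query dtype is fp8 (element size 1), which needs at least 32
        tile_size_prefill = 32
        tile_size_decode = 16 if q_element_size >= 2 else 32
        return tile_size_prefill, tile_size_decode

    assert default_tile_sizes(2) == (32, 16)  # fp16 / bf16 queries
    assert default_tile_sizes(1) == (32, 32)  # fp8 queries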