
Commit a0ee87e

jvlunteren authored and charlifu committed
[Kernel] Decouple Tile Size from Block Size in Triton Unified Attention Kernel (vllm-project#21197)
Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
Signed-off-by: charlifu <charlifu@amd.com>
1 parent b50707a commit a0ee87e
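
The change in a nutshell: the attention kernels previously walked the KV cache one paged-KV block (BLOCK_SIZE) per loop iteration, so the compute tile was forced to equal the cache block size. They now iterate in tiles of TILE_SIZE token positions and map each position back to its cache block through the block table, so tile size and block size can be tuned independently. A minimal sketch of that mapping in plain PyTorch (host-side, illustrative names; not the Triton kernel itself):

import torch

def gather_key_tile(key_cache, block_table, tile_start, tile_size, block_size):
    # key_cache:   [num_blocks, block_size, num_kv_heads, head_size]
    # block_table: [max_blocks_per_seq] physical block ids of one sequence
    # Gathers keys for token positions [tile_start, tile_start + tile_size),
    # mirroring the kernel's seq_offset // BLOCK_SIZE and % BLOCK_SIZE indexing.
    seq_offset = tile_start + torch.arange(tile_size)
    # One block id per tile position, so a tile may straddle several blocks.
    physical_block_idx = block_table[seq_offset // block_size]
    # (The kernel additionally masks positions past the sequence end via tile_mask.)
    return key_cache[physical_block_idx, seq_offset % block_size]

Because every position carries its own block id, physical_block_idx becomes a vector load in the kernel, which is why the diff adds .to(tl.int64) and the [:, None] / [None, :] broadcasting on the cache offsets.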

File tree

2 files changed: +70, -55 lines


tests/kernels/attention/test_triton_unified_attention.py

Lines changed: 0 additions & 3 deletions
@@ -102,9 +102,6 @@ def test_triton_unified_attn(
 ) -> None:
     torch.set_default_device("cuda")
 
-    if q_dtype is not None and q_dtype.itemsize < 2 and block_size < 32:
-        pytest.skip("block size must be at least 32 for fp8")
-
     current_platform.seed_everything(0)
     num_seqs = len(seq_lens)
     query_lens = [x[0] for x in seq_lens]
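
The deleted skip existed because the old kernel tiled the KV cache by BLOCK_SIZE, and the comment added in unified_attention() notes that a tile must be at least 32 wide for fp8 query dtypes (q.element_size() == 1) and at least 16 otherwise. With TILE_SIZE decoupled from the cache block size the kernel can satisfy that requirement on its own, so fp8 tests no longer need to skip small block sizes. A hedged sketch of the new defaults (pick_tile_sizes is a hypothetical helper written for illustration, not part of the commit):

def pick_tile_sizes(q_element_size: int) -> tuple[int, int]:
    # Mirrors the defaults added in unified_attention(): tiles must be
    # at least 32 wide for fp8 queries and at least 16 for other dtypes.
    tile_size_prefill = 32
    tile_size_decode = 16 if q_element_size >= 2 else 32
    return tile_size_prefill, tile_size_decode

assert pick_tile_sizes(1) == (32, 32)  # fp8 query dtype
assert pick_tile_sizes(2) == (32, 16)  # fp16 / bf16 query dtype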

vllm/attention/ops/triton_unified_attention.py

Lines changed: 70 additions & 52 deletions
@@ -73,6 +73,7 @@ def kernel_unified_attention_2d(
     output_stride_1: tl.int64,  # int, should be equal to head_size
     qq_bias_stride_0: tl.int64,  # int
     BLOCK_SIZE: tl.constexpr,  # int
+    TILE_SIZE: tl.constexpr,  # int must be power of 2
     HEAD_SIZE: tl.constexpr,  # int
     HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
     USE_ALIBI_SLOPES: tl.constexpr,  # bool
@@ -118,6 +119,7 @@ def kernel_unified_attention_2d(
 
     offs_m = tl.arange(0, BLOCK_M)
     offs_d = tl.arange(0, HEAD_SIZE_PADDED)
+    offs_t = tl.arange(0, TILE_SIZE)
     query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv
 
     query_offset_0 = cur_batch_in_all_start_index + query_pos
@@ -177,31 +179,32 @@ def kernel_unified_attention_2d(
     # actual sequence length
     max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len)
 
-    # calculate the number of tiles (blocks) that need to be processed to
-    # cover the longest sequence prefix (due to causal masking, blocks beyond
+    # calculate the number of tiles that need to be processed to
+    # cover the longest sequence prefix (due to causal masking, tiles beyond
     # this prefix can be skipped)
-    num_blocks = cdiv_fn(max_seq_prefix_len, BLOCK_SIZE)
+    num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE)
 
     # iterate through tiles
-    for j in range(0, num_blocks):
+    for j in range(0, num_tiles):
+        seq_offset = j * TILE_SIZE + offs_t
+        tile_mask = seq_offset < max_seq_prefix_len
 
-        physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j)
+        physical_block_idx = tl.load(block_tables_ptr + block_table_offset +
+                                     seq_offset // BLOCK_SIZE).to(tl.int64)
 
-        offs_n = tl.arange(0, BLOCK_SIZE)
-
-        v_offset = (physical_block_idx * stride_v_cache_0 +
+        v_offset = (physical_block_idx[:, None] * stride_v_cache_0 +
                     kv_head_idx * stride_v_cache_2 +
                     offs_d[None, :] * stride_v_cache_3 +
-                    offs_n[:, None] * stride_v_cache_1)
+                    (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1)
 
-        k_offset = (physical_block_idx * stride_k_cache_0 +
+        k_offset = (physical_block_idx[None, :] * stride_k_cache_0 +
                     kv_head_idx * stride_k_cache_2 +
                     offs_d[:, None] * stride_k_cache_3 +
-                    offs_n[None, :] * stride_k_cache_1)
+                    (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1)
 
-        # K : (HEAD_SIZE, BLOCK_SIZE)
+        # K : (HEAD_SIZE, TILE_SIZE)
         K_load = tl.load(key_cache_ptr + k_offset,
-                         mask=dim_mask[:, None],
+                         mask=dim_mask[:, None] & tile_mask[None, :],
                          other=0.0)
 
         if K_load.dtype.is_fp8():
@@ -212,9 +215,9 @@ def kernel_unified_attention_2d(
         else:
             K = K_load
 
-        # V : (BLOCK_SIZE, HEAD_SIZE)
+        # V : (TILE_SIZE, HEAD_SIZE)
         V_load = tl.load(value_cache_ptr + v_offset,
-                         mask=dim_mask[None, :],
+                         mask=dim_mask[None, :] & tile_mask[:, None],
                          other=0.0)
 
         if V_load.dtype.is_fp8():
@@ -225,12 +228,10 @@ def kernel_unified_attention_2d(
         else:
             V = V_load
 
-        seq_offset = j * BLOCK_SIZE + offs_n
-
         seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
 
-        # S : (BLOCK_M, BLOCK_SIZE)
-        S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32)
+        # S : (BLOCK_M, TILE_SIZE)
+        S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
 
         S += scale * tl.dot(Q, K)
 
@@ -262,11 +263,12 @@ def kernel_unified_attention_2d(
         # compute running maximum
         # m_j : (BLOCK_M,)
         m_j = tl.maximum(M, tl.max(S, axis=1))
+
         # For sliding window there's a chance the max is -inf due to masking of
         # the entire row. In this case we need to set m_j 0 to avoid NaN
         m_j = tl.where(m_j > float("-inf"), m_j, 0.0)
 
-        # P : (BLOCK_M, BLOCK_SIZE)
+        # P : (BLOCK_M, TILE_SIZE)
         P = tl.exp(S - m_j[:, None])
 
         # l_j : (BLOCK_M,)
@@ -327,6 +329,7 @@ def kernel_unified_attention_3d(
     query_stride_1: tl.int64,  # int, should be equal to head_size
     qq_bias_stride_0: tl.int64,  # int
     BLOCK_SIZE: tl.constexpr,  # int
+    TILE_SIZE: tl.constexpr,  # int, must be power of 2
     HEAD_SIZE: tl.constexpr,  # int
     HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
     USE_ALIBI_SLOPES: tl.constexpr,  # bool
@@ -374,20 +377,19 @@ def kernel_unified_attention_3d(
 
     # number of segments for this particular sequence
     num_segments = NUM_SEGMENTS_PER_SEQ
-    blocks_per_segment = cdiv_fn(seq_len, num_segments * BLOCK_SIZE)
+    tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE)
 
-    if segm_idx * blocks_per_segment * BLOCK_SIZE >= seq_len:
+    if segm_idx * tiles_per_segment * TILE_SIZE >= seq_len:
         return
 
     offs_m = tl.arange(0, BLOCK_M)
     offs_d = tl.arange(0, HEAD_SIZE_PADDED)
-
+    offs_t = tl.arange(0, TILE_SIZE)
     query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv
 
     query_offset_0 = cur_batch_in_all_start_index + query_pos
     query_offset_1 = kv_head_idx * num_queries_per_kv + \
         offs_m % num_queries_per_kv
-
     query_offset = (query_offset_0[:, None] * query_stride_0 +
                     query_offset_1[:, None] * query_stride_1 + offs_d[None, :])
 
@@ -433,30 +435,44 @@ def kernel_unified_attention_3d(
     qq_bias_row_ptrs = (qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0
                         )  # shape: [BLOCK_M]
 
-    num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)
+    # compute the length of the longest sequence prefix spanned by any
+    # query token in the current q_block (q_block_local_idx)
+    max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + (
+        BLOCK_M - 1) // num_queries_per_kv + 1
+
+    # adjust for potential padding in the last q_block by considering the
+    # actual sequence length
+    max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len)
+
+    # calculate the number of tiles that need to be processed to
+    # cover the longest sequence prefix (due to causal masking, tiles beyond
+    # this prefix can be skipped)
+    num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE)
 
     # iterate through tiles within current segment
     for j in range(
-            segm_idx * blocks_per_segment,
-            min((segm_idx + 1) * blocks_per_segment, num_blocks),
+            segm_idx * tiles_per_segment,
+            min((segm_idx + 1) * tiles_per_segment, num_tiles),
     ):
-        physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j)
+        seq_offset = j * TILE_SIZE + offs_t
+        tile_mask = seq_offset < max_seq_prefix_len
 
-        offs_n = tl.arange(0, BLOCK_SIZE)
+        physical_block_idx = tl.load(block_tables_ptr + block_table_offset +
+                                     seq_offset // BLOCK_SIZE).to(tl.int64)
 
-        v_offset = (physical_block_idx * stride_v_cache_0 +
+        v_offset = (physical_block_idx[:, None] * stride_v_cache_0 +
                     kv_head_idx * stride_v_cache_2 +
                     offs_d[None, :] * stride_v_cache_3 +
-                    offs_n[:, None] * stride_v_cache_1)
+                    (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1)
 
-        k_offset = (physical_block_idx * stride_k_cache_0 +
+        k_offset = (physical_block_idx[None, :] * stride_k_cache_0 +
                     kv_head_idx * stride_k_cache_2 +
                     offs_d[:, None] * stride_k_cache_3 +
-                    offs_n[None, :] * stride_k_cache_1)
+                    (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1)
 
-        # K : (HEAD_SIZE, BLOCK_SIZE)
+        # K : (HEAD_SIZE, TILE_SIZE)
         K_load = tl.load(key_cache_ptr + k_offset,
-                         mask=dim_mask[:, None],
+                         mask=dim_mask[:, None] & tile_mask[None, :],
                          other=0.0)
 
         if K_load.dtype.is_fp8():
@@ -467,9 +483,9 @@ def kernel_unified_attention_3d(
         else:
             K = K_load
 
-        # V : (BLOCK_SIZE, HEAD_SIZE)
+        # V : (TILE_SIZE, HEAD_SIZE)
         V_load = tl.load(value_cache_ptr + v_offset,
-                         mask=dim_mask[None, :],
+                         mask=dim_mask[None, :] & tile_mask[:, None],
                          other=0.0)
 
         if V_load.dtype.is_fp8():
@@ -480,13 +496,10 @@ def kernel_unified_attention_3d(
         else:
             V = V_load
 
-        seq_offset = j * BLOCK_SIZE + offs_n
-
         seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
 
-        # S : (BLOCK_M, BLOCK_SIZE)
-        S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32)
-
+        # S : (BLOCK_M, TILE_SIZE)
+        S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
         S += scale * tl.dot(Q, K)
 
         if USE_SOFTCAP:
@@ -517,11 +530,12 @@ def kernel_unified_attention_3d(
         # compute running maximum
         # m_j : (BLOCK_M,)
         m_j = tl.maximum(M, tl.max(S, axis=1))
+
         # For sliding window there's a chance the max is -inf due to masking of
         # the entire row. In this case we need to set m_j 0 to avoid NaN
         m_j = tl.where(m_j > float("-inf"), m_j, 0.0)
 
-        # P : (BLOCK_M, BLOCK_SIZE,)
+        # P : (BLOCK_M, TILE_SIZE,)
         P = tl.exp(S - m_j[:, None])
 
         # l_j : (BLOCK_M,)
@@ -573,7 +587,7 @@ def reduce_segments(
     output_stride_0: tl.int64,  # int
     output_stride_1: tl.int64,  # int, should be equal to head_size
     block_table_stride: tl.int64,  # int
-    BLOCK_SIZE: tl.constexpr,  # int
+    TILE_SIZE: tl.constexpr,  # int
     HEAD_SIZE: tl.constexpr,  # int, must be power of 2
     HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
     query_start_len_ptr,  # [num_seqs+1]
@@ -594,10 +608,10 @@ def reduce_segments(
 
     # number of segments for this particular sequence
     num_segments = NUM_SEGMENTS_PER_SEQ
-    blocks_per_segment = cdiv_fn(seq_len, num_segments * BLOCK_SIZE)
+    tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE)
 
     # create masks for subsequent loads
-    act_num_segments = cdiv_fn(seq_len, blocks_per_segment * BLOCK_SIZE)
+    act_num_segments = cdiv_fn(seq_len, tiles_per_segment * TILE_SIZE)
     segm_mask = tl.arange(0, NUM_SEGMENTS_PER_SEQ) < tl.full(
         [NUM_SEGMENTS_PER_SEQ], act_num_segments, dtype=tl.int32)
     dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1,
@@ -671,13 +685,10 @@ def unified_attention(
     # Optional tensor for sinks
     sinks=None,
 ):
+
     assert causal, "Only causal attention is supported"
     assert q_descale is None, "Q scales not supported"
 
-    block_size = v.shape[1]
-    assert q.element_size() >= 2 or block_size >= 32, \
-        "Block size must be at least 32 for fp8"
-
     if sinks is not None:
         assert sinks.shape[0] == q.shape[1], \
             "Sinks must be num_query_heads size"
@@ -707,6 +718,12 @@ def unified_attention(
     # = floor(q.shape[0] / BLOCK_Q) + num_seqs
     total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs
 
+    # Assigning default tile sizes for prefill and decode.
+    # Note: each tile size must be at least 32 for "fp8" (q.element_size() == 1)
+    # and at least 16 for all other data types.
+    TILE_SIZE_PREFILL = 32
+    TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32
+
     # if batch contains a prefill
     if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128:
         kernel_unified_attention_2d[(
@@ -736,6 +753,7 @@ def unified_attention(
             output_stride_1=out.stride(1),
             qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0,
             BLOCK_SIZE=block_size,
+            TILE_SIZE=TILE_SIZE_PREFILL,
             HEAD_SIZE=head_size,
             HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
             USE_ALIBI_SLOPES=use_alibi_slopes,
@@ -809,6 +827,7 @@ def unified_attention(
             query_stride_1=q.stride(1),
             qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0,
             BLOCK_SIZE=block_size,
+            TILE_SIZE=TILE_SIZE_DECODE,
             HEAD_SIZE=head_size,
             HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
             USE_ALIBI_SLOPES=use_alibi_slopes,
@@ -830,7 +849,6 @@ def unified_attention(
             BLOCK_M=BLOCK_M,
             NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
         )
-
         reduce_segments[(q.shape[0], num_query_heads)](
             output_ptr=out,
             segm_output_ptr=segm_output,
@@ -844,7 +862,7 @@ def unified_attention(
             output_stride_0=out.stride(0),
             output_stride_1=out.stride(1),
             block_table_stride=block_table.stride(0),
-            BLOCK_SIZE=block_size,
+            TILE_SIZE=TILE_SIZE_DECODE,
             HEAD_SIZE=head_size,
             HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
             query_start_len_ptr=cu_seqlens_q,
