Merged

Changes from all commits · 30 commits
c3f8dd4
modify shape of kv cache
jvlunteren Jul 18, 2025
7d52de7
removed prefill support from split-kv attention
jvlunteren Jul 18, 2025
41a1254
added reorder_batch method
jvlunteren Jul 18, 2025
6109a14
added tiling to support large non-power-of-2 block sizes
jvlunteren Jul 18, 2025
343ed93
formatting
jvlunteren Jul 18, 2025
a96e229
updated parameters
jvlunteren Jul 18, 2025
ffa29db
resolved conflicts
jvlunteren Jul 21, 2025
6a375ee
removed unneeded files
jvlunteren Jul 21, 2025
39e352d
Merge branch 'vllm-project:main' into jvl-hybrid-support
jvlunteren Jul 22, 2025
c4f7d67
add prefill support back to split-kv kernel
jvlunteren Jul 29, 2025
b4fd773
removed changes to triton_attn.py
jvlunteren Aug 2, 2025
40218f4
Merge branch 'vllm-project:main' into jvl-hybrid-support
jvlunteren Aug 2, 2025
832ccb8
restored changes to triton_attn.py
jvlunteren Aug 2, 2025
d893bcb
formatting
jvlunteren Aug 2, 2025
02b0464
Merge branch 'main' into jvl-hybrid-support
jvlunteren Aug 4, 2025
9b63155
formatting
jvlunteren Aug 4, 2025
d3979e9
Merge branch 'main' into jvl-hybrid-support
jvlunteren Aug 16, 2025
771f524
Merge branch 'main' into jvl-hybrid-support
jvlunteren Aug 18, 2025
ec62011
Merge branch 'main' into jvl-hybrid-support
jvlunteren Aug 18, 2025
39ed6b6
temporarily revert changes to enable simpler merge with latest main
jvlunteren Sep 16, 2025
eead28e
Merge branch 'vllm-project:main' into jvl-hybrid-support
jvlunteren Sep 16, 2025
2b415b6
Merge branch 'vllm-project:main' into jvl-hybrid-support
jvlunteren Sep 16, 2025
4254fad
restore changes to triton_unified_attention.py
jvlunteren Sep 16, 2025
db61501
Merge branch 'main' into jvl-hybrid-support
jvlunteren Sep 16, 2025
3605c65
Merge branch 'main' into jvl-hybrid-support
jvlunteren Sep 17, 2025
8a846bc
replaced block size check by tile size check
jvlunteren Sep 17, 2025
f41f90e
Merge branch 'main' into jvl-hybrid-support
jvlunteren Sep 17, 2025
3dedd20
assign default tile sizes for prefill and decode, remove check
jvlunteren Sep 17, 2025
ee82dd5
Merge branch 'main' into jvl-hybrid-support
jvlunteren Sep 17, 2025
8da5f74
Merge branch 'main' into jvl-hybrid-support
jvlunteren Sep 18, 2025
3 changes: 0 additions & 3 deletions tests/kernels/attention/test_triton_unified_attention.py
@@ -102,9 +102,6 @@ def test_triton_unified_attn(
) -> None:
torch.set_default_device("cuda")

if q_dtype is not None and q_dtype.itemsize < 2 and block_size < 32:
pytest.skip("block size must be at least 32 for fp8")

current_platform.seed_everything(0)
num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens]
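Note: the skip removed above is superseded by the per-kernel tile-size defaults introduced in triton_unified_attention.py below; the fp8 minimum now applies to the tile size rather than the KV-cache block size. A minimal sketch of that selection (the helper name pick_tile_sizes is hypothetical; the values are taken from the diff):

def pick_tile_sizes(q_element_size: int) -> tuple[int, int]:
    """Return (TILE_SIZE_PREFILL, TILE_SIZE_DECODE) for a given query dtype size."""
    # fp8 queries (1-byte elements) need tiles of at least 32 positions;
    # 16 is sufficient for fp16/bf16 and wider dtypes
    tile_size_prefill = 32
    tile_size_decode = 16 if q_element_size >= 2 else 32
    return tile_size_prefill, tile_size_decode

assert pick_tile_sizes(1) == (32, 32)   # fp8 query dtype
assert pick_tile_sizes(2) == (32, 16)   # fp16 / bf16 query dtype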
122 changes: 70 additions & 52 deletions vllm/attention/ops/triton_unified_attention.py
@@ -73,6 +73,7 @@ def kernel_unified_attention_2d(
output_stride_1: tl.int64, # int, should be equal to head_size
qq_bias_stride_0: tl.int64, # int
BLOCK_SIZE: tl.constexpr, # int
TILE_SIZE: tl.constexpr, # int, must be power of 2
HEAD_SIZE: tl.constexpr, # int
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
USE_ALIBI_SLOPES: tl.constexpr, # bool
@@ -118,6 +119,7 @@ def kernel_unified_attention_2d(

offs_m = tl.arange(0, BLOCK_M)
offs_d = tl.arange(0, HEAD_SIZE_PADDED)
offs_t = tl.arange(0, TILE_SIZE)
query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv

query_offset_0 = cur_batch_in_all_start_index + query_pos
@@ -177,31 +179,32 @@ def kernel_unified_attention_2d(
# actual sequence length
max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len)

# calculate the number of tiles (blocks) that need to be processed to
# cover the longest sequence prefix (due to causal masking, blocks beyond
# calculate the number of tiles that need to be processed to
# cover the longest sequence prefix (due to causal masking, tiles beyond
# this prefix can be skipped)
num_blocks = cdiv_fn(max_seq_prefix_len, BLOCK_SIZE)
num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE)

# iterate through tiles
for j in range(0, num_blocks):
for j in range(0, num_tiles):
seq_offset = j * TILE_SIZE + offs_t
tile_mask = seq_offset < max_seq_prefix_len

physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j)
physical_block_idx = tl.load(block_tables_ptr + block_table_offset +
seq_offset // BLOCK_SIZE).to(tl.int64)

offs_n = tl.arange(0, BLOCK_SIZE)

v_offset = (physical_block_idx * stride_v_cache_0 +
v_offset = (physical_block_idx[:, None] * stride_v_cache_0 +
kv_head_idx * stride_v_cache_2 +
offs_d[None, :] * stride_v_cache_3 +
offs_n[:, None] * stride_v_cache_1)
(seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1)

k_offset = (physical_block_idx * stride_k_cache_0 +
k_offset = (physical_block_idx[None, :] * stride_k_cache_0 +
kv_head_idx * stride_k_cache_2 +
offs_d[:, None] * stride_k_cache_3 +
offs_n[None, :] * stride_k_cache_1)
(seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1)

# K : (HEAD_SIZE, BLOCK_SIZE)
# K : (HEAD_SIZE, TILE_SIZE)
K_load = tl.load(key_cache_ptr + k_offset,
mask=dim_mask[:, None],
mask=dim_mask[:, None] & tile_mask[None, :],
other=0.0)

if K_load.dtype.is_fp8():
@@ -212,9 +215,9 @@ def kernel_unified_attention_2d(
else:
K = K_load

# V : (BLOCK_SIZE, HEAD_SIZE)
# V : (TILE_SIZE, HEAD_SIZE)
V_load = tl.load(value_cache_ptr + v_offset,
mask=dim_mask[None, :],
mask=dim_mask[None, :] & tile_mask[:, None],
other=0.0)

if V_load.dtype.is_fp8():
@@ -225,12 +228,10 @@ def kernel_unified_attention_2d(
else:
V = V_load

seq_offset = j * BLOCK_SIZE + offs_n

seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1

# S : (BLOCK_M, BLOCK_SIZE)
S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32)
# S : (BLOCK_M, TILE_SIZE)
S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)

S += scale * tl.dot(Q, K)

@@ -262,11 +263,12 @@ def kernel_unified_attention_2d(
# compute running maximum
# m_j : (BLOCK_M,)
m_j = tl.maximum(M, tl.max(S, axis=1))

# For sliding window there's a chance the max is -inf due to masking of
# the entire row. In this case we need to set m_j 0 to avoid NaN
m_j = tl.where(m_j > float("-inf"), m_j, 0.0)

# P : (BLOCK_M, BLOCK_SIZE)
# P : (BLOCK_M, TILE_SIZE)
P = tl.exp(S - m_j[:, None])

# l_j : (BLOCK_M,)
@@ -327,6 +329,7 @@ def kernel_unified_attention_3d(
query_stride_1: tl.int64, # int, should be equal to head_size
qq_bias_stride_0: tl.int64, # int
BLOCK_SIZE: tl.constexpr, # int
TILE_SIZE: tl.constexpr, # int, must be power of 2
HEAD_SIZE: tl.constexpr, # int
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
USE_ALIBI_SLOPES: tl.constexpr, # bool
@@ -374,20 +377,19 @@ def kernel_unified_attention_3d(

# number of segments for this particular sequence
num_segments = NUM_SEGMENTS_PER_SEQ
blocks_per_segment = cdiv_fn(seq_len, num_segments * BLOCK_SIZE)
tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE)

if segm_idx * blocks_per_segment * BLOCK_SIZE >= seq_len:
if segm_idx * tiles_per_segment * TILE_SIZE >= seq_len:
return

offs_m = tl.arange(0, BLOCK_M)
offs_d = tl.arange(0, HEAD_SIZE_PADDED)

offs_t = tl.arange(0, TILE_SIZE)
query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv

query_offset_0 = cur_batch_in_all_start_index + query_pos
query_offset_1 = kv_head_idx * num_queries_per_kv + \
offs_m % num_queries_per_kv

query_offset = (query_offset_0[:, None] * query_stride_0 +
query_offset_1[:, None] * query_stride_1 + offs_d[None, :])

@@ -433,30 +435,44 @@ def kernel_unified_attention_3d(
qq_bias_row_ptrs = (qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0
) # shape: [BLOCK_M]

num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)
# compute the length of the longest sequence prefix spanned by any
# query token in the current q_block (q_block_local_idx)
max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + (
BLOCK_M - 1) // num_queries_per_kv + 1

# adjust for potential padding in the last q_block by considering the
# actual sequence length
max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len)

# calculate the number of tiles that need to be processed to
# cover the longest sequence prefix (due to causal masking, tiles beyond
# this prefix can be skipped)
num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE)
Member commented:
How come the 3D kernel didn't need to compute max_seq_prefix_len before this change? It looks like the 2D kernel did need to.

jvlunteren (Contributor, Author) replied on Sep 17, 2025:
The value max_seq_prefix_len relates to the number of tokens preceding the last query token in a Q block (for prefill, there can be multiple query tokens in a Q block), and it also accounts for potential padding. Because the 3D kernel was only used for decodes, each Q block contained exactly one query token, so the number of tokens preceding that query token follows directly from the sequence length and max_seq_prefix_len did not need to be calculated.

When I added prefill support back to the 3D (split-kv) kernel to simplify the review process (c4f7d67), I also reintroduced max_seq_prefix_len in the 3D kernel for the same purpose, making the 2D and 3D kernels very similar.

In order to synchronize the branch of this PR with the vLLM main branch, it was easier to temporarily remove my updates, perform the synchronization, and then reapply them. Because the 3D kernel in the vLLM main branch does not use max_seq_prefix_len, it looks as if this update only arrived with the last change, but as indicated above it was already included in commit c4f7d67. In a follow-up PR, I intend to simplify the 3D kernel by removing all functionality needed to support prefills, including the use of max_seq_prefix_len.
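For reference, a minimal plain-Python sketch (not part of the PR; the function name and values are illustrative) of the bound discussed above, showing that for decode the prefix length collapses to the sequence length:

def cdiv(a, b):
    return (a + b - 1) // b

def num_tiles_to_process(context_len, seq_len, q_block_local_idx,
                         BLOCK_Q, BLOCK_M, num_queries_per_kv, TILE_SIZE):
    # longest sequence prefix attended to by any query token in this Q block
    max_seq_prefix_len = (context_len + q_block_local_idx * BLOCK_Q
                          + (BLOCK_M - 1) // num_queries_per_kv + 1)
    # clamp for padding in the last Q block
    max_seq_prefix_len = min(max_seq_prefix_len, seq_len)
    return cdiv(max_seq_prefix_len, TILE_SIZE)

# decode: one query token per Q block, so the bound is simply seq_len
assert num_tiles_to_process(context_len=127, seq_len=128, q_block_local_idx=0,
                            BLOCK_Q=1, BLOCK_M=16, num_queries_per_kv=16,
                            TILE_SIZE=16) == cdiv(128, 16)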


# iterate through tiles within current segment
for j in range(
segm_idx * blocks_per_segment,
min((segm_idx + 1) * blocks_per_segment, num_blocks),
segm_idx * tiles_per_segment,
min((segm_idx + 1) * tiles_per_segment, num_tiles),
):
physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j)
seq_offset = j * TILE_SIZE + offs_t
tile_mask = seq_offset < max_seq_prefix_len

offs_n = tl.arange(0, BLOCK_SIZE)
physical_block_idx = tl.load(block_tables_ptr + block_table_offset +
seq_offset // BLOCK_SIZE).to(tl.int64)

v_offset = (physical_block_idx * stride_v_cache_0 +
v_offset = (physical_block_idx[:, None] * stride_v_cache_0 +
kv_head_idx * stride_v_cache_2 +
offs_d[None, :] * stride_v_cache_3 +
offs_n[:, None] * stride_v_cache_1)
(seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1)

k_offset = (physical_block_idx * stride_k_cache_0 +
k_offset = (physical_block_idx[None, :] * stride_k_cache_0 +
kv_head_idx * stride_k_cache_2 +
offs_d[:, None] * stride_k_cache_3 +
offs_n[None, :] * stride_k_cache_1)
(seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1)

# K : (HEAD_SIZE, BLOCK_SIZE)
# K : (HEAD_SIZE, TILE_SIZE)
K_load = tl.load(key_cache_ptr + k_offset,
mask=dim_mask[:, None],
mask=dim_mask[:, None] & tile_mask[None, :],
other=0.0)

if K_load.dtype.is_fp8():
@@ -467,9 +483,9 @@ def kernel_unified_attention_3d(
else:
K = K_load

# V : (BLOCK_SIZE, HEAD_SIZE)
# V : (TILE_SIZE, HEAD_SIZE)
V_load = tl.load(value_cache_ptr + v_offset,
mask=dim_mask[None, :],
mask=dim_mask[None, :] & tile_mask[:, None],
other=0.0)

if V_load.dtype.is_fp8():
@@ -480,13 +496,10 @@ def kernel_unified_attention_3d(
else:
V = V_load

seq_offset = j * BLOCK_SIZE + offs_n

seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1

# S : (BLOCK_M, BLOCK_SIZE)
S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32)

# S : (BLOCK_M, TILE_SIZE)
S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
S += scale * tl.dot(Q, K)

if USE_SOFTCAP:
@@ -517,11 +530,12 @@ def kernel_unified_attention_3d(
# compute running maximum
# m_j : (BLOCK_M,)
m_j = tl.maximum(M, tl.max(S, axis=1))

# For sliding window there's a chance the max is -inf due to masking of
# the entire row. In this case we need to set m_j 0 to avoid NaN
m_j = tl.where(m_j > float("-inf"), m_j, 0.0)

# P : (BLOCK_M, BLOCK_SIZE,)
# P : (BLOCK_M, TILE_SIZE,)
P = tl.exp(S - m_j[:, None])

# l_j : (BLOCK_M,)
@@ -573,7 +587,7 @@ def reduce_segments(
output_stride_0: tl.int64, # int
output_stride_1: tl.int64, # int, should be equal to head_size
block_table_stride: tl.int64, # int
BLOCK_SIZE: tl.constexpr, # int
TILE_SIZE: tl.constexpr, # int
HEAD_SIZE: tl.constexpr, # int, must be power of 2
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
query_start_len_ptr, # [num_seqs+1]
@@ -594,10 +608,10 @@ def reduce_segments(

# number of segments for this particular sequence
num_segments = NUM_SEGMENTS_PER_SEQ
blocks_per_segment = cdiv_fn(seq_len, num_segments * BLOCK_SIZE)
tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE)

# create masks for subsequent loads
act_num_segments = cdiv_fn(seq_len, blocks_per_segment * BLOCK_SIZE)
act_num_segments = cdiv_fn(seq_len, tiles_per_segment * TILE_SIZE)
segm_mask = tl.arange(0, NUM_SEGMENTS_PER_SEQ) < tl.full(
[NUM_SEGMENTS_PER_SEQ], act_num_segments, dtype=tl.int32)
dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1,
@@ -671,13 +685,10 @@ def unified_attention(
# Optional tensor for sinks
sinks=None,
):

assert causal, "Only causal attention is supported"
assert q_descale is None, "Q scales not supported"

block_size = v.shape[1]
assert q.element_size() >= 2 or block_size >= 32, \
"Block size must be at least 32 for fp8"

if sinks is not None:
assert sinks.shape[0] == q.shape[1], \
"Sinks must be num_query_heads size"
@@ -707,6 +718,12 @@ def unified_attention(
# = floor(q.shape[0] / BLOCK_Q) + num_seqs
total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs

# Assigning default tile sizes for prefill and decode.
# Note: each tile size must be at least 32 for "fp8" (q.element_size() == 1)
# and at least 16 for all other data types.
TILE_SIZE_PREFILL = 32
TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32

# if batch contains a prefill
if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128:
kernel_unified_attention_2d[(
@@ -736,6 +753,7 @@ def unified_attention(
output_stride_1=out.stride(1),
qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0,
BLOCK_SIZE=block_size,
TILE_SIZE=TILE_SIZE_PREFILL,
HEAD_SIZE=head_size,
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
USE_ALIBI_SLOPES=use_alibi_slopes,
@@ -809,6 +827,7 @@ def unified_attention(
query_stride_1=q.stride(1),
qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0,
BLOCK_SIZE=block_size,
TILE_SIZE=TILE_SIZE_DECODE,
HEAD_SIZE=head_size,
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
USE_ALIBI_SLOPES=use_alibi_slopes,
@@ -830,7 +849,6 @@ def unified_attention(
BLOCK_M=BLOCK_M,
NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
)

reduce_segments[(q.shape[0], num_query_heads)](
output_ptr=out,
segm_output_ptr=segm_output,
@@ -844,7 +862,7 @@ def unified_attention(
output_stride_0=out.stride(0),
output_stride_1=out.stride(1),
block_table_stride=block_table.stride(0),
BLOCK_SIZE=block_size,
TILE_SIZE=TILE_SIZE_DECODE,
HEAD_SIZE=head_size,
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
query_start_len_ptr=cu_seqlens_q,
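To summarize the core change in the diff above, a minimal NumPy sketch (not part of the PR; array names and sizes are illustrative) of how a tile of TILE_SIZE token positions is gathered from a paged KV cache whose blocks hold BLOCK_SIZE slots. Because each position in the tile is routed through the block table independently (seq_offset // BLOCK_SIZE selects the page, seq_offset % BLOCK_SIZE the slot within it), the tile size no longer has to equal the KV-cache block size, which is what allows large non-power-of-2 block sizes:

import numpy as np

BLOCK_SIZE = 48          # hypothetical non-power-of-2 KV-cache block size
TILE_SIZE = 32           # power-of-2 tile processed per inner-loop iteration
HEAD_SIZE = 4
seq_len = 200

# fake paged K cache for one KV head: [num_blocks, BLOCK_SIZE, HEAD_SIZE]
num_blocks = -(-seq_len // BLOCK_SIZE)
key_cache = np.random.randn(num_blocks, BLOCK_SIZE, HEAD_SIZE)
block_table = np.arange(num_blocks)      # identity block table for the sketch

def load_key_tile(j):
    """Gather the keys for tile j, zeroing positions past the sequence end."""
    seq_offset = j * TILE_SIZE + np.arange(TILE_SIZE)     # token positions
    tile_mask = seq_offset < seq_len
    safe_offset = np.minimum(seq_offset, seq_len - 1)     # keep lookups in range
    physical_block_idx = block_table[safe_offset // BLOCK_SIZE]
    K = key_cache[physical_block_idx, safe_offset % BLOCK_SIZE]  # [TILE, HEAD]
    return np.where(tile_mask[:, None], K, 0.0)

num_tiles = -(-seq_len // TILE_SIZE)
tiles = [load_key_tile(j) for j in range(num_tiles)]
assert all(t.shape == (TILE_SIZE, HEAD_SIZE) for t in tiles)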