From c3f8dd47789b2fd2655ffd27e7af7736f0c1725a Mon Sep 17 00:00:00 2001 From: Jan van Lunteren Date: Fri, 18 Jul 2025 00:48:10 -0400 Subject: [PATCH 01/16] modify shape of kv cache Signed-off-by: Jan van Lunteren --- vllm/v1/attention/backends/triton_attn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 79796ac14928..709adac2aa5d 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -183,7 +183,7 @@ def get_kv_cache_shape( ) -> tuple[int, ...]: if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") - return (2, num_blocks, block_size, num_kv_heads, head_size) + return (num_blocks, 2, block_size, num_kv_heads, head_size) @staticmethod def use_cascade_attention(*args, **kwargs) -> bool: @@ -299,7 +299,7 @@ def forward( key_cache, value_cache = PagedAttention.split_kv_cache( kv_cache, self.num_kv_heads, self.head_size) else: - key_cache, value_cache = kv_cache.unbind(0) + key_cache, value_cache = kv_cache.unbind(1) if self.kv_sharing_target_layer_name is None: # Reshape the input keys and values and store them in the cache. From 7d52de7eed4482de192902d56e9abefc60221bcd Mon Sep 17 00:00:00 2001 From: Jan van Lunteren Date: Fri, 18 Jul 2025 05:13:29 -0400 Subject: [PATCH 02/16] removed prefill support from split-kv attention Signed-off-by: Jan van Lunteren --- .../attention/ops/triton_unified_attention.py | 143 +++++++----------- 1 file changed, 51 insertions(+), 92 deletions(-) diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index eb9c4f1c1030..cd5cec12b404 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -281,6 +281,7 @@ def kernel_unified_attention_3d( softcap, # float32 num_query_heads: tl.constexpr, # int num_queries_per_kv: tl.constexpr, # int + num_queries_per_kv_padded: tl.constexpr, # int block_table_stride: tl.int64, # int query_stride_0: tl.int64, # int query_stride_1: tl.int64, # int, should be equal to head_size @@ -298,33 +299,12 @@ def kernel_unified_attention_3d( stride_v_cache_1: tl.int64, # int stride_v_cache_2: tl.int64, # int stride_v_cache_3: tl.constexpr, # int - query_start_len_ptr, # [num_seqs+1] - BLOCK_Q: tl.constexpr, # int - num_seqs: tl.int32, - BLOCK_M: tl.constexpr, # int NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int ): - q_block_global_idx = tl.program_id(0) + seq_idx = tl.program_id(0) kv_head_idx = tl.program_id(1) segm_idx = tl.program_id(2) - seq_idx = find_seq_idx(query_start_len_ptr, q_block_global_idx, num_seqs, - BLOCK_Q, True) - - q_block_start_idx = tl.load(query_start_len_ptr + - seq_idx) // BLOCK_Q + seq_idx - - q_block_local_idx = q_block_global_idx - q_block_start_idx - - cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) - cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1) - - cur_batch_query_len = cur_batch_in_all_stop_index \ - - cur_batch_in_all_start_index - - if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: - return - # sequence len for this particular sequence seq_len = tl.load(seq_lens_ptr + seq_idx) @@ -335,42 +315,35 @@ def kernel_unified_attention_3d( if segm_idx * blocks_per_segment * BLOCK_SIZE >= seq_len: return - offs_m = tl.arange(0, BLOCK_M) offs_d = tl.arange(0, HEAD_SIZE_PADDED) + offs_n = tl.arange(0, BLOCK_SIZE) - query_pos = q_block_local_idx * BLOCK_Q + offs_m // 
num_queries_per_kv - - query_offset_0 = cur_batch_in_all_start_index + query_pos - query_offset_1 = kv_head_idx * num_queries_per_kv + \ - offs_m % num_queries_per_kv + query_head_idx = kv_head_idx * num_queries_per_kv + tl.arange(0, num_queries_per_kv_padded) + query_offset = seq_idx * query_stride_0 + query_head_idx[:, None] * query_stride_1 + offs_d[None, :] - query_offset = (query_offset_0[:, None] * query_stride_0 + - query_offset_1[:, None] * query_stride_1 + offs_d[None, :]) + head_mask = query_head_idx < (kv_head_idx + 1) * num_queries_per_kv + head_mask = head_mask & (query_head_idx < num_query_heads) dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1) - query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1) - query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1) - # Q : (BLOCK_M, HEAD_SIZE_PADDED) + # Q : (num_queries_per_kv_padded, HEAD_SIZE_PADDED) Q = tl.load( query_ptr + query_offset, - mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + mask=dim_mask[None, :] & head_mask[:, None], other=0.0, ) block_table_offset = seq_idx * block_table_stride - M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) - L = tl.full([BLOCK_M], 1.0, dtype=tl.float32) - acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32) - - # context length for this particular sequences - context_len = seq_len - cur_batch_query_len + M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32) + L = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32) + acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED], + dtype=tl.float32) # alibi slope for this head if USE_ALIBI_SLOPES: - alibi_slope = tl.load(alibi_slopes_ptr + query_offset_1, - mask=query_mask_1, + alibi_slope = tl.load(alibi_slopes_ptr + query_head_idx, + mask=head_mask, other=0.0) num_blocks = cdiv_fn(seq_len, BLOCK_SIZE) @@ -382,8 +355,6 @@ def kernel_unified_attention_3d( ): physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j) - offs_n = tl.arange(0, BLOCK_SIZE) - v_offset = (physical_block_idx * stride_v_cache_0 + kv_head_idx * stride_v_cache_2 + offs_d[None, :] * stride_v_cache_3 + @@ -421,81 +392,80 @@ def kernel_unified_attention_3d( V = V_load seq_offset = j * BLOCK_SIZE + offs_n + seq_mask = seq_offset[None, :] < seq_len - seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1 - - # S : (BLOCK_M, BLOCK_SIZE) - S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32) - + # S : (num_queries_per_kv_padded, BLOCK_SIZE) + S = tl.zeros(shape=(num_queries_per_kv_padded, BLOCK_SIZE), dtype=tl.float32) S += scale * tl.dot(Q, K) + context_len = seq_len - 1 + if USE_SOFTCAP: S = apply_softcap(S, softcap) - S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, - S, float("-inf")) + S = tl.where(head_mask[:, None] & seq_mask, S, float("-inf")).to(tl.float32) if SLIDING_WINDOW > 0: - S = tl.where((context_len + query_pos[:, None] - seq_offset) - < SLIDING_WINDOW, S, float("-inf")) + S = tl.where((context_len - seq_offset) < SLIDING_WINDOW, S, + float("-inf")) if USE_ALIBI_SLOPES: S += alibi_slope[:, None] * (seq_offset - context_len) # compute running maximum - # m_j : (BLOCK_M,) + # m_j : (num_queries_per_kv_padded,) m_j = tl.maximum(M, tl.max(S, axis=1)) + # For sliding window there's a chance the max is -inf due to masking of # the entire row. 
In this case we need to set m_j 0 to avoid NaN m_j = tl.where(m_j > float("-inf"), m_j, 0.0) - # P : (BLOCK_M, BLOCK_SIZE,) + # P : (num_queries_per_kv_padded, BLOCK_SIZE,) P = tl.exp(S - m_j[:, None]) - # l_j : (BLOCK_M,) + # l_j : (num_queries_per_kv_padded,) l_j = tl.sum(P, axis=1) - # alpha : (BLOCK_M, ) + # alpha : (num_queries_per_kv_padded, ) alpha = tl.exp(M - m_j) - # acc : (BLOCK_M, HEAD_SIZE_PADDED) + # acc : (num_queries_per_kv_padded, HEAD_SIZE_PADDED) acc = acc * alpha[:, None] # update constants L = L * alpha + l_j M = m_j - # acc : (BLOCK_M, HEAD_SIZE_PADDED) + # acc : (num_queries_per_kv_padded, HEAD_SIZE_PADDED) acc += tl.dot(P.to(V.dtype), V) segm_output_offset = ( - query_offset_0[:, None].to(tl.int64) * + seq_idx.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + - query_offset_1[:, None] * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + query_head_idx[:, None] * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + segm_idx * HEAD_SIZE_PADDED + tl.arange(0, HEAD_SIZE_PADDED)[None, :]) tl.store( segm_output_ptr + segm_output_offset, acc, - mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + mask=dim_mask[None, :] & head_mask[:, None], ) - segm_offset = (query_offset_0.to(tl.int64) * + segm_offset = (seq_idx.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ) + - query_offset_1 * NUM_SEGMENTS_PER_SEQ + segm_idx) - tl.store(segm_max_ptr + segm_offset, M, mask=query_mask_0 & query_mask_1) + query_head_idx * NUM_SEGMENTS_PER_SEQ + segm_idx) + tl.store(segm_max_ptr + segm_offset, M, mask=head_mask) tl.store(segm_expsum_ptr + segm_offset, L, - mask=query_mask_0 & query_mask_1) + mask=head_mask) @triton.jit def reduce_segments( - output_ptr, # [num_tokens, num_query_heads, head_size] + output_ptr, # [num_seqs, num_query_heads, head_size] segm_output_ptr, - #[num_tokens, num_query_heads, max_num_segments, head_size] - segm_max_ptr, # [num_tokens, num_query_heads, max_num_segments] - segm_expsum_ptr, # [num_tokens, num_query_heads, max_num_segments] + #[num_seqs, num_query_heads, max_num_segments, head_size] + segm_max_ptr, # [num_seqs, num_query_heads, max_num_segments] + segm_expsum_ptr, # [num_seqs, num_query_heads, max_num_segments] seq_lens_ptr, # [num_seqs] - num_seqs, # int num_query_heads: tl.constexpr, # int output_stride_0: tl.int64, # int output_stride_1: tl.int64, # int, should be equal to head_size @@ -503,16 +473,11 @@ def reduce_segments( BLOCK_SIZE: tl.constexpr, # int HEAD_SIZE: tl.constexpr, # int, must be power of 2 HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 - query_start_len_ptr, # [num_seqs+1] - BLOCK_Q: tl.constexpr, # int NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int ): - query_token_idx = tl.program_id(0) + seq_idx = tl.program_id(0) query_head_idx = tl.program_id(1) - seq_idx = find_seq_idx(query_start_len_ptr, query_token_idx, num_seqs, - BLOCK_Q, False) - # sequence len for this particular sequence seq_len = tl.load(seq_lens_ptr + seq_idx) @@ -528,7 +493,7 @@ def reduce_segments( 0).to(tl.int1) # load segment maxima - segm_offset = (query_token_idx.to(tl.int64) * + segm_offset = (seq_idx.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ) + query_head_idx * NUM_SEGMENTS_PER_SEQ + tl.arange(0, NUM_SEGMENTS_PER_SEQ)) @@ -546,7 +511,7 @@ def reduce_segments( # load, rescale, and add segment attention outputs segm_output_offset = ( - query_token_idx.to(tl.int64) * + seq_idx.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + query_head_idx * (NUM_SEGMENTS_PER_SEQ * 
HEAD_SIZE_PADDED) + tl.arange(0, NUM_SEGMENTS_PER_SEQ)[:, None] * HEAD_SIZE_PADDED + @@ -562,7 +527,7 @@ def reduce_segments( acc = tl.where(overall_expsum == 0.0, 0.0, acc_sum / overall_expsum) # write result - output_offset = (query_token_idx * output_stride_0 + + output_offset = (seq_idx * output_stride_0 + query_head_idx * output_stride_1 + tl.arange(0, HEAD_SIZE_PADDED)) tl.store(output_ptr + output_offset, acc, mask=dim_mask) @@ -587,6 +552,7 @@ def unified_attention( v_descale, alibi_slopes=None, ): + assert causal, "Only causal attention is supported" assert q_descale is None, "Q scales not supported" @@ -663,7 +629,8 @@ def unified_attention( else: # for initial version, NUM_SEGMENTS = 16 is chosen as a default # value that showed good performance in tests - NUM_SEGMENTS = 16 + NUM_SEGMENTS = 64 + num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv), 16) segm_output = torch.empty( q.shape[0], @@ -687,9 +654,8 @@ def unified_attention( dtype=torch.float32, device=q.device, ) - kernel_unified_attention_3d[( - total_num_q_blocks, num_kv_heads, NUM_SEGMENTS)]( + num_seqs, num_kv_heads, NUM_SEGMENTS)]( segm_output_ptr=segm_output, segm_max_ptr=segm_max, segm_expsum_ptr=segm_expsum, @@ -705,6 +671,7 @@ def unified_attention( softcap=softcap, num_query_heads=num_query_heads, num_queries_per_kv=num_queries_per_kv, + num_queries_per_kv_padded=num_queries_per_kv_padded, block_table_stride=block_table.stride(0), query_stride_0=q.stride(0), query_stride_1=q.stride(1), @@ -722,20 +689,14 @@ def unified_attention( stride_v_cache_1=v.stride(1), stride_v_cache_2=v.stride(2), stride_v_cache_3=v.stride(3), - query_start_len_ptr=cu_seqlens_q, - BLOCK_Q=BLOCK_Q, - num_seqs=num_seqs, - BLOCK_M=BLOCK_M, NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, ) - - reduce_segments[(q.shape[0], num_query_heads)]( + reduce_segments[(num_seqs, num_query_heads)]( output_ptr=out, segm_output_ptr=segm_output, segm_max_ptr=segm_max, segm_expsum_ptr=segm_expsum, seq_lens_ptr=seqused_k, - num_seqs=num_seqs, num_query_heads=num_query_heads, output_stride_0=out.stride(0), output_stride_1=out.stride(1), @@ -743,7 +704,5 @@ def unified_attention( BLOCK_SIZE=block_size, HEAD_SIZE=head_size, HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), - query_start_len_ptr=cu_seqlens_q, - BLOCK_Q=BLOCK_Q, NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, ) From 41a125438bb84b3cdb2b43c06d7c60e3dc039b87 Mon Sep 17 00:00:00 2001 From: Jan van Lunteren Date: Fri, 18 Jul 2025 08:28:29 -0400 Subject: [PATCH 03/16] added reorder_batch method Signed-off-by: Jan van Lunteren --- vllm/v1/attention/backends/triton_attn.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 709adac2aa5d..7c992a501516 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -19,8 +19,11 @@ from vllm.platforms import current_platform from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata) + CommonAttentionMetadata, + reorder_batch_to_split_decodes_and_prefills) from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.worker.gpu_input_batch import InputBatch logger = init_logger(__name__) @@ -75,6 +78,12 @@ def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, self.attention_chunk_size = 
getattr(vllm_config.scheduler_config, 'attention_chunk_size', None) + def reorder_batch(self, input_batch: InputBatch, + scheduler_output: SchedulerOutput) -> bool: + return reorder_batch_to_split_decodes_and_prefills(input_batch, + scheduler_output, + decode_threshold=1) + def build_for_cudagraph_capture( self, common_attn_metadata: CommonAttentionMetadata ) -> TritonAttentionMetadata: From 6109a140dbe0372774dac4c5de9fe386205d56fc Mon Sep 17 00:00:00 2001 From: Jan van Lunteren Date: Fri, 18 Jul 2025 08:53:26 -0400 Subject: [PATCH 04/16] added tiling to support large non-power-of-2 block sizes Signed-off-by: Jan van Lunteren --- .../attention/ops/triton_unified_attention.py | 194 +++++++++--------- 1 file changed, 102 insertions(+), 92 deletions(-) diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index cd5cec12b404..ce7286d6023e 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -67,6 +67,7 @@ def kernel_unified_attention_2d( output_stride_0: tl.int64, # int output_stride_1: tl.int64, # int, should be equal to head_size BLOCK_SIZE: tl.constexpr, # int + TILE_SIZE: tl.constexpr, # int must be power of 2 HEAD_SIZE: tl.constexpr, # int HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 USE_ALIBI_SLOPES: tl.constexpr, # bool @@ -107,6 +108,7 @@ def kernel_unified_attention_2d( offs_m = tl.arange(0, BLOCK_M) offs_d = tl.arange(0, HEAD_SIZE_PADDED) + offs_t = tl.arange(0, TILE_SIZE) query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv query_offset_0 = cur_batch_in_all_start_index + query_pos @@ -153,31 +155,32 @@ def kernel_unified_attention_2d( # actual sequence length max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len) - # calculate the number of tiles (blocks) that need to be processed to - # cover the longest sequence prefix (due to causal masking, blocks beyond + # calculate the number of tiles that need to be processed to + # cover the longest sequence prefix (due to causal masking, tiles beyond # this prefix can be skipped) - num_blocks = cdiv_fn(max_seq_prefix_len, BLOCK_SIZE) + num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE) # iterate through tiles - for j in range(0, num_blocks): + for j in range(0, num_tiles): + seq_offset = j * TILE_SIZE + offs_t + tile_mask = seq_offset < max_seq_prefix_len - physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j) + physical_block_idx = tl.load(block_tables_ptr + block_table_offset + + seq_offset // BLOCK_SIZE).to(tl.int64) - offs_n = tl.arange(0, BLOCK_SIZE) - - v_offset = (physical_block_idx * stride_v_cache_0 + + v_offset = (physical_block_idx[:, None] * stride_v_cache_0 + kv_head_idx * stride_v_cache_2 + offs_d[None, :] * stride_v_cache_3 + - offs_n[:, None] * stride_v_cache_1) + (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1) - k_offset = (physical_block_idx * stride_k_cache_0 + + k_offset = (physical_block_idx[None, :] * stride_k_cache_0 + kv_head_idx * stride_k_cache_2 + offs_d[:, None] * stride_k_cache_3 + - offs_n[None, :] * stride_k_cache_1) + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1) - # K : (HEAD_SIZE, BLOCK_SIZE) + # K : (HEAD_SIZE, TILE_SIZE) K_load = tl.load(key_cache_ptr + k_offset, - mask=dim_mask[:, None], + mask=dim_mask[:, None] & tile_mask[None, :], other=0.0) if K_load.dtype.is_fp8(): @@ -188,9 +191,9 @@ def kernel_unified_attention_2d( else: K = K_load - # V : (BLOCK_SIZE, HEAD_SIZE) + # V : (TILE_SIZE, HEAD_SIZE) V_load = 
tl.load(value_cache_ptr + v_offset, - mask=dim_mask[None, :], + mask=dim_mask[None, :] & tile_mask[:, None], other=0.0) if V_load.dtype.is_fp8(): @@ -201,13 +204,10 @@ def kernel_unified_attention_2d( else: V = V_load - seq_offset = j * BLOCK_SIZE + offs_n - seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1 - # S : (BLOCK_M, BLOCK_SIZE) - S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32) - + # S : (BLOCK_M, TILE_SIZE) + S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32) S += scale * tl.dot(Q, K) if USE_SOFTCAP: @@ -226,11 +226,12 @@ def kernel_unified_attention_2d( # compute running maximum # m_j : (BLOCK_M,) m_j = tl.maximum(M, tl.max(S, axis=1)) + # For sliding window there's a chance the max is -inf due to masking of # the entire row. In this case we need to set m_j 0 to avoid NaN m_j = tl.where(m_j > float("-inf"), m_j, 0.0) - # P : (BLOCK_M, BLOCK_SIZE) + # P : (BLOCK_M, TILE_SIZE) P = tl.exp(S - m_j[:, None]) # l_j : (BLOCK_M,) @@ -281,11 +282,12 @@ def kernel_unified_attention_3d( softcap, # float32 num_query_heads: tl.constexpr, # int num_queries_per_kv: tl.constexpr, # int - num_queries_per_kv_padded: tl.constexpr, # int + num_queries_per_kv_padded: tl.constexpr, # int block_table_stride: tl.int64, # int query_stride_0: tl.int64, # int query_stride_1: tl.int64, # int, should be equal to head_size BLOCK_SIZE: tl.constexpr, # int + TILE_SIZE: tl.constexpr, # int, must be power of 2 HEAD_SIZE: tl.constexpr, # int HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 USE_ALIBI_SLOPES: tl.constexpr, # bool @@ -309,17 +311,19 @@ def kernel_unified_attention_3d( seq_len = tl.load(seq_lens_ptr + seq_idx) # number of segments for this particular sequence - num_segments = NUM_SEGMENTS_PER_SEQ - blocks_per_segment = cdiv_fn(seq_len, num_segments * BLOCK_SIZE) + num_tiles = cdiv_fn(seq_len, TILE_SIZE) + tiles_per_segment = cdiv_fn(num_tiles, NUM_SEGMENTS_PER_SEQ) - if segm_idx * blocks_per_segment * BLOCK_SIZE >= seq_len: + if segm_idx * tiles_per_segment * TILE_SIZE >= seq_len: return offs_d = tl.arange(0, HEAD_SIZE_PADDED) - offs_n = tl.arange(0, BLOCK_SIZE) + offs_t = tl.arange(0, TILE_SIZE) - query_head_idx = kv_head_idx * num_queries_per_kv + tl.arange(0, num_queries_per_kv_padded) - query_offset = seq_idx * query_stride_0 + query_head_idx[:, None] * query_stride_1 + offs_d[None, :] + query_head_idx = kv_head_idx * num_queries_per_kv + tl.arange( + 0, num_queries_per_kv_padded) + query_offset = (seq_idx * query_stride_0 + + query_head_idx[:, None] * query_stride_1 + offs_d[None, :]) head_mask = query_head_idx < (kv_head_idx + 1) * num_queries_per_kv head_mask = head_mask & (query_head_idx < num_query_heads) @@ -346,28 +350,30 @@ def kernel_unified_attention_3d( mask=head_mask, other=0.0) - num_blocks = cdiv_fn(seq_len, BLOCK_SIZE) - # iterate through tiles within current segment for j in range( - segm_idx * blocks_per_segment, - min((segm_idx + 1) * blocks_per_segment, num_blocks), + segm_idx * tiles_per_segment, + min((segm_idx + 1) * tiles_per_segment, num_tiles), ): - physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j) + seq_offset = j * TILE_SIZE + offs_t + tile_mask = seq_offset < seq_len - v_offset = (physical_block_idx * stride_v_cache_0 + + physical_block_idx = tl.load(block_tables_ptr + block_table_offset + + seq_offset // BLOCK_SIZE).to(tl.int64) + + v_offset = (physical_block_idx[:, None] * stride_v_cache_0 + kv_head_idx * stride_v_cache_2 + offs_d[None, :] * stride_v_cache_3 + - offs_n[:, None] * stride_v_cache_1) 
+ (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1) - k_offset = (physical_block_idx * stride_k_cache_0 + + k_offset = (physical_block_idx[None, :] * stride_k_cache_0 + kv_head_idx * stride_k_cache_2 + offs_d[:, None] * stride_k_cache_3 + - offs_n[None, :] * stride_k_cache_1) + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1) - # K : (HEAD_SIZE, BLOCK_SIZE) + # K : (HEAD_SIZE, TILE_SIZE) K_load = tl.load(key_cache_ptr + k_offset, - mask=dim_mask[:, None], + mask=dim_mask[:, None] & tile_mask[None, :], other=0.0) if K_load.dtype.is_fp8(): @@ -378,9 +384,9 @@ def kernel_unified_attention_3d( else: K = K_load - # V : (BLOCK_SIZE, HEAD_SIZE) + # V : (TILE_SIZE, HEAD_SIZE) V_load = tl.load(value_cache_ptr + v_offset, - mask=dim_mask[None, :], + mask=dim_mask[None, :] & tile_mask[:, None], other=0.0) if V_load.dtype.is_fp8(): @@ -391,11 +397,11 @@ def kernel_unified_attention_3d( else: V = V_load - seq_offset = j * BLOCK_SIZE + offs_n seq_mask = seq_offset[None, :] < seq_len - # S : (num_queries_per_kv_padded, BLOCK_SIZE) - S = tl.zeros(shape=(num_queries_per_kv_padded, BLOCK_SIZE), dtype=tl.float32) + # S : (num_queries_per_kv_padded, TILE_SIZE) + S = tl.zeros(shape=(num_queries_per_kv_padded, TILE_SIZE), + dtype=tl.float32) S += scale * tl.dot(Q, K) context_len = seq_len - 1 @@ -403,7 +409,8 @@ def kernel_unified_attention_3d( if USE_SOFTCAP: S = apply_softcap(S, softcap) - S = tl.where(head_mask[:, None] & seq_mask, S, float("-inf")).to(tl.float32) + S = tl.where(head_mask[:, None] & seq_mask, S, + float("-inf")).to(tl.float32) if SLIDING_WINDOW > 0: S = tl.where((context_len - seq_offset) < SLIDING_WINDOW, S, @@ -420,7 +427,7 @@ def kernel_unified_attention_3d( # the entire row. In this case we need to set m_j 0 to avoid NaN m_j = tl.where(m_j > float("-inf"), m_j, 0.0) - # P : (num_queries_per_kv_padded, BLOCK_SIZE,) + # P : (num_queries_per_kv_padded, TILE_SIZE,) P = tl.exp(S - m_j[:, None]) # l_j : (num_queries_per_kv_padded,) @@ -453,9 +460,7 @@ def kernel_unified_attention_3d( (num_query_heads * NUM_SEGMENTS_PER_SEQ) + query_head_idx * NUM_SEGMENTS_PER_SEQ + segm_idx) tl.store(segm_max_ptr + segm_offset, M, mask=head_mask) - tl.store(segm_expsum_ptr + segm_offset, - L, - mask=head_mask) + tl.store(segm_expsum_ptr + segm_offset, L, mask=head_mask) @triton.jit @@ -470,7 +475,7 @@ def reduce_segments( output_stride_0: tl.int64, # int output_stride_1: tl.int64, # int, should be equal to head_size block_table_stride: tl.int64, # int - BLOCK_SIZE: tl.constexpr, # int + TILE_SIZE: tl.constexpr, # int, must be power of 2 HEAD_SIZE: tl.constexpr, # int, must be power of 2 HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int @@ -483,10 +488,10 @@ def reduce_segments( # number of segments for this particular sequence num_segments = NUM_SEGMENTS_PER_SEQ - blocks_per_segment = cdiv_fn(seq_len, num_segments * BLOCK_SIZE) + tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE) # create masks for subsequent loads - act_num_segments = cdiv_fn(seq_len, blocks_per_segment * BLOCK_SIZE) + act_num_segments = cdiv_fn(seq_len, tiles_per_segment * TILE_SIZE) segm_mask = tl.arange(0, NUM_SEGMENTS_PER_SEQ) < tl.full( [NUM_SEGMENTS_PER_SEQ], act_num_segments, dtype=tl.int32) dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1, @@ -583,6 +588,9 @@ def unified_attention( # = floor(q.shape[0] / BLOCK_Q) + num_seqs total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs + TILE_SIZE_PREFILL = 16 + TILE_SIZE_DECODE = 16 + # if batch contains 
a prefill if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128: kernel_unified_attention_2d[( @@ -608,6 +616,7 @@ def unified_attention( output_stride_0=out.stride(0), output_stride_1=out.stride(1), BLOCK_SIZE=block_size, + TILE_SIZE=TILE_SIZE_PREFILL, HEAD_SIZE=head_size, HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), USE_ALIBI_SLOPES=use_alibi_slopes, @@ -629,8 +638,9 @@ def unified_attention( else: # for initial version, NUM_SEGMENTS = 16 is chosen as a default # value that showed good performance in tests - NUM_SEGMENTS = 64 - num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv), 16) + NUM_SEGMENTS = 16 + num_queries_per_kv_padded = max( + triton.next_power_of_2(num_queries_per_kv), 16) segm_output = torch.empty( q.shape[0], @@ -654,43 +664,43 @@ def unified_attention( dtype=torch.float32, device=q.device, ) - kernel_unified_attention_3d[( - num_seqs, num_kv_heads, NUM_SEGMENTS)]( - segm_output_ptr=segm_output, - segm_max_ptr=segm_max, - segm_expsum_ptr=segm_expsum, - query_ptr=q, - key_cache_ptr=k, - value_cache_ptr=v, - block_tables_ptr=block_table, - seq_lens_ptr=seqused_k, - alibi_slopes_ptr=alibi_slopes, - scale=softmax_scale, - k_scale=k_descale, - v_scale=v_descale, - softcap=softcap, - num_query_heads=num_query_heads, - num_queries_per_kv=num_queries_per_kv, - num_queries_per_kv_padded=num_queries_per_kv_padded, - block_table_stride=block_table.stride(0), - query_stride_0=q.stride(0), - query_stride_1=q.stride(1), - BLOCK_SIZE=block_size, - HEAD_SIZE=head_size, - HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), - USE_ALIBI_SLOPES=use_alibi_slopes, - USE_SOFTCAP=(softcap > 0), - SLIDING_WINDOW=(1 + window_size[0]), - stride_k_cache_0=k.stride(0), - stride_k_cache_1=k.stride(1), - stride_k_cache_2=k.stride(2), - stride_k_cache_3=k.stride(3), - stride_v_cache_0=v.stride(0), - stride_v_cache_1=v.stride(1), - stride_v_cache_2=v.stride(2), - stride_v_cache_3=v.stride(3), - NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, - ) + kernel_unified_attention_3d[(num_seqs, num_kv_heads, NUM_SEGMENTS)]( + segm_output_ptr=segm_output, + segm_max_ptr=segm_max, + segm_expsum_ptr=segm_expsum, + query_ptr=q, + key_cache_ptr=k, + value_cache_ptr=v, + block_tables_ptr=block_table, + seq_lens_ptr=seqused_k, + alibi_slopes_ptr=alibi_slopes, + scale=softmax_scale, + k_scale=k_descale, + v_scale=v_descale, + softcap=softcap, + num_query_heads=num_query_heads, + num_queries_per_kv=num_queries_per_kv, + num_queries_per_kv_padded=num_queries_per_kv_padded, + block_table_stride=block_table.stride(0), + query_stride_0=q.stride(0), + query_stride_1=q.stride(1), + BLOCK_SIZE=block_size, + TILE_SIZE=TILE_SIZE_DECODE, + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + USE_ALIBI_SLOPES=use_alibi_slopes, + USE_SOFTCAP=(softcap > 0), + SLIDING_WINDOW=(1 + window_size[0]), + stride_k_cache_0=k.stride(0), + stride_k_cache_1=k.stride(1), + stride_k_cache_2=k.stride(2), + stride_k_cache_3=k.stride(3), + stride_v_cache_0=v.stride(0), + stride_v_cache_1=v.stride(1), + stride_v_cache_2=v.stride(2), + stride_v_cache_3=v.stride(3), + NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, + ) reduce_segments[(num_seqs, num_query_heads)]( output_ptr=out, segm_output_ptr=segm_output, @@ -701,7 +711,7 @@ def unified_attention( output_stride_0=out.stride(0), output_stride_1=out.stride(1), block_table_stride=block_table.stride(0), - BLOCK_SIZE=block_size, + TILE_SIZE=TILE_SIZE_DECODE, HEAD_SIZE=head_size, HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, From 
343ed93c592244c3328cd690ec752126e60ef626 Mon Sep 17 00:00:00 2001 From: Jan van Lunteren Date: Fri, 18 Jul 2025 08:56:01 -0400 Subject: [PATCH 05/16] formatting Signed-off-by: Jan van Lunteren --- vllm/v1/attention/backends/triton_attn.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 7c992a501516..608a33a4d171 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -18,11 +18,11 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata, - reorder_batch_to_split_decodes_and_prefills) -from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.attention.backends.utils import ( + AttentionMetadataBuilder, CommonAttentionMetadata, + reorder_batch_to_split_decodes_and_prefills) from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.gpu_input_batch import InputBatch logger = init_logger(__name__) From a96e2295893f0904e6061bd2484078dde701c151 Mon Sep 17 00:00:00 2001 From: Jan van Lunteren Date: Fri, 18 Jul 2025 11:59:12 -0400 Subject: [PATCH 06/16] updated parameters Signed-off-by: Jan van Lunteren --- vllm/attention/ops/triton_unified_attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index ce7286d6023e..50ed4ab8cc16 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -588,8 +588,8 @@ def unified_attention( # = floor(q.shape[0] / BLOCK_Q) + num_seqs total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs - TILE_SIZE_PREFILL = 16 - TILE_SIZE_DECODE = 16 + TILE_SIZE_PREFILL = 32 + TILE_SIZE_DECODE = 32 # if batch contains a prefill if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128: From 6a375ee4e8cdfa37abad858b48cda0881ff3aace Mon Sep 17 00:00:00 2001 From: Jan van Lunteren Date: Mon, 21 Jul 2025 15:40:15 -0400 Subject: [PATCH 07/16] removed unneeded files Signed-off-by: Jan van Lunteren --- vllm/v1/attention/backends/tmp.txt | 382 ----------------- vllm/v1/attention/backends/triton_attn_new.py | 403 ------------------ vllm/v1/attention/backends/triton_attn_org.py | 394 ----------------- 3 files changed, 1179 deletions(-) delete mode 100644 vllm/v1/attention/backends/tmp.txt delete mode 100644 vllm/v1/attention/backends/triton_attn_new.py delete mode 100644 vllm/v1/attention/backends/triton_attn_org.py diff --git a/vllm/v1/attention/backends/tmp.txt b/vllm/v1/attention/backends/tmp.txt deleted file mode 100644 index 485a866c5ea7..000000000000 --- a/vllm/v1/attention/backends/tmp.txt +++ /dev/null @@ -1,382 +0,0 @@ -On branch jvl-hybrid-support -Your branch is up to date with 'origin/jvl-hybrid-support'. - -All conflicts fixed but you are still merging. 
- (use "git commit" to conclude merge) - -Changes to be committed: - modified: ../../../../.buildkite/scripts/hardware_ci/run-amd-test.sh - modified: ../../../../.buildkite/scripts/hardware_ci/run-cpu-test.sh - modified: ../../../../.buildkite/test-pipeline.yaml - modified: ../../../../.github/CODEOWNERS - modified: ../../../../.github/ISSUE_TEMPLATE/750-RFC.yml - modified: ../../../../.github/mergify.yml - modified: ../../../../RELEASE.md - new file: ../../../../benchmarks/auto_tune/README.md - renamed: ../../../../benchmarks/auto_tune.sh -> ../../../../benchmarks/auto_tune/auto_tune.sh - modified: ../../../../benchmarks/kernels/benchmark_moe.py - modified: ../../../../benchmarks/kernels/benchmark_moe_align_block_size.py - modified: ../../../../benchmarks/kernels/benchmark_moe_permute_unpermute.py - new file: ../../../../benchmarks/kv_cache/benchmark_block_pool.py - modified: ../../../../csrc/cpu/shm.cpp - modified: ../../../../csrc/moe/moe_align_sum_kernels.cu - modified: ../../../../csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu - modified: ../../../../csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh - modified: ../../../../csrc/quantization/cutlass_w8a8/moe/moe_data.cu - modified: ../../../../csrc/torch_bindings.cpp - modified: ../../../../docker/Dockerfile - modified: ../../../../docker/Dockerfile.xpu - modified: ../../../../docs/configuration/model_resolution.md - modified: ../../../../docs/design/v1/metrics.md - modified: ../../../../docs/features/lora.md - modified: ../../../../docs/features/multimodal_inputs.md - modified: ../../../../docs/features/quantization/bitblas.md - modified: ../../../../docs/features/quantization/fp8.md - modified: ../../../../docs/features/quantization/int4.md - modified: ../../../../docs/features/quantization/int8.md - modified: ../../../../docs/features/tool_calling.md - modified: ../../../../docs/getting_started/installation/cpu.md - modified: ../../../../docs/models/extensions/tensorizer.md - modified: ../../../../docs/models/pooling_models.md - modified: ../../../../docs/models/supported_models.md - modified: ../../../../docs/usage/v1_guide.md - modified: ../../../../examples/offline_inference/basic/classify.py - modified: ../../../../examples/offline_inference/basic/embed.py - modified: ../../../../examples/offline_inference/basic/score.py - modified: ../../../../examples/offline_inference/embed_jina_embeddings_v3.py - modified: ../../../../examples/offline_inference/embed_matryoshka_fy.py - modified: ../../../../examples/offline_inference/neuron_eagle.py - modified: ../../../../examples/offline_inference/neuron_speculation.py - modified: ../../../../examples/offline_inference/prithvi_geospatial_mae.py - modified: ../../../../examples/offline_inference/qwen3_reranker.py - new file: ../../../../examples/online_serving/elastic_ep/bench.sh - new file: ../../../../examples/online_serving/elastic_ep/scale.py - new file: ../../../../examples/online_serving/elastic_ep/serve_deepseek_v2.sh - modified: ../../../../pyproject.toml - modified: ../../../../requirements/cpu.txt - modified: ../../../../tests/basic_correctness/test_basic_correctness.py - modified: ../../../../tests/basic_correctness/test_preemption.py - modified: ../../../../tests/conftest.py - modified: ../../../../tests/core/test_num_computed_tokens_update.py - modified: ../../../../tests/core/test_serialization.py - modified: ../../../../tests/core/utils.py - modified: ../../../../tests/detokenizer/test_stop_reason.py - modified: ../../../../tests/detokenizer/test_stop_strings.py - 
modified: ../../../../tests/entrypoints/llm/test_accuracy.py - modified: ../../../../tests/entrypoints/openai/test_cli_args.py - deleted: ../../../../tests/kernels/attention/test_blocksparse_attention.py - modified: ../../../../tests/kernels/attention/test_flashinfer.py - modified: ../../../../tests/kernels/attention/test_rocm_attention_selector.py - modified: ../../../../tests/kernels/moe/test_cutlass_moe.py - modified: ../../../../tests/kernels/moe/test_moe_align_block_size.py - modified: ../../../../tests/lora/conftest.py - modified: ../../../../tests/lora/test_llama_tp.py - modified: ../../../../tests/lora/test_peft_helper.py - modified: ../../../../tests/metrics/test_metrics.py - modified: ../../../../tests/model_executor/test_guided_processors.py - modified: ../../../../tests/model_executor/test_model_load_with_params.py - modified: ../../../../tests/models/language/generation/test_hybrid.py - modified: ../../../../tests/models/language/generation/test_mistral.py - modified: ../../../../tests/models/language/pooling/mteb_utils.py - modified: ../../../../tests/models/language/pooling/test_gritlm.py - modified: ../../../../tests/models/language/pooling/test_jina.py - modified: ../../../../tests/models/language/pooling/test_nomic_max_model_len.py - modified: ../../../../tests/models/language/pooling/test_truncation_control.py - modified: ../../../../tests/models/multimodal/generation/test_common.py - new file: ../../../../tests/models/multimodal/generation/test_maverick.py - modified: ../../../../tests/models/multimodal/generation/test_pixtral.py - modified: ../../../../tests/models/multimodal/generation/test_whisper.py - modified: ../../../../tests/models/multimodal/generation/vlm_utils/core.py - modified: ../../../../tests/models/multimodal/pooling/test_dse_qwen2_vl.py - modified: ../../../../tests/models/multimodal/pooling/test_jinavl_reranker.py - modified: ../../../../tests/models/quantization/test_modelopt.py - modified: ../../../../tests/models/quantization/test_nvfp4.py - modified: ../../../../tests/models/registry.py - modified: ../../../../tests/models/test_initialization.py - modified: ../../../../tests/models/test_registry.py - modified: ../../../../tests/models/test_transformers.py - modified: ../../../../tests/multimodal/test_video.py - modified: ../../../../tests/multimodal/utils.py - modified: ../../../../tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py - modified: ../../../../tests/prefix_caching/test_disable_sliding_window.py - modified: ../../../../tests/prefix_caching/test_prefix_caching.py - modified: ../../../../tests/quantization/test_gptq_dynamic.py - new file: ../../../../tests/quantization/test_modelopt.py - modified: ../../../../tests/quantization/test_quark.py - modified: ../../../../tests/quantization/test_register_quantization_config.py - modified: ../../../../tests/samplers/test_ignore_eos.py - modified: ../../../../tests/samplers/test_logits_processor.py - modified: ../../../../tests/samplers/test_logprobs.py - modified: ../../../../tests/samplers/test_no_bad_words.py - deleted: ../../../../tests/samplers/test_rejection_sampler.py - modified: ../../../../tests/samplers/test_seeded_generate.py - deleted: ../../../../tests/samplers/test_typical_acceptance_sampler.py - deleted: ../../../../tests/spec_decode/conftest.py - deleted: ../../../../tests/spec_decode/e2e/__init__.py - deleted: ../../../../tests/spec_decode/e2e/conftest.py - deleted: ../../../../tests/spec_decode/e2e/test_compatibility.py - deleted: 
../../../../tests/spec_decode/e2e/test_eagle_correctness.py - deleted: ../../../../tests/spec_decode/e2e/test_integration.py - deleted: ../../../../tests/spec_decode/e2e/test_integration_dist_tp2.py - deleted: ../../../../tests/spec_decode/e2e/test_integration_dist_tp4.py - deleted: ../../../../tests/spec_decode/e2e/test_logprobs.py - deleted: ../../../../tests/spec_decode/e2e/test_medusa_correctness.py - deleted: ../../../../tests/spec_decode/e2e/test_mlp_correctness.py - deleted: ../../../../tests/spec_decode/e2e/test_mtp_correctness.py - deleted: ../../../../tests/spec_decode/e2e/test_multistep_correctness.py - deleted: ../../../../tests/spec_decode/e2e/test_ngram_correctness.py - deleted: ../../../../tests/spec_decode/e2e/test_seed.py - deleted: ../../../../tests/spec_decode/test_batch_expansion.py - deleted: ../../../../tests/spec_decode/test_dynamic_spec_decode.py - deleted: ../../../../tests/spec_decode/test_memory_usage.py - deleted: ../../../../tests/spec_decode/test_metrics.py - deleted: ../../../../tests/spec_decode/test_multi_step_worker.py - deleted: ../../../../tests/spec_decode/test_ngram_worker.py - deleted: ../../../../tests/spec_decode/test_scorer.py - deleted: ../../../../tests/spec_decode/test_spec_decode_worker.py - deleted: ../../../../tests/spec_decode/test_utils.py - deleted: ../../../../tests/spec_decode/utils.py - modified: ../../../../tests/test_sequence.py - modified: ../../../../tests/test_utils.py - modified: ../../../../tests/tokenization/test_detokenize.py - new file: ../../../../tests/tokenization/test_do_lower_case.py - new file: ../../../../tests/tool_use/test_glm4_moe_tool_parser.py - modified: ../../../../tests/v1/core/test_kv_cache_utils.py - modified: ../../../../tests/v1/core/test_prefix_caching.py - modified: ../../../../tests/v1/core/test_scheduler_e2e.py - modified: ../../../../tests/v1/core/test_specialized_manager.py - modified: ../../../../tests/v1/engine/test_async_llm.py - modified: ../../../../tests/v1/engine/test_llm_engine.py - modified: ../../../../tests/v1/kv_connector/unit/test_nixl_connector.py - renamed: ../../../../tests/v1/executor/test_multiproc_executor.py -> ../../../../tests/v1/kv_connector/unit/test_output_aggreagator.py - modified: ../../../../tests/v1/metrics/test_ray_metrics.py - modified: ../../../../tests/v1/sample/test_logprobs.py - modified: ../../../../tests/v1/sample/test_sampling_params_e2e.py - modified: ../../../../tests/v1/test_async_llm_dp.py - modified: ../../../../tests/v1/test_oracle.py - modified: ../../../../tests/v1/tpu/test_pallas.py - new file: ../../../../tools/ep_kernels/elastic_ep/eep_nvshmem.patch - new file: ../../../../tools/ep_kernels/elastic_ep/install_eep_libraries.sh - modified: ../../../../tools/mypy.sh - modified: ../../../_custom_ops.py - modified: ../../../assets/video.py - modified: ../../../attention/backends/abstract.py - deleted: ../../../attention/backends/blocksparse_attn.py - modified: ../../../attention/backends/differential_flash_attn.py - modified: ../../../attention/backends/dual_chunk_flash_attn.py - modified: ../../../attention/backends/flash_attn.py - modified: ../../../attention/backends/flashinfer.py - modified: ../../../attention/backends/flashmla.py - modified: ../../../attention/backends/mla/common.py - modified: ../../../attention/backends/rocm_aiter_mla.py - modified: ../../../attention/backends/rocm_flash_attn.py - modified: ../../../attention/backends/triton_mla.py - modified: ../../../attention/backends/xformers.py - modified: ../../../attention/layer.py - deleted: 
../../../attention/ops/blocksparse_attention/__init__.py - deleted: ../../../attention/ops/blocksparse_attention/blocksparse_attention_kernel.py - deleted: ../../../attention/ops/blocksparse_attention/interface.py - deleted: ../../../attention/ops/blocksparse_attention/utils.py - modified: ../../../attention/ops/rocm_aiter_mla.py - modified: ../../../attention/selector.py - modified: ../../../benchmarks/serve.py - modified: ../../../config.py - modified: ../../../distributed/device_communicators/cpu_communicator.py - modified: ../../../distributed/eplb/eplb_state.py - modified: ../../../distributed/eplb/rebalance_execute.py - modified: ../../../distributed/kv_transfer/kv_connector/utils.py - modified: ../../../distributed/kv_transfer/kv_connector/v1/base.py - modified: ../../../distributed/parallel_state.py - modified: ../../../engine/arg_utils.py - modified: ../../../engine/llm_engine.py - modified: ../../../engine/metrics.py - modified: ../../../engine/metrics_types.py - modified: ../../../engine/output_processor/multi_step.py - modified: ../../../engine/protocol.py - modified: ../../../entrypoints/llm.py - modified: ../../../entrypoints/openai/api_server.py - modified: ../../../entrypoints/openai/cli_args.py - modified: ../../../entrypoints/openai/protocol.py - modified: ../../../entrypoints/openai/serving_classification.py - modified: ../../../entrypoints/openai/serving_embedding.py - modified: ../../../entrypoints/openai/serving_engine.py - modified: ../../../entrypoints/openai/serving_pooling.py - modified: ../../../entrypoints/openai/serving_score.py - modified: ../../../entrypoints/openai/tool_parsers/__init__.py - new file: ../../../entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py - modified: ../../../envs.py - modified: ../../../executor/executor_base.py - modified: ../../../executor/uniproc_executor.py - modified: ../../../lora/layers.py - modified: ../../../lora/models.py - modified: ../../../lora/peft_helper.py - modified: ../../../lora/punica_wrapper/punica_base.py - modified: ../../../lora/punica_wrapper/punica_gpu.py - modified: ../../../lora/punica_wrapper/punica_tpu.py - modified: ../../../lora/punica_wrapper/utils.py - modified: ../../../lora/utils.py - modified: ../../../lora/worker_manager.py - renamed: ../../../../tests/spec_decode/__init__.py -> ../../../mocks/__init__.py - new file: ../../../mocks/mock_nixl_connector.py - modified: ../../../model_executor/layers/fused_moe/batched_deep_gemm_moe.py - modified: ../../../model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py - modified: ../../../model_executor/layers/fused_moe/config.py - modified: ../../../model_executor/layers/fused_moe/cutlass_moe.py - modified: ../../../model_executor/layers/fused_moe/deep_gemm_moe.py - modified: ../../../model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py - modified: ../../../model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py - new file: ../../../model_executor/layers/fused_moe/flashinfer_cutlass_moe.py - new file: ../../../model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py - modified: ../../../model_executor/layers/fused_moe/fused_batched_moe.py - modified: ../../../model_executor/layers/fused_moe/fused_moe.py - modified: ../../../model_executor/layers/fused_moe/layer.py - modified: ../../../model_executor/layers/fused_moe/modular_kernel.py - modified: ../../../model_executor/layers/fused_moe/moe_align_block_size.py - modified: ../../../model_executor/layers/fused_moe/pplx_prepare_finalize.py - modified: 
../../../model_executor/layers/fused_moe/prepare_finalize.py - modified: ../../../model_executor/layers/fused_moe/triton_deep_gemm_moe.py - modified: ../../../model_executor/layers/fused_moe/utils.py - modified: ../../../model_executor/layers/mamba/mamba_mixer2.py - modified: ../../../model_executor/layers/pooler.py - modified: ../../../model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py - modified: ../../../model_executor/layers/quantization/fp8.py - modified: ../../../model_executor/layers/quantization/modelopt.py - deleted: ../../../model_executor/layers/rejection_sampler.py - modified: ../../../model_executor/layers/sampler.py - deleted: ../../../model_executor/layers/spec_decode_base_sampler.py - deleted: ../../../model_executor/layers/typical_acceptance_sampler.py - modified: ../../../model_executor/model_loader/utils.py - modified: ../../../model_executor/models/adapters.py - modified: ../../../model_executor/models/bailing_moe.py - modified: ../../../model_executor/models/bamba.py - modified: ../../../model_executor/models/bart.py - modified: ../../../model_executor/models/bert.py - modified: ../../../model_executor/models/deepseek_v2.py - deleted: ../../../model_executor/models/eagle.py - modified: ../../../model_executor/models/ernie45_moe.py - new file: ../../../model_executor/models/exaone4.py - modified: ../../../model_executor/models/falcon_h1.py - modified: ../../../model_executor/models/gemma.py - new file: ../../../model_executor/models/glm4_moe.py - new file: ../../../model_executor/models/glm4_moe_mtp.py - modified: ../../../model_executor/models/gpt2.py - modified: ../../../model_executor/models/granitemoe.py - modified: ../../../model_executor/models/granitemoehybrid.py - modified: ../../../model_executor/models/granitemoeshared.py - modified: ../../../model_executor/models/gritlm.py - modified: ../../../model_executor/models/grok1.py - modified: ../../../model_executor/models/hunyuan_v1_moe.py - modified: ../../../model_executor/models/interfaces.py - modified: ../../../model_executor/models/internlm2.py - modified: ../../../model_executor/models/jamba.py - modified: ../../../model_executor/models/jina_vl.py - modified: ../../../model_executor/models/llama.py - modified: ../../../model_executor/models/mamba2.py - modified: ../../../model_executor/models/modernbert.py - modified: ../../../model_executor/models/nemotron_h.py - deleted: ../../../model_executor/models/phi3_small.py - modified: ../../../model_executor/models/qwen2.py - modified: ../../../model_executor/models/qwen2_rm.py - modified: ../../../model_executor/models/qwen3.py - modified: ../../../model_executor/models/registry.py - modified: ../../../model_executor/models/roberta.py - modified: ../../../model_executor/models/transformers.py - modified: ../../../model_executor/models/zamba2.py - modified: ../../../model_executor/pooling_metadata.py - modified: ../../../platforms/cpu.py - modified: ../../../platforms/cuda.py - modified: ../../../platforms/interface.py - modified: ../../../platforms/rocm.py - modified: ../../../platforms/tpu.py - modified: ../../../pooling_params.py - modified: ../../../reasoning/__init__.py - new file: ../../../reasoning/glm4_moe_reasoning_parser.py - modified: ../../../sequence.py - deleted: ../../../spec_decode/__init__.py - deleted: ../../../spec_decode/batch_expansion.py - deleted: ../../../spec_decode/draft_model_runner.py - deleted: ../../../spec_decode/interfaces.py - deleted: ../../../spec_decode/medusa_worker.py - deleted: 
../../../spec_decode/metrics.py - deleted: ../../../spec_decode/mlp_speculator_worker.py - deleted: ../../../spec_decode/mqa_scorer.py - deleted: ../../../spec_decode/multi_step_worker.py - deleted: ../../../spec_decode/ngram_worker.py - deleted: ../../../spec_decode/proposer_worker_base.py - deleted: ../../../spec_decode/smaller_tp_proposer_worker.py - deleted: ../../../spec_decode/spec_decode_worker.py - deleted: ../../../spec_decode/target_model_runner.py - deleted: ../../../spec_decode/top1_proposer.py - deleted: ../../../spec_decode/util.py - modified: ../../../transformers_utils/config.py - modified: ../../../transformers_utils/configs/__init__.py - modified: ../../../transformers_utils/configs/eagle.py - new file: ../../../transformers_utils/configs/exaone4.py - modified: ../../../transformers_utils/configs/mistral.py - modified: ../../../transformers_utils/tokenizer.py - modified: ../../../utils/__init__.py - modified: ../../../utils/deep_gemm.py - new file: ../../../utils/flashinfer.py - modified: cpu_attn.py - modified: flash_attn.py - modified: flashinfer.py - modified: flex_attention.py - modified: mla/common.py - modified: mla/cutlass_mla.py - modified: mla/flashmla.py - modified: mla/rocm_aiter_mla.py - modified: mla/triton_mla.py - modified: pallas.py - modified: rocm_aiter_fa.py - modified: triton_attn.py - new file: triton_attn_new.py - new file: triton_attn_org.py - modified: utils.py - modified: ../../core/kv_cache_utils.py - modified: ../../core/single_type_kv_cache_manager.py - modified: ../../engine/__init__.py - modified: ../../engine/async_llm.py - modified: ../../engine/coordinator.py - modified: ../../engine/core.py - modified: ../../engine/core_client.py - modified: ../../engine/utils.py - modified: ../../executor/multiproc_executor.py - modified: ../../executor/ray_distributed_executor.py - modified: ../../kv_cache_interface.py - modified: ../../metrics/loggers.py - modified: ../../metrics/ray_wrappers.py - modified: ../../pool/metadata.py - modified: ../../worker/cpu_model_runner.py - modified: ../../worker/cpu_worker.py - modified: ../../worker/gpu_input_batch.py - modified: ../../worker/gpu_model_runner.py - modified: ../../worker/gpu_worker.py - modified: ../../worker/tpu_model_runner.py - modified: ../../worker/tpu_worker.py - modified: ../../../worker/model_runner_base.py - modified: ../../../worker/pooling_model_runner.py - modified: ../../../worker/worker.py - modified: ../../../worker/worker_base.py - -Untracked files: - (use "git add ..." 
to include in what will be committed) - ../../../../ShareGPT_V3_unfiltered_cleaned_split.json - ../../../../logs_unif/ - ../../../../ref.txt - ../../../../res_dec_128_16.txt - ../../../../res_prompt_128_32.txt - ../../../../res_prompt_256_16.txt - ../../../../res_prompt_32_16.txt - ../../../../run_benchmark_latency_7.py - ../../../../tests/test1.py - ../../../../tests/test2.py - ../../../../tmp.txt - ../../../attention/ops/backup/ - ../../../attention/ops/diff.txt - ../../../attention/ops/temp.py - ../../../attention/ops/triton_unified_attention_base.py - ../../../attention/ops/triton_unified_attention_new.py - ../../../attention/ops/triton_unified_attention_org.py - ../../../attention/ops/triton_unified_attention_working.py - tmp.txt - diff --git a/vllm/v1/attention/backends/triton_attn_new.py b/vllm/v1/attention/backends/triton_attn_new.py deleted file mode 100644 index 608a33a4d171..000000000000 --- a/vllm/v1/attention/backends/triton_attn_new.py +++ /dev/null @@ -1,403 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Attention layer with PagedAttention and Triton prefix prefill.""" -from dataclasses import dataclass -from typing import Any, ClassVar, Optional - -import torch - -from vllm import _custom_ops as ops -from vllm import envs -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType) -from vllm.attention.ops.chunked_prefill_paged_decode import ( - chunked_prefill_paged_decode) -from vllm.attention.ops.paged_attn import PagedAttention -from vllm.attention.ops.triton_unified_attention import unified_attention -from vllm.config import VllmConfig -from vllm.logger import init_logger -from vllm.platforms import current_platform -from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata -from vllm.v1.attention.backends.utils import ( - AttentionMetadataBuilder, CommonAttentionMetadata, - reorder_batch_to_split_decodes_and_prefills) -from vllm.v1.core.sched.output import SchedulerOutput -from vllm.v1.kv_cache_interface import AttentionSpec -from vllm.v1.worker.gpu_input_batch import InputBatch - -logger = init_logger(__name__) - - -@dataclass -class TritonAttentionMetadata: - # NOTE(sang): Definition of context_len, query_len, and seq_len. - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ---------------------| - # |-- query_len ---| - - num_actual_tokens: int # Number of tokens excluding padding. - max_query_len: int - query_start_loc: torch.Tensor - max_seq_len: int - seq_lens: torch.Tensor - block_table: torch.Tensor - slot_mapping: torch.Tensor - - # For cascade attention. 
- use_cascade: bool - common_prefix_len: int - cu_prefix_query_lens: Optional[torch.Tensor] - prefix_kv_lens: Optional[torch.Tensor] - suffix_kv_lens: Optional[torch.Tensor] - - # Optional aot scheduling - scheduler_metadata: Optional[torch.Tensor] = None - prefix_scheduler_metadata: Optional[torch.Tensor] = None - - -class TritonAttentionMetadataBuilder( - AttentionMetadataBuilder[TritonAttentionMetadata]): - full_cudagraph_supported: ClassVar[bool] = True - - def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, - device: torch.device): - self.device = device - self.block_size = kv_cache_spec.block_size - self.kv_cache_spec = kv_cache_spec - - model_config = vllm_config.model_config - self.num_heads_q = model_config.get_num_attention_heads( - vllm_config.parallel_config) - self.num_heads_kv = model_config.get_num_kv_heads( - vllm_config.parallel_config) - self.headdim = model_config.get_head_size() - - self.attention_chunk_size = getattr(vllm_config.scheduler_config, - 'attention_chunk_size', None) - - def reorder_batch(self, input_batch: InputBatch, - scheduler_output: SchedulerOutput) -> bool: - return reorder_batch_to_split_decodes_and_prefills(input_batch, - scheduler_output, - decode_threshold=1) - - def build_for_cudagraph_capture( - self, common_attn_metadata: CommonAttentionMetadata - ) -> TritonAttentionMetadata: - attn_metadata = self.build(0, common_attn_metadata) - # When doing full graph capture, setting seq_lens to - # max_model_len will cause graph capture to be extremely - # slow, so here we set it to 1. - attn_metadata.seq_lens.fill_(1) - return attn_metadata - - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> TritonAttentionMetadata: - num_actual_tokens = common_attn_metadata.num_actual_tokens - max_query_len = common_attn_metadata.max_query_len - - max_seq_len = int(common_attn_metadata.seq_lens_cpu.max()) - query_start_loc = common_attn_metadata.query_start_loc - seq_lens = common_attn_metadata.seq_lens - block_table_tensor = common_attn_metadata.block_table_tensor - slot_mapping = common_attn_metadata.slot_mapping - - use_cascade = common_prefix_len > 0 - - if use_cascade: - cu_prefix_query_lens = torch.tensor([0, num_actual_tokens], - dtype=torch.int32, - device=self.device) - prefix_kv_lens = torch.tensor([common_prefix_len], - dtype=torch.int32, - device=self.device) - suffix_kv_lens = (common_attn_metadata.seq_lens_cpu - - common_prefix_len) - suffix_kv_lens = suffix_kv_lens.to(self.device) - else: - cu_prefix_query_lens = None - prefix_kv_lens = None - suffix_kv_lens = None - prefix_scheduler_metadata = None - - attn_metadata = TritonAttentionMetadata( - num_actual_tokens=num_actual_tokens, - max_query_len=max_query_len, - query_start_loc=query_start_loc, - max_seq_len=max_seq_len, - seq_lens=seq_lens, - block_table=block_table_tensor, - slot_mapping=slot_mapping, - use_cascade=use_cascade, - common_prefix_len=common_prefix_len, - cu_prefix_query_lens=cu_prefix_query_lens, - prefix_kv_lens=prefix_kv_lens, - suffix_kv_lens=suffix_kv_lens, - prefix_scheduler_metadata=prefix_scheduler_metadata, - ) - return attn_metadata - - def can_run_in_cudagraph( - self, common_attn_metadata: CommonAttentionMetadata) -> bool: - # Full CUDA Graph always supported - return True - - -class TritonAttentionBackend(AttentionBackend): - - accept_output_buffer: bool = True - - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] - - 
@classmethod - def get_supported_head_sizes(cls) -> list[int]: - return [32, 64, 96, 128, 160, 192, 224, 256] - - @classmethod - def validate_head_size(cls, head_size: int) -> None: - supported_head_sizes = cls.get_supported_head_sizes() - if head_size not in supported_head_sizes: - attn_type = cls.__name__.removesuffix("Backend") - raise ValueError( - f"Head size {head_size} is not supported by {attn_type}. " - f"Supported head sizes are: {supported_head_sizes}. " - "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes.") - - @staticmethod - def get_name() -> str: - return "TRITON_ATTN_VLLM_V1" - - @staticmethod - def get_impl_cls() -> type["TritonAttentionImpl"]: - return TritonAttentionImpl - - @staticmethod - def get_metadata_cls() -> type["AttentionMetadata"]: - return TritonAttentionMetadata - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - ) -> tuple[int, ...]: - if block_size % 16 != 0: - raise ValueError("Block size must be a multiple of 16.") - return (num_blocks, 2, block_size, num_kv_heads, head_size) - - @staticmethod - def use_cascade_attention(*args, **kwargs) -> bool: - return False - - @staticmethod - def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]: - return TritonAttentionMetadataBuilder - - -class TritonAttentionImpl(AttentionImpl): - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - blocksparse_params: Optional[dict[str, Any]] = None, - logits_soft_cap: Optional[float] = None, - attn_type: AttentionType = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[int] = None, - use_irope: bool = False, - ) -> None: - if blocksparse_params is not None: - raise ValueError( - "TritonAttention does not support block-sparse attention.") - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_kv_heads - if alibi_slopes is not None: - alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) - self.alibi_slopes = alibi_slopes - if sliding_window is None: - self.sliding_window = (-1, -1) - else: - self.sliding_window = (sliding_window - 1, 0) - self.kv_cache_dtype = kv_cache_dtype - if logits_soft_cap is None: - # In flash-attn, setting logits_soft_cap as 0 means no soft cap. - logits_soft_cap = 0 - self.logits_soft_cap = logits_soft_cap - self.kv_sharing_target_layer_name = kv_sharing_target_layer_name - - self.use_irope = use_irope - - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - - TritonAttentionBackend.validate_head_size(head_size) - - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "TritonAttentionImpl") - - self.fp8_dtype = current_platform.fp8_dtype() - self.force_prefill_decode_attn = \ - envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION - - def forward( - self, - layer: torch.nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: FlashAttentionMetadata, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Forward pass with FlashAttention. 
- - Args: - query: shape = [num_tokens, num_heads, head_size] - key: shape = [num_tokens, num_kv_heads, head_size] - value: shape = [num_tokens, num_kv_heads, head_size] - kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] - attn_metadata: Metadata for attention. - Returns: - shape = [num_tokens, num_heads * head_size] - """ - assert output is not None, "Output tensor must be provided." - - if output_scale is not None: - raise NotImplementedError( - "fused output quantization is not yet supported" - " for TritonAttentionImpl") - - if attn_metadata is None: - # Profiling run. - return output - - assert attn_metadata.use_cascade is False - - # IMPORTANT! - # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in - # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead - # in this method. For example, `view` and `slice` (or `[:n]`) operations - # are surprisingly slow even in the case they do not invoke any GPU ops. - # Minimize the PyTorch ops in this method as much as possible. - # Whenever making a change in this method, please benchmark the - # performance to make sure it does not introduce any overhead. - - use_prefill_decode_attn = self.force_prefill_decode_attn - num_actual_tokens = attn_metadata.num_actual_tokens - - if use_prefill_decode_attn: - key_cache, value_cache = PagedAttention.split_kv_cache( - kv_cache, self.num_kv_heads, self.head_size) - else: - key_cache, value_cache = kv_cache.unbind(1) - - if self.kv_sharing_target_layer_name is None: - # Reshape the input keys and values and store them in the cache. - # Skip this if sharing KV cache with an earlier attention layer. - if use_prefill_decode_attn: - PagedAttention.write_to_paged_cache( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) - else: - torch.ops._C_cache_ops.reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) - - if self.kv_cache_dtype.startswith("fp8"): - key_cache = key_cache.view(self.fp8_dtype) - value_cache = value_cache.view(self.fp8_dtype) - num_tokens, num_heads, head_size = query.shape - assert layer._q_scale == 1.0, \ - "A non 1.0 q_scale is not currently supported." - if not current_platform.is_rocm(): - # Skip Q quantization on ROCm, since dequantizing back to - # f32 in the attention kernel is not supported. - query, _ = ops.scaled_fp8_quant( - query.reshape( - (num_tokens, num_heads * head_size)).contiguous(), - layer._q_scale) - query = query.reshape((num_tokens, num_heads, head_size)) - - cu_seqlens_q = attn_metadata.query_start_loc - seqused_k = attn_metadata.seq_lens - max_seqlen_q = attn_metadata.max_query_len - max_seqlen_k = attn_metadata.max_seq_len - block_table = attn_metadata.block_table - - if use_prefill_decode_attn: - # Compute attention and update output up to `num_actual_tokens`. 
- chunked_prefill_paged_decode(query=query[:num_actual_tokens], - key=key[:num_actual_tokens], - value=value[:num_actual_tokens], - output=output[:num_actual_tokens], - kv_cache_dtype=self.kv_cache_dtype, - key_cache=key_cache, - value_cache=value_cache, - block_table=block_table, - query_start_loc=cu_seqlens_q, - seq_lens=seqused_k, - max_seq_len=max_seqlen_k, - max_query_len=max_seqlen_q, - k_scale=layer._k_scale, - v_scale=layer._v_scale, - alibi_slopes=self.alibi_slopes, - sliding_window=self.sliding_window[0], - sm_scale=self.scale) - - else: - descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) - - unified_attention( - q=query[:num_actual_tokens], - k=key_cache, - v=value_cache, - out=output[:num_actual_tokens], - cu_seqlens_q=cu_seqlens_q, - max_seqlen_q=max_seqlen_q, - seqused_k=seqused_k, - max_seqlen_k=max_seqlen_k, - softmax_scale=self.scale, - causal=True, - alibi_slopes=self.alibi_slopes, - window_size=self.sliding_window, - block_table=block_table, - softcap=self.logits_soft_cap, - q_descale=None, # Not supported - k_descale=layer._k_scale.expand(descale_shape), - v_descale=layer._v_scale.expand(descale_shape), - ) - - return output diff --git a/vllm/v1/attention/backends/triton_attn_org.py b/vllm/v1/attention/backends/triton_attn_org.py deleted file mode 100644 index 79796ac14928..000000000000 --- a/vllm/v1/attention/backends/triton_attn_org.py +++ /dev/null @@ -1,394 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Attention layer with PagedAttention and Triton prefix prefill.""" -from dataclasses import dataclass -from typing import Any, ClassVar, Optional - -import torch - -from vllm import _custom_ops as ops -from vllm import envs -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType) -from vllm.attention.ops.chunked_prefill_paged_decode import ( - chunked_prefill_paged_decode) -from vllm.attention.ops.paged_attn import PagedAttention -from vllm.attention.ops.triton_unified_attention import unified_attention -from vllm.config import VllmConfig -from vllm.logger import init_logger -from vllm.platforms import current_platform -from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata) -from vllm.v1.kv_cache_interface import AttentionSpec - -logger = init_logger(__name__) - - -@dataclass -class TritonAttentionMetadata: - # NOTE(sang): Definition of context_len, query_len, and seq_len. - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ---------------------| - # |-- query_len ---| - - num_actual_tokens: int # Number of tokens excluding padding. - max_query_len: int - query_start_loc: torch.Tensor - max_seq_len: int - seq_lens: torch.Tensor - block_table: torch.Tensor - slot_mapping: torch.Tensor - - # For cascade attention. 
- use_cascade: bool - common_prefix_len: int - cu_prefix_query_lens: Optional[torch.Tensor] - prefix_kv_lens: Optional[torch.Tensor] - suffix_kv_lens: Optional[torch.Tensor] - - # Optional aot scheduling - scheduler_metadata: Optional[torch.Tensor] = None - prefix_scheduler_metadata: Optional[torch.Tensor] = None - - -class TritonAttentionMetadataBuilder( - AttentionMetadataBuilder[TritonAttentionMetadata]): - full_cudagraph_supported: ClassVar[bool] = True - - def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, - device: torch.device): - self.device = device - self.block_size = kv_cache_spec.block_size - self.kv_cache_spec = kv_cache_spec - - model_config = vllm_config.model_config - self.num_heads_q = model_config.get_num_attention_heads( - vllm_config.parallel_config) - self.num_heads_kv = model_config.get_num_kv_heads( - vllm_config.parallel_config) - self.headdim = model_config.get_head_size() - - self.attention_chunk_size = getattr(vllm_config.scheduler_config, - 'attention_chunk_size', None) - - def build_for_cudagraph_capture( - self, common_attn_metadata: CommonAttentionMetadata - ) -> TritonAttentionMetadata: - attn_metadata = self.build(0, common_attn_metadata) - # When doing full graph capture, setting seq_lens to - # max_model_len will cause graph capture to be extremely - # slow, so here we set it to 1. - attn_metadata.seq_lens.fill_(1) - return attn_metadata - - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> TritonAttentionMetadata: - num_actual_tokens = common_attn_metadata.num_actual_tokens - max_query_len = common_attn_metadata.max_query_len - - max_seq_len = int(common_attn_metadata.seq_lens_cpu.max()) - query_start_loc = common_attn_metadata.query_start_loc - seq_lens = common_attn_metadata.seq_lens - block_table_tensor = common_attn_metadata.block_table_tensor - slot_mapping = common_attn_metadata.slot_mapping - - use_cascade = common_prefix_len > 0 - - if use_cascade: - cu_prefix_query_lens = torch.tensor([0, num_actual_tokens], - dtype=torch.int32, - device=self.device) - prefix_kv_lens = torch.tensor([common_prefix_len], - dtype=torch.int32, - device=self.device) - suffix_kv_lens = (common_attn_metadata.seq_lens_cpu - - common_prefix_len) - suffix_kv_lens = suffix_kv_lens.to(self.device) - else: - cu_prefix_query_lens = None - prefix_kv_lens = None - suffix_kv_lens = None - prefix_scheduler_metadata = None - - attn_metadata = TritonAttentionMetadata( - num_actual_tokens=num_actual_tokens, - max_query_len=max_query_len, - query_start_loc=query_start_loc, - max_seq_len=max_seq_len, - seq_lens=seq_lens, - block_table=block_table_tensor, - slot_mapping=slot_mapping, - use_cascade=use_cascade, - common_prefix_len=common_prefix_len, - cu_prefix_query_lens=cu_prefix_query_lens, - prefix_kv_lens=prefix_kv_lens, - suffix_kv_lens=suffix_kv_lens, - prefix_scheduler_metadata=prefix_scheduler_metadata, - ) - return attn_metadata - - def can_run_in_cudagraph( - self, common_attn_metadata: CommonAttentionMetadata) -> bool: - # Full CUDA Graph always supported - return True - - -class TritonAttentionBackend(AttentionBackend): - - accept_output_buffer: bool = True - - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] - - @classmethod - def get_supported_head_sizes(cls) -> list[int]: - return [32, 64, 96, 128, 160, 192, 224, 256] - - @classmethod - def validate_head_size(cls, head_size: int) -> None: - supported_head_sizes = 
cls.get_supported_head_sizes() - if head_size not in supported_head_sizes: - attn_type = cls.__name__.removesuffix("Backend") - raise ValueError( - f"Head size {head_size} is not supported by {attn_type}. " - f"Supported head sizes are: {supported_head_sizes}. " - "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes.") - - @staticmethod - def get_name() -> str: - return "TRITON_ATTN_VLLM_V1" - - @staticmethod - def get_impl_cls() -> type["TritonAttentionImpl"]: - return TritonAttentionImpl - - @staticmethod - def get_metadata_cls() -> type["AttentionMetadata"]: - return TritonAttentionMetadata - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - ) -> tuple[int, ...]: - if block_size % 16 != 0: - raise ValueError("Block size must be a multiple of 16.") - return (2, num_blocks, block_size, num_kv_heads, head_size) - - @staticmethod - def use_cascade_attention(*args, **kwargs) -> bool: - return False - - @staticmethod - def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]: - return TritonAttentionMetadataBuilder - - -class TritonAttentionImpl(AttentionImpl): - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - blocksparse_params: Optional[dict[str, Any]] = None, - logits_soft_cap: Optional[float] = None, - attn_type: AttentionType = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[int] = None, - use_irope: bool = False, - ) -> None: - if blocksparse_params is not None: - raise ValueError( - "TritonAttention does not support block-sparse attention.") - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_kv_heads - if alibi_slopes is not None: - alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) - self.alibi_slopes = alibi_slopes - if sliding_window is None: - self.sliding_window = (-1, -1) - else: - self.sliding_window = (sliding_window - 1, 0) - self.kv_cache_dtype = kv_cache_dtype - if logits_soft_cap is None: - # In flash-attn, setting logits_soft_cap as 0 means no soft cap. - logits_soft_cap = 0 - self.logits_soft_cap = logits_soft_cap - self.kv_sharing_target_layer_name = kv_sharing_target_layer_name - - self.use_irope = use_irope - - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - - TritonAttentionBackend.validate_head_size(head_size) - - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "TritonAttentionImpl") - - self.fp8_dtype = current_platform.fp8_dtype() - self.force_prefill_decode_attn = \ - envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION - - def forward( - self, - layer: torch.nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: FlashAttentionMetadata, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Forward pass with FlashAttention. - - Args: - query: shape = [num_tokens, num_heads, head_size] - key: shape = [num_tokens, num_kv_heads, head_size] - value: shape = [num_tokens, num_kv_heads, head_size] - kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] - attn_metadata: Metadata for attention. 
- Returns: - shape = [num_tokens, num_heads * head_size] - """ - assert output is not None, "Output tensor must be provided." - - if output_scale is not None: - raise NotImplementedError( - "fused output quantization is not yet supported" - " for TritonAttentionImpl") - - if attn_metadata is None: - # Profiling run. - return output - - assert attn_metadata.use_cascade is False - - # IMPORTANT! - # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in - # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead - # in this method. For example, `view` and `slice` (or `[:n]`) operations - # are surprisingly slow even in the case they do not invoke any GPU ops. - # Minimize the PyTorch ops in this method as much as possible. - # Whenever making a change in this method, please benchmark the - # performance to make sure it does not introduce any overhead. - - use_prefill_decode_attn = self.force_prefill_decode_attn - num_actual_tokens = attn_metadata.num_actual_tokens - - if use_prefill_decode_attn: - key_cache, value_cache = PagedAttention.split_kv_cache( - kv_cache, self.num_kv_heads, self.head_size) - else: - key_cache, value_cache = kv_cache.unbind(0) - - if self.kv_sharing_target_layer_name is None: - # Reshape the input keys and values and store them in the cache. - # Skip this if sharing KV cache with an earlier attention layer. - if use_prefill_decode_attn: - PagedAttention.write_to_paged_cache( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) - else: - torch.ops._C_cache_ops.reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) - - if self.kv_cache_dtype.startswith("fp8"): - key_cache = key_cache.view(self.fp8_dtype) - value_cache = value_cache.view(self.fp8_dtype) - num_tokens, num_heads, head_size = query.shape - assert layer._q_scale == 1.0, \ - "A non 1.0 q_scale is not currently supported." - if not current_platform.is_rocm(): - # Skip Q quantization on ROCm, since dequantizing back to - # f32 in the attention kernel is not supported. - query, _ = ops.scaled_fp8_quant( - query.reshape( - (num_tokens, num_heads * head_size)).contiguous(), - layer._q_scale) - query = query.reshape((num_tokens, num_heads, head_size)) - - cu_seqlens_q = attn_metadata.query_start_loc - seqused_k = attn_metadata.seq_lens - max_seqlen_q = attn_metadata.max_query_len - max_seqlen_k = attn_metadata.max_seq_len - block_table = attn_metadata.block_table - - if use_prefill_decode_attn: - # Compute attention and update output up to `num_actual_tokens`. 
- chunked_prefill_paged_decode(query=query[:num_actual_tokens], - key=key[:num_actual_tokens], - value=value[:num_actual_tokens], - output=output[:num_actual_tokens], - kv_cache_dtype=self.kv_cache_dtype, - key_cache=key_cache, - value_cache=value_cache, - block_table=block_table, - query_start_loc=cu_seqlens_q, - seq_lens=seqused_k, - max_seq_len=max_seqlen_k, - max_query_len=max_seqlen_q, - k_scale=layer._k_scale, - v_scale=layer._v_scale, - alibi_slopes=self.alibi_slopes, - sliding_window=self.sliding_window[0], - sm_scale=self.scale) - - else: - descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) - - unified_attention( - q=query[:num_actual_tokens], - k=key_cache, - v=value_cache, - out=output[:num_actual_tokens], - cu_seqlens_q=cu_seqlens_q, - max_seqlen_q=max_seqlen_q, - seqused_k=seqused_k, - max_seqlen_k=max_seqlen_k, - softmax_scale=self.scale, - causal=True, - alibi_slopes=self.alibi_slopes, - window_size=self.sliding_window, - block_table=block_table, - softcap=self.logits_soft_cap, - q_descale=None, # Not supported - k_descale=layer._k_scale.expand(descale_shape), - v_descale=layer._v_scale.expand(descale_shape), - ) - - return output From c4f7d677590e140c2a22651b4cae73e3d4727b4f Mon Sep 17 00:00:00 2001 From: Jan van Lunteren Date: Tue, 29 Jul 2025 11:39:21 -0400 Subject: [PATCH 08/16] add prefill support back to split-kv kernel Signed-off-by: Jan van Lunteren --- .../attention/ops/triton_unified_attention.py | 234 +++++++++++------- 1 file changed, 142 insertions(+), 92 deletions(-) diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index 50ed4ab8cc16..d1829e50ba2a 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -208,6 +208,7 @@ def kernel_unified_attention_2d( # S : (BLOCK_M, TILE_SIZE) S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32) + S += scale * tl.dot(Q, K) if USE_SOFTCAP: @@ -282,7 +283,6 @@ def kernel_unified_attention_3d( softcap, # float32 num_query_heads: tl.constexpr, # int num_queries_per_kv: tl.constexpr, # int - num_queries_per_kv_padded: tl.constexpr, # int block_table_stride: tl.int64, # int query_stride_0: tl.int64, # int query_stride_1: tl.int64, # int, should be equal to head_size @@ -301,62 +301,101 @@ def kernel_unified_attention_3d( stride_v_cache_1: tl.int64, # int stride_v_cache_2: tl.int64, # int stride_v_cache_3: tl.constexpr, # int + query_start_len_ptr, # [num_seqs+1] + BLOCK_Q: tl.constexpr, # int + num_seqs: tl.int32, + BLOCK_M: tl.constexpr, # int NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int ): - seq_idx = tl.program_id(0) + q_block_global_idx = tl.program_id(0) kv_head_idx = tl.program_id(1) segm_idx = tl.program_id(2) + seq_idx = find_seq_idx(query_start_len_ptr, q_block_global_idx, num_seqs, + BLOCK_Q, True) + + q_block_start_idx = tl.load(query_start_len_ptr + + seq_idx) // BLOCK_Q + seq_idx + + q_block_local_idx = q_block_global_idx - q_block_start_idx + + cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) + cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1) + + cur_batch_query_len = cur_batch_in_all_stop_index \ + - cur_batch_in_all_start_index + + if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: + return + # sequence len for this particular sequence seq_len = tl.load(seq_lens_ptr + seq_idx) # number of segments for this particular sequence - num_tiles = cdiv_fn(seq_len, TILE_SIZE) - tiles_per_segment = cdiv_fn(num_tiles, NUM_SEGMENTS_PER_SEQ) + 
num_segments = NUM_SEGMENTS_PER_SEQ + tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE) if segm_idx * tiles_per_segment * TILE_SIZE >= seq_len: return + offs_m = tl.arange(0, BLOCK_M) offs_d = tl.arange(0, HEAD_SIZE_PADDED) offs_t = tl.arange(0, TILE_SIZE) + query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv - query_head_idx = kv_head_idx * num_queries_per_kv + tl.arange( - 0, num_queries_per_kv_padded) - query_offset = (seq_idx * query_stride_0 + - query_head_idx[:, None] * query_stride_1 + offs_d[None, :]) - - head_mask = query_head_idx < (kv_head_idx + 1) * num_queries_per_kv - head_mask = head_mask & (query_head_idx < num_query_heads) + query_offset_0 = cur_batch_in_all_start_index + query_pos + query_offset_1 = kv_head_idx * num_queries_per_kv + \ + offs_m % num_queries_per_kv + query_offset = (query_offset_0[:, None] * query_stride_0 + + query_offset_1[:, None] * query_stride_1 + offs_d[None, :]) dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1) + query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1) + query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1) - # Q : (num_queries_per_kv_padded, HEAD_SIZE_PADDED) + # Q : (BLOCK_M, HEAD_SIZE_PADDED) Q = tl.load( query_ptr + query_offset, - mask=dim_mask[None, :] & head_mask[:, None], + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], other=0.0, ) block_table_offset = seq_idx * block_table_stride - M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32) - L = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32) - acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED], - dtype=tl.float32) + M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + L = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32) + + # context length for this particular sequences + context_len = seq_len - cur_batch_query_len # alibi slope for this head if USE_ALIBI_SLOPES: - alibi_slope = tl.load(alibi_slopes_ptr + query_head_idx, - mask=head_mask, + alibi_slope = tl.load(alibi_slopes_ptr + query_offset_1, + mask=query_mask_1, other=0.0) + # compute the length of the longest sequence prefix spanned by any + # query token in the current q_block (q_block_local_idx) + max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + ( + BLOCK_M - 1) // num_queries_per_kv + 1 + + # adjust for potential padding in the last q_block by considering the + # actual sequence length + max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len) + + # calculate the number of tiles that need to be processed to + # cover the longest sequence prefix (due to causal masking, tiles beyond + # this prefix can be skipped) + num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE) + # iterate through tiles within current segment for j in range( segm_idx * tiles_per_segment, min((segm_idx + 1) * tiles_per_segment, num_tiles), ): seq_offset = j * TILE_SIZE + offs_t - tile_mask = seq_offset < seq_len + tile_mask = seq_offset < max_seq_prefix_len physical_block_idx = tl.load(block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE).to(tl.int64) @@ -397,92 +436,97 @@ def kernel_unified_attention_3d( else: V = V_load - seq_mask = seq_offset[None, :] < seq_len + seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1 - # S : (num_queries_per_kv_padded, TILE_SIZE) - S = tl.zeros(shape=(num_queries_per_kv_padded, TILE_SIZE), - dtype=tl.float32) + # S : (BLOCK_M, TILE_SIZE) + S = 
tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32) S += scale * tl.dot(Q, K) - context_len = seq_len - 1 - if USE_SOFTCAP: S = apply_softcap(S, softcap) - S = tl.where(head_mask[:, None] & seq_mask, S, - float("-inf")).to(tl.float32) + S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, + S, float("-inf")) if SLIDING_WINDOW > 0: - S = tl.where((context_len - seq_offset) < SLIDING_WINDOW, S, - float("-inf")) + S = tl.where((context_len + query_pos[:, None] - seq_offset) + < SLIDING_WINDOW, S, float("-inf")) if USE_ALIBI_SLOPES: S += alibi_slope[:, None] * (seq_offset - context_len) # compute running maximum - # m_j : (num_queries_per_kv_padded,) + # m_j : (BLOCK_M,) m_j = tl.maximum(M, tl.max(S, axis=1)) # For sliding window there's a chance the max is -inf due to masking of # the entire row. In this case we need to set m_j 0 to avoid NaN m_j = tl.where(m_j > float("-inf"), m_j, 0.0) - # P : (num_queries_per_kv_padded, TILE_SIZE,) + # P : (BLOCK_M, TILE_SIZE,) P = tl.exp(S - m_j[:, None]) - # l_j : (num_queries_per_kv_padded,) + # l_j : (BLOCK_M,) l_j = tl.sum(P, axis=1) - # alpha : (num_queries_per_kv_padded, ) + # alpha : (BLOCK_M, ) alpha = tl.exp(M - m_j) - # acc : (num_queries_per_kv_padded, HEAD_SIZE_PADDED) + # acc : (BLOCK_M, HEAD_SIZE_PADDED) acc = acc * alpha[:, None] # update constants L = L * alpha + l_j M = m_j - # acc : (num_queries_per_kv_padded, HEAD_SIZE_PADDED) + # acc : (BLOCK_M, HEAD_SIZE_PADDED) acc += tl.dot(P.to(V.dtype), V) segm_output_offset = ( - seq_idx.to(tl.int64) * + query_offset_0[:, None].to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + - query_head_idx[:, None] * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + query_offset_1[:, None] * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + segm_idx * HEAD_SIZE_PADDED + tl.arange(0, HEAD_SIZE_PADDED)[None, :]) tl.store( segm_output_ptr + segm_output_offset, acc, - mask=dim_mask[None, :] & head_mask[:, None], + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], ) - segm_offset = (seq_idx.to(tl.int64) * + segm_offset = (query_offset_0.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ) + - query_head_idx * NUM_SEGMENTS_PER_SEQ + segm_idx) - tl.store(segm_max_ptr + segm_offset, M, mask=head_mask) - tl.store(segm_expsum_ptr + segm_offset, L, mask=head_mask) + query_offset_1 * NUM_SEGMENTS_PER_SEQ + segm_idx) + tl.store(segm_max_ptr + segm_offset, M, mask=query_mask_0 & query_mask_1) + tl.store(segm_expsum_ptr + segm_offset, + L, + mask=query_mask_0 & query_mask_1) @triton.jit def reduce_segments( - output_ptr, # [num_seqs, num_query_heads, head_size] + output_ptr, # [num_tokens, num_query_heads, head_size] segm_output_ptr, - #[num_seqs, num_query_heads, max_num_segments, head_size] - segm_max_ptr, # [num_seqs, num_query_heads, max_num_segments] - segm_expsum_ptr, # [num_seqs, num_query_heads, max_num_segments] + #[num_tokens, num_query_heads, max_num_segments, head_size] + segm_max_ptr, # [num_tokens, num_query_heads, max_num_segments] + segm_expsum_ptr, # [num_tokens, num_query_heads, max_num_segments] seq_lens_ptr, # [num_seqs] + num_seqs, # int num_query_heads: tl.constexpr, # int output_stride_0: tl.int64, # int output_stride_1: tl.int64, # int, should be equal to head_size block_table_stride: tl.int64, # int - TILE_SIZE: tl.constexpr, # int, must be power of 2 + TILE_SIZE: tl.constexpr, # int HEAD_SIZE: tl.constexpr, # int, must be power of 2 HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + query_start_len_ptr, # [num_seqs+1] + 
BLOCK_Q: tl.constexpr, # int NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int ): - seq_idx = tl.program_id(0) + query_token_idx = tl.program_id(0) query_head_idx = tl.program_id(1) + seq_idx = find_seq_idx(query_start_len_ptr, query_token_idx, num_seqs, + BLOCK_Q, False) + # sequence len for this particular sequence seq_len = tl.load(seq_lens_ptr + seq_idx) @@ -498,7 +542,7 @@ def reduce_segments( 0).to(tl.int1) # load segment maxima - segm_offset = (seq_idx.to(tl.int64) * + segm_offset = (query_token_idx.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ) + query_head_idx * NUM_SEGMENTS_PER_SEQ + tl.arange(0, NUM_SEGMENTS_PER_SEQ)) @@ -516,7 +560,7 @@ def reduce_segments( # load, rescale, and add segment attention outputs segm_output_offset = ( - seq_idx.to(tl.int64) * + query_token_idx.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + query_head_idx * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + tl.arange(0, NUM_SEGMENTS_PER_SEQ)[:, None] * HEAD_SIZE_PADDED + @@ -532,7 +576,7 @@ def reduce_segments( acc = tl.where(overall_expsum == 0.0, 0.0, acc_sum / overall_expsum) # write result - output_offset = (seq_idx * output_stride_0 + + output_offset = (query_token_idx * output_stride_0 + query_head_idx * output_stride_1 + tl.arange(0, HEAD_SIZE_PADDED)) tl.store(output_ptr + output_offset, acc, mask=dim_mask) @@ -639,8 +683,6 @@ def unified_attention( # for initial version, NUM_SEGMENTS = 16 is chosen as a default # value that showed good performance in tests NUM_SEGMENTS = 16 - num_queries_per_kv_padded = max( - triton.next_power_of_2(num_queries_per_kv), 16) segm_output = torch.empty( q.shape[0], @@ -664,49 +706,55 @@ def unified_attention( dtype=torch.float32, device=q.device, ) - kernel_unified_attention_3d[(num_seqs, num_kv_heads, NUM_SEGMENTS)]( - segm_output_ptr=segm_output, - segm_max_ptr=segm_max, - segm_expsum_ptr=segm_expsum, - query_ptr=q, - key_cache_ptr=k, - value_cache_ptr=v, - block_tables_ptr=block_table, - seq_lens_ptr=seqused_k, - alibi_slopes_ptr=alibi_slopes, - scale=softmax_scale, - k_scale=k_descale, - v_scale=v_descale, - softcap=softcap, - num_query_heads=num_query_heads, - num_queries_per_kv=num_queries_per_kv, - num_queries_per_kv_padded=num_queries_per_kv_padded, - block_table_stride=block_table.stride(0), - query_stride_0=q.stride(0), - query_stride_1=q.stride(1), - BLOCK_SIZE=block_size, - TILE_SIZE=TILE_SIZE_DECODE, - HEAD_SIZE=head_size, - HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), - USE_ALIBI_SLOPES=use_alibi_slopes, - USE_SOFTCAP=(softcap > 0), - SLIDING_WINDOW=(1 + window_size[0]), - stride_k_cache_0=k.stride(0), - stride_k_cache_1=k.stride(1), - stride_k_cache_2=k.stride(2), - stride_k_cache_3=k.stride(3), - stride_v_cache_0=v.stride(0), - stride_v_cache_1=v.stride(1), - stride_v_cache_2=v.stride(2), - stride_v_cache_3=v.stride(3), - NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, - ) - reduce_segments[(num_seqs, num_query_heads)]( + + kernel_unified_attention_3d[( + total_num_q_blocks, num_kv_heads, NUM_SEGMENTS)]( + segm_output_ptr=segm_output, + segm_max_ptr=segm_max, + segm_expsum_ptr=segm_expsum, + query_ptr=q, + key_cache_ptr=k, + value_cache_ptr=v, + block_tables_ptr=block_table, + seq_lens_ptr=seqused_k, + alibi_slopes_ptr=alibi_slopes, + scale=softmax_scale, + k_scale=k_descale, + v_scale=v_descale, + softcap=softcap, + num_query_heads=num_query_heads, + num_queries_per_kv=num_queries_per_kv, + block_table_stride=block_table.stride(0), + query_stride_0=q.stride(0), + query_stride_1=q.stride(1), + BLOCK_SIZE=block_size, + 
TILE_SIZE=TILE_SIZE_DECODE, + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + USE_ALIBI_SLOPES=use_alibi_slopes, + USE_SOFTCAP=(softcap > 0), + SLIDING_WINDOW=(1 + window_size[0]), + stride_k_cache_0=k.stride(0), + stride_k_cache_1=k.stride(1), + stride_k_cache_2=k.stride(2), + stride_k_cache_3=k.stride(3), + stride_v_cache_0=v.stride(0), + stride_v_cache_1=v.stride(1), + stride_v_cache_2=v.stride(2), + stride_v_cache_3=v.stride(3), + query_start_len_ptr=cu_seqlens_q, + BLOCK_Q=BLOCK_Q, + num_seqs=num_seqs, + BLOCK_M=BLOCK_M, + NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, + ) + reduce_segments[(q.shape[0], num_query_heads)]( output_ptr=out, segm_output_ptr=segm_output, segm_max_ptr=segm_max, segm_expsum_ptr=segm_expsum, seq_lens_ptr=seqused_k, + num_seqs=num_seqs, num_query_heads=num_query_heads, output_stride_0=out.stride(0), output_stride_1=out.stride(1), @@ -714,5 +762,7 @@ def unified_attention( TILE_SIZE=TILE_SIZE_DECODE, HEAD_SIZE=head_size, HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + query_start_len_ptr=cu_seqlens_q, + BLOCK_Q=BLOCK_Q, NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, ) From b4fd773ddc2106e8f5a1b7c704c028cbaac3f470 Mon Sep 17 00:00:00 2001 From: Jan van Lunteren Date: Sat, 2 Aug 2025 06:45:36 -0400 Subject: [PATCH 09/16] removed changes to triton_attn.py Signed-off-by: Jan van Lunteren --- vllm/v1/attention/backends/triton_attn.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index e2283b8d5cb5..83471ca51b73 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -18,12 +18,9 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata -from vllm.v1.attention.backends.utils import ( - AttentionMetadataBuilder, CommonAttentionMetadata, - reorder_batch_to_split_decodes_and_prefills) -from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, + CommonAttentionMetadata) from vllm.v1.kv_cache_interface import AttentionSpec -from vllm.v1.worker.gpu_input_batch import InputBatch logger = init_logger(__name__) @@ -75,12 +72,6 @@ def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, vllm_config.parallel_config) self.headdim = model_config.get_head_size() - def reorder_batch(self, input_batch: InputBatch, - scheduler_output: SchedulerOutput) -> bool: - return reorder_batch_to_split_decodes_and_prefills(input_batch, - scheduler_output, - decode_threshold=1) - def build_for_cudagraph_capture( self, common_attn_metadata: CommonAttentionMetadata ) -> TritonAttentionMetadata: @@ -189,7 +180,7 @@ def get_kv_cache_shape( ) -> tuple[int, ...]: if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") - return (num_blocks, 2, block_size, num_kv_heads, head_size) + return (2, num_blocks, block_size, num_kv_heads, head_size) @staticmethod def use_cascade_attention(*args, **kwargs) -> bool: @@ -298,7 +289,7 @@ def forward( key_cache, value_cache = PagedAttention.split_kv_cache( kv_cache, self.num_kv_heads, self.head_size) else: - key_cache, value_cache = kv_cache.unbind(1) + key_cache, value_cache = kv_cache.unbind(0) if self.kv_sharing_target_layer_name is None: # Reshape the input keys and values and store them in the cache. 
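The layout change that patches 09 and 10 toggle back and forth only works if both sides move together: get_kv_cache_shape() decides on which axis the K/V pair lives, and kv_cache.unbind() in forward() must use that same axis -- dim 0 for (2, num_blocks, block_size, num_kv_heads, head_size), dim 1 for (num_blocks, 2, block_size, num_kv_heads, head_size). A minimal standalone sketch of that pairing (plain PyTorch with made-up sizes, not the vLLM code):

import torch

def split_kv(kv_cache: torch.Tensor, kv_dim: int):
    # kv_dim is the axis holding the K/V pair: 0 for the
    # (2, num_blocks, ...) layout, 1 for the (num_blocks, 2, ...) layout.
    return kv_cache.unbind(kv_dim)

num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 8

# original layout: K/V axis first
kv_old = torch.randn(2, num_blocks, block_size, num_kv_heads, head_size)
k_old, v_old = split_kv(kv_old, kv_dim=0)

# modified layout: K/V axis second, i.e. one (K, V) pair per physical block
kv_new = kv_old.transpose(0, 1).contiguous()
k_new, v_new = split_kv(kv_new, kv_dim=1)

# same logical contents, same shapes -- only the memory layout differs
assert k_old.shape == k_new.shape == (num_blocks, block_size, num_kv_heads, head_size)
assert torch.equal(k_old, k_new) and torch.equal(v_old, v_new)

Both layouts expose identical key/value views; the difference is only in the strides the Triton kernels see, with the num_blocks-first layout keeping the K and V data of one physical block adjacent in memory -- presumably the motivation for the reshape in this series. The functional behaviour is unchanged as long as the shape and the unbind axis are flipped in lockstep, which is why patches 09 and 10 revert and restore them together.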
From 832ccb8d138b34e08ea04a07c27c2d611c26525f Mon Sep 17 00:00:00 2001 From: Jan van Lunteren Date: Sat, 2 Aug 2025 08:02:55 -0400 Subject: [PATCH 10/16] restored changes to triton_attn.py Signed-off-by: Jan van Lunteren --- vllm/v1/attention/backends/triton_attn.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 942cb95eefa2..2a8a5ea71fc0 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -18,10 +18,12 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata -from vllm.v1.attention.backends.utils import (AttentionCGSupport, - AttentionMetadataBuilder, - CommonAttentionMetadata) +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, AttentionMetadataBuilder, + CommonAttentionMetadata, reorder_batch_to_split_decodes_and_prefills) from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.worker.gpu_input_batch import InputBatch logger = init_logger(__name__) @@ -74,6 +76,12 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config.parallel_config) self.headdim = model_config.get_head_size() + def reorder_batch(self, input_batch: InputBatch, + scheduler_output: SchedulerOutput) -> bool: + return reorder_batch_to_split_decodes_and_prefills(input_batch, + scheduler_output, + decode_threshold=1) + def build_for_cudagraph_capture( self, common_attn_metadata: CommonAttentionMetadata ) -> TritonAttentionMetadata: @@ -182,7 +190,7 @@ def get_kv_cache_shape( ) -> tuple[int, ...]: if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") - return (2, num_blocks, block_size, num_kv_heads, head_size) + return (num_blocks, 2, block_size, num_kv_heads, head_size) @staticmethod def use_cascade_attention(*args, **kwargs) -> bool: @@ -291,7 +299,7 @@ def forward( key_cache, value_cache = PagedAttention.split_kv_cache( kv_cache, self.num_kv_heads, self.head_size) else: - key_cache, value_cache = kv_cache.unbind(0) + key_cache, value_cache = kv_cache.unbind(1) if self.kv_sharing_target_layer_name is None: # Reshape the input keys and values and store them in the cache. 
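The reorder_batch() hook restored in the next patch delegates to reorder_batch_to_split_decodes_and_prefills with decode_threshold=1, i.e. requests whose current query length is a single token are grouped ahead of prefill requests. The helper below only illustrates that grouping on plain lists; it is not the vLLM utility, which permutes the InputBatch in place:

def split_decodes_and_prefills(query_lens: list[int],
                               decode_threshold: int = 1) -> list[int]:
    """Return request indices reordered so that decode requests
    (query_len <= decode_threshold) come first, then prefills,
    preserving the relative order inside each group."""
    decodes = [i for i, n in enumerate(query_lens) if n <= decode_threshold]
    prefills = [i for i, n in enumerate(query_lens) if n > decode_threshold]
    return decodes + prefills

# a mixed batch: requests 0 and 3 are single-token decodes
assert split_decodes_and_prefills([1, 37, 5, 1]) == [0, 3, 1, 2]

With decodes packed at the front, the decode-only prefix of the batch can be identified from query_start_loc alone, which is presumably what the split-KV path relies on; the choice between the 2D and 3D kernels in unified_attention is still made for the whole batch, based on max_seqlen_q and total_num_q_blocks * num_kv_heads.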
From d893bcb6d3dae0ff1d818e22cfd2ceffc7d02ffe Mon Sep 17 00:00:00 2001 From: Jan van Lunteren Date: Sat, 2 Aug 2025 08:14:24 -0400 Subject: [PATCH 11/16] formatting Signed-off-by: Jan van Lunteren --- vllm/v1/attention/backends/triton_attn.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 2a8a5ea71fc0..b12eb1291619 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -19,10 +19,10 @@ from vllm.platforms import current_platform from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.attention.backends.utils import ( - AttentionCGSupport, AttentionMetadataBuilder, - CommonAttentionMetadata, reorder_batch_to_split_decodes_and_prefills) -from vllm.v1.kv_cache_interface import AttentionSpec + AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, + reorder_batch_to_split_decodes_and_prefills) from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.gpu_input_batch import InputBatch logger = init_logger(__name__) From 9b63155d39455fd60e90ca48821c073bd3d20e7b Mon Sep 17 00:00:00 2001 From: Jan van Lunteren Date: Mon, 4 Aug 2025 04:14:18 -0400 Subject: [PATCH 12/16] formatting Signed-off-by: Jan van Lunteren --- vllm/attention/ops/triton_unified_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index 6690ac16852a..b224da2841db 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -402,7 +402,7 @@ def kernel_unified_attention_3d( if USE_QQ_BIAS: qq_bias_row_ptrs = (qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0 ) # shape: [BLOCK_M] - + # compute the length of the longest sequence prefix spanned by any # query token in the current q_block (q_block_local_idx) max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + ( From 39ed6b6e67eb82787a7180e3cfb65e76c8fc1a51 Mon Sep 17 00:00:00 2001 From: Jan van Lunteren Date: Tue, 16 Sep 2025 08:51:25 -0400 Subject: [PATCH 13/16] temporarily revert changes to enable simpler merge with latest main Signed-off-by: Jan van Lunteren --- .../attention/ops/triton_unified_attention.py | 115 ++++++++---------- vllm/v1/attention/backends/triton_attn.py | 18 +-- 2 files changed, 53 insertions(+), 80 deletions(-) diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index 8b506c4622cc..56ebed0f5244 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -70,7 +70,6 @@ def kernel_unified_attention_2d( output_stride_1: tl.int64, # int, should be equal to head_size qq_bias_stride_0: tl.int64, # int BLOCK_SIZE: tl.constexpr, # int - TILE_SIZE: tl.constexpr, # int must be power of 2 HEAD_SIZE: tl.constexpr, # int HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 USE_ALIBI_SLOPES: tl.constexpr, # bool @@ -113,7 +112,6 @@ def kernel_unified_attention_2d( offs_m = tl.arange(0, BLOCK_M) offs_d = tl.arange(0, HEAD_SIZE_PADDED) - offs_t = tl.arange(0, TILE_SIZE) query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv query_offset_0 = cur_batch_in_all_start_index + query_pos @@ -173,32 +171,31 @@ def kernel_unified_attention_2d( # actual sequence length max_seq_prefix_len 
= tl.minimum(max_seq_prefix_len, seq_len) - # calculate the number of tiles that need to be processed to - # cover the longest sequence prefix (due to causal masking, tiles beyond + # calculate the number of tiles (blocks) that need to be processed to + # cover the longest sequence prefix (due to causal masking, blocks beyond # this prefix can be skipped) - num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE) + num_blocks = cdiv_fn(max_seq_prefix_len, BLOCK_SIZE) # iterate through tiles - for j in range(0, num_tiles): - seq_offset = j * TILE_SIZE + offs_t - tile_mask = seq_offset < max_seq_prefix_len + for j in range(0, num_blocks): - physical_block_idx = tl.load(block_tables_ptr + block_table_offset + - seq_offset // BLOCK_SIZE).to(tl.int64) + physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j) - v_offset = (physical_block_idx[:, None] * stride_v_cache_0 + + offs_n = tl.arange(0, BLOCK_SIZE) + + v_offset = (physical_block_idx * stride_v_cache_0 + kv_head_idx * stride_v_cache_2 + offs_d[None, :] * stride_v_cache_3 + - (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1) + offs_n[:, None] * stride_v_cache_1) - k_offset = (physical_block_idx[None, :] * stride_k_cache_0 + + k_offset = (physical_block_idx * stride_k_cache_0 + kv_head_idx * stride_k_cache_2 + offs_d[:, None] * stride_k_cache_3 + - (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1) + offs_n[None, :] * stride_k_cache_1) - # K : (HEAD_SIZE, TILE_SIZE) + # K : (HEAD_SIZE, BLOCK_SIZE) K_load = tl.load(key_cache_ptr + k_offset, - mask=dim_mask[:, None] & tile_mask[None, :], + mask=dim_mask[:, None], other=0.0) if K_load.dtype.is_fp8(): @@ -209,9 +206,9 @@ def kernel_unified_attention_2d( else: K = K_load - # V : (TILE_SIZE, HEAD_SIZE) + # V : (BLOCK_SIZE, HEAD_SIZE) V_load = tl.load(value_cache_ptr + v_offset, - mask=dim_mask[None, :] & tile_mask[:, None], + mask=dim_mask[None, :], other=0.0) if V_load.dtype.is_fp8(): @@ -222,10 +219,12 @@ def kernel_unified_attention_2d( else: V = V_load + seq_offset = j * BLOCK_SIZE + offs_n + seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1 - # S : (BLOCK_M, TILE_SIZE) - S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32) + # S : (BLOCK_M, BLOCK_SIZE) + S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32) S += scale * tl.dot(Q, K) @@ -257,12 +256,11 @@ def kernel_unified_attention_2d( # compute running maximum # m_j : (BLOCK_M,) m_j = tl.maximum(M, tl.max(S, axis=1)) - # For sliding window there's a chance the max is -inf due to masking of # the entire row. 
In this case we need to set m_j 0 to avoid NaN m_j = tl.where(m_j > float("-inf"), m_j, 0.0) - # P : (BLOCK_M, TILE_SIZE) + # P : (BLOCK_M, BLOCK_SIZE) P = tl.exp(S - m_j[:, None]) # l_j : (BLOCK_M,) @@ -320,7 +318,6 @@ def kernel_unified_attention_3d( query_stride_1: tl.int64, # int, should be equal to head_size qq_bias_stride_0: tl.int64, # int BLOCK_SIZE: tl.constexpr, # int - TILE_SIZE: tl.constexpr, # int, must be power of 2 HEAD_SIZE: tl.constexpr, # int HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 USE_ALIBI_SLOPES: tl.constexpr, # bool @@ -368,19 +365,20 @@ def kernel_unified_attention_3d( # number of segments for this particular sequence num_segments = NUM_SEGMENTS_PER_SEQ - tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE) + blocks_per_segment = cdiv_fn(seq_len, num_segments * BLOCK_SIZE) - if segm_idx * tiles_per_segment * TILE_SIZE >= seq_len: + if segm_idx * blocks_per_segment * BLOCK_SIZE >= seq_len: return offs_m = tl.arange(0, BLOCK_M) offs_d = tl.arange(0, HEAD_SIZE_PADDED) - offs_t = tl.arange(0, TILE_SIZE) + query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv query_offset_0 = cur_batch_in_all_start_index + query_pos query_offset_1 = kv_head_idx * num_queries_per_kv + \ offs_m % num_queries_per_kv + query_offset = (query_offset_0[:, None] * query_stride_0 + query_offset_1[:, None] * query_stride_1 + offs_d[None, :]) @@ -426,44 +424,30 @@ def kernel_unified_attention_3d( qq_bias_row_ptrs = (qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0 ) # shape: [BLOCK_M] - # compute the length of the longest sequence prefix spanned by any - # query token in the current q_block (q_block_local_idx) - max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + ( - BLOCK_M - 1) // num_queries_per_kv + 1 - - # adjust for potential padding in the last q_block by considering the - # actual sequence length - max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len) - - # calculate the number of tiles that need to be processed to - # cover the longest sequence prefix (due to causal masking, tiles beyond - # this prefix can be skipped) - num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE) + num_blocks = cdiv_fn(seq_len, BLOCK_SIZE) # iterate through tiles within current segment for j in range( - segm_idx * tiles_per_segment, - min((segm_idx + 1) * tiles_per_segment, num_tiles), + segm_idx * blocks_per_segment, + min((segm_idx + 1) * blocks_per_segment, num_blocks), ): - seq_offset = j * TILE_SIZE + offs_t - tile_mask = seq_offset < max_seq_prefix_len + physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j) - physical_block_idx = tl.load(block_tables_ptr + block_table_offset + - seq_offset // BLOCK_SIZE).to(tl.int64) + offs_n = tl.arange(0, BLOCK_SIZE) - v_offset = (physical_block_idx[:, None] * stride_v_cache_0 + + v_offset = (physical_block_idx * stride_v_cache_0 + kv_head_idx * stride_v_cache_2 + offs_d[None, :] * stride_v_cache_3 + - (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1) + offs_n[:, None] * stride_v_cache_1) - k_offset = (physical_block_idx[None, :] * stride_k_cache_0 + + k_offset = (physical_block_idx * stride_k_cache_0 + kv_head_idx * stride_k_cache_2 + offs_d[:, None] * stride_k_cache_3 + - (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1) + offs_n[None, :] * stride_k_cache_1) - # K : (HEAD_SIZE, TILE_SIZE) + # K : (HEAD_SIZE, BLOCK_SIZE) K_load = tl.load(key_cache_ptr + k_offset, - mask=dim_mask[:, None] & tile_mask[None, :], + mask=dim_mask[:, None], other=0.0) if K_load.dtype.is_fp8(): @@ 
-474,9 +458,9 @@ def kernel_unified_attention_3d( else: K = K_load - # V : (TILE_SIZE, HEAD_SIZE) + # V : (BLOCK_SIZE, HEAD_SIZE) V_load = tl.load(value_cache_ptr + v_offset, - mask=dim_mask[None, :] & tile_mask[:, None], + mask=dim_mask[None, :], other=0.0) if V_load.dtype.is_fp8(): @@ -487,10 +471,13 @@ def kernel_unified_attention_3d( else: V = V_load + seq_offset = j * BLOCK_SIZE + offs_n + seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1 - # S : (BLOCK_M, TILE_SIZE) - S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32) + # S : (BLOCK_M, BLOCK_SIZE) + S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32) + S += scale * tl.dot(Q, K) if USE_SOFTCAP: @@ -521,12 +508,11 @@ def kernel_unified_attention_3d( # compute running maximum # m_j : (BLOCK_M,) m_j = tl.maximum(M, tl.max(S, axis=1)) - # For sliding window there's a chance the max is -inf due to masking of # the entire row. In this case we need to set m_j 0 to avoid NaN m_j = tl.where(m_j > float("-inf"), m_j, 0.0) - # P : (BLOCK_M, TILE_SIZE,) + # P : (BLOCK_M, BLOCK_SIZE,) P = tl.exp(S - m_j[:, None]) # l_j : (BLOCK_M,) @@ -577,7 +563,7 @@ def reduce_segments( output_stride_0: tl.int64, # int output_stride_1: tl.int64, # int, should be equal to head_size block_table_stride: tl.int64, # int - TILE_SIZE: tl.constexpr, # int + BLOCK_SIZE: tl.constexpr, # int HEAD_SIZE: tl.constexpr, # int, must be power of 2 HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 query_start_len_ptr, # [num_seqs+1] @@ -595,10 +581,10 @@ def reduce_segments( # number of segments for this particular sequence num_segments = NUM_SEGMENTS_PER_SEQ - tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE) + blocks_per_segment = cdiv_fn(seq_len, num_segments * BLOCK_SIZE) # create masks for subsequent loads - act_num_segments = cdiv_fn(seq_len, tiles_per_segment * TILE_SIZE) + act_num_segments = cdiv_fn(seq_len, blocks_per_segment * BLOCK_SIZE) segm_mask = tl.arange(0, NUM_SEGMENTS_PER_SEQ) < tl.full( [NUM_SEGMENTS_PER_SEQ], act_num_segments, dtype=tl.int32) dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1, @@ -667,7 +653,6 @@ def unified_attention( # Optional tensor for sinks sinks=None, ): - assert causal, "Only causal attention is supported" assert q_descale is None, "Q scales not supported" @@ -703,9 +688,6 @@ def unified_attention( # = floor(q.shape[0] / BLOCK_Q) + num_seqs total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs - TILE_SIZE_PREFILL = 32 - TILE_SIZE_DECODE = 32 - # if batch contains a prefill if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128: kernel_unified_attention_2d[( @@ -734,7 +716,6 @@ def unified_attention( output_stride_1=out.stride(1), qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0, BLOCK_SIZE=block_size, - TILE_SIZE=TILE_SIZE_PREFILL, HEAD_SIZE=head_size, HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), USE_ALIBI_SLOPES=use_alibi_slopes, @@ -807,7 +788,6 @@ def unified_attention( query_stride_1=q.stride(1), qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0, BLOCK_SIZE=block_size, - TILE_SIZE=TILE_SIZE_DECODE, HEAD_SIZE=head_size, HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), USE_ALIBI_SLOPES=use_alibi_slopes, @@ -829,6 +809,7 @@ def unified_attention( BLOCK_M=BLOCK_M, NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, ) + reduce_segments[(q.shape[0], num_query_heads)]( output_ptr=out, segm_output_ptr=segm_output, @@ -840,7 +821,7 @@ def unified_attention( output_stride_0=out.stride(0), output_stride_1=out.stride(1), 
block_table_stride=block_table.stride(0), - TILE_SIZE=TILE_SIZE_DECODE, + BLOCK_SIZE=block_size, HEAD_SIZE=head_size, HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), query_start_len_ptr=cu_seqlens_q, diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 73823b0d0abb..48a9af3decac 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -18,12 +18,10 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata -from vllm.v1.attention.backends.utils import ( - AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, - reorder_batch_to_split_decodes_and_prefills) -from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.attention.backends.utils import (AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata) from vllm.v1.kv_cache_interface import AttentionSpec -from vllm.v1.worker.gpu_input_batch import InputBatch logger = init_logger(__name__) @@ -75,12 +73,6 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config.parallel_config) self.headdim = model_config.get_head_size() - def reorder_batch(self, input_batch: InputBatch, - scheduler_output: SchedulerOutput) -> bool: - return reorder_batch_to_split_decodes_and_prefills(input_batch, - scheduler_output, - decode_threshold=1) - def build_for_cudagraph_capture( self, common_attn_metadata: CommonAttentionMetadata ) -> TritonAttentionMetadata: @@ -184,7 +176,7 @@ def get_kv_cache_shape( ) -> tuple[int, ...]: if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") - return (num_blocks, 2, block_size, num_kv_heads, head_size) + return (2, num_blocks, block_size, num_kv_heads, head_size) @staticmethod def use_cascade_attention(*args, **kwargs) -> bool: @@ -326,7 +318,7 @@ def forward( key_cache, value_cache = PagedAttention.split_kv_cache( kv_cache, self.num_kv_heads, self.head_size) else: - key_cache, value_cache = kv_cache.unbind(1) + key_cache, value_cache = kv_cache.unbind(0) if self.kv_sharing_target_layer_name is None: # Reshape the input keys and values and store them in the cache. 
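Patch 14 below restores the tile-based inner loop that patch 13 temporarily reverted. With per-block iteration, one loop step loads exactly one paged block (physical_block_idx = block_table[j], in-block offsets offs_n); with per-tile iteration, one step covers TILE_SIZE logical positions and derives a block-table entry per position (seq_offset // BLOCK_SIZE) plus an in-block slot (seq_offset % BLOCK_SIZE), so the compute tile no longer has to match the KV-cache block size. A NumPy sketch of that gather with illustrative sizes (the names mirror the kernel, but this is not Triton code, and the index clamp stands in for the kernel's tile_mask-ed load):

import numpy as np

BLOCK_SIZE, TILE_SIZE, HEAD = 16, 32, 8   # illustrative; a tile spans two blocks here
seq_len = 40
num_blocks = -(-seq_len // BLOCK_SIZE)    # cdiv(seq_len, BLOCK_SIZE) = 3

k_cache = np.random.randn(num_blocks + 2, BLOCK_SIZE, HEAD)  # paged cache with spare blocks
block_table = np.array([2, 0, 1])         # logical block j -> physical block

# contiguous reference view of the first seq_len key vectors
k_ref = k_cache[block_table].reshape(-1, HEAD)[:seq_len]

# tile-based gather, as in the restored TILE_SIZE loop (second tile, j = 1)
j = 1
seq_offset = j * TILE_SIZE + np.arange(TILE_SIZE)
tile_mask = seq_offset < seq_len
logical_block = np.minimum(seq_offset // BLOCK_SIZE, num_blocks - 1)  # clamp ~ masked load
physical_block = block_table[logical_block]
k_tile = k_cache[physical_block, seq_offset % BLOCK_SIZE]             # (TILE_SIZE, HEAD)

assert np.allclose(k_tile[tile_mask], k_ref[j * TILE_SIZE:seq_len])

Because the physical block index is computed per position, TILE_SIZE can be chosen independently of block_size (both TILE_SIZE_PREFILL and TILE_SIZE_DECODE are set to 32 in this series), and positions beyond max_seq_prefix_len are loaded as zeros and discarded by the masking applied to S.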
From 4254fad73d273a06409f2cdb6a1216cbf0c8c145 Mon Sep 17 00:00:00 2001 From: Jan van Lunteren Date: Tue, 16 Sep 2025 09:52:55 -0400 Subject: [PATCH 14/16] restore changes to triton_unified_attention.py Signed-off-by: Jan van Lunteren --- .../attention/ops/triton_unified_attention.py | 115 ++++++++++-------- 1 file changed, 67 insertions(+), 48 deletions(-) diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index d2ad2f7e8d2a..0a708c6b870d 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -73,6 +73,7 @@ def kernel_unified_attention_2d( output_stride_1: tl.int64, # int, should be equal to head_size qq_bias_stride_0: tl.int64, # int BLOCK_SIZE: tl.constexpr, # int + TILE_SIZE: tl.constexpr, # int must be power of 2 HEAD_SIZE: tl.constexpr, # int HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 USE_ALIBI_SLOPES: tl.constexpr, # bool @@ -118,6 +119,7 @@ def kernel_unified_attention_2d( offs_m = tl.arange(0, BLOCK_M) offs_d = tl.arange(0, HEAD_SIZE_PADDED) + offs_t = tl.arange(0, TILE_SIZE) query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv query_offset_0 = cur_batch_in_all_start_index + query_pos @@ -177,31 +179,32 @@ def kernel_unified_attention_2d( # actual sequence length max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len) - # calculate the number of tiles (blocks) that need to be processed to - # cover the longest sequence prefix (due to causal masking, blocks beyond + # calculate the number of tiles that need to be processed to + # cover the longest sequence prefix (due to causal masking, tiles beyond # this prefix can be skipped) - num_blocks = cdiv_fn(max_seq_prefix_len, BLOCK_SIZE) + num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE) # iterate through tiles - for j in range(0, num_blocks): + for j in range(0, num_tiles): + seq_offset = j * TILE_SIZE + offs_t + tile_mask = seq_offset < max_seq_prefix_len - physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j) + physical_block_idx = tl.load(block_tables_ptr + block_table_offset + + seq_offset // BLOCK_SIZE).to(tl.int64) - offs_n = tl.arange(0, BLOCK_SIZE) - - v_offset = (physical_block_idx * stride_v_cache_0 + + v_offset = (physical_block_idx[:, None] * stride_v_cache_0 + kv_head_idx * stride_v_cache_2 + offs_d[None, :] * stride_v_cache_3 + - offs_n[:, None] * stride_v_cache_1) + (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1) - k_offset = (physical_block_idx * stride_k_cache_0 + + k_offset = (physical_block_idx[None, :] * stride_k_cache_0 + kv_head_idx * stride_k_cache_2 + offs_d[:, None] * stride_k_cache_3 + - offs_n[None, :] * stride_k_cache_1) + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1) - # K : (HEAD_SIZE, BLOCK_SIZE) + # K : (HEAD_SIZE, TILE_SIZE) K_load = tl.load(key_cache_ptr + k_offset, - mask=dim_mask[:, None], + mask=dim_mask[:, None] & tile_mask[None, :], other=0.0) if K_load.dtype.is_fp8(): @@ -212,9 +215,9 @@ def kernel_unified_attention_2d( else: K = K_load - # V : (BLOCK_SIZE, HEAD_SIZE) + # V : (TILE_SIZE, HEAD_SIZE) V_load = tl.load(value_cache_ptr + v_offset, - mask=dim_mask[None, :], + mask=dim_mask[None, :] & tile_mask[:, None], other=0.0) if V_load.dtype.is_fp8(): @@ -225,12 +228,10 @@ def kernel_unified_attention_2d( else: V = V_load - seq_offset = j * BLOCK_SIZE + offs_n - seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1 - # S : (BLOCK_M, BLOCK_SIZE) - S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), 
dtype=tl.float32) + # S : (BLOCK_M, TILE_SIZE) + S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32) S += scale * tl.dot(Q, K) @@ -262,11 +263,12 @@ def kernel_unified_attention_2d( # compute running maximum # m_j : (BLOCK_M,) m_j = tl.maximum(M, tl.max(S, axis=1)) + # For sliding window there's a chance the max is -inf due to masking of # the entire row. In this case we need to set m_j 0 to avoid NaN m_j = tl.where(m_j > float("-inf"), m_j, 0.0) - # P : (BLOCK_M, BLOCK_SIZE) + # P : (BLOCK_M, TILE_SIZE) P = tl.exp(S - m_j[:, None]) # l_j : (BLOCK_M,) @@ -327,6 +329,7 @@ def kernel_unified_attention_3d( query_stride_1: tl.int64, # int, should be equal to head_size qq_bias_stride_0: tl.int64, # int BLOCK_SIZE: tl.constexpr, # int + TILE_SIZE: tl.constexpr, # int, must be power of 2 HEAD_SIZE: tl.constexpr, # int HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 USE_ALIBI_SLOPES: tl.constexpr, # bool @@ -374,20 +377,19 @@ def kernel_unified_attention_3d( # number of segments for this particular sequence num_segments = NUM_SEGMENTS_PER_SEQ - blocks_per_segment = cdiv_fn(seq_len, num_segments * BLOCK_SIZE) + tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE) - if segm_idx * blocks_per_segment * BLOCK_SIZE >= seq_len: + if segm_idx * tiles_per_segment * TILE_SIZE >= seq_len: return offs_m = tl.arange(0, BLOCK_M) offs_d = tl.arange(0, HEAD_SIZE_PADDED) - + offs_t = tl.arange(0, TILE_SIZE) query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv query_offset_0 = cur_batch_in_all_start_index + query_pos query_offset_1 = kv_head_idx * num_queries_per_kv + \ offs_m % num_queries_per_kv - query_offset = (query_offset_0[:, None] * query_stride_0 + query_offset_1[:, None] * query_stride_1 + offs_d[None, :]) @@ -433,30 +435,44 @@ def kernel_unified_attention_3d( qq_bias_row_ptrs = (qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0 ) # shape: [BLOCK_M] - num_blocks = cdiv_fn(seq_len, BLOCK_SIZE) + # compute the length of the longest sequence prefix spanned by any + # query token in the current q_block (q_block_local_idx) + max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + ( + BLOCK_M - 1) // num_queries_per_kv + 1 + + # adjust for potential padding in the last q_block by considering the + # actual sequence length + max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len) + + # calculate the number of tiles that need to be processed to + # cover the longest sequence prefix (due to causal masking, tiles beyond + # this prefix can be skipped) + num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE) # iterate through tiles within current segment for j in range( - segm_idx * blocks_per_segment, - min((segm_idx + 1) * blocks_per_segment, num_blocks), + segm_idx * tiles_per_segment, + min((segm_idx + 1) * tiles_per_segment, num_tiles), ): - physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j) + seq_offset = j * TILE_SIZE + offs_t + tile_mask = seq_offset < max_seq_prefix_len - offs_n = tl.arange(0, BLOCK_SIZE) + physical_block_idx = tl.load(block_tables_ptr + block_table_offset + + seq_offset // BLOCK_SIZE).to(tl.int64) - v_offset = (physical_block_idx * stride_v_cache_0 + + v_offset = (physical_block_idx[:, None] * stride_v_cache_0 + kv_head_idx * stride_v_cache_2 + offs_d[None, :] * stride_v_cache_3 + - offs_n[:, None] * stride_v_cache_1) + (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1) - k_offset = (physical_block_idx * stride_k_cache_0 + + k_offset = (physical_block_idx[None, :] * stride_k_cache_0 + kv_head_idx * 
stride_k_cache_2 + offs_d[:, None] * stride_k_cache_3 + - offs_n[None, :] * stride_k_cache_1) + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1) - # K : (HEAD_SIZE, BLOCK_SIZE) + # K : (HEAD_SIZE, TILE_SIZE) K_load = tl.load(key_cache_ptr + k_offset, - mask=dim_mask[:, None], + mask=dim_mask[:, None] & tile_mask[None, :], other=0.0) if K_load.dtype.is_fp8(): @@ -467,9 +483,9 @@ def kernel_unified_attention_3d( else: K = K_load - # V : (BLOCK_SIZE, HEAD_SIZE) + # V : (TILE_SIZE, HEAD_SIZE) V_load = tl.load(value_cache_ptr + v_offset, - mask=dim_mask[None, :], + mask=dim_mask[None, :] & tile_mask[:, None], other=0.0) if V_load.dtype.is_fp8(): @@ -480,13 +496,10 @@ def kernel_unified_attention_3d( else: V = V_load - seq_offset = j * BLOCK_SIZE + offs_n - seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1 - # S : (BLOCK_M, BLOCK_SIZE) - S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32) - + # S : (BLOCK_M, TILE_SIZE) + S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32) S += scale * tl.dot(Q, K) if USE_SOFTCAP: @@ -517,11 +530,12 @@ def kernel_unified_attention_3d( # compute running maximum # m_j : (BLOCK_M,) m_j = tl.maximum(M, tl.max(S, axis=1)) + # For sliding window there's a chance the max is -inf due to masking of # the entire row. In this case we need to set m_j 0 to avoid NaN m_j = tl.where(m_j > float("-inf"), m_j, 0.0) - # P : (BLOCK_M, BLOCK_SIZE,) + # P : (BLOCK_M, TILE_SIZE,) P = tl.exp(S - m_j[:, None]) # l_j : (BLOCK_M,) @@ -573,7 +587,7 @@ def reduce_segments( output_stride_0: tl.int64, # int output_stride_1: tl.int64, # int, should be equal to head_size block_table_stride: tl.int64, # int - BLOCK_SIZE: tl.constexpr, # int + TILE_SIZE: tl.constexpr, # int HEAD_SIZE: tl.constexpr, # int, must be power of 2 HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 query_start_len_ptr, # [num_seqs+1] @@ -594,10 +608,10 @@ def reduce_segments( # number of segments for this particular sequence num_segments = NUM_SEGMENTS_PER_SEQ - blocks_per_segment = cdiv_fn(seq_len, num_segments * BLOCK_SIZE) + tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE) # create masks for subsequent loads - act_num_segments = cdiv_fn(seq_len, blocks_per_segment * BLOCK_SIZE) + act_num_segments = cdiv_fn(seq_len, tiles_per_segment * TILE_SIZE) segm_mask = tl.arange(0, NUM_SEGMENTS_PER_SEQ) < tl.full( [NUM_SEGMENTS_PER_SEQ], act_num_segments, dtype=tl.int32) dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1, @@ -671,6 +685,7 @@ def unified_attention( # Optional tensor for sinks sinks=None, ): + assert causal, "Only causal attention is supported" assert q_descale is None, "Q scales not supported" @@ -707,6 +722,9 @@ def unified_attention( # = floor(q.shape[0] / BLOCK_Q) + num_seqs total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs + TILE_SIZE_PREFILL = 32 + TILE_SIZE_DECODE = 32 + # if batch contains a prefill if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128: kernel_unified_attention_2d[( @@ -736,6 +754,7 @@ def unified_attention( output_stride_1=out.stride(1), qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0, BLOCK_SIZE=block_size, + TILE_SIZE=TILE_SIZE_PREFILL, HEAD_SIZE=head_size, HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), USE_ALIBI_SLOPES=use_alibi_slopes, @@ -809,6 +828,7 @@ def unified_attention( query_stride_1=q.stride(1), qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0, BLOCK_SIZE=block_size, + TILE_SIZE=TILE_SIZE_DECODE, HEAD_SIZE=head_size, 
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), USE_ALIBI_SLOPES=use_alibi_slopes, @@ -830,7 +850,6 @@ def unified_attention( BLOCK_M=BLOCK_M, NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, ) - reduce_segments[(q.shape[0], num_query_heads)]( output_ptr=out, segm_output_ptr=segm_output, @@ -844,7 +863,7 @@ def unified_attention( output_stride_0=out.stride(0), output_stride_1=out.stride(1), block_table_stride=block_table.stride(0), - BLOCK_SIZE=block_size, + TILE_SIZE=TILE_SIZE_DECODE, HEAD_SIZE=head_size, HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), query_start_len_ptr=cu_seqlens_q, From 8a846bcf53d39069fc3cfe296930c814d193c953 Mon Sep 17 00:00:00 2001 From: Jan van Lunteren Date: Wed, 17 Sep 2025 10:33:56 -0400 Subject: [PATCH 15/16] replaced block size check by tile size check Signed-off-by: Jan van Lunteren --- tests/kernels/attention/test_triton_unified_attention.py | 3 --- vllm/attention/ops/triton_unified_attention.py | 7 +++---- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/tests/kernels/attention/test_triton_unified_attention.py b/tests/kernels/attention/test_triton_unified_attention.py index 4b97d51e6ed2..ab91560e995c 100644 --- a/tests/kernels/attention/test_triton_unified_attention.py +++ b/tests/kernels/attention/test_triton_unified_attention.py @@ -102,9 +102,6 @@ def test_triton_unified_attn( ) -> None: torch.set_default_device("cuda") - if q_dtype is not None and q_dtype.itemsize < 2 and block_size < 32: - pytest.skip("block size must be at least 32 for fp8") - current_platform.seed_everything(0) num_seqs = len(seq_lens) query_lens = [x[0] for x in seq_lens] diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index 0a708c6b870d..baebb8004beb 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -689,10 +689,6 @@ def unified_attention( assert causal, "Only causal attention is supported" assert q_descale is None, "Q scales not supported" - block_size = v.shape[1] - assert q.element_size() >= 2 or block_size >= 32, \ - "Block size must be at least 32 for fp8" - if sinks is not None: assert sinks.shape[0] == q.shape[1], \ "Sinks must be num_query_heads size" @@ -725,6 +721,9 @@ def unified_attention( TILE_SIZE_PREFILL = 32 TILE_SIZE_DECODE = 32 + assert q.element_size() >= 2 or (TILE_SIZE_PREFILL >= 32 and TILE_SIZE_DECODE >= 32), \ + "tile size must be at least 32 for fp8" + # if batch contains a prefill if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128: kernel_unified_attention_2d[( From 3dedd20951996d0f386380ec296bf7a664a7da17 Mon Sep 17 00:00:00 2001 From: Jan van Lunteren Date: Wed, 17 Sep 2025 10:54:51 -0400 Subject: [PATCH 16/16] assign default tile sizes for prefill and decode, remove check Signed-off-by: Jan van Lunteren --- vllm/attention/ops/triton_unified_attention.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index baebb8004beb..591b68bfa646 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -718,11 +718,11 @@ def unified_attention( # = floor(q.shape[0] / BLOCK_Q) + num_seqs total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs + # Assigning default tile sizes for prefill and decode. + # Note: each tile size must be at least 32 for "fp8" (q.element_size() == 1) + # and at least 16 for all other data types. 
TILE_SIZE_PREFILL = 32 - TILE_SIZE_DECODE = 32 - - assert q.element_size() >= 2 or (TILE_SIZE_PREFILL >= 32 and TILE_SIZE_DECODE >= 32), \ - "tile size must be at least 32 for fp8" + TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32 # if batch contains a prefill if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128:
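
For reference, the tile-based addressing introduced in the kernels above decouples the tile size from the KV-cache block size: each token position covered by a tile looks up its own physical block via seq_offset // BLOCK_SIZE and its slot within that block via seq_offset % BLOCK_SIZE. The following is a minimal PyTorch sketch of that index math with made-up sizes (the actual kernel does this with tl.arange/tl.load; the block_table values here are hypothetical):

    import torch

    # Hypothetical sizes for illustration only.
    BLOCK_SIZE, TILE_SIZE = 16, 32            # one tile spans TILE_SIZE // BLOCK_SIZE = 2 KV blocks
    block_table = torch.tensor([7, 3, 9, 0])  # logical block -> physical block for one sequence
    j = 1                                     # tile index within the sequence

    seq_offset = j * TILE_SIZE + torch.arange(TILE_SIZE)        # token positions covered by tile j
    physical_block_idx = block_table[seq_offset // BLOCK_SIZE]  # one physical block per token
    within_block = seq_offset % BLOCK_SIZE                      # slot of each token inside its block

    # Each K/V load then addresses the cache at
    # (physical_block_idx, within_block, kv_head_idx, :),
    # mirroring the k_offset / v_offset computations in the kernels.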