diff --git a/tests/kernels/attention/test_triton_unified_attention.py b/tests/kernels/attention/test_triton_unified_attention.py index 4b97d51e6ed2..ab91560e995c 100644 --- a/tests/kernels/attention/test_triton_unified_attention.py +++ b/tests/kernels/attention/test_triton_unified_attention.py @@ -102,9 +102,6 @@ def test_triton_unified_attn( ) -> None: torch.set_default_device("cuda") - if q_dtype is not None and q_dtype.itemsize < 2 and block_size < 32: - pytest.skip("block size must be at least 32 for fp8") - current_platform.seed_everything(0) num_seqs = len(seq_lens) query_lens = [x[0] for x in seq_lens] diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index d2ad2f7e8d2a..591b68bfa646 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -73,6 +73,7 @@ def kernel_unified_attention_2d( output_stride_1: tl.int64, # int, should be equal to head_size qq_bias_stride_0: tl.int64, # int BLOCK_SIZE: tl.constexpr, # int + TILE_SIZE: tl.constexpr, # int must be power of 2 HEAD_SIZE: tl.constexpr, # int HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 USE_ALIBI_SLOPES: tl.constexpr, # bool @@ -118,6 +119,7 @@ def kernel_unified_attention_2d( offs_m = tl.arange(0, BLOCK_M) offs_d = tl.arange(0, HEAD_SIZE_PADDED) + offs_t = tl.arange(0, TILE_SIZE) query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv query_offset_0 = cur_batch_in_all_start_index + query_pos @@ -177,31 +179,32 @@ def kernel_unified_attention_2d( # actual sequence length max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len) - # calculate the number of tiles (blocks) that need to be processed to - # cover the longest sequence prefix (due to causal masking, blocks beyond + # calculate the number of tiles that need to be processed to + # cover the longest sequence prefix (due to causal masking, tiles beyond # this prefix can be skipped) - num_blocks = cdiv_fn(max_seq_prefix_len, BLOCK_SIZE) + num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE) # iterate through tiles - for j in range(0, num_blocks): + for j in range(0, num_tiles): + seq_offset = j * TILE_SIZE + offs_t + tile_mask = seq_offset < max_seq_prefix_len - physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j) + physical_block_idx = tl.load(block_tables_ptr + block_table_offset + + seq_offset // BLOCK_SIZE).to(tl.int64) - offs_n = tl.arange(0, BLOCK_SIZE) - - v_offset = (physical_block_idx * stride_v_cache_0 + + v_offset = (physical_block_idx[:, None] * stride_v_cache_0 + kv_head_idx * stride_v_cache_2 + offs_d[None, :] * stride_v_cache_3 + - offs_n[:, None] * stride_v_cache_1) + (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1) - k_offset = (physical_block_idx * stride_k_cache_0 + + k_offset = (physical_block_idx[None, :] * stride_k_cache_0 + kv_head_idx * stride_k_cache_2 + offs_d[:, None] * stride_k_cache_3 + - offs_n[None, :] * stride_k_cache_1) + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1) - # K : (HEAD_SIZE, BLOCK_SIZE) + # K : (HEAD_SIZE, TILE_SIZE) K_load = tl.load(key_cache_ptr + k_offset, - mask=dim_mask[:, None], + mask=dim_mask[:, None] & tile_mask[None, :], other=0.0) if K_load.dtype.is_fp8(): @@ -212,9 +215,9 @@ def kernel_unified_attention_2d( else: K = K_load - # V : (BLOCK_SIZE, HEAD_SIZE) + # V : (TILE_SIZE, HEAD_SIZE) V_load = tl.load(value_cache_ptr + v_offset, - mask=dim_mask[None, :], + mask=dim_mask[None, :] & tile_mask[:, None], other=0.0) if V_load.dtype.is_fp8(): @@ -225,12 +228,10 @@ def kernel_unified_attention_2d( else: V = V_load - seq_offset = j * BLOCK_SIZE + offs_n - seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1 - # S : (BLOCK_M, BLOCK_SIZE) - S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32) + # S : (BLOCK_M, TILE_SIZE) + S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32) S += scale * tl.dot(Q, K) @@ -262,11 +263,12 @@ def kernel_unified_attention_2d( # compute running maximum # m_j : (BLOCK_M,) m_j = tl.maximum(M, tl.max(S, axis=1)) + # For sliding window there's a chance the max is -inf due to masking of # the entire row. In this case we need to set m_j 0 to avoid NaN m_j = tl.where(m_j > float("-inf"), m_j, 0.0) - # P : (BLOCK_M, BLOCK_SIZE) + # P : (BLOCK_M, TILE_SIZE) P = tl.exp(S - m_j[:, None]) # l_j : (BLOCK_M,) @@ -327,6 +329,7 @@ def kernel_unified_attention_3d( query_stride_1: tl.int64, # int, should be equal to head_size qq_bias_stride_0: tl.int64, # int BLOCK_SIZE: tl.constexpr, # int + TILE_SIZE: tl.constexpr, # int, must be power of 2 HEAD_SIZE: tl.constexpr, # int HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 USE_ALIBI_SLOPES: tl.constexpr, # bool @@ -374,20 +377,19 @@ def kernel_unified_attention_3d( # number of segments for this particular sequence num_segments = NUM_SEGMENTS_PER_SEQ - blocks_per_segment = cdiv_fn(seq_len, num_segments * BLOCK_SIZE) + tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE) - if segm_idx * blocks_per_segment * BLOCK_SIZE >= seq_len: + if segm_idx * tiles_per_segment * TILE_SIZE >= seq_len: return offs_m = tl.arange(0, BLOCK_M) offs_d = tl.arange(0, HEAD_SIZE_PADDED) - + offs_t = tl.arange(0, TILE_SIZE) query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv query_offset_0 = cur_batch_in_all_start_index + query_pos query_offset_1 = kv_head_idx * num_queries_per_kv + \ offs_m % num_queries_per_kv - query_offset = (query_offset_0[:, None] * query_stride_0 + query_offset_1[:, None] * query_stride_1 + offs_d[None, :]) @@ -433,30 +435,44 @@ def kernel_unified_attention_3d( qq_bias_row_ptrs = (qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0 ) # shape: [BLOCK_M] - num_blocks = cdiv_fn(seq_len, BLOCK_SIZE) + # compute the length of the longest sequence prefix spanned by any + # query token in the current q_block (q_block_local_idx) + max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + ( + BLOCK_M - 1) // num_queries_per_kv + 1 + + # adjust for potential padding in the last q_block by considering the + # actual sequence length + max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len) + + # calculate the number of tiles that need to be processed to + # cover the longest sequence prefix (due to causal masking, tiles beyond + # this prefix can be skipped) + num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE) # iterate through tiles within current segment for j in range( - segm_idx * blocks_per_segment, - min((segm_idx + 1) * blocks_per_segment, num_blocks), + segm_idx * tiles_per_segment, + min((segm_idx + 1) * tiles_per_segment, num_tiles), ): - physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j) + seq_offset = j * TILE_SIZE + offs_t + tile_mask = seq_offset < max_seq_prefix_len - offs_n = tl.arange(0, BLOCK_SIZE) + physical_block_idx = tl.load(block_tables_ptr + block_table_offset + + seq_offset // BLOCK_SIZE).to(tl.int64) - v_offset = (physical_block_idx * stride_v_cache_0 + + v_offset = (physical_block_idx[:, None] * stride_v_cache_0 + kv_head_idx * stride_v_cache_2 + offs_d[None, :] * stride_v_cache_3 + - offs_n[:, None] * stride_v_cache_1) + (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1) - k_offset = (physical_block_idx * stride_k_cache_0 + + k_offset = (physical_block_idx[None, :] * stride_k_cache_0 + kv_head_idx * stride_k_cache_2 + offs_d[:, None] * stride_k_cache_3 + - offs_n[None, :] * stride_k_cache_1) + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1) - # K : (HEAD_SIZE, BLOCK_SIZE) + # K : (HEAD_SIZE, TILE_SIZE) K_load = tl.load(key_cache_ptr + k_offset, - mask=dim_mask[:, None], + mask=dim_mask[:, None] & tile_mask[None, :], other=0.0) if K_load.dtype.is_fp8(): @@ -467,9 +483,9 @@ def kernel_unified_attention_3d( else: K = K_load - # V : (BLOCK_SIZE, HEAD_SIZE) + # V : (TILE_SIZE, HEAD_SIZE) V_load = tl.load(value_cache_ptr + v_offset, - mask=dim_mask[None, :], + mask=dim_mask[None, :] & tile_mask[:, None], other=0.0) if V_load.dtype.is_fp8(): @@ -480,13 +496,10 @@ def kernel_unified_attention_3d( else: V = V_load - seq_offset = j * BLOCK_SIZE + offs_n - seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1 - # S : (BLOCK_M, BLOCK_SIZE) - S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32) - + # S : (BLOCK_M, TILE_SIZE) + S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32) S += scale * tl.dot(Q, K) if USE_SOFTCAP: @@ -517,11 +530,12 @@ def kernel_unified_attention_3d( # compute running maximum # m_j : (BLOCK_M,) m_j = tl.maximum(M, tl.max(S, axis=1)) + # For sliding window there's a chance the max is -inf due to masking of # the entire row. In this case we need to set m_j 0 to avoid NaN m_j = tl.where(m_j > float("-inf"), m_j, 0.0) - # P : (BLOCK_M, BLOCK_SIZE,) + # P : (BLOCK_M, TILE_SIZE,) P = tl.exp(S - m_j[:, None]) # l_j : (BLOCK_M,) @@ -573,7 +587,7 @@ def reduce_segments( output_stride_0: tl.int64, # int output_stride_1: tl.int64, # int, should be equal to head_size block_table_stride: tl.int64, # int - BLOCK_SIZE: tl.constexpr, # int + TILE_SIZE: tl.constexpr, # int HEAD_SIZE: tl.constexpr, # int, must be power of 2 HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 query_start_len_ptr, # [num_seqs+1] @@ -594,10 +608,10 @@ def reduce_segments( # number of segments for this particular sequence num_segments = NUM_SEGMENTS_PER_SEQ - blocks_per_segment = cdiv_fn(seq_len, num_segments * BLOCK_SIZE) + tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE) # create masks for subsequent loads - act_num_segments = cdiv_fn(seq_len, blocks_per_segment * BLOCK_SIZE) + act_num_segments = cdiv_fn(seq_len, tiles_per_segment * TILE_SIZE) segm_mask = tl.arange(0, NUM_SEGMENTS_PER_SEQ) < tl.full( [NUM_SEGMENTS_PER_SEQ], act_num_segments, dtype=tl.int32) dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1, @@ -671,13 +685,10 @@ def unified_attention( # Optional tensor for sinks sinks=None, ): + assert causal, "Only causal attention is supported" assert q_descale is None, "Q scales not supported" - block_size = v.shape[1] - assert q.element_size() >= 2 or block_size >= 32, \ - "Block size must be at least 32 for fp8" - if sinks is not None: assert sinks.shape[0] == q.shape[1], \ "Sinks must be num_query_heads size" @@ -707,6 +718,12 @@ def unified_attention( # = floor(q.shape[0] / BLOCK_Q) + num_seqs total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs + # Assigning default tile sizes for prefill and decode. + # Note: each tile size must be at least 32 for "fp8" (q.element_size() == 1) + # and at least 16 for all other data types. + TILE_SIZE_PREFILL = 32 + TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32 + # if batch contains a prefill if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128: kernel_unified_attention_2d[( @@ -736,6 +753,7 @@ def unified_attention( output_stride_1=out.stride(1), qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0, BLOCK_SIZE=block_size, + TILE_SIZE=TILE_SIZE_PREFILL, HEAD_SIZE=head_size, HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), USE_ALIBI_SLOPES=use_alibi_slopes, @@ -809,6 +827,7 @@ def unified_attention( query_stride_1=q.stride(1), qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0, BLOCK_SIZE=block_size, + TILE_SIZE=TILE_SIZE_DECODE, HEAD_SIZE=head_size, HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), USE_ALIBI_SLOPES=use_alibi_slopes, @@ -830,7 +849,6 @@ def unified_attention( BLOCK_M=BLOCK_M, NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, ) - reduce_segments[(q.shape[0], num_query_heads)]( output_ptr=out, segm_output_ptr=segm_output, @@ -844,7 +862,7 @@ def unified_attention( output_stride_0=out.stride(0), output_stride_1=out.stride(1), block_table_stride=block_table.stride(0), - BLOCK_SIZE=block_size, + TILE_SIZE=TILE_SIZE_DECODE, HEAD_SIZE=head_size, HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), query_start_len_ptr=cu_seqlens_q,