diff --git a/tests/kernels/attention/test_triton_unified_attention.py b/tests/kernels/attention/test_triton_unified_attention.py
index 4e15d00255a4..be3d1879de24 100644
--- a/tests/kernels/attention/test_triton_unified_attention.py
+++ b/tests/kernels/attention/test_triton_unified_attention.py
@@ -13,7 +13,9 @@
 BLOCK_SIZES = [16, 32]
 
 DTYPES = [torch.float16, torch.bfloat16]
-QDTYPES = [None, torch.float8_e4m3fn]
+QDTYPES = [None, torch.float8_e4m3fn] if not current_platform.is_rocm() else [
+    None, torch.float8_e4m3fnuz
+]
 # one value large enough to test overflow in index calculation.
 # one value small enough to test the schema op check
 NUM_BLOCKS = [32768, 2048]
diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py
index 4bced779785a..87cf333f7f0a 100644
--- a/vllm/attention/ops/triton_unified_attention.py
+++ b/vllm/attention/ops/triton_unified_attention.py
@@ -29,41 +29,42 @@ def apply_softcap(S, x):
 
 @triton.jit
 def kernel_unified_attention_2d(
-        output_ptr,  # [num_tokens, num_query_heads, head_size]
-        query_ptr,  # [num_tokens, num_query_heads, head_size]
-        key_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
-        value_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
-        block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
-        seq_lens_ptr,  # [num_seqs]
-        alibi_slopes_ptr,  # [num_query_heads]
-        scale,  # float32
-        k_scale,  # float32
-        v_scale,  # float32
-        softcap,  # float32
-        num_query_heads: tl.constexpr,  # int
-        num_queries_per_kv: tl.constexpr,  # int
-        block_table_stride: tl.int64,  # int
-        query_stride_0: tl.int64,  # int
-        query_stride_1: tl.int64,  # int, should be equal to head_size
-        output_stride_0: tl.int64,  # int
-        output_stride_1: tl.int64,  # int, should be equal to head_size
-        BLOCK_SIZE: tl.constexpr,  # int
-        HEAD_SIZE: tl.constexpr,  # int
-        HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
-        USE_ALIBI_SLOPES: tl.constexpr,  # bool
-        USE_SOFTCAP: tl.constexpr,  # bool
-        SLIDING_WINDOW: tl.constexpr,  # int
-        stride_k_cache_0: tl.int64,  # int
-        stride_k_cache_1: tl.int64,  # int
-        stride_k_cache_2: tl.int64,  # int
-        stride_k_cache_3: tl.constexpr,  # int
-        stride_v_cache_0: tl.int64,  # int
-        stride_v_cache_1: tl.int64,  # int
-        stride_v_cache_2: tl.int64,  # int
-        stride_v_cache_3: tl.constexpr,  # int
-        query_start_len_ptr,  # [num_seqs+1]
-        BLOCK_Q: tl.constexpr,  # int
-        num_seqs: tl.int32,
+        output_ptr,  # [num_tokens, num_query_heads, head_size]
+        query_ptr,  # [num_tokens, num_query_heads, head_size]
+        key_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
+        value_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
+        block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
+        seq_lens_ptr,  # [num_seqs]
+        alibi_slopes_ptr,  # [num_query_heads]
+        scale,  # float32
+        k_scale,  # float32
+        v_scale,  # float32
+        softcap,  # float32
+        num_query_heads: tl.constexpr,  # int
+        num_queries_per_kv: tl.constexpr,  # int
+        block_table_stride: tl.int64,  # int
+        query_stride_0: tl.int64,  # int
+        query_stride_1: tl.int64,  # int, should be equal to head_size
+        output_stride_0: tl.int64,  # int
+        output_stride_1: tl.int64,  # int, should be equal to head_size
+        BLOCK_SIZE: tl.constexpr,  # int
+        HEAD_SIZE: tl.constexpr,  # int
+        HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
+        USE_ALIBI_SLOPES: tl.constexpr,  # bool
+        USE_SOFTCAP: tl.constexpr,  # bool
+        SLIDING_WINDOW: tl.constexpr,  # int
+        stride_k_cache_0: tl.int64,  # int
+        stride_k_cache_1: tl.int64,  # int
+        stride_k_cache_2: tl.int64,  # int
+        stride_k_cache_3: tl.constexpr,  # int
+        stride_v_cache_0: tl.int64,  # int
+        stride_v_cache_1: tl.int64,  # int
+        stride_v_cache_2: tl.int64,  # int
+        stride_v_cache_3: tl.constexpr,  # int
+        query_start_len_ptr,  # [num_seqs+1]
+        BLOCK_Q: tl.constexpr,  # int
+        num_seqs: tl.int32,
+        BLOCK_M: tl.constexpr,  # int
 ):
 
     q_block_global_idx = tl.program_id(0)
@@ -94,15 +95,13 @@ def kernel_unified_attention_2d(
     if q_block_local_idx * BLOCK_Q >= cur_batch_query_len:
         return
 
-    offs_m = tl.arange(0, BLOCK_Q * num_queries_per_kv)
+    offs_m = tl.arange(0, BLOCK_M)
     offs_d = tl.arange(0, HEAD_SIZE_PADDED)
-
     query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv
 
     query_offset_0 = cur_batch_in_all_start_index + query_pos
     query_offset_1 = kv_head_idx * num_queries_per_kv + \
         offs_m % num_queries_per_kv
-
     query_offset = (query_offset_0[:, None] * query_stride_0 +
                     query_offset_1[:, None] * query_stride_1 + offs_d[None, :])
 
@@ -110,7 +109,7 @@ def kernel_unified_attention_2d(
     query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1)
     query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1)
 
-    # Q : (BLOCK_Q * num_queries_per_kv, HEAD_SIZE,)
+    # Q : (BLOCK_M, HEAD_SIZE_PADDED)
     Q = tl.load(
         query_ptr + query_offset,
         mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
@@ -119,12 +118,9 @@ def kernel_unified_attention_2d(
 
     block_table_offset = seq_idx * block_table_stride
 
-    M = tl.full([BLOCK_Q * num_queries_per_kv],
-                float("-inf"),
-                dtype=tl.float32)
-    L = tl.full([BLOCK_Q * num_queries_per_kv], 1.0, dtype=tl.float32)
-    acc = tl.zeros([BLOCK_Q * num_queries_per_kv, HEAD_SIZE_PADDED],
-                   dtype=tl.float32)
+    M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
+    acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)
 
     # sequence len for this particular sequence
     seq_len = tl.load(seq_lens_ptr + seq_idx)
@@ -183,13 +179,12 @@ def kernel_unified_attention_2d(
         else:
             V = V_load
 
-        seq_offset = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        seq_offset = j * BLOCK_SIZE + offs_n
 
         seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
 
-        # S : (BLOCK_Q * num_queries_per_kv, BLOCK_SIZE,)
-        S = tl.zeros(shape=(BLOCK_Q * num_queries_per_kv, BLOCK_SIZE),
-                     dtype=tl.float32)
+        # S : (BLOCK_M, BLOCK_SIZE)
+        S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32)
 
         S += scale * tl.dot(Q, K)
 
@@ -207,29 +202,29 @@ def kernel_unified_attention_2d(
             S += alibi_slope[:, None] * (seq_offset - context_len)
 
         # compute running maximum
-        # m_j : (BLOCK_Q * num_queries_per_kv,)
+        # m_j : (BLOCK_M,)
         m_j = tl.maximum(M, tl.max(S, axis=1))
 
         # For sliding window there's a chance the max is -inf due to masking of
         # the entire row. In this case we need to set m_j 0 to avoid NaN
         m_j = tl.where(m_j > float("-inf"), m_j, 0.0)
-        # P : (BLOCK_Q * num_queries_per_kv, BLOCK_SIZE,)
+        # P : (BLOCK_M, BLOCK_SIZE)
         P = tl.exp(S - m_j[:, None])
 
-        # l_j : (BLOCK_Q * num_queries_per_kv,)
+        # l_j : (BLOCK_M,)
         l_j = tl.sum(P, axis=1)
 
-        # alpha : (BLOCK_Q * num_queries_per_kv, )
+        # alpha : (BLOCK_M, )
        alpha = tl.exp(M - m_j)
 
-        # acc : (BLOCK_Q * num_queries_per_kv, BLOCK_SIZE,)
+        # acc : (BLOCK_M, HEAD_SIZE_PADDED)
         acc = acc * alpha[:, None]
 
         # update constants
         L = L * alpha + l_j
         M = m_j
 
-        # acc : (BLOCK_Q * num_queries_per_kv, BLOCK_SIZE,)
+        # acc : (BLOCK_M, HEAD_SIZE_PADDED)
         acc += tl.dot(P.to(V.dtype), V)
 
     # epilogue
@@ -334,4 +329,5 @@ def unified_attention(
        query_start_len_ptr=cu_seqlens_q,
        BLOCK_Q=BLOCK_Q,
        num_seqs=num_seqs,
+       BLOCK_M=BLOCK_M,
    )
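Note: the hunks above replace every `BLOCK_Q * num_queries_per_kv` tile extent inside the kernel with a single `BLOCK_M` constexpr and thread it through from the launch site. The diff does not show how the host wrapper derives `BLOCK_M`; the sketch below is only an illustration of the constraint the kernel relies on (Triton block extents such as the argument to `tl.arange` must be powers of two). The helper names are hypothetical, not the actual vLLM code.

```python
# Hypothetical sketch, not part of the diff: one way a host wrapper could
# derive a power-of-two BLOCK_M from BLOCK_Q and num_queries_per_kv before
# launching kernel_unified_attention_2d. Helper names are illustrative.


def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n (mirrors triton.next_power_of_2)."""
    return 1 << (n - 1).bit_length()


def choose_block_m(block_q: int, num_queries_per_kv: int) -> int:
    # tl.arange(0, BLOCK_M) requires a power-of-two extent, so pad the
    # logical tile height BLOCK_Q * num_queries_per_kv up if needed.
    return next_power_of_2(block_q * num_queries_per_kv)


if __name__ == "__main__":
    assert choose_block_m(4, 4) == 16  # already a power of two
    assert choose_block_m(4, 3) == 16  # 12 rows padded up to 16
```

Any rows beyond the logical `BLOCK_Q * num_queries_per_kv` extent are already handled by the kernel's `query_mask_0` / `query_mask_1` masking shown above, so padding `BLOCK_M` up only adds a few masked lanes.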