Merged
4 changes: 3 additions & 1 deletion tests/kernels/attention/test_triton_unified_attention.py
@@ -13,7 +13,9 @@
BLOCK_SIZES = [16, 32]

DTYPES = [torch.float16, torch.bfloat16]
-QDTYPES = [None, torch.float8_e4m3fn]
+QDTYPES = [None, torch.float8_e4m3fn] if not current_platform.is_rocm() else [
+    None, torch.float8_e4m3fnuz
+]
Comment on lines +16 to +18
Collaborator:

This is incorrect; it should use current_platform.fp8_dtype():

QDTYPES = [None, current_platform.fp8_dtype()]
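
For reference, a minimal sketch of what the suggested change would look like in the test module, assuming the usual vllm import path and that current_platform.fp8_dtype() resolves to the platform-appropriate fp8 format (torch.float8_e4m3fnuz on ROCm builds that need it, torch.float8_e4m3fn elsewhere), which is what would make the explicit is_rocm() branch unnecessary:

from vllm.platforms import current_platform

# Sketch of the reviewer's suggestion, not the merged code: fp8_dtype()
# picks the fp8 format for the current platform, so one line covers both
# CUDA and ROCm.
QDTYPES = [None, current_platform.fp8_dtype()]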

# one value large enough to test overflow in index calculation.
# one value small enough to test the schema op check
NUM_BLOCKS = [32768, 2048]
106 changes: 51 additions & 55 deletions vllm/attention/ops/triton_unified_attention.py
@@ -29,41 +29,42 @@ def apply_softcap(S, x):

@triton.jit
def kernel_unified_attention_2d(
        output_ptr,  # [num_tokens, num_query_heads, head_size]
        query_ptr,  # [num_tokens, num_query_heads, head_size]
        key_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
        value_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
        block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
        seq_lens_ptr,  # [num_seqs]
        alibi_slopes_ptr,  # [num_query_heads]
        scale,  # float32
        k_scale,  # float32
        v_scale,  # float32
        softcap,  # float32
        num_query_heads: tl.constexpr,  # int
        num_queries_per_kv: tl.constexpr,  # int
        block_table_stride: tl.int64,  # int
        query_stride_0: tl.int64,  # int
        query_stride_1: tl.int64,  # int, should be equal to head_size
        output_stride_0: tl.int64,  # int
        output_stride_1: tl.int64,  # int, should be equal to head_size
        BLOCK_SIZE: tl.constexpr,  # int
        HEAD_SIZE: tl.constexpr,  # int
        HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
        USE_ALIBI_SLOPES: tl.constexpr,  # bool
        USE_SOFTCAP: tl.constexpr,  # bool
        SLIDING_WINDOW: tl.constexpr,  # int
        stride_k_cache_0: tl.int64,  # int
        stride_k_cache_1: tl.int64,  # int
        stride_k_cache_2: tl.int64,  # int
        stride_k_cache_3: tl.constexpr,  # int
        stride_v_cache_0: tl.int64,  # int
        stride_v_cache_1: tl.int64,  # int
        stride_v_cache_2: tl.int64,  # int
        stride_v_cache_3: tl.constexpr,  # int
        query_start_len_ptr,  # [num_seqs+1]
        BLOCK_Q: tl.constexpr,  # int
        num_seqs: tl.int32,
+       BLOCK_M: tl.constexpr,  # int
):

    q_block_global_idx = tl.program_id(0)
@@ -94,23 +95,21 @@ def kernel_unified_attention_2d(
    if q_block_local_idx * BLOCK_Q >= cur_batch_query_len:
        return

-   offs_m = tl.arange(0, BLOCK_Q * num_queries_per_kv)
+   offs_m = tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, HEAD_SIZE_PADDED)

    query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv

    query_offset_0 = cur_batch_in_all_start_index + query_pos
    query_offset_1 = kv_head_idx * num_queries_per_kv + \
        offs_m % num_queries_per_kv

    query_offset = (query_offset_0[:, None] * query_stride_0 +
                    query_offset_1[:, None] * query_stride_1 + offs_d[None, :])

    dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1)
    query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1)
    query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1)

-   # Q : (BLOCK_Q * num_queries_per_kv, HEAD_SIZE,)
+   # Q : (BLOCK_M, HEAD_SIZE_PADDED)
    Q = tl.load(
        query_ptr + query_offset,
        mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
@@ -119,12 +118,9 @@

    block_table_offset = seq_idx * block_table_stride

-   M = tl.full([BLOCK_Q * num_queries_per_kv],
-               float("-inf"),
-               dtype=tl.float32)
-   L = tl.full([BLOCK_Q * num_queries_per_kv], 1.0, dtype=tl.float32)
-   acc = tl.zeros([BLOCK_Q * num_queries_per_kv, HEAD_SIZE_PADDED],
-                  dtype=tl.float32)
+   M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+   L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
+   acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)

    # sequence len for this particular sequence
    seq_len = tl.load(seq_lens_ptr + seq_idx)
@@ -183,13 +179,12 @@
        else:
            V = V_load

-       seq_offset = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+       seq_offset = j * BLOCK_SIZE + offs_n

        seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1

-       # S : (BLOCK_Q * num_queries_per_kv, BLOCK_SIZE,)
-       S = tl.zeros(shape=(BLOCK_Q * num_queries_per_kv, BLOCK_SIZE),
-                    dtype=tl.float32)
+       # S : (BLOCK_M, BLOCK_SIZE)
+       S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32)

        S += scale * tl.dot(Q, K)

@@ -207,29 +202,29 @@
            S += alibi_slope[:, None] * (seq_offset - context_len)

        # compute running maximum
-       # m_j : (BLOCK_Q * num_queries_per_kv,)
+       # m_j : (BLOCK_M,)
        m_j = tl.maximum(M, tl.max(S, axis=1))
        # For sliding window there's a chance the max is -inf due to masking of
        # the entire row. In this case we need to set m_j 0 to avoid NaN
        m_j = tl.where(m_j > float("-inf"), m_j, 0.0)

-       # P : (BLOCK_Q * num_queries_per_kv, BLOCK_SIZE,)
+       # P : (BLOCK_M, BLOCK_SIZE)
        P = tl.exp(S - m_j[:, None])

-       # l_j : (BLOCK_Q * num_queries_per_kv,)
+       # l_j : (BLOCK_M,)
        l_j = tl.sum(P, axis=1)

-       # alpha : (BLOCK_Q * num_queries_per_kv, )
+       # alpha : (BLOCK_M, )
        alpha = tl.exp(M - m_j)

-       # acc : (BLOCK_Q * num_queries_per_kv, BLOCK_SIZE,)
+       # acc : (BLOCK_M, HEAD_SIZE_PADDED)
        acc = acc * alpha[:, None]

        # update constants
        L = L * alpha + l_j
        M = m_j

-       # acc : (BLOCK_Q * num_queries_per_kv, BLOCK_SIZE,)
+       # acc : (BLOCK_M, HEAD_SIZE_PADDED)
        acc += tl.dot(P.to(V.dtype), V)

    # epilogue
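
Aside, not part of the diff: the M/L/acc updates in the hunk above are the standard online-softmax (flash-attention) recurrence, merely re-tiled from BLOCK_Q * num_queries_per_kv rows to BLOCK_M rows. A small PyTorch sketch of the same per-tile update, for readers following the shapes; the function name and signature are illustrative only:

import torch

def online_softmax_step(M, L, acc, S, V):
    # M: (rows,) running max, L: (rows,) running sum of exponentials,
    # acc: (rows, d) running un-normalized output, S: (rows, cols) scores
    # for the new tile, V: (cols, d) values for the new tile.
    m_j = torch.maximum(M, S.max(dim=1).values)  # new running maximum
    m_j = torch.where(m_j > float("-inf"), m_j, torch.zeros_like(m_j))
    P = torch.exp(S - m_j[:, None])              # scores re-based to the new max
    l_j = P.sum(dim=1)
    alpha = torch.exp(M - m_j)                   # rescale factor for the old state
    acc = acc * alpha[:, None] + P @ V
    L = L * alpha + l_j
    return m_j, L, acc                           # the epilogue divides acc by L
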
@@ -334,4 +329,5 @@ def unified_attention(
        query_start_len_ptr=cu_seqlens_q,
        BLOCK_Q=BLOCK_Q,
        num_seqs=num_seqs,
+       BLOCK_M=BLOCK_M,
    )
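
Not shown in these hunks is how BLOCK_M itself is chosen. From the kernel change, BLOCK_M stands in for BLOCK_Q * num_queries_per_kv (the number of query rows one program handles), and tl.arange(0, BLOCK_M) requires a power of two. A hypothetical host-side helper along those lines is sketched below; the merged launcher may well fix BLOCK_M first and derive BLOCK_Q from it instead, so treat this purely as an illustration:

import triton

def compute_block_m(block_q: int, num_query_heads: int, num_kv_heads: int) -> int:
    # Illustrative only: tile height = BLOCK_Q query tokens expanded by the
    # GQA factor, rounded up to a power of two for tl.arange(0, BLOCK_M).
    num_queries_per_kv = num_query_heads // num_kv_heads
    return triton.next_power_of_2(block_q * num_queries_per_kv)

# Example: block_q=16 with 32 query heads over 8 KV heads gives 16 * 4 = 64,
# so BLOCK_M = 64.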