From 1e1491c1d46b136a8d8a289348ac902416f487d9 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Mon, 14 Apr 2025 09:42:21 +0000 Subject: [PATCH 01/13] Working version (although slow) Signed-off-by: Thomas Parnell --- tests/kernels/test_flash_attn.py | 7 +- tests/kernels/test_tpa_attn.py | 190 ++++++++++++ .../ops/chunked_prefill_paged_decode.py | 277 +++++++----------- vllm/v1/attention/backends/triton_attn.py | 55 ++-- 4 files changed, 326 insertions(+), 203 deletions(-) create mode 100644 tests/kernels/test_tpa_attn.py diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py index 572563c0bd82..62753ba6d066 100644 --- a/tests/kernels/test_flash_attn.py +++ b/tests/kernels/test_flash_attn.py @@ -258,7 +258,7 @@ def test_varlen_with_paged_kv( (num_seqs, max_num_blocks_per_seq), dtype=torch.int32) - out = torch.empty_like(query) if use_out else None + output = torch.empty_like(query) if use_out else None maybe_quantized_query = query maybe_quantized_key_cache = key_cache @@ -277,11 +277,11 @@ def test_varlen_with_paged_kv( k_descale = torch.ones(scale_shape, dtype=torch.float32) v_descale = torch.ones(scale_shape, dtype=torch.float32) - output = flash_attn_varlen_func( + chunked_prefill_paged_decode( q=maybe_quantized_query, k=maybe_quantized_key_cache, v=maybe_quantized_value_cache, - out=out, + out=output, cu_seqlens_q=cu_query_lens, seqused_k=kv_lens, max_seqlen_q=max_query_len, @@ -296,7 +296,6 @@ def test_varlen_with_paged_kv( k_descale=k_descale, v_descale=v_descale, ) - output = output if not use_out else out ref_output = ref_paged_attn( query=query, diff --git a/tests/kernels/test_tpa_attn.py b/tests/kernels/test_tpa_attn.py new file mode 100644 index 000000000000..daa31252fd6a --- /dev/null +++ b/tests/kernels/test_tpa_attn.py @@ -0,0 +1,190 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +import pytest +import torch + +from vllm.platforms import current_platform +from vllm.attention.ops.chunked_prefill_paged_decode import ( + chunked_prefill_paged_decode) + +NUM_HEADS = [(4, 4), (8, 2), (16, 2)] +HEAD_SIZES = [128, 256] +BLOCK_SIZES = [16, 32] +DTYPES = [torch.float16, torch.bfloat16] +QDTYPES = [None] +# one value large enough to test overflow in index calculation. 
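# Note on the paged layout this test file assumes: NUM_BLOCKS sizes the
# physical KV pool. key_cache is allocated as
# [num_blocks, block_size, num_kv_heads, head_size], and block_tables[seq, i]
# names the physical block backing logical block i of that sequence, so token
# `pos` of a sequence is stored at
#   key_cache[block_tables[seq, pos // block_size], pos % block_size].
# ref_paged_attn below gathers whole blocks via block_tables and trims the
# flattened result to kv_len before running a plain masked softmax(QK^T)V
# reference.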
+# one value small enough to test the schema op check +NUM_BLOCKS = [32768, 2048] + + +def ref_paged_attn( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + query_lens: list[int], + kv_lens: list[int], + block_tables: torch.Tensor, + scale: float, + sliding_window: Optional[int] = None, + soft_cap: Optional[float] = None, +) -> torch.Tensor: + num_seqs = len(query_lens) + block_tables = block_tables.cpu().numpy() + _, block_size, num_kv_heads, head_size = key_cache.shape + + outputs: list[torch.Tensor] = [] + start_idx = 0 + for i in range(num_seqs): + query_len = query_lens[i] + kv_len = kv_lens[i] + q = query[start_idx:start_idx + query_len] + q *= scale + + num_kv_blocks = (kv_len + block_size - 1) // block_size + block_indices = block_tables[i, :num_kv_blocks] + + k = key_cache[block_indices].view(-1, num_kv_heads, head_size) + k = k[:kv_len] + v = value_cache[block_indices].view(-1, num_kv_heads, head_size) + v = v[:kv_len] + + if q.shape[1] != k.shape[1]: + k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1) + v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1) + attn = torch.einsum("qhd,khd->hqk", q, k).float() + empty_mask = torch.ones(query_len, kv_len) + mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool() + if sliding_window is not None: + sliding_window_mask = torch.triu(empty_mask, + diagonal=kv_len - + (query_len + sliding_window) + + 1).bool().logical_not() + mask |= sliding_window_mask + if soft_cap is not None: + attn = soft_cap * torch.tanh(attn / soft_cap) + attn.masked_fill_(mask, float("-inf")) + attn = torch.softmax(attn, dim=-1).to(v.dtype) + out = torch.einsum("hqk,khd->qhd", attn, v) + + outputs.append(out) + start_idx += query_len + + return torch.cat(outputs, dim=0) + + +@pytest.mark.parametrize("seq_lens", + [[(1, 1328), (5, 18), + (129, 463)], [(1, 523), (1, 37), (1, 2011)]]) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("sliding_window", [None, 256]) +@pytest.mark.parametrize("dtype", DTYPES) +#@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0]) +@pytest.mark.parametrize("soft_cap", [None]) +@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) +@pytest.mark.parametrize("q_dtype", QDTYPES) +@torch.inference_mode() +def test_varlen_with_paged_kv( + seq_lens: list[tuple[int, int]], + num_heads: tuple[int, int], + head_size: int, + sliding_window: Optional[int], + dtype: torch.dtype, + block_size: int, + soft_cap: Optional[float], + num_blocks: int, + q_dtype: Optional[torch.dtype], +) -> None: + torch.set_default_device("cuda") + + current_platform.seed_everything(0) + num_seqs = len(seq_lens) + query_lens = [x[0] for x in seq_lens] + kv_lens = [x[1] for x in seq_lens] + num_query_heads = num_heads[0] + num_kv_heads = num_heads[1] + assert num_query_heads % num_kv_heads == 0 + max_query_len = max(query_lens) + max_kv_len = max(kv_lens) + window_size = ((sliding_window - 1, 0) if sliding_window is not None else + (-1, -1)) + scale = head_size**-0.5 + + query = torch.randn(sum(query_lens), + num_query_heads, + head_size, + dtype=dtype) + key_cache = torch.randn(num_blocks, + block_size, + num_kv_heads, + head_size, + dtype=dtype) + value_cache = torch.randn_like(key_cache) + cu_query_lens = torch.tensor([0] + query_lens, + dtype=torch.int32).cumsum(dim=0, + dtype=torch.int32) + kv_lens = torch.tensor(kv_lens, dtype=torch.int32) + + max_num_blocks_per_seq = 
(max_kv_len + block_size - 1) // block_size + block_tables = torch.randint(0, + num_blocks, + (num_seqs, max_num_blocks_per_seq), + dtype=torch.int32) + + output = torch.empty_like(query) + + maybe_quantized_query = query + maybe_quantized_key_cache = key_cache + maybe_quantized_value_cache = value_cache + q_descale = None + k_descale = None + v_descale = None + if q_dtype is not None: + # QKV are drawn from N(0, 1): no need for a fp8 scaling factor + maybe_quantized_query = query.to(q_dtype) + maybe_quantized_key_cache = key_cache.to(q_dtype) + maybe_quantized_value_cache = value_cache.to(q_dtype) + + scale_shape = (num_seqs, num_kv_heads) + q_descale = torch.ones(scale_shape, dtype=torch.float32) + k_descale = torch.ones(scale_shape, dtype=torch.float32) + v_descale = torch.ones(scale_shape, dtype=torch.float32) + + chunked_prefill_paged_decode( + q=maybe_quantized_query, + k=maybe_quantized_key_cache, + v=maybe_quantized_value_cache, + out=output, + cu_seqlens_q=cu_query_lens, + seqused_k=kv_lens, + max_seqlen_q=max_query_len, + max_seqlen_k=max_kv_len, + softmax_scale=scale, + causal=True, + window_size=window_size, + block_table=block_tables, + softcap=soft_cap if soft_cap is not None else 0, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + ) + + ref_output = ref_paged_attn( + query=query, + key_cache=key_cache, + value_cache=value_cache, + query_lens=query_lens, + kv_lens=kv_lens, + block_tables=block_tables, + scale=scale, + sliding_window=sliding_window, + soft_cap=soft_cap, + ) + atol, rtol = 1.5e-2, 1e-2 + if q_dtype is not None: + atol, rtol = 1.5e-1, 1.5e-1 + torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \ + f"{torch.max(torch.abs(output - ref_output))}" diff --git a/vllm/attention/ops/chunked_prefill_paged_decode.py b/vllm/attention/ops/chunked_prefill_paged_decode.py index 1b47581641b0..dff803585ac6 100644 --- a/vllm/attention/ops/chunked_prefill_paged_decode.py +++ b/vllm/attention/ops/chunked_prefill_paged_decode.py @@ -21,6 +21,10 @@ def cdiv_fn(x, y): return (x + y - 1) // y +import time +t_prefix = 0.0 +t_paged = 0.0 + @triton.jit def kernel_paged_attention_2d( output_ptr, # [num_tokens, num_query_heads, head_size] @@ -46,37 +50,33 @@ def kernel_paged_attention_2d( HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 USE_ALIBI_SLOPES: tl.constexpr, # bool SLIDING_WINDOW: tl.constexpr, # int - x: tl.constexpr, # int stride_k_cache_0: tl.int64, # int stride_k_cache_1: tl.int64, # int stride_k_cache_2: tl.int64, # int stride_k_cache_3: tl.int64, # int - stride_k_cache_4: tl.int64, # int stride_v_cache_0: tl.int64, # int stride_v_cache_1: tl.int64, # int stride_v_cache_2: tl.int64, # int stride_v_cache_3: tl.int64, # int - filter_by_query_len: tl.constexpr, # bool query_start_len_ptr, # [num_seqs+1] ): seq_idx = tl.program_id(0) kv_head_idx = tl.program_id(1) + q_idx = tl.program_id(2) + + cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) + cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1) + + cur_batch_query_len = cur_batch_in_all_stop_index \ + - cur_batch_in_all_start_index - if filter_by_query_len: - cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) - cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + - 1) - cur_batch_query_len = cur_batch_in_all_stop_index \ - - cur_batch_in_all_start_index - if cur_batch_query_len > 1: - return - else: - cur_batch_in_all_start_index = seq_idx + if q_idx >= cur_batch_query_len: + return query_head_idx = 
kv_head_idx * num_queries_per_kv + tl.arange( 0, num_queries_per_kv_padded) - query_offset = (cur_batch_in_all_start_index * query_stride_0 + + query_offset = ((cur_batch_in_all_start_index+q_idx) * query_stride_0 + query_head_idx[:, None] * query_stride_1) head_mask = query_head_idx < (kv_head_idx + 1) * num_queries_per_kv @@ -102,6 +102,9 @@ def kernel_paged_attention_2d( # sequence len for this particular sequence seq_len = tl.load(seq_lens_ptr + seq_idx) + # context length for this particular sequences + context_len = seq_len - cur_batch_query_len + # alibi slope for this head if USE_ALIBI_SLOPES: alibi_slope = tl.load(alibi_slopes_ptr + query_head_idx, @@ -119,15 +122,14 @@ def kernel_paged_attention_2d( offs_d = tl.arange(0, HEAD_SIZE_PADDED) v_offset = (physical_block_idx * stride_v_cache_0 + - kv_head_idx * stride_v_cache_1 + - offs_d[None, :] * stride_v_cache_2 + - offs_n[:, None] * stride_v_cache_3) + kv_head_idx * stride_v_cache_2 + + offs_d[None, :] * stride_v_cache_3 + + offs_n[:, None] * stride_v_cache_1) k_offset = (physical_block_idx * stride_k_cache_0 + - kv_head_idx * stride_k_cache_1 + - (offs_d[:, None] // x) * stride_k_cache_2 + - offs_n[None, :] * stride_k_cache_3 + - (offs_d[:, None] % x) * stride_k_cache_4) + kv_head_idx * stride_k_cache_2 + + offs_d[:, None] * stride_k_cache_3 + + offs_n[None, :] * stride_k_cache_1) # K : (HEAD_SIZE, BLOCK_SIZE) K_load = tl.load(key_cache_ptr + k_offset, @@ -150,7 +152,7 @@ def kernel_paged_attention_2d( V = V_load seq_offset = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - boundary = tl.full([BLOCK_SIZE], seq_len, dtype=tl.int32) + boundary = tl.full([BLOCK_SIZE], context_len+1+q_idx, dtype=tl.int32) seq_mask = seq_offset[None, :] < boundary # S : (num_queries_per_kv, BLOCK_SIZE,) @@ -158,10 +160,8 @@ def kernel_paged_attention_2d( float("-inf")).to(tl.float32) S += scale * tl.dot(Q, K) - context_len = seq_len - 1 - if SLIDING_WINDOW > 0: - S = tl.where((context_len - seq_offset) < SLIDING_WINDOW, S, + S = tl.where((context_len + q_idx - seq_offset) < SLIDING_WINDOW, S, -10000) if USE_ALIBI_SLOPES: @@ -193,7 +193,7 @@ def kernel_paged_attention_2d( # epilogue acc = acc / L[:, None] - output_offset = (cur_batch_in_all_start_index * output_stride_0 + + output_offset = ((cur_batch_in_all_start_index + q_idx) * output_stride_0 + query_head_idx * output_stride_1) tl.store( @@ -205,162 +205,89 @@ def kernel_paged_attention_2d( def chunked_prefill_paged_decode( - query, - key, - value, - output, - kv_cache_dtype, - key_cache, - value_cache, + q, + k, + v, + out, + cu_seqlens_q, + max_seqlen_q, + seqused_k, + max_seqlen_k, + softmax_scale, + causal, + window_size, block_table, - query_start_loc, - seq_lens, - max_seq_len, - max_query_len, - k_scale, - v_scale, - alibi_slopes=None, - sliding_window=None, - sm_scale=None, + softcap, + q_descale, + k_descale, + v_descale, + alibi_slopes=None ): - if sm_scale is None: - sm_scale = 1.0 / (query.shape[1]**0.5) - use_alibi_slopes = alibi_slopes is not None - if sliding_window is None or sliding_window <= 0: - sliding_window = 0 - - if max_query_len > 1: - context_attention_fwd( - q=query, - k=key, - v=value, - o=output, - kv_cache_dtype=kv_cache_dtype, - k_cache=key_cache, - v_cache=value_cache, - b_loc=block_table, - b_start_loc=query_start_loc, - b_seq_len=seq_lens, - max_seq_len=max_seq_len, - max_input_len=max_query_len, - k_scale=k_scale, - v_scale=v_scale, - alibi_slopes=alibi_slopes, - sliding_window=sliding_window, - sm_scale=sm_scale, - skip_decode=True, - ) - - block_size = 
value_cache.shape[3] - num_seqs = len(seq_lens) - num_query_heads = query.shape[1] - num_kv_heads = key.shape[1] - num_queries_per_kv = query.shape[1] // key.shape[1] - head_size = query.shape[2] - - # Conversion of FP8 Tensor from uint8 storage to - # appropriate torch.dtype for interpretation by Triton - if "fp8" in kv_cache_dtype: - assert key_cache.dtype == torch.uint8 - assert value_cache.dtype == torch.uint8 - - if kv_cache_dtype in ("fp8", "fp8_e4m3"): - target_dtype = torch.float8_e4m3fn - elif kv_cache_dtype == "fp8_e5m2": - target_dtype = torch.float8_e5m2 - else: - raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype) + #print("q.shape: ", q.shape) + #print("k.shape: ", k.shape) + #print("v.shape: ", v.shape) + #print("seqused_k: ", seqused_k) + #print("window_size: ", window_size) + #print("cu_seqlens_q: ", cu_seqlens_q) - key_cache = key_cache.view(target_dtype) - value_cache = value_cache.view(target_dtype) + block_size = v.shape[1] + num_seqs = len(seqused_k) + num_query_heads = q.shape[1] + num_kv_heads = k.shape[2] + num_queries_per_kv = num_query_heads // num_kv_heads + head_size = q.shape[2] num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv), 16) - use_custom = use_rocm_custom_paged_attention(query.dtype, head_size, - block_size, - num_queries_per_kv, - max_seq_len, sliding_window) - if use_custom: - _PARTITION_SIZE_ROCM = 256 - max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) // - _PARTITION_SIZE_ROCM) - assert _PARTITION_SIZE_ROCM % block_size == 0 - total_num_seq = query.shape[0] - tmp_output = torch.empty( - size=(total_num_seq, num_query_heads, max_num_partitions, - head_size), - dtype=output.dtype, - device=output.device, - ) - exp_sums = torch.empty( - size=(total_num_seq, num_query_heads, max_num_partitions), - dtype=torch.float32, - device=output.device, - ) - max_logits = torch.empty_like(exp_sums) - - ops.paged_attention_rocm( - output, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - num_kv_heads, - scale=sm_scale, - block_tables=block_table, - seq_lens=seq_lens, - query_start_loc=query_start_loc, - block_size=block_size, - max_seq_len=max_seq_len, - alibi_slopes=alibi_slopes, - kv_cache_dtype=kv_cache_dtype, - k_scale=k_scale, - v_scale=v_scale, - ) - else: - kernel_paged_attention_2d[( - num_seqs, - num_kv_heads, - )]( - output_ptr=output, - query_ptr=query, - key_cache_ptr=key_cache, - value_cache_ptr=value_cache, - block_tables_ptr=block_table, - seq_lens_ptr=seq_lens, - alibi_slopes_ptr=alibi_slopes, - scale=sm_scale, - k_scale=k_scale, - v_scale=v_scale, - num_query_heads=num_query_heads, - num_queries_per_kv=num_queries_per_kv, - num_queries_per_kv_padded=num_queries_per_kv_padded, - block_table_stride=block_table.stride(0), - query_stride_0=query.stride(0), - query_stride_1=query.stride(1), - output_stride_0=output.stride(0), - output_stride_1=output.stride(1), - BLOCK_SIZE=block_size, - HEAD_SIZE=head_size, - HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), - USE_ALIBI_SLOPES=use_alibi_slopes, - SLIDING_WINDOW=sliding_window, - x=key_cache.shape[4], - stride_k_cache_0=key_cache.stride(0), - stride_k_cache_1=key_cache.stride(1), - stride_k_cache_2=key_cache.stride(2), - stride_k_cache_3=key_cache.stride(3), - stride_k_cache_4=key_cache.stride(4), - stride_v_cache_0=value_cache.stride(0), - stride_v_cache_1=value_cache.stride(1), - stride_v_cache_2=value_cache.stride(2), - stride_v_cache_3=value_cache.stride(3), - filter_by_query_len=True, - query_start_len_ptr=query_start_loc, 
- ) + #print("max_seqlen_q: ", max_seqlen_q) + + #t0 = time.time() + + kernel_paged_attention_2d[( + num_seqs, + num_kv_heads, + max_seqlen_q, + )]( + output_ptr=out, + query_ptr=q, + key_cache_ptr=k, + value_cache_ptr=v, + block_tables_ptr=block_table, + seq_lens_ptr=seqused_k, + alibi_slopes_ptr=alibi_slopes, + scale=softmax_scale, + k_scale=k_descale, + v_scale=v_descale, + num_query_heads=num_query_heads, + num_queries_per_kv=num_queries_per_kv, + num_queries_per_kv_padded=num_queries_per_kv_padded, + block_table_stride=block_table.stride(0), + query_stride_0=q.stride(0), + query_stride_1=q.stride(1), + output_stride_0=out.stride(0), + output_stride_1=out.stride(1), + BLOCK_SIZE=block_size, + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + USE_ALIBI_SLOPES=use_alibi_slopes, + SLIDING_WINDOW=(1+window_size[0]), + stride_k_cache_0=k.stride(0), + stride_k_cache_1=k.stride(1), + stride_k_cache_2=k.stride(2), + stride_k_cache_3=k.stride(3), + stride_v_cache_0=v.stride(0), + stride_v_cache_1=v.stride(1), + stride_v_cache_2=v.stride(2), + stride_v_cache_3=v.stride(3), + query_start_len_ptr=cu_seqlens_q, + ) + + #torch.cuda.synchronize() + #global t_paged + #t_paged += time.time()-t0 + + #print("t_prefix: %.2f seconds, t_paged: %.2f seconds" % (t_prefix, t_paged)) diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 5f9610470567..a838f8466dff 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -87,6 +87,11 @@ def __init__( else: self.sliding_window = (sliding_window - 1, 0) self.kv_cache_dtype = kv_cache_dtype + if logits_soft_cap is None: + # In flash-attn, setting logits_soft_cap as 0 means no soft cap. + logits_soft_cap = 0 + self.logits_soft_cap = logits_soft_cap + self.use_irope = use_irope assert self.num_heads % self.num_kv_heads == 0 @@ -143,11 +148,9 @@ def forward( # performance to make sure it does not introduce any overhead. num_actual_tokens = attn_metadata.num_actual_tokens - key_cache, value_cache = PagedAttention.split_kv_cache( - kv_cache, self.num_kv_heads, self.head_size) - # Reshape the input keys and values and store them in the cache. - PagedAttention.write_to_paged_cache( + key_cache, value_cache = kv_cache.unbind(0) + torch.ops._C_cache_ops.reshape_and_cache_flash( key, value, key_cache, @@ -165,34 +168,38 @@ def forward( assert attn_metadata.local_attn_metadata is not None local_metadata = attn_metadata.local_attn_metadata cu_seqlens_q = local_metadata.local_query_start_loc - sequesd_k = local_metadata.local_seqused_k + seqused_k = local_metadata.local_seqused_k max_seqlen_q = local_metadata.local_max_query_len max_seqlen_k = local_metadata.local_max_seq_len block_table = local_metadata.local_block_table else: cu_seqlens_q = attn_metadata.query_start_loc - sequesd_k = attn_metadata.seq_lens + seqused_k = attn_metadata.seq_lens max_seqlen_q = attn_metadata.max_query_len max_seqlen_k = attn_metadata.max_seq_len block_table = attn_metadata.block_table - # Compute attention and update output up to `num_actual_tokens`. 
- chunked_prefill_paged_decode(query=query[:num_actual_tokens], - key=key[:num_actual_tokens], - value=value[:num_actual_tokens], - output=output[:num_actual_tokens], - kv_cache_dtype=self.kv_cache_dtype, - key_cache=key_cache, - value_cache=value_cache, - block_table=block_table, - query_start_loc=cu_seqlens_q, - seq_lens=sequesd_k, - max_seq_len=max_seqlen_k, - max_query_len=max_seqlen_q, - k_scale=layer._k_scale, - v_scale=layer._v_scale, - alibi_slopes=self.alibi_slopes, - sliding_window=self.sliding_window[0], - sm_scale=self.scale) + descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) + + chunked_prefill_paged_decode( + q=query[:num_actual_tokens], + k=key_cache, + v=value_cache, + out=output[:num_actual_tokens], + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + seqused_k=seqused_k, + max_seqlen_k=max_seqlen_k, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + block_table=block_table, + softcap=self.logits_soft_cap, + q_descale=layer._q_scale.expand(descale_shape), + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + ) + return output From 7e44cace6bce510ffeda477b09899ea2b1b39faf Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Mon, 14 Apr 2025 19:25:52 +0000 Subject: [PATCH 02/13] working next version Signed-off-by: Thomas Parnell --- tests/kernels/test_tpa_attn.py | 22 +++- .../ops/chunked_prefill_paged_decode.py | 123 ++++++++++++------ vllm/v1/attention/backends/triton_attn.py | 5 + 3 files changed, 107 insertions(+), 43 deletions(-) diff --git a/tests/kernels/test_tpa_attn.py b/tests/kernels/test_tpa_attn.py index daa31252fd6a..b4df6571e8a1 100644 --- a/tests/kernels/test_tpa_attn.py +++ b/tests/kernels/test_tpa_attn.py @@ -9,9 +9,15 @@ from vllm.attention.ops.chunked_prefill_paged_decode import ( chunked_prefill_paged_decode) -NUM_HEADS = [(4, 4), (8, 2), (16, 2)] -HEAD_SIZES = [128, 256] -BLOCK_SIZES = [16, 32] +#NUM_HEADS = [(4, 4), (8, 2), (16, 2)] +NUM_HEADS = [(32, 8)] + +#HEAD_SIZES = [128, 256] +HEAD_SIZES = [128] + +#BLOCK_SIZES = [16, 32] +BLOCK_SIZES = [16] + DTYPES = [torch.float16, torch.bfloat16] QDTYPES = [None] # one value large enough to test overflow in index calculation. 
@@ -74,9 +80,15 @@ def ref_paged_attn( return torch.cat(outputs, dim=0) +#@pytest.mark.parametrize("seq_lens", +# [[(1, 1328), (5, 18), +# (129, 463)], [(1, 523), (1, 37), (1, 2011)]]) + @pytest.mark.parametrize("seq_lens", - [[(1, 1328), (5, 18), - (129, 463)], [(1, 523), (1, 37), (1, 2011)]]) + [[(1081, 1081)]]) + +#@pytest.mark.parametrize("seq_lens", [[(1, 523), (1, 37), (1, 2011)]]) + @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) diff --git a/vllm/attention/ops/chunked_prefill_paged_decode.py b/vllm/attention/ops/chunked_prefill_paged_decode.py index dff803585ac6..c54702a74bc5 100644 --- a/vllm/attention/ops/chunked_prefill_paged_decode.py +++ b/vllm/attention/ops/chunked_prefill_paged_decode.py @@ -59,10 +59,19 @@ def kernel_paged_attention_2d( stride_v_cache_2: tl.int64, # int stride_v_cache_3: tl.int64, # int query_start_len_ptr, # [num_seqs+1] + BLOCK_Q: tl.constexpr, # int + num_seqs: tl.int64, + max_seqlen_q : tl.int64 ): - seq_idx = tl.program_id(0) + + max_num_q_blocks = cdiv_fn(max_seqlen_q, BLOCK_Q) + seq_idx = tl.program_id(0) % num_seqs + q_block_idx = tl.program_id(0) // num_seqs kv_head_idx = tl.program_id(1) - q_idx = tl.program_id(2) + + #print("seq_idx: %d, q_block_idx: %d, kv_head_idx: %d" % (seq_idx, q_block_idx, kv_head_idx)) + + #tl.device_print("%d %d %d " % (max_num_q_blocks, seq_idx, q_block_idx)) cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1) @@ -70,33 +79,48 @@ def kernel_paged_attention_2d( cur_batch_query_len = cur_batch_in_all_stop_index \ - cur_batch_in_all_start_index - if q_idx >= cur_batch_query_len: + #print("q_block_idx*BLOCK_Q: %d, cur_batch_query_len: %d" % (q_block_idx*BLOCK_Q, cur_batch_query_len)) + + if q_block_idx*BLOCK_Q >= cur_batch_query_len: return - query_head_idx = kv_head_idx * num_queries_per_kv + tl.arange( - 0, num_queries_per_kv_padded) + offs_m = tl.arange(0, BLOCK_Q * num_queries_per_kv) + offs_d = tl.arange(0, HEAD_SIZE_PADDED) + + #print("offs_m: ", offs_m) + #print("offs_d: ", offs_d) + + query_pos = q_block_idx * BLOCK_Q + offs_m // num_queries_per_kv + + query_offset_0 = cur_batch_in_all_start_index + query_pos + query_offset_1 = kv_head_idx * num_queries_per_kv + offs_m % num_queries_per_kv - query_offset = ((cur_batch_in_all_start_index+q_idx) * query_stride_0 + - query_head_idx[:, None] * query_stride_1) + #print("query_offset_0: ", query_offset_0) + #print("query_offset_1: ", query_offset_1) - head_mask = query_head_idx < (kv_head_idx + 1) * num_queries_per_kv - head_mask = head_mask & (query_head_idx < num_query_heads) + query_offset = (query_offset_0[:, None] * query_stride_0 + + query_offset_1[:, None] * query_stride_1 + + offs_d[None, :]) - dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1, - 0).to(tl.int1) + dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1) + query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1) + query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1) - # Q : (num_queries_per_kv, HEAD_SIZE,) + #print("query_mask_0: ", query_mask_0) + #print("query_mask_1: ", query_mask_1) + + # Q : (BLOCK_Q * num_queries_per_kv, HEAD_SIZE,) Q = tl.load( - query_ptr + query_offset + tl.arange(0, HEAD_SIZE_PADDED)[None, :], - mask=dim_mask[None, :] & head_mask[:, None], + query_ptr + query_offset, + mask=dim_mask[None, :] & query_mask_0[:, None] & 
query_mask_1[:, None], other=0.0, ) block_table_offset = seq_idx * block_table_stride - M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32) - L = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32) - acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED], + M = tl.full([BLOCK_Q * num_queries_per_kv], float("-inf"), dtype=tl.float32) + L = tl.full([BLOCK_Q * num_queries_per_kv], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_Q * num_queries_per_kv, HEAD_SIZE_PADDED], dtype=tl.float32) # sequence len for this particular sequence @@ -107,8 +131,8 @@ def kernel_paged_attention_2d( # alibi slope for this head if USE_ALIBI_SLOPES: - alibi_slope = tl.load(alibi_slopes_ptr + query_head_idx, - mask=head_mask, + alibi_slope = tl.load(alibi_slopes_ptr + query_offset_1, + mask=query_mask_1, other=0.0) num_blocks = cdiv_fn(seq_len, BLOCK_SIZE) @@ -119,7 +143,6 @@ def kernel_paged_attention_2d( physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j) offs_n = tl.arange(0, BLOCK_SIZE) - offs_d = tl.arange(0, HEAD_SIZE_PADDED) v_offset = (physical_block_idx * stride_v_cache_0 + kv_head_idx * stride_v_cache_2 + @@ -152,55 +175,59 @@ def kernel_paged_attention_2d( V = V_load seq_offset = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - boundary = tl.full([BLOCK_SIZE], context_len+1+q_idx, dtype=tl.int32) - seq_mask = seq_offset[None, :] < boundary - # S : (num_queries_per_kv, BLOCK_SIZE,) - S = tl.where(head_mask[:, None] & seq_mask, 0.0, + #print("seq_offset: ", seq_offset) + #print("context_len: ", context_len + query_pos) + + #boundary = tl.full([BLOCK_SIZE], context_len+query_lke, dtype=tl.int32) + seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1 + + # S : (BLOCK_Q * num_queries_per_kv, BLOCK_SIZE,) + S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, 0.0, float("-inf")).to(tl.float32) S += scale * tl.dot(Q, K) if SLIDING_WINDOW > 0: - S = tl.where((context_len + q_idx - seq_offset) < SLIDING_WINDOW, S, + S = tl.where((context_len + query_pos[:, None] - seq_offset) < SLIDING_WINDOW, S, -10000) if USE_ALIBI_SLOPES: S += alibi_slope[:, None] * (seq_offset - context_len) # compute running maximum - # m_j : (num_queries_per_kv,) + # m_j : (BLOCK_Q * num_queries_per_kv,) m_j = tl.maximum(M, tl.max(S, axis=1)) - # P : (num_queries_per_kv, BLOCK_SIZE,) + # P : (BLOCK_Q * num_queries_per_kv, BLOCK_SIZE,) P = tl.exp(S - m_j[:, None]) - # l_j : (num_queries_per_kv,) + # l_j : (BLOCK_Q * num_queries_per_kv,) l_j = tl.sum(P, axis=1) - # alpha : (num_queries_per_kv, ) + # alpha : (BLOCK_Q * num_queries_per_kv, ) alpha = tl.exp(M - m_j) - # acc : (num_queries_per_kv, BLOCK_SIZE,) + # acc : (BLOCK_Q * num_queries_per_kv, BLOCK_SIZE,) acc = acc * alpha[:, None] # update constants L = L * alpha + l_j M = m_j - # acc : (num_queries_per_kv, BLOCK_SIZE,) + # acc : (BLOCK_Q * num_queries_per_kv, BLOCK_SIZE,) acc += tl.dot(P.to(V.dtype), V) # epilogue acc = acc / L[:, None] - output_offset = ((cur_batch_in_all_start_index + q_idx) * output_stride_0 + - query_head_idx * output_stride_1) + output_offset = (query_offset_0[:, None] * output_stride_0 + + query_offset_1[:, None] * output_stride_1 + + offs_d[None, :]) tl.store( - output_ptr + output_offset[:, None] + - tl.arange(0, HEAD_SIZE_PADDED)[None, :], + output_ptr + output_offset, acc, - mask=dim_mask[None, :] & head_mask[:, None], + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], ) @@ -243,14 +270,31 @@ def chunked_prefill_paged_decode( num_queries_per_kv_padded = 
max(triton.next_power_of_2(num_queries_per_kv), 16) + #print("block_size: ", block_size) + #print("num_seqs: ", num_seqs) + #print("num_query_heads: ", num_query_heads) + #print("num_kv_heads: ", num_kv_heads) + #print("head_size: ", head_size) #print("max_seqlen_q: ", max_seqlen_q) + #print("seqused_k: ", seqused_k) #t0 = time.time() + BLOCK_M = 128 + BLOCK_Q = BLOCK_M // num_queries_per_kv + + max_num_query_blocks = triton.cdiv(max_seqlen_q, BLOCK_Q) + num_query_blocks = num_seqs * max_num_query_blocks + + #print("num_queries_per_kv: ", num_queries_per_kv) + #print("BLOCK_Q: ", BLOCK_Q) + #print("num_query_blocks: ", num_query_blocks) + #print("num_seqs: ", num_seqs) + #print("max_num_query_blocks: ", max_num_query_blocks) + kernel_paged_attention_2d[( - num_seqs, + num_query_blocks, num_kv_heads, - max_seqlen_q, )]( output_ptr=out, query_ptr=q, @@ -284,6 +328,9 @@ def chunked_prefill_paged_decode( stride_v_cache_2=v.stride(2), stride_v_cache_3=v.stride(3), query_start_len_ptr=cu_seqlens_q, + BLOCK_Q=BLOCK_Q, + num_seqs=num_seqs, + max_seqlen_q=max_seqlen_q, ) #torch.cuda.synchronize() diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index a838f8466dff..40576a6ae907 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -181,6 +181,11 @@ def forward( descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) + #print("query.shape: ", query.shape) + #print("query.stride: ", query.stride()) + #print("output.shape: ", output.shape) + #print("output.stride: ", output.stride()) + chunked_prefill_paged_decode( q=query[:num_actual_tokens], k=key_cache, From 1212fea21f1aafb441ae5f4dde70aa2fc629fe81 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Mon, 14 Apr 2025 20:51:37 +0000 Subject: [PATCH 03/13] Good results Signed-off-by: Thomas Parnell --- .../ops/chunked_prefill_paged_decode.py | 70 +++++++++++++++---- vllm/v1/attention/backends/flash_attn.py | 16 +++++ vllm/v1/attention/backends/triton_attn.py | 4 ++ 3 files changed, 75 insertions(+), 15 deletions(-) diff --git a/vllm/attention/ops/chunked_prefill_paged_decode.py b/vllm/attention/ops/chunked_prefill_paged_decode.py index c54702a74bc5..b49b65a559ec 100644 --- a/vllm/attention/ops/chunked_prefill_paged_decode.py +++ b/vllm/attention/ops/chunked_prefill_paged_decode.py @@ -9,6 +9,7 @@ import torch import triton import triton.language as tl +import triton_dejavu from vllm import _custom_ops as ops from vllm.platforms.rocm import use_rocm_custom_paged_attention @@ -25,6 +26,10 @@ def cdiv_fn(x, y): t_prefix = 0.0 t_paged = 0.0 +@triton_dejavu.jitcache( + check_keys=["USE_ALIBI_SLOPES", "SLIDING_WINDOW"], + cache_launch_grid=False, +) @triton.jit def kernel_paged_attention_2d( output_ptr, # [num_tokens, num_query_heads, head_size] @@ -60,15 +65,28 @@ def kernel_paged_attention_2d( stride_v_cache_3: tl.int64, # int query_start_len_ptr, # [num_seqs+1] BLOCK_Q: tl.constexpr, # int - num_seqs: tl.int64, - max_seqlen_q : tl.int64 + num_seqs: tl.int32, + q_block_start_idx_ptr, ): - max_num_q_blocks = cdiv_fn(max_seqlen_q, BLOCK_Q) - seq_idx = tl.program_id(0) % num_seqs - q_block_idx = tl.program_id(0) // num_seqs + q_block_global_idx = tl.program_id(0) kv_head_idx = tl.program_id(1) + left: tl.int32 = 0 + right = num_seqs + while left < right: + mid = (left + right) // 2 + mid_val = tl.load(q_block_start_idx_ptr + mid) + if mid_val <= q_block_global_idx: + left = mid + 1 + else: + right = mid + + seq_idx = left - 1 + q_block_start_idx = 
tl.load(q_block_start_idx_ptr + seq_idx) + + q_block_local_idx = q_block_global_idx - q_block_start_idx + #print("seq_idx: %d, q_block_idx: %d, kv_head_idx: %d" % (seq_idx, q_block_idx, kv_head_idx)) #tl.device_print("%d %d %d " % (max_num_q_blocks, seq_idx, q_block_idx)) @@ -81,7 +99,7 @@ def kernel_paged_attention_2d( #print("q_block_idx*BLOCK_Q: %d, cur_batch_query_len: %d" % (q_block_idx*BLOCK_Q, cur_batch_query_len)) - if q_block_idx*BLOCK_Q >= cur_batch_query_len: + if q_block_local_idx*BLOCK_Q >= cur_batch_query_len: return offs_m = tl.arange(0, BLOCK_Q * num_queries_per_kv) @@ -90,7 +108,7 @@ def kernel_paged_attention_2d( #print("offs_m: ", offs_m) #print("offs_d: ", offs_d) - query_pos = q_block_idx * BLOCK_Q + offs_m // num_queries_per_kv + query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv query_offset_0 = cur_batch_in_all_start_index + query_pos query_offset_1 = kv_head_idx * num_queries_per_kv + offs_m % num_queries_per_kv @@ -248,7 +266,9 @@ def chunked_prefill_paged_decode( q_descale, k_descale, v_descale, - alibi_slopes=None + total_num_q_blocks, + cu_seqlens_q_block, + alibi_slopes=None, ): use_alibi_slopes = alibi_slopes is not None @@ -278,13 +298,33 @@ def chunked_prefill_paged_decode( #print("max_seqlen_q: ", max_seqlen_q) #print("seqused_k: ", seqused_k) - #t0 = time.time() - - BLOCK_M = 128 + BLOCK_M = 32 BLOCK_Q = BLOCK_M // num_queries_per_kv - max_num_query_blocks = triton.cdiv(max_seqlen_q, BLOCK_Q) - num_query_blocks = num_seqs * max_num_query_blocks + ''' + t0 = time.time() + torch.cuda.synchronize() + + q_block_start_idx = torch.empty(size=(num_seqs+1,), device=q.device, dtype=torch.int32) + q_block_start_idx[0] = 0 + for i in range(num_seqs): + this_q_len = cu_seqlens_q[i+1]-cu_seqlens_q[i] + this_n_q_blocks = triton.cdiv(this_q_len, BLOCK_Q) + q_block_start_idx[i+1] = q_block_start_idx[i] + this_n_q_blocks + + tot_num_q_blocks = q_block_start_idx[num_seqs] + + q_block_seq_idx = torch.empty(size=(tot_num_q_blocks,), device=q.device, dtype=torch.int32) + for i in range(num_seqs): + start_idx, stop_idx = q_block_start_idx[i], q_block_start_idx[i+1] + q_block_seq_idx[start_idx:stop_idx] = i + + torch.cuda.synchronize() + global t_prefix + t_prefix += time.time()-t0 + ''' + + #t0 = time.time() #print("num_queries_per_kv: ", num_queries_per_kv) #print("BLOCK_Q: ", BLOCK_Q) @@ -293,7 +333,7 @@ def chunked_prefill_paged_decode( #print("max_num_query_blocks: ", max_num_query_blocks) kernel_paged_attention_2d[( - num_query_blocks, + total_num_q_blocks, num_kv_heads, )]( output_ptr=out, @@ -330,7 +370,7 @@ def chunked_prefill_paged_decode( query_start_len_ptr=cu_seqlens_q, BLOCK_Q=BLOCK_Q, num_seqs=num_seqs, - max_seqlen_q=max_seqlen_q, + q_block_start_idx_ptr=cu_seqlens_q_block, ) #torch.cuda.synchronize() diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index b4c7708daab9..59fb82370091 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -5,6 +5,7 @@ import numpy as np import torch +import triton from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, @@ -85,6 +86,8 @@ class FlashAttentionMetadata: seq_lens: torch.Tensor block_table: torch.Tensor slot_mapping: torch.Tensor + query_block_start_loc: torch.Tensor + total_num_q_blocks: int # For cascade attention. 
use_cascade: bool @@ -287,6 +290,17 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, common_prefix_len: int): max_seq_len = self.runner.seq_lens_np[:num_reqs].max() query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1] + + # probably a better way + query_block_start_loc_cpu = torch.zeros_like(query_start_loc_cpu) + for i in range(num_reqs): + this_q_len = query_start_loc_cpu[i+1]-query_start_loc_cpu[i] + query_block_start_loc_cpu[i+1] = query_block_start_loc_cpu[i] + triton.cdiv(this_q_len, 8) + total_num_q_blocks = query_block_start_loc_cpu[num_reqs] + + query_block_start_loc = query_block_start_loc_cpu.to(self.runner.device, + non_blocking=True) + query_start_loc = query_start_loc_cpu.to(self.runner.device, non_blocking=True) seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs] @@ -339,6 +353,8 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, num_actual_tokens=num_actual_tokens, max_query_len=max_query_len, query_start_loc=query_start_loc, + query_block_start_loc=query_block_start_loc, + total_num_q_blocks=total_num_q_blocks, max_seq_len=max_seq_len, seq_lens=seq_lens, block_table=block_table, diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 40576a6ae907..84e67a952182 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -178,6 +178,8 @@ def forward( max_seqlen_q = attn_metadata.max_query_len max_seqlen_k = attn_metadata.max_seq_len block_table = attn_metadata.block_table + cu_seqlens_q_block = attn_metadata.query_block_start_loc + total_num_q_blocks = attn_metadata.total_num_q_blocks descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) @@ -204,6 +206,8 @@ def forward( q_descale=layer._q_scale.expand(descale_shape), k_descale=layer._k_scale.expand(descale_shape), v_descale=layer._v_scale.expand(descale_shape), + total_num_q_blocks=total_num_q_blocks, + cu_seqlens_q_block=cu_seqlens_q_block, ) From c970a8174d48d522d1d44c07a6357c757c5498ad Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Mon, 14 Apr 2025 21:02:00 +0000 Subject: [PATCH 04/13] Leave 16 as default Signed-off-by: Thomas Parnell --- vllm/attention/ops/chunked_prefill_paged_decode.py | 2 +- vllm/v1/attention/backends/flash_attn.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/attention/ops/chunked_prefill_paged_decode.py b/vllm/attention/ops/chunked_prefill_paged_decode.py index b49b65a559ec..2c4356f031d9 100644 --- a/vllm/attention/ops/chunked_prefill_paged_decode.py +++ b/vllm/attention/ops/chunked_prefill_paged_decode.py @@ -298,7 +298,7 @@ def chunked_prefill_paged_decode( #print("max_seqlen_q: ", max_seqlen_q) #print("seqused_k: ", seqused_k) - BLOCK_M = 32 + BLOCK_M = 16 BLOCK_Q = BLOCK_M // num_queries_per_kv ''' diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 59fb82370091..2391598d01bb 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -295,7 +295,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, query_block_start_loc_cpu = torch.zeros_like(query_start_loc_cpu) for i in range(num_reqs): this_q_len = query_start_loc_cpu[i+1]-query_start_loc_cpu[i] - query_block_start_loc_cpu[i+1] = query_block_start_loc_cpu[i] + triton.cdiv(this_q_len, 8) + query_block_start_loc_cpu[i+1] = query_block_start_loc_cpu[i] + triton.cdiv(this_q_len, 4) total_num_q_blocks = 
query_block_start_loc_cpu[num_reqs] query_block_start_loc = query_block_start_loc_cpu.to(self.runner.device, From 32c9b3b9b05107873589b118d577a71ac6c5ab6d Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Tue, 15 Apr 2025 10:09:25 +0000 Subject: [PATCH 05/13] Problem with launch grid Signed-off-by: Thomas Parnell --- tests/kernels/test_flash_attn.py | 7 ++++--- tests/kernels/test_tpa_attn.py | 18 +++++------------- .../ops/chunked_prefill_paged_decode.py | 12 ++++++------ vllm/v1/attention/backends/flash_attn.py | 16 ---------------- vllm/v1/attention/backends/triton_attn.py | 4 ---- 5 files changed, 15 insertions(+), 42 deletions(-) diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py index 62753ba6d066..572563c0bd82 100644 --- a/tests/kernels/test_flash_attn.py +++ b/tests/kernels/test_flash_attn.py @@ -258,7 +258,7 @@ def test_varlen_with_paged_kv( (num_seqs, max_num_blocks_per_seq), dtype=torch.int32) - output = torch.empty_like(query) if use_out else None + out = torch.empty_like(query) if use_out else None maybe_quantized_query = query maybe_quantized_key_cache = key_cache @@ -277,11 +277,11 @@ def test_varlen_with_paged_kv( k_descale = torch.ones(scale_shape, dtype=torch.float32) v_descale = torch.ones(scale_shape, dtype=torch.float32) - chunked_prefill_paged_decode( + output = flash_attn_varlen_func( q=maybe_quantized_query, k=maybe_quantized_key_cache, v=maybe_quantized_value_cache, - out=output, + out=out, cu_seqlens_q=cu_query_lens, seqused_k=kv_lens, max_seqlen_q=max_query_len, @@ -296,6 +296,7 @@ def test_varlen_with_paged_kv( k_descale=k_descale, v_descale=v_descale, ) + output = output if not use_out else out ref_output = ref_paged_attn( query=query, diff --git a/tests/kernels/test_tpa_attn.py b/tests/kernels/test_tpa_attn.py index b4df6571e8a1..1bba51720c64 100644 --- a/tests/kernels/test_tpa_attn.py +++ b/tests/kernels/test_tpa_attn.py @@ -9,14 +9,9 @@ from vllm.attention.ops.chunked_prefill_paged_decode import ( chunked_prefill_paged_decode) -#NUM_HEADS = [(4, 4), (8, 2), (16, 2)] -NUM_HEADS = [(32, 8)] - -#HEAD_SIZES = [128, 256] -HEAD_SIZES = [128] - -#BLOCK_SIZES = [16, 32] -BLOCK_SIZES = [16] +NUM_HEADS = [(4, 4), (8, 2), (16, 2)] +HEAD_SIZES = [128, 256] +BLOCK_SIZES = [16, 32] DTYPES = [torch.float16, torch.bfloat16] QDTYPES = [None] @@ -80,12 +75,9 @@ def ref_paged_attn( return torch.cat(outputs, dim=0) -#@pytest.mark.parametrize("seq_lens", -# [[(1, 1328), (5, 18), -# (129, 463)], [(1, 523), (1, 37), (1, 2011)]]) - @pytest.mark.parametrize("seq_lens", - [[(1081, 1081)]]) + [[(1, 1328), (5, 18), + (129, 463)], [(1, 523), (1, 37), (1, 2011)]]) #@pytest.mark.parametrize("seq_lens", [[(1, 523), (1, 37), (1, 2011)]]) diff --git a/vllm/attention/ops/chunked_prefill_paged_decode.py b/vllm/attention/ops/chunked_prefill_paged_decode.py index 2c4356f031d9..598e28aee435 100644 --- a/vllm/attention/ops/chunked_prefill_paged_decode.py +++ b/vllm/attention/ops/chunked_prefill_paged_decode.py @@ -66,7 +66,6 @@ def kernel_paged_attention_2d( query_start_len_ptr, # [num_seqs+1] BLOCK_Q: tl.constexpr, # int num_seqs: tl.int32, - q_block_start_idx_ptr, ): q_block_global_idx = tl.program_id(0) @@ -76,14 +75,14 @@ def kernel_paged_attention_2d( right = num_seqs while left < right: mid = (left + right) // 2 - mid_val = tl.load(q_block_start_idx_ptr + mid) + mid_val = tl.load(query_start_len_ptr + mid) // BLOCK_Q + mid if mid_val <= q_block_global_idx: left = mid + 1 else: right = mid seq_idx = left - 1 - q_block_start_idx = tl.load(q_block_start_idx_ptr + 
seq_idx) + q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx q_block_local_idx = q_block_global_idx - q_block_start_idx @@ -266,8 +265,6 @@ def chunked_prefill_paged_decode( q_descale, k_descale, v_descale, - total_num_q_blocks, - cu_seqlens_q_block, alibi_slopes=None, ): @@ -332,6 +329,10 @@ def chunked_prefill_paged_decode( #print("num_seqs: ", num_seqs) #print("max_num_query_blocks: ", max_num_query_blocks) + + total_num_q_blocks = cu_seqlens_q[num_seqs].to(device="cpu", non_blocking=False).item() // BLOCK_Q + num_seqs + + kernel_paged_attention_2d[( total_num_q_blocks, num_kv_heads, @@ -370,7 +371,6 @@ def chunked_prefill_paged_decode( query_start_len_ptr=cu_seqlens_q, BLOCK_Q=BLOCK_Q, num_seqs=num_seqs, - q_block_start_idx_ptr=cu_seqlens_q_block, ) #torch.cuda.synchronize() diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 2391598d01bb..b4c7708daab9 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -5,7 +5,6 @@ import numpy as np import torch -import triton from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, @@ -86,8 +85,6 @@ class FlashAttentionMetadata: seq_lens: torch.Tensor block_table: torch.Tensor slot_mapping: torch.Tensor - query_block_start_loc: torch.Tensor - total_num_q_blocks: int # For cascade attention. use_cascade: bool @@ -290,17 +287,6 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, common_prefix_len: int): max_seq_len = self.runner.seq_lens_np[:num_reqs].max() query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1] - - # probably a better way - query_block_start_loc_cpu = torch.zeros_like(query_start_loc_cpu) - for i in range(num_reqs): - this_q_len = query_start_loc_cpu[i+1]-query_start_loc_cpu[i] - query_block_start_loc_cpu[i+1] = query_block_start_loc_cpu[i] + triton.cdiv(this_q_len, 4) - total_num_q_blocks = query_block_start_loc_cpu[num_reqs] - - query_block_start_loc = query_block_start_loc_cpu.to(self.runner.device, - non_blocking=True) - query_start_loc = query_start_loc_cpu.to(self.runner.device, non_blocking=True) seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs] @@ -353,8 +339,6 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, num_actual_tokens=num_actual_tokens, max_query_len=max_query_len, query_start_loc=query_start_loc, - query_block_start_loc=query_block_start_loc, - total_num_q_blocks=total_num_q_blocks, max_seq_len=max_seq_len, seq_lens=seq_lens, block_table=block_table, diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 84e67a952182..40576a6ae907 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -178,8 +178,6 @@ def forward( max_seqlen_q = attn_metadata.max_query_len max_seqlen_k = attn_metadata.max_seq_len block_table = attn_metadata.block_table - cu_seqlens_q_block = attn_metadata.query_block_start_loc - total_num_q_blocks = attn_metadata.total_num_q_blocks descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) @@ -206,8 +204,6 @@ def forward( q_descale=layer._q_scale.expand(descale_shape), k_descale=layer._k_scale.expand(descale_shape), v_descale=layer._v_scale.expand(descale_shape), - total_num_q_blocks=total_num_q_blocks, - cu_seqlens_q_block=cu_seqlens_q_block, ) From 41b572cc405a79f89f5e4871c8b74e1ecbee148e Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Tue, 
15 Apr 2025 19:20:31 +0000 Subject: [PATCH 06/13] Get total number of query tokens from query shape Signed-off-by: Thomas Parnell --- vllm/attention/ops/chunked_prefill_paged_decode.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/attention/ops/chunked_prefill_paged_decode.py b/vllm/attention/ops/chunked_prefill_paged_decode.py index 598e28aee435..1b36421cdaa2 100644 --- a/vllm/attention/ops/chunked_prefill_paged_decode.py +++ b/vllm/attention/ops/chunked_prefill_paged_decode.py @@ -329,9 +329,11 @@ def chunked_prefill_paged_decode( #print("num_seqs: ", num_seqs) #print("max_num_query_blocks: ", max_num_query_blocks) + #torch.cuda.synchronize() + #print("q.shape[0]: ", q.shape[0]) + #print("cu_seqlens_q[num_seqs]: ", cu_seqlens_q[num_seqs]) - total_num_q_blocks = cu_seqlens_q[num_seqs].to(device="cpu", non_blocking=False).item() // BLOCK_Q + num_seqs - + total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs kernel_paged_attention_2d[( total_num_q_blocks, From 8ee2dad4e6a6f501d5011597ddbe99d9249887af Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 18 Apr 2025 04:39:55 -0400 Subject: [PATCH 07/13] Clean up Signed-off-by: Thomas Parnell --- ...tn.py => test_triton_unified_attention.py} | 18 +-- ..._decode.py => triton_unified_attention.py} | 128 ++++-------------- vllm/v1/attention/backends/triton_attn.py | 22 +-- 3 files changed, 50 insertions(+), 118 deletions(-) rename tests/kernels/{test_tpa_attn.py => test_triton_unified_attention.py} (94%) rename vllm/attention/ops/{chunked_prefill_paged_decode.py => triton_unified_attention.py} (73%) diff --git a/tests/kernels/test_tpa_attn.py b/tests/kernels/test_triton_unified_attention.py similarity index 94% rename from tests/kernels/test_tpa_attn.py rename to tests/kernels/test_triton_unified_attention.py index 1bba51720c64..a798345521f7 100644 --- a/tests/kernels/test_tpa_attn.py +++ b/tests/kernels/test_triton_unified_attention.py @@ -6,15 +6,15 @@ import torch from vllm.platforms import current_platform -from vllm.attention.ops.chunked_prefill_paged_decode import ( - chunked_prefill_paged_decode) +from vllm.attention.ops.triton_unified_attention import ( + unified_attention) NUM_HEADS = [(4, 4), (8, 2), (16, 2)] HEAD_SIZES = [128, 256] BLOCK_SIZES = [16, 32] DTYPES = [torch.float16, torch.bfloat16] -QDTYPES = [None] +QDTYPES = [None, torch.float8_e4m3fn] # one value large enough to test overflow in index calculation. 
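# (re QDTYPES above) With q_dtype=torch.float8_e4m3fn the test keeps the
# reference computation in the original dtype but hands the kernel fp8 Q/K/V
# plus per-(sequence, kv-head) descale tensors of ones (the inputs are already
# ~N(0, 1), so no real scaling factor is needed). Inside
# kernel_unified_attention_2d, K/V loaded from an fp8 cache are dequantized as
#   value.to(float32) * descale
# unless the query itself is fp8; the comparison tolerances at the bottom of
# this test are relaxed to 1.5e-1 to absorb the quantization error.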
# one value small enough to test the schema op check NUM_BLOCKS = [32768, 2048] @@ -63,7 +63,7 @@ def ref_paged_attn( (query_len + sliding_window) + 1).bool().logical_not() mask |= sliding_window_mask - if soft_cap is not None: + if soft_cap is not None and soft_cap > 0: attn = soft_cap * torch.tanh(attn / soft_cap) attn.masked_fill_(mask, float("-inf")) attn = torch.softmax(attn, dim=-1).to(v.dtype) @@ -78,20 +78,16 @@ def ref_paged_attn( @pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)], [(1, 523), (1, 37), (1, 2011)]]) - -#@pytest.mark.parametrize("seq_lens", [[(1, 523), (1, 37), (1, 2011)]]) - @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("sliding_window", [None, 256]) @pytest.mark.parametrize("dtype", DTYPES) -#@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0]) -@pytest.mark.parametrize("soft_cap", [None]) +@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0]) @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("q_dtype", QDTYPES) @torch.inference_mode() -def test_varlen_with_paged_kv( +def test_triton_unified_attn( seq_lens: list[tuple[int, int]], num_heads: tuple[int, int], head_size: int, @@ -157,7 +153,7 @@ def test_varlen_with_paged_kv( k_descale = torch.ones(scale_shape, dtype=torch.float32) v_descale = torch.ones(scale_shape, dtype=torch.float32) - chunked_prefill_paged_decode( + unified_attention( q=maybe_quantized_query, k=maybe_quantized_key_cache, v=maybe_quantized_value_cache, diff --git a/vllm/attention/ops/chunked_prefill_paged_decode.py b/vllm/attention/ops/triton_unified_attention.py similarity index 73% rename from vllm/attention/ops/chunked_prefill_paged_decode.py rename to vllm/attention/ops/triton_unified_attention.py index 1b36421cdaa2..f5c0ffdbdd5e 100644 --- a/vllm/attention/ops/chunked_prefill_paged_decode.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -11,27 +11,23 @@ import triton.language as tl import triton_dejavu -from vllm import _custom_ops as ops -from vllm.platforms.rocm import use_rocm_custom_paged_attention - -from .prefix_prefill import context_attention_fwd - - @triton.jit def cdiv_fn(x, y): return (x + y - 1) // y - -import time -t_prefix = 0.0 -t_paged = 0.0 +@triton.jit +def apply_softcap(S, x): + Sdiv = S/x + p1 = tl.exp(Sdiv) + p2 = tl.exp(-Sdiv) + return x * (p1-p2)/(p1+p2) @triton_dejavu.jitcache( - check_keys=["USE_ALIBI_SLOPES", "SLIDING_WINDOW"], + check_keys=[], cache_launch_grid=False, ) @triton.jit -def kernel_paged_attention_2d( +def kernel_unified_attention_2d( output_ptr, # [num_tokens, num_query_heads, head_size] query_ptr, # [num_tokens, num_query_heads, head_size] key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x] @@ -42,9 +38,9 @@ def kernel_paged_attention_2d( scale, # float32 k_scale, # float32 v_scale, # float32 + softcap, # float32 num_query_heads: tl.constexpr, # int num_queries_per_kv: tl.constexpr, # int - num_queries_per_kv_padded: tl.constexpr, # int block_table_stride: tl.int64, # int query_stride_0: tl.int64, # int query_stride_1: tl.int64, # int, should be equal to head_size @@ -54,6 +50,7 @@ def kernel_paged_attention_2d( HEAD_SIZE: tl.constexpr, # int HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 USE_ALIBI_SLOPES: tl.constexpr, # bool + USE_SOFTCAP: tl.constexpr, # bool SLIDING_WINDOW: tl.constexpr, # int stride_k_cache_0: tl.int64, # int stride_k_cache_1: tl.int64, # int @@ -86,35 +83,23 @@ 
def kernel_paged_attention_2d( q_block_local_idx = q_block_global_idx - q_block_start_idx - #print("seq_idx: %d, q_block_idx: %d, kv_head_idx: %d" % (seq_idx, q_block_idx, kv_head_idx)) - - #tl.device_print("%d %d %d " % (max_num_q_blocks, seq_idx, q_block_idx)) - cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1) cur_batch_query_len = cur_batch_in_all_stop_index \ - cur_batch_in_all_start_index - #print("q_block_idx*BLOCK_Q: %d, cur_batch_query_len: %d" % (q_block_idx*BLOCK_Q, cur_batch_query_len)) - if q_block_local_idx*BLOCK_Q >= cur_batch_query_len: return offs_m = tl.arange(0, BLOCK_Q * num_queries_per_kv) offs_d = tl.arange(0, HEAD_SIZE_PADDED) - #print("offs_m: ", offs_m) - #print("offs_d: ", offs_d) - query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv query_offset_0 = cur_batch_in_all_start_index + query_pos query_offset_1 = kv_head_idx * num_queries_per_kv + offs_m % num_queries_per_kv - #print("query_offset_0: ", query_offset_0) - #print("query_offset_1: ", query_offset_1) - query_offset = (query_offset_0[:, None] * query_stride_0 + query_offset_1[:, None] * query_stride_1 + offs_d[None, :]) @@ -123,9 +108,6 @@ def kernel_paged_attention_2d( query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1) query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1) - #print("query_mask_0: ", query_mask_0) - #print("query_mask_1: ", query_mask_1) - # Q : (BLOCK_Q * num_queries_per_kv, HEAD_SIZE,) Q = tl.load( query_ptr + query_offset, @@ -177,7 +159,10 @@ def kernel_paged_attention_2d( other=0.0) if K_load.dtype.is_fp8(): - K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype) + if Q.dtype.is_fp8(): + K = K_load + else: + K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype) else: K = K_load @@ -187,23 +172,28 @@ def kernel_paged_attention_2d( other=0.0) if V_load.dtype.is_fp8(): - V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype) + if Q.dtype.is_fp8(): + V = V_load + else: + V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype) else: V = V_load seq_offset = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - #print("seq_offset: ", seq_offset) - #print("context_len: ", context_len + query_pos) - - #boundary = tl.full([BLOCK_SIZE], context_len+query_lke, dtype=tl.int32) seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1 # S : (BLOCK_Q * num_queries_per_kv, BLOCK_SIZE,) - S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, 0.0, - float("-inf")).to(tl.float32) + S = tl.zeros(shape=(BLOCK_Q * num_queries_per_kv, BLOCK_SIZE), dtype=tl.float32) + S += scale * tl.dot(Q, K) + if USE_SOFTCAP: + S = apply_softcap(S, softcap) + + S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, + float("-inf")) + if SLIDING_WINDOW > 0: S = tl.where((context_len + query_pos[:, None] - seq_offset) < SLIDING_WINDOW, S, -10000) @@ -248,7 +238,7 @@ def kernel_paged_attention_2d( ) -def chunked_prefill_paged_decode( +def unified_attention( q, k, v, @@ -270,13 +260,6 @@ def chunked_prefill_paged_decode( use_alibi_slopes = alibi_slopes is not None - #print("q.shape: ", q.shape) - #print("k.shape: ", k.shape) - #print("v.shape: ", v.shape) - #print("seqused_k: ", seqused_k) - #print("window_size: ", window_size) - #print("cu_seqlens_q: ", cu_seqlens_q) - block_size = v.shape[1] num_seqs = len(seqused_k) num_query_heads = q.shape[1] @@ -284,58 +267,12 @@ def chunked_prefill_paged_decode( num_queries_per_kv 
= num_query_heads // num_kv_heads head_size = q.shape[2] - num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv), - 16) - - #print("block_size: ", block_size) - #print("num_seqs: ", num_seqs) - #print("num_query_heads: ", num_query_heads) - #print("num_kv_heads: ", num_kv_heads) - #print("head_size: ", head_size) - #print("max_seqlen_q: ", max_seqlen_q) - #print("seqused_k: ", seqused_k) - BLOCK_M = 16 BLOCK_Q = BLOCK_M // num_queries_per_kv - ''' - t0 = time.time() - torch.cuda.synchronize() - - q_block_start_idx = torch.empty(size=(num_seqs+1,), device=q.device, dtype=torch.int32) - q_block_start_idx[0] = 0 - for i in range(num_seqs): - this_q_len = cu_seqlens_q[i+1]-cu_seqlens_q[i] - this_n_q_blocks = triton.cdiv(this_q_len, BLOCK_Q) - q_block_start_idx[i+1] = q_block_start_idx[i] + this_n_q_blocks - - tot_num_q_blocks = q_block_start_idx[num_seqs] - - q_block_seq_idx = torch.empty(size=(tot_num_q_blocks,), device=q.device, dtype=torch.int32) - for i in range(num_seqs): - start_idx, stop_idx = q_block_start_idx[i], q_block_start_idx[i+1] - q_block_seq_idx[start_idx:stop_idx] = i - - torch.cuda.synchronize() - global t_prefix - t_prefix += time.time()-t0 - ''' - - #t0 = time.time() - - #print("num_queries_per_kv: ", num_queries_per_kv) - #print("BLOCK_Q: ", BLOCK_Q) - #print("num_query_blocks: ", num_query_blocks) - #print("num_seqs: ", num_seqs) - #print("max_num_query_blocks: ", max_num_query_blocks) - - #torch.cuda.synchronize() - #print("q.shape[0]: ", q.shape[0]) - #print("cu_seqlens_q[num_seqs]: ", cu_seqlens_q[num_seqs]) - total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs - kernel_paged_attention_2d[( + kernel_unified_attention_2d[( total_num_q_blocks, num_kv_heads, )]( @@ -349,9 +286,9 @@ def chunked_prefill_paged_decode( scale=softmax_scale, k_scale=k_descale, v_scale=v_descale, + softcap=softcap, num_query_heads=num_query_heads, num_queries_per_kv=num_queries_per_kv, - num_queries_per_kv_padded=num_queries_per_kv_padded, block_table_stride=block_table.stride(0), query_stride_0=q.stride(0), query_stride_1=q.stride(1), @@ -361,6 +298,7 @@ def chunked_prefill_paged_decode( HEAD_SIZE=head_size, HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), USE_ALIBI_SLOPES=use_alibi_slopes, + USE_SOFTCAP=(softcap > 0), SLIDING_WINDOW=(1+window_size[0]), stride_k_cache_0=k.stride(0), stride_k_cache_1=k.stride(1), @@ -374,9 +312,3 @@ def chunked_prefill_paged_decode( BLOCK_Q=BLOCK_Q, num_seqs=num_seqs, ) - - #torch.cuda.synchronize() - #global t_paged - #t_paged += time.time()-t0 - - #print("t_prefix: %.2f seconds, t_paged: %.2f seconds" % (t_prefix, t_paged)) diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 40576a6ae907..d1b63a91e711 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -6,9 +6,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) -from vllm.attention.ops.chunked_prefill_paged_decode import ( - chunked_prefill_paged_decode) -from vllm.attention.ops.paged_attn import PagedAttention +from vllm.attention.ops.triton_unified_attention import ( + unified_attention) from vllm.logger import init_logger from vllm.v1.attention.backends.flash_attn import ( FlashAttentionMetadata, FlashAttentionMetadataBuilder) @@ -161,6 +160,16 @@ def forward( layer._v_scale, ) + if self.kv_cache_dtype.startswith("fp8"): + key_cache = key_cache.view(torch.float8_e4m3fn) + value_cache = 
value_cache.view(torch.float8_e4m3fn) + num_tokens, num_heads, head_size = query.shape + query, _ = ops.scaled_fp8_quant( + query.reshape( + (num_tokens, num_heads * head_size)).contiguous(), + layer._q_scale) + query = query.reshape((num_tokens, num_heads, head_size)) + use_local_attn = \ (self.use_irope and attn_metadata.local_attn_metadata is not None) @@ -181,12 +190,7 @@ def forward( descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) - #print("query.shape: ", query.shape) - #print("query.stride: ", query.stride()) - #print("output.shape: ", output.shape) - #print("output.stride: ", output.stride()) - - chunked_prefill_paged_decode( + unified_attention( q=query[:num_actual_tokens], k=key_cache, v=value_cache, From ee0053efc8f421ef824a76a47dca95f3b12a31ce Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 18 Apr 2025 04:40:29 -0400 Subject: [PATCH 08/13] fmt Signed-off-by: Thomas Parnell --- .../kernels/test_triton_unified_attention.py | 3 +- .../attention/ops/triton_unified_attention.py | 103 +++++++++--------- vllm/v1/attention/backends/triton_attn.py | 4 +- 3 files changed, 56 insertions(+), 54 deletions(-) diff --git a/tests/kernels/test_triton_unified_attention.py b/tests/kernels/test_triton_unified_attention.py index a798345521f7..1b8de3357b9e 100644 --- a/tests/kernels/test_triton_unified_attention.py +++ b/tests/kernels/test_triton_unified_attention.py @@ -5,9 +5,8 @@ import pytest import torch +from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.platforms import current_platform -from vllm.attention.ops.triton_unified_attention import ( - unified_attention) NUM_HEADS = [(4, 4), (8, 2), (16, 2)] HEAD_SIZES = [128, 256] diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index f5c0ffdbdd5e..f26fa0d47034 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -6,21 +6,23 @@ # - Chih-Chieh Yang # - Thomas Parnell -import torch import triton import triton.language as tl import triton_dejavu + @triton.jit def cdiv_fn(x, y): return (x + y - 1) // y + @triton.jit def apply_softcap(S, x): - Sdiv = S/x + Sdiv = S / x p1 = tl.exp(Sdiv) p2 = tl.exp(-Sdiv) - return x * (p1-p2)/(p1+p2) + return x * (p1 - p2) / (p1 + p2) + @triton_dejavu.jitcache( check_keys=[], @@ -28,41 +30,41 @@ def apply_softcap(S, x): ) @triton.jit def kernel_unified_attention_2d( - output_ptr, # [num_tokens, num_query_heads, head_size] - query_ptr, # [num_tokens, num_query_heads, head_size] - key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x] - value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size] - block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] - seq_lens_ptr, # [num_seqs] - alibi_slopes_ptr, # [num_query_heads] - scale, # float32 - k_scale, # float32 - v_scale, # float32 - softcap, # float32 - num_query_heads: tl.constexpr, # int - num_queries_per_kv: tl.constexpr, # int - block_table_stride: tl.int64, # int - query_stride_0: tl.int64, # int - query_stride_1: tl.int64, # int, should be equal to head_size - output_stride_0: tl.int64, # int - output_stride_1: tl.int64, # int, should be equal to head_size - BLOCK_SIZE: tl.constexpr, # int - HEAD_SIZE: tl.constexpr, # int - HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 - USE_ALIBI_SLOPES: tl.constexpr, # bool - USE_SOFTCAP: tl.constexpr, # bool - SLIDING_WINDOW: tl.constexpr, # int - stride_k_cache_0: tl.int64, # int - stride_k_cache_1: tl.int64, # int - 
stride_k_cache_2: tl.int64, # int - stride_k_cache_3: tl.int64, # int - stride_v_cache_0: tl.int64, # int - stride_v_cache_1: tl.int64, # int - stride_v_cache_2: tl.int64, # int - stride_v_cache_3: tl.int64, # int - query_start_len_ptr, # [num_seqs+1] - BLOCK_Q: tl.constexpr, # int - num_seqs: tl.int32, + output_ptr, # [num_tokens, num_query_heads, head_size] + query_ptr, # [num_tokens, num_query_heads, head_size] + key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x] + value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + seq_lens_ptr, # [num_seqs] + alibi_slopes_ptr, # [num_query_heads] + scale, # float32 + k_scale, # float32 + v_scale, # float32 + softcap, # float32 + num_query_heads: tl.constexpr, # int + num_queries_per_kv: tl.constexpr, # int + block_table_stride: tl.int64, # int + query_stride_0: tl.int64, # int + query_stride_1: tl.int64, # int, should be equal to head_size + output_stride_0: tl.int64, # int + output_stride_1: tl.int64, # int, should be equal to head_size + BLOCK_SIZE: tl.constexpr, # int + HEAD_SIZE: tl.constexpr, # int + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + USE_ALIBI_SLOPES: tl.constexpr, # bool + USE_SOFTCAP: tl.constexpr, # bool + SLIDING_WINDOW: tl.constexpr, # int + stride_k_cache_0: tl.int64, # int + stride_k_cache_1: tl.int64, # int + stride_k_cache_2: tl.int64, # int + stride_k_cache_3: tl.int64, # int + stride_v_cache_0: tl.int64, # int + stride_v_cache_1: tl.int64, # int + stride_v_cache_2: tl.int64, # int + stride_v_cache_3: tl.int64, # int + query_start_len_ptr, # [num_seqs+1] + BLOCK_Q: tl.constexpr, # int + num_seqs: tl.int32, ): q_block_global_idx = tl.program_id(0) @@ -79,7 +81,8 @@ def kernel_unified_attention_2d( right = mid seq_idx = left - 1 - q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx + q_block_start_idx = tl.load(query_start_len_ptr + + seq_idx) // BLOCK_Q + seq_idx q_block_local_idx = q_block_global_idx - q_block_start_idx @@ -89,7 +92,7 @@ def kernel_unified_attention_2d( cur_batch_query_len = cur_batch_in_all_stop_index \ - cur_batch_in_all_start_index - if q_block_local_idx*BLOCK_Q >= cur_batch_query_len: + if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: return offs_m = tl.arange(0, BLOCK_Q * num_queries_per_kv) @@ -101,8 +104,7 @@ def kernel_unified_attention_2d( query_offset_1 = kv_head_idx * num_queries_per_kv + offs_m % num_queries_per_kv query_offset = (query_offset_0[:, None] * query_stride_0 + - query_offset_1[:, None] * query_stride_1 + - offs_d[None, :]) + query_offset_1[:, None] * query_stride_1 + offs_d[None, :]) dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1) query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1) @@ -117,7 +119,9 @@ def kernel_unified_attention_2d( block_table_offset = seq_idx * block_table_stride - M = tl.full([BLOCK_Q * num_queries_per_kv], float("-inf"), dtype=tl.float32) + M = tl.full([BLOCK_Q * num_queries_per_kv], + float("-inf"), + dtype=tl.float32) L = tl.full([BLOCK_Q * num_queries_per_kv], 1.0, dtype=tl.float32) acc = tl.zeros([BLOCK_Q * num_queries_per_kv, HEAD_SIZE_PADDED], dtype=tl.float32) @@ -184,19 +188,20 @@ def kernel_unified_attention_2d( seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1 # S : (BLOCK_Q * num_queries_per_kv, BLOCK_SIZE,) - S = tl.zeros(shape=(BLOCK_Q * num_queries_per_kv, BLOCK_SIZE), dtype=tl.float32) + S = tl.zeros(shape=(BLOCK_Q * num_queries_per_kv, BLOCK_SIZE), + 
dtype=tl.float32) S += scale * tl.dot(Q, K) if USE_SOFTCAP: S = apply_softcap(S, softcap) - S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, - float("-inf")) + S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, + S, float("-inf")) if SLIDING_WINDOW > 0: - S = tl.where((context_len + query_pos[:, None] - seq_offset) < SLIDING_WINDOW, S, - -10000) + S = tl.where((context_len + query_pos[:, None] - seq_offset) + < SLIDING_WINDOW, S, -10000) if USE_ALIBI_SLOPES: S += alibi_slope[:, None] * (seq_offset - context_len) @@ -299,7 +304,7 @@ def unified_attention( HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), USE_ALIBI_SLOPES=use_alibi_slopes, USE_SOFTCAP=(softcap > 0), - SLIDING_WINDOW=(1+window_size[0]), + SLIDING_WINDOW=(1 + window_size[0]), stride_k_cache_0=k.stride(0), stride_k_cache_1=k.stride(1), stride_k_cache_2=k.stride(2), diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index d1b63a91e711..876fcc96bb5d 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -6,8 +6,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) -from vllm.attention.ops.triton_unified_attention import ( - unified_attention) +from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.logger import init_logger from vllm.v1.attention.backends.flash_attn import ( FlashAttentionMetadata, FlashAttentionMetadataBuilder) @@ -210,5 +209,4 @@ def forward( v_descale=layer._v_scale.expand(descale_shape), ) - return output From 5cb44d95de6ab1ec79ca96e207e0ac31962ffa02 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 18 Apr 2025 04:52:54 -0400 Subject: [PATCH 09/13] add missing file Signed-off-by: Thomas Parnell --- .../ops/chunked_prefill_paged_decode.py | 366 ++++++++++++++++++ 1 file changed, 366 insertions(+) create mode 100644 vllm/attention/ops/chunked_prefill_paged_decode.py diff --git a/vllm/attention/ops/chunked_prefill_paged_decode.py b/vllm/attention/ops/chunked_prefill_paged_decode.py new file mode 100644 index 000000000000..1b47581641b0 --- /dev/null +++ b/vllm/attention/ops/chunked_prefill_paged_decode.py @@ -0,0 +1,366 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Authors: +# - Burkhard Ringlein +# - Jan van Lunteren +# - Chih-Chieh Yang +# - Thomas Parnell + +import torch +import triton +import triton.language as tl + +from vllm import _custom_ops as ops +from vllm.platforms.rocm import use_rocm_custom_paged_attention + +from .prefix_prefill import context_attention_fwd + + +@triton.jit +def cdiv_fn(x, y): + return (x + y - 1) // y + + +@triton.jit +def kernel_paged_attention_2d( + output_ptr, # [num_tokens, num_query_heads, head_size] + query_ptr, # [num_tokens, num_query_heads, head_size] + key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x] + value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + seq_lens_ptr, # [num_seqs] + alibi_slopes_ptr, # [num_query_heads] + scale, # float32 + k_scale, # float32 + v_scale, # float32 + num_query_heads: tl.constexpr, # int + num_queries_per_kv: tl.constexpr, # int + num_queries_per_kv_padded: tl.constexpr, # int + block_table_stride: tl.int64, # int + query_stride_0: tl.int64, # int + query_stride_1: tl.int64, # int, should be equal to head_size + output_stride_0: tl.int64, # int + output_stride_1: tl.int64, # int, should be equal to head_size 
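+    # NOTE: the constexpr arguments below are compile-time constants; per the
+    # caller in this file, BLOCK_SIZE is the paged KV-cache block size and
+    # HEAD_SIZE_PADDED is triton.next_power_of_2(head_size).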
+ BLOCK_SIZE: tl.constexpr, # int + HEAD_SIZE: tl.constexpr, # int + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + USE_ALIBI_SLOPES: tl.constexpr, # bool + SLIDING_WINDOW: tl.constexpr, # int + x: tl.constexpr, # int + stride_k_cache_0: tl.int64, # int + stride_k_cache_1: tl.int64, # int + stride_k_cache_2: tl.int64, # int + stride_k_cache_3: tl.int64, # int + stride_k_cache_4: tl.int64, # int + stride_v_cache_0: tl.int64, # int + stride_v_cache_1: tl.int64, # int + stride_v_cache_2: tl.int64, # int + stride_v_cache_3: tl.int64, # int + filter_by_query_len: tl.constexpr, # bool + query_start_len_ptr, # [num_seqs+1] +): + seq_idx = tl.program_id(0) + kv_head_idx = tl.program_id(1) + + if filter_by_query_len: + cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) + cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + + 1) + cur_batch_query_len = cur_batch_in_all_stop_index \ + - cur_batch_in_all_start_index + if cur_batch_query_len > 1: + return + else: + cur_batch_in_all_start_index = seq_idx + + query_head_idx = kv_head_idx * num_queries_per_kv + tl.arange( + 0, num_queries_per_kv_padded) + + query_offset = (cur_batch_in_all_start_index * query_stride_0 + + query_head_idx[:, None] * query_stride_1) + + head_mask = query_head_idx < (kv_head_idx + 1) * num_queries_per_kv + head_mask = head_mask & (query_head_idx < num_query_heads) + + dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1, + 0).to(tl.int1) + + # Q : (num_queries_per_kv, HEAD_SIZE,) + Q = tl.load( + query_ptr + query_offset + tl.arange(0, HEAD_SIZE_PADDED)[None, :], + mask=dim_mask[None, :] & head_mask[:, None], + other=0.0, + ) + + block_table_offset = seq_idx * block_table_stride + + M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32) + L = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32) + acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED], + dtype=tl.float32) + + # sequence len for this particular sequence + seq_len = tl.load(seq_lens_ptr + seq_idx) + + # alibi slope for this head + if USE_ALIBI_SLOPES: + alibi_slope = tl.load(alibi_slopes_ptr + query_head_idx, + mask=head_mask, + other=0.0) + + num_blocks = cdiv_fn(seq_len, BLOCK_SIZE) + + # iterate through tiles + for j in range(0, num_blocks): + + physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j) + + offs_n = tl.arange(0, BLOCK_SIZE) + offs_d = tl.arange(0, HEAD_SIZE_PADDED) + + v_offset = (physical_block_idx * stride_v_cache_0 + + kv_head_idx * stride_v_cache_1 + + offs_d[None, :] * stride_v_cache_2 + + offs_n[:, None] * stride_v_cache_3) + + k_offset = (physical_block_idx * stride_k_cache_0 + + kv_head_idx * stride_k_cache_1 + + (offs_d[:, None] // x) * stride_k_cache_2 + + offs_n[None, :] * stride_k_cache_3 + + (offs_d[:, None] % x) * stride_k_cache_4) + + # K : (HEAD_SIZE, BLOCK_SIZE) + K_load = tl.load(key_cache_ptr + k_offset, + mask=dim_mask[:, None], + other=0.0) + + if K_load.dtype.is_fp8(): + K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype) + else: + K = K_load + + # V : (BLOCK_SIZE, HEAD_SIZE) + V_load = tl.load(value_cache_ptr + v_offset, + mask=dim_mask[None, :], + other=0.0) + + if V_load.dtype.is_fp8(): + V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype) + else: + V = V_load + + seq_offset = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + boundary = tl.full([BLOCK_SIZE], seq_len, dtype=tl.int32) + seq_mask = seq_offset[None, :] < boundary + + # S : (num_queries_per_kv, BLOCK_SIZE,) + S = tl.where(head_mask[:, None] & 
seq_mask, 0.0, + float("-inf")).to(tl.float32) + S += scale * tl.dot(Q, K) + + context_len = seq_len - 1 + + if SLIDING_WINDOW > 0: + S = tl.where((context_len - seq_offset) < SLIDING_WINDOW, S, + -10000) + + if USE_ALIBI_SLOPES: + S += alibi_slope[:, None] * (seq_offset - context_len) + + # compute running maximum + # m_j : (num_queries_per_kv,) + m_j = tl.maximum(M, tl.max(S, axis=1)) + + # P : (num_queries_per_kv, BLOCK_SIZE,) + P = tl.exp(S - m_j[:, None]) + + # l_j : (num_queries_per_kv,) + l_j = tl.sum(P, axis=1) + + # alpha : (num_queries_per_kv, ) + alpha = tl.exp(M - m_j) + + # acc : (num_queries_per_kv, BLOCK_SIZE,) + acc = acc * alpha[:, None] + + # update constants + L = L * alpha + l_j + M = m_j + + # acc : (num_queries_per_kv, BLOCK_SIZE,) + acc += tl.dot(P.to(V.dtype), V) + + # epilogue + acc = acc / L[:, None] + + output_offset = (cur_batch_in_all_start_index * output_stride_0 + + query_head_idx * output_stride_1) + + tl.store( + output_ptr + output_offset[:, None] + + tl.arange(0, HEAD_SIZE_PADDED)[None, :], + acc, + mask=dim_mask[None, :] & head_mask[:, None], + ) + + +def chunked_prefill_paged_decode( + query, + key, + value, + output, + kv_cache_dtype, + key_cache, + value_cache, + block_table, + query_start_loc, + seq_lens, + max_seq_len, + max_query_len, + k_scale, + v_scale, + alibi_slopes=None, + sliding_window=None, + sm_scale=None, +): + + if sm_scale is None: + sm_scale = 1.0 / (query.shape[1]**0.5) + + use_alibi_slopes = alibi_slopes is not None + + if sliding_window is None or sliding_window <= 0: + sliding_window = 0 + + if max_query_len > 1: + context_attention_fwd( + q=query, + k=key, + v=value, + o=output, + kv_cache_dtype=kv_cache_dtype, + k_cache=key_cache, + v_cache=value_cache, + b_loc=block_table, + b_start_loc=query_start_loc, + b_seq_len=seq_lens, + max_seq_len=max_seq_len, + max_input_len=max_query_len, + k_scale=k_scale, + v_scale=v_scale, + alibi_slopes=alibi_slopes, + sliding_window=sliding_window, + sm_scale=sm_scale, + skip_decode=True, + ) + + block_size = value_cache.shape[3] + num_seqs = len(seq_lens) + num_query_heads = query.shape[1] + num_kv_heads = key.shape[1] + num_queries_per_kv = query.shape[1] // key.shape[1] + head_size = query.shape[2] + + # Conversion of FP8 Tensor from uint8 storage to + # appropriate torch.dtype for interpretation by Triton + if "fp8" in kv_cache_dtype: + assert key_cache.dtype == torch.uint8 + assert value_cache.dtype == torch.uint8 + + if kv_cache_dtype in ("fp8", "fp8_e4m3"): + target_dtype = torch.float8_e4m3fn + elif kv_cache_dtype == "fp8_e5m2": + target_dtype = torch.float8_e5m2 + else: + raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype) + + key_cache = key_cache.view(target_dtype) + value_cache = value_cache.view(target_dtype) + + num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv), + 16) + + use_custom = use_rocm_custom_paged_attention(query.dtype, head_size, + block_size, + num_queries_per_kv, + max_seq_len, sliding_window) + if use_custom: + _PARTITION_SIZE_ROCM = 256 + max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) // + _PARTITION_SIZE_ROCM) + assert _PARTITION_SIZE_ROCM % block_size == 0 + total_num_seq = query.shape[0] + tmp_output = torch.empty( + size=(total_num_seq, num_query_heads, max_num_partitions, + head_size), + dtype=output.dtype, + device=output.device, + ) + exp_sums = torch.empty( + size=(total_num_seq, num_query_heads, max_num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + + 
ops.paged_attention_rocm( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale=sm_scale, + block_tables=block_table, + seq_lens=seq_lens, + query_start_loc=query_start_loc, + block_size=block_size, + max_seq_len=max_seq_len, + alibi_slopes=alibi_slopes, + kv_cache_dtype=kv_cache_dtype, + k_scale=k_scale, + v_scale=v_scale, + ) + else: + kernel_paged_attention_2d[( + num_seqs, + num_kv_heads, + )]( + output_ptr=output, + query_ptr=query, + key_cache_ptr=key_cache, + value_cache_ptr=value_cache, + block_tables_ptr=block_table, + seq_lens_ptr=seq_lens, + alibi_slopes_ptr=alibi_slopes, + scale=sm_scale, + k_scale=k_scale, + v_scale=v_scale, + num_query_heads=num_query_heads, + num_queries_per_kv=num_queries_per_kv, + num_queries_per_kv_padded=num_queries_per_kv_padded, + block_table_stride=block_table.stride(0), + query_stride_0=query.stride(0), + query_stride_1=query.stride(1), + output_stride_0=output.stride(0), + output_stride_1=output.stride(1), + BLOCK_SIZE=block_size, + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + USE_ALIBI_SLOPES=use_alibi_slopes, + SLIDING_WINDOW=sliding_window, + x=key_cache.shape[4], + stride_k_cache_0=key_cache.stride(0), + stride_k_cache_1=key_cache.stride(1), + stride_k_cache_2=key_cache.stride(2), + stride_k_cache_3=key_cache.stride(3), + stride_k_cache_4=key_cache.stride(4), + stride_v_cache_0=value_cache.stride(0), + stride_v_cache_1=value_cache.stride(1), + stride_v_cache_2=value_cache.stride(2), + stride_v_cache_3=value_cache.stride(3), + filter_by_query_len=True, + query_start_len_ptr=query_start_loc, + ) From 13c1c87e6e16f43484b5cec0a267373645a4c925 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 1 May 2025 15:54:08 +0000 Subject: [PATCH 10/13] make jit cache optional + fix precommit Signed-off-by: Lucas Wilkinson review comments Signed-off-by: Lucas Wilkinson review comments + make unit tests pass Signed-off-by: Lucas Wilkinson fix assert Signed-off-by: Lucas Wilkinson --- .../kernels/test_triton_unified_attention.py | 9 +++-- .../attention/ops/triton_unified_attention.py | 39 +++++++++++++++++-- vllm/v1/attention/backends/triton_attn.py | 5 ++- 3 files changed, 45 insertions(+), 8 deletions(-) diff --git a/tests/kernels/test_triton_unified_attention.py b/tests/kernels/test_triton_unified_attention.py index 1b8de3357b9e..4e15d00255a4 100644 --- a/tests/kernels/test_triton_unified_attention.py +++ b/tests/kernels/test_triton_unified_attention.py @@ -99,6 +99,9 @@ def test_triton_unified_attn( ) -> None: torch.set_default_device("cuda") + if q_dtype is not None and q_dtype.itemsize < 2 and block_size < 32: + pytest.skip("block size must be at least 32 for fp8") + current_platform.seed_everything(0) num_seqs = len(seq_lens) query_lens = [x[0] for x in seq_lens] @@ -148,9 +151,9 @@ def test_triton_unified_attn( maybe_quantized_value_cache = value_cache.to(q_dtype) scale_shape = (num_seqs, num_kv_heads) - q_descale = torch.ones(scale_shape, dtype=torch.float32) - k_descale = torch.ones(scale_shape, dtype=torch.float32) - v_descale = torch.ones(scale_shape, dtype=torch.float32) + q_descale = None # Not yet supported + k_descale = torch.rand(scale_shape, dtype=torch.float32) + v_descale = torch.rand(scale_shape, dtype=torch.float32) unified_attention( q=maybe_quantized_query, diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index f26fa0d47034..be36884371b5 100644 --- 
a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -8,7 +8,28 @@ import triton import triton.language as tl -import triton_dejavu + +from vllm.logger import init_logger + +logger = init_logger(__name__) + +try: + from triton_dejavu import jitcache +except ImportError: + + def jitcache(**kwargs): + + def decorator(func): + logger.warning_once( + f"triton_dejavu is not installed. {func.__name__} benefits " + " from jitcache caching, please install triton_dejavu " + " (https://github.com/IBM/triton-dejavu) for best performance." + " \nNOTE: Currently does not support Triton 3.3 (and as a " + " result) PyTorch 2.7.0+. Please open a PR to remove this NOTE " + " once Triton 3.3 is supported.") + return func + + return decorator @triton.jit @@ -24,7 +45,7 @@ def apply_softcap(S, x): return x * (p1 - p2) / (p1 + p2) -@triton_dejavu.jitcache( +@jitcache( check_keys=[], cache_launch_grid=False, ) @@ -101,7 +122,8 @@ def kernel_unified_attention_2d( query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv query_offset_0 = cur_batch_in_all_start_index + query_pos - query_offset_1 = kv_head_idx * num_queries_per_kv + offs_m % num_queries_per_kv + query_offset_1 = kv_head_idx * num_queries_per_kv + \ + offs_m % num_queries_per_kv query_offset = (query_offset_0[:, None] * query_stride_0 + query_offset_1[:, None] * query_stride_1 + offs_d[None, :]) @@ -201,7 +223,7 @@ def kernel_unified_attention_2d( if SLIDING_WINDOW > 0: S = tl.where((context_len + query_pos[:, None] - seq_offset) - < SLIDING_WINDOW, S, -10000) + < SLIDING_WINDOW, S, float("-inf")) if USE_ALIBI_SLOPES: S += alibi_slope[:, None] * (seq_offset - context_len) @@ -209,6 +231,9 @@ def kernel_unified_attention_2d( # compute running maximum # m_j : (BLOCK_Q * num_queries_per_kv,) m_j = tl.maximum(M, tl.max(S, axis=1)) + # For sliding window there's a chance the max is -inf due to masking of + # the entire row. In this case we need to set m_j 0 to avoid NaN + m_j = tl.where(m_j > float("-inf"), m_j, 0.0) # P : (BLOCK_Q * num_queries_per_kv, BLOCK_SIZE,) P = tl.exp(S - m_j[:, None]) @@ -262,6 +287,12 @@ def unified_attention( v_descale, alibi_slopes=None, ): + assert causal, "Only causal attention is supported" + assert q_descale is None, "Q scales not supported" + + block_size = v.shape[1] + assert q.element_size() >= 2 or block_size >= 32, \ + "Block size must be at least 32 for fp8" use_alibi_slopes = alibi_slopes is not None diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 876fcc96bb5d..bb700c8e2e7a 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -4,6 +4,7 @@ import torch +from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) from vllm.attention.ops.triton_unified_attention import unified_attention @@ -163,6 +164,8 @@ def forward( key_cache = key_cache.view(torch.float8_e4m3fn) value_cache = value_cache.view(torch.float8_e4m3fn) num_tokens, num_heads, head_size = query.shape + assert layer._q_scale == 1.0, \ + "A non 1.0 q_scale is not currently supported." 
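+            # NOTE: with an fp8 KV cache the query is quantized to fp8 as well,
+            # so the Triton kernel receives Q, K and V in the same fp8 dtype and
+            # can skip dequantizing K/V inside the attention loop (see the fp8
+            # path added to the kernel earlier in this series).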
query, _ = ops.scaled_fp8_quant( query.reshape( (num_tokens, num_heads * head_size)).contiguous(), @@ -204,7 +207,7 @@ def forward( window_size=self.sliding_window, block_table=block_table, softcap=self.logits_soft_cap, - q_descale=layer._q_scale.expand(descale_shape), + q_descale=None, # Not supported k_descale=layer._k_scale.expand(descale_shape), v_descale=layer._v_scale.expand(descale_shape), ) From c00eca3d0ed4bfafd062fe8e9ceab00df20b273f Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 1 May 2025 18:13:07 +0000 Subject: [PATCH 11/13] rip out jitcache since triton dejavu isnt working with PyTorch 2.7 and we have questions around cache keys Signed-off-by: Lucas Wilkinson --- .../attention/ops/triton_unified_attention.py | 22 ------------------- 1 file changed, 22 deletions(-) diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index be36884371b5..1d653e4d431e 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -13,24 +13,6 @@ logger = init_logger(__name__) -try: - from triton_dejavu import jitcache -except ImportError: - - def jitcache(**kwargs): - - def decorator(func): - logger.warning_once( - f"triton_dejavu is not installed. {func.__name__} benefits " - " from jitcache caching, please install triton_dejavu " - " (https://github.com/IBM/triton-dejavu) for best performance." - " \nNOTE: Currently does not support Triton 3.3 (and as a " - " result) PyTorch 2.7.0+. Please open a PR to remove this NOTE " - " once Triton 3.3 is supported.") - return func - - return decorator - @triton.jit def cdiv_fn(x, y): @@ -45,10 +27,6 @@ def apply_softcap(S, x): return x * (p1 - p2) / (p1 + p2) -@jitcache( - check_keys=[], - cache_launch_grid=False, -) @triton.jit def kernel_unified_attention_2d( output_ptr, # [num_tokens, num_query_heads, head_size] From 6ff606b1ec885fd3858e175dd31bf25140994eb2 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Tue, 6 May 2025 15:52:04 +0000 Subject: [PATCH 12/13] Remove restriction on block size 16 for fp8 since it seems to work Signed-off-by: Thomas Parnell --- tests/kernels/test_triton_unified_attention.py | 3 --- vllm/attention/ops/triton_unified_attention.py | 4 ---- 2 files changed, 7 deletions(-) diff --git a/tests/kernels/test_triton_unified_attention.py b/tests/kernels/test_triton_unified_attention.py index 4e15d00255a4..50da8e5fd5cd 100644 --- a/tests/kernels/test_triton_unified_attention.py +++ b/tests/kernels/test_triton_unified_attention.py @@ -99,9 +99,6 @@ def test_triton_unified_attn( ) -> None: torch.set_default_device("cuda") - if q_dtype is not None and q_dtype.itemsize < 2 and block_size < 32: - pytest.skip("block size must be at least 32 for fp8") - current_platform.seed_everything(0) num_seqs = len(seq_lens) query_lens = [x[0] for x in seq_lens] diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index 1d653e4d431e..680593b15371 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -268,10 +268,6 @@ def unified_attention( assert causal, "Only causal attention is supported" assert q_descale is None, "Q scales not supported" - block_size = v.shape[1] - assert q.element_size() >= 2 or block_size >= 32, \ - "Block size must be at least 32 for fp8" - use_alibi_slopes = alibi_slopes is not None block_size = v.shape[1] From 0efb66907b4b25466874ee3b289f240a0c584ea9 Mon Sep 17 00:00:00 2001 From: Thomas 
Parnell
Date: Tue, 6 May 2025 18:51:58 +0000
Subject: [PATCH 13/13] Add note about upper-bound

Signed-off-by: Thomas Parnell
---
 vllm/attention/ops/triton_unified_attention.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py
index 680593b15371..8c0cf9267f35 100644
--- a/vllm/attention/ops/triton_unified_attention.py
+++ b/vllm/attention/ops/triton_unified_attention.py
@@ -280,6 +280,15 @@ def unified_attention(
     BLOCK_M = 16
     BLOCK_Q = BLOCK_M // num_queries_per_kv
 
+    # Ideally we would launch the kernel with:
+    # \sum_i[ceil(query_len[i] / BLOCK_Q)] blocks.
+    # However, it is slow to realize the query_lens on the CPU.
+    # Instead we use the upper bound:
+    # \sum_i[ceil(query_len[i] / BLOCK_Q)]
+    #   <= \sum_i[floor(query_len[i] / BLOCK_Q) + 1]
+    #    = \sum_i[floor(query_len[i] / BLOCK_Q)] + num_seqs
+    #   <= floor(\sum_i(query_len[i]) / BLOCK_Q) + num_seqs
+    #    = floor(q.shape[0] / BLOCK_Q) + num_seqs
     total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs
 
     kernel_unified_attention_2d[(