diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py
index 8c0cf9267f35..22a38a05a2a1 100644
--- a/vllm/attention/ops/triton_unified_attention.py
+++ b/vllm/attention/ops/triton_unified_attention.py
@@ -56,11 +56,11 @@ def kernel_unified_attention_2d(
         stride_k_cache_0: tl.int64,  # int
         stride_k_cache_1: tl.int64,  # int
         stride_k_cache_2: tl.int64,  # int
-        stride_k_cache_3: tl.int64,  # int
+        stride_k_cache_3: tl.constexpr,  # int
         stride_v_cache_0: tl.int64,  # int
         stride_v_cache_1: tl.int64,  # int
         stride_v_cache_2: tl.int64,  # int
-        stride_v_cache_3: tl.int64,  # int
+        stride_v_cache_3: tl.constexpr,  # int
         query_start_len_ptr,  # [num_seqs+1]
         BLOCK_Q: tl.constexpr,  # int
         num_seqs: tl.int32,
diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index c4922a716bc2..908bf1274125 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 """Attention layer with PagedAttention and Triton prefix prefill."""
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
 import torch
 
@@ -12,10 +12,23 @@
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.flash_attn import (
     FlashAttentionMetadata, FlashAttentionMetadataBuilder)
+from vllm.v1.kv_cache_interface import AttentionSpec
+from vllm.v1.worker.block_table import BlockTable
+
+if TYPE_CHECKING:
+    from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 
 logger = init_logger(__name__)
 
 
+class TritonAttentionMetadataBuilder(FlashAttentionMetadataBuilder):
+
+    def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
+                 block_table: BlockTable):
+        super().__init__(runner, kv_cache_spec, block_table)
+        self.aot_schedule = False
+
+
 class TritonAttentionBackend(AttentionBackend):
 
     accept_output_buffer: bool = True
@@ -52,8 +65,8 @@ def use_cascade_attention(*args, **kwargs) -> bool:
         return False
 
     @staticmethod
-    def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
-        return FlashAttentionMetadataBuilder
+    def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]:
+        return TritonAttentionMetadataBuilder
 
 
 class TritonAttentionImpl(AttentionImpl):