Commit 01c2233

[Kernel] [V1] Fix performance regression for triton unified attention (#18161)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
1 parent 451da4b commit 01c2233

2 files changed, with 18 additions and 5 deletions.

vllm/attention/ops/triton_unified_attention.py (2 additions, 2 deletions)

@@ -56,11 +56,11 @@ def kernel_unified_attention_2d(
         stride_k_cache_0: tl.int64,  # int
         stride_k_cache_1: tl.int64,  # int
         stride_k_cache_2: tl.int64,  # int
-        stride_k_cache_3: tl.int64,  # int
+        stride_k_cache_3: tl.constexpr,  # int
         stride_v_cache_0: tl.int64,  # int
         stride_v_cache_1: tl.int64,  # int
         stride_v_cache_2: tl.int64,  # int
-        stride_v_cache_3: tl.int64,  # int
+        stride_v_cache_3: tl.constexpr,  # int
         query_start_len_ptr,  # [num_seqs+1]
         BLOCK_Q: tl.constexpr,  # int
         num_seqs: tl.int32,
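Why switching these two arguments from tl.int64 to tl.constexpr plausibly restores performance (a hedged reading, not stated in the commit itself): Triton specializes a kernel for each distinct tl.constexpr value at compile time, so when the innermost KV-cache stride is the constant 1 the compiler can treat accesses along that dimension as contiguous and vectorize them, instead of carrying the stride as an opaque runtime 64-bit value. The standalone sketch below only illustrates that general mechanism; the kernel name, arguments, and shapes are invented for the example and are not vLLM code.

import torch
import triton
import triton.language as tl


@triton.jit
def copy_last_dim_kernel(
        src_ptr,
        dst_ptr,
        stride_row: tl.int64,  # runtime value: kernel stays generic over it
        stride_col: tl.constexpr,  # compile-time constant: kernel is specialized
        BLOCK: tl.constexpr,
):
    # Each program copies one row of a (num_rows, BLOCK) tensor.
    row = tl.program_id(0)
    offs = tl.arange(0, BLOCK)
    # When stride_col is the constant 1, this address arithmetic folds into a
    # contiguous access pattern that the compiler can vectorize.
    vals = tl.load(src_ptr + row * stride_row + offs * stride_col)
    tl.store(dst_ptr + row * BLOCK + offs, vals)


x = torch.randn(128, 64, device="cuda")
y = torch.empty_like(x)
copy_last_dim_kernel[(x.shape[0], )](x, y, x.stride(0), x.stride(1), BLOCK=64)
torch.testing.assert_close(y, x)

Note the trade-off: a tl.constexpr argument triggers a separate compilation per distinct value, so it suits strides that rarely change, which the innermost KV-cache stride is.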

vllm/v1/attention/backends/triton_attn.py (16 additions, 3 deletions)

@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 """Attention layer with PagedAttention and Triton prefix prefill."""
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional

 import torch

@@ -12,10 +12,23 @@
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.flash_attn import (
     FlashAttentionMetadata, FlashAttentionMetadataBuilder)
+from vllm.v1.kv_cache_interface import AttentionSpec
+from vllm.v1.worker.block_table import BlockTable
+
+if TYPE_CHECKING:
+    from vllm.v1.worker.gpu_model_runner import GPUModelRunner

 logger = init_logger(__name__)


+class TritonAttentionMetadataBuilder(FlashAttentionMetadataBuilder):
+
+    def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
+                 block_table: BlockTable):
+        super().__init__(runner, kv_cache_spec, block_table)
+        self.aot_schedule = False
+
+
 class TritonAttentionBackend(AttentionBackend):

     accept_output_buffer: bool = True
@@ -52,8 +65,8 @@ def use_cascade_attention(*args, **kwargs) -> bool:
         return False

     @staticmethod
-    def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
-        return FlashAttentionMetadataBuilder
+    def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]:
+        return TritonAttentionMetadataBuilder


 class TritonAttentionImpl(AttentionImpl):
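What the backend change does, read side by side with the diff (a sketch; only the names shown in the diff are repository code): FlashAttentionMetadataBuilder carries an aot_schedule flag for FlashAttention's ahead-of-time scheduling path, which the Triton unified attention kernel does not use, so the new subclass reuses the rest of the metadata logic while forcing that flag off. The helper below only illustrates how a caller might go through get_builder_cls(); the runner, kv_cache_spec, and block_table arguments stand in for the real objects the V1 model runner supplies.

from vllm.v1.attention.backends.triton_attn import (
    TritonAttentionBackend, TritonAttentionMetadataBuilder)


def make_triton_metadata_builder(runner, kv_cache_spec, block_table):
    # Ask the backend which builder class to use; after this commit the
    # Triton backend returns its own subclass instead of the FlashAttention one.
    builder_cls = TritonAttentionBackend.get_builder_cls()
    assert builder_cls is TritonAttentionMetadataBuilder

    builder = builder_cls(runner, kv_cache_spec, block_table)
    # The subclass opts out of FlashAttention's ahead-of-time scheduling.
    assert builder.aot_schedule is False
    return builder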
