Skip to content

Commit c90dc87

Browse files
committed
Fix performance regression for Triton unified attention
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
1 parent dc372b9 commit c90dc87

File tree

2 files changed: +5 additions, −5 deletions

vllm/attention/ops/triton_unified_attention.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,11 @@ def kernel_unified_attention_2d(
     stride_k_cache_0: tl.int64,  # int
     stride_k_cache_1: tl.int64,  # int
     stride_k_cache_2: tl.int64,  # int
-    stride_k_cache_3: tl.int64,  # int
+    stride_k_cache_3: tl.constexpr,  # int
     stride_v_cache_0: tl.int64,  # int
     stride_v_cache_1: tl.int64,  # int
     stride_v_cache_2: tl.int64,  # int
-    stride_v_cache_3: tl.int64,  # int
+    stride_v_cache_3: tl.constexpr,  # int
     query_start_len_ptr,  # [num_seqs+1]
     BLOCK_Q: tl.constexpr,  # int
     num_seqs: tl.int32,

vllm/v1/attention/backends/triton_attn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@
 import torch

 from vllm import _custom_ops as ops
-from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+from vllm.attention.backends.abstract import (AttentionBackend,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.ops.triton_unified_attention import unified_attention
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.flash_attn import (
-    FlashAttentionMetadata, FlashAttentionMetadataBuilder)
+    FlashAttentionImpl, FlashAttentionMetadata, FlashAttentionMetadataBuilder)

 logger = init_logger(__name__)

@@ -56,7 +56,7 @@ def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
         return FlashAttentionMetadataBuilder


-class TritonAttentionImpl(AttentionImpl):
+class TritonAttentionImpl(FlashAttentionImpl):

     def __init__(
         self,

0 commit comments

Comments (0)