Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions vllm/attention/ops/triton_unified_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ def kernel_unified_attention_2d(
stride_k_cache_0: tl.int64, # int
stride_k_cache_1: tl.int64, # int
stride_k_cache_2: tl.int64, # int
stride_k_cache_3: tl.int64, # int
stride_k_cache_3: tl.constexpr, # int
stride_v_cache_0: tl.int64, # int
stride_v_cache_1: tl.int64, # int
stride_v_cache_2: tl.int64, # int
stride_v_cache_3: tl.int64, # int
stride_v_cache_3: tl.constexpr, # int
Comment on lines 56 to +63
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not make all of the strides constexpr to be safe?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would think its best to not to avoid over recompilation, but for the last stride it makes sense since this is almost always 1 (and when it is 1 we want the compiler to optimize around this, i.e. use wider loads)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Avoiding recompilation is one reason, but I also think for really long sequences there is a risk that the strides can overflow unless they are explicitly marked as tl.int64. This can't happen for the stride_k_cache_3 and stride_v_cache_3 though, so I think we are safe to do this.

query_start_len_ptr, # [num_seqs+1]
BLOCK_Q: tl.constexpr, # int
num_seqs: tl.int32,
Expand Down
19 changes: 16 additions & 3 deletions vllm/v1/attention/backends/triton_attn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
"""Attention layer with PagedAttention and Triton prefix prefill."""
from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Optional

import torch

Expand All @@ -12,10 +12,23 @@
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import (
FlashAttentionMetadata, FlashAttentionMetadataBuilder)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable

if TYPE_CHECKING:
from vllm.v1.worker.gpu_model_runner import GPUModelRunner

logger = init_logger(__name__)


class TritonAttentionMetadataBuilder(FlashAttentionMetadataBuilder):

def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
block_table: BlockTable):
super().__init__(runner, kv_cache_spec, block_table)
self.aot_schedule = False


class TritonAttentionBackend(AttentionBackend):

accept_output_buffer: bool = True
Expand Down Expand Up @@ -52,8 +65,8 @@ def use_cascade_attention(*args, **kwargs) -> bool:
return False

@staticmethod
def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
return FlashAttentionMetadataBuilder
def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]:
return TritonAttentionMetadataBuilder


class TritonAttentionImpl(AttentionImpl):
Expand Down