2 files changed, +5 −5 lines.

File 1 of 2: the Triton unified-attention kernel (`kernel_unified_attention_2d`):

```diff
@@ -56,11 +56,11 @@ def kernel_unified_attention_2d(
         stride_k_cache_0: tl.int64,  # int
         stride_k_cache_1: tl.int64,  # int
         stride_k_cache_2: tl.int64,  # int
-        stride_k_cache_3: tl.int64,  # int
+        stride_k_cache_3: tl.constexpr,  # int
         stride_v_cache_0: tl.int64,  # int
         stride_v_cache_1: tl.int64,  # int
         stride_v_cache_2: tl.int64,  # int
-        stride_v_cache_3: tl.int64,  # int
+        stride_v_cache_3: tl.constexpr,  # int
         query_start_len_ptr,  # [num_seqs+1]
         BLOCK_Q: tl.constexpr,  # int
         num_seqs: tl.int32,
```
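Marking only the last cache stride as `tl.constexpr` (while the outer strides stay runtime `tl.int64` values) presumably lets Triton specialize the kernel for that value: the innermost stride of a contiguous KV cache is typically 1, so the address arithmetic folds away at compile time and loads can be coalesced, at the cost of one recompile per distinct stride value. A minimal sketch of the mechanism, using a hypothetical `copy_strided` kernel rather than vLLM's:

```python
# Illustrative sketch (not vLLM code): a strided copy whose stride is a
# tl.constexpr, so Triton compiles one specialized kernel per stride value.
# With stride == 1, the `offs * stride` term folds away at compile time.
import torch
import triton
import triton.language as tl


@triton.jit
def copy_strided(x_ptr, y_ptr, n,
                 stride: tl.constexpr,  # baked into the compiled kernel
                 BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    vals = tl.load(x_ptr + offs * stride, mask=mask)
    tl.store(y_ptr + offs * stride, vals, mask=mask)


if torch.cuda.is_available():
    x = torch.arange(1024, device="cuda", dtype=torch.float32)
    y = torch.empty_like(x)
    copy_strided[(4,)](x, y, x.numel(), stride=1, BLOCK=256)
    assert torch.equal(x, y)
```

The extra compilations are cheap here because the last cache stride is fixed for a given cache layout.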
File 2 of 2: the v1 Triton attention backend (`TritonAttentionImpl`):

```diff
@@ -5,13 +5,13 @@
 import torch
 
 from vllm import _custom_ops as ops
-from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+from vllm.attention.backends.abstract import (AttentionBackend,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.ops.triton_unified_attention import unified_attention
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.flash_attn import (
-    FlashAttentionMetadata, FlashAttentionMetadataBuilder)
+    FlashAttentionImpl, FlashAttentionMetadata, FlashAttentionMetadataBuilder)
 
 logger = init_logger(__name__)
 
@@ -56,7 +56,7 @@ def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
         return FlashAttentionMetadataBuilder
 
 
-class TritonAttentionImpl(AttentionImpl):
+class TritonAttentionImpl(FlashAttentionImpl):
 
     def __init__(
         self,
```
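Rebasing `TritonAttentionImpl` on `FlashAttentionImpl` instead of the abstract `AttentionImpl` (with the imports updated accordingly) lets the Triton backend inherit the flash-attn backend's shared setup and override only the kernel dispatch. A sketch of the pattern, with hypothetical class names and a simplified signature (vLLM's real `__init__` takes many more parameters):

```python
# Sketch of the subclassing pattern (hypothetical names, not vLLM's real
# classes). The subclass inherits the shared validation and bookkeeping
# and overrides only the attention call itself.
class FlashAttnImplSketch:
    def __init__(self, num_heads: int, head_size: int) -> None:
        if num_heads <= 0:  # shared validation lives in the base class
            raise ValueError("num_heads must be positive")
        self.num_heads = num_heads
        self.head_size = head_size

    def run(self, q, k, v) -> str:
        return "flash-attn kernel"  # stand-in for the flash-attn call


class TritonImplSketch(FlashAttnImplSketch):
    # __init__ is inherited unchanged; only the dispatch differs.
    def run(self, q, k, v) -> str:
        return "triton unified_attention kernel"  # stand-in for the Triton call


impl = TritonImplSketch(num_heads=8, head_size=64)
print(impl.run(None, None, None))  # -> triton unified_attention kernel
```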