@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 """Attention layer with PagedAttention and Triton prefix prefill."""
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
 import torch
 
@@ -12,10 +12,23 @@
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.flash_attn import (
     FlashAttentionMetadata, FlashAttentionMetadataBuilder)
+from vllm.v1.kv_cache_interface import AttentionSpec
+from vllm.v1.worker.block_table import BlockTable
+
+if TYPE_CHECKING:
+    from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 
 logger = init_logger(__name__)
 
 
+class TritonAttentionMetadataBuilder(FlashAttentionMetadataBuilder):
+
+    def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
+                 block_table: BlockTable):
+        super().__init__(runner, kv_cache_spec, block_table)
+        self.aot_schedule = False
+
+
 class TritonAttentionBackend(AttentionBackend):
 
     accept_output_buffer: bool = True
@@ -52,8 +65,8 @@ def use_cascade_attention(*args, **kwargs) -> bool:
         return False
 
     @staticmethod
-    def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
-        return FlashAttentionMetadataBuilder
+    def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]:
+        return TritonAttentionMetadataBuilder
 
 
 class TritonAttentionImpl(AttentionImpl):
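
For context, a minimal sketch of what the `get_builder_cls()` change means for callers. The module path `vllm.v1.attention.backends.triton_attn` is assumed from the imports shown in the diff, not stated in it; the rest follows directly from the added code.

```python
# Sketch only: the triton_attn module path is an assumption.
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadataBuilder
from vllm.v1.attention.backends.triton_attn import (
    TritonAttentionBackend, TritonAttentionMetadataBuilder)

# The backend now hands out the Triton-specific metadata builder...
builder_cls = TritonAttentionBackend.get_builder_cls()
assert builder_cls is TritonAttentionMetadataBuilder

# ...which is still a FlashAttentionMetadataBuilder subclass, so all of the
# existing metadata-building logic is inherited; the only difference is that
# its __init__ sets self.aot_schedule = False, opting out of ahead-of-time
# scheduling.
assert issubclass(builder_cls, FlashAttentionMetadataBuilder)
```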
 | 