 import numpy as np
 import torch
 import torch_npu
-from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
-                                              AttentionLayer, AttentionType)
-from vllm.attention.backends.utils import PAD_SLOT_ID, CommonAttentionState
-from vllm.v1.core.sched.output import SchedulerOutput
-
-from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer,
+                                              AttentionType)
+from vllm.attention.backends.utils import PAD_SLOT_ID
+
+from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
+                                                AscendAttentionMetadataBuilder,
+                                                AscendAttentionState,
+                                                AscendMetadata)
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
                                nd_to_nz_2d)
-from vllm_ascend.worker.npu_input_batch import InputBatch


-class AscendAttentionTorchairBackend(AttentionBackend):
+class AscendAttentionTorchairBackend(AscendAttentionBackend):
     accept_output_buffer: bool = True

     @staticmethod
@@ -47,10 +48,6 @@ def get_impl_cls() -> Type["AscendAttentionTorchairBackendImpl"]:
     def get_metadata_cls() -> Type["AscendTorchairMetadata"]:
         return AscendTorchairMetadata

-    @staticmethod
-    def get_state_cls() -> Type["CommonAttentionState"]:
-        return CommonAttentionState
-
     @staticmethod
     def get_builder_cls() -> type["AscendAttentionTorchairMetadataBuilder"]:
         return AscendAttentionTorchairMetadataBuilder
@@ -73,36 +70,6 @@ def get_bsh_kv_cache_shape(
     ) -> Tuple[int, ...]:
         return (2, num_blocks, block_size, num_kv_heads * head_size)

-    @staticmethod
-    def swap_blocks(
-        src_kv_cache: List[torch.Tensor],
-        dst_kv_cache: List[torch.Tensor],
-        src_to_dst: torch.Tensor,
-    ) -> None:
-        src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1]
-        dst_key_cache, dst_value_cache = dst_kv_cache[0], dst_kv_cache[1]
-        src_indices = src_to_dst[:, 0]
-        dst_indices = src_to_dst[:, 1]
-
-        dst_key_cache[dst_indices] = src_key_cache[src_indices].to(
-            dst_key_cache.device)
-        dst_value_cache[dst_indices] = src_value_cache[src_indices].to(
-            dst_key_cache.device)
-
-    @staticmethod
-    def copy_blocks(
-        kv_caches: List[torch.Tensor],
-        src_to_dists: torch.Tensor,
-    ) -> None:
-        src_indices = src_to_dists[:, 0]
-        dst_indices = src_to_dists[:, 1]
-
-        for kv_cache in kv_caches:
-            key_caches = kv_cache[0]
-            value_caches = kv_cache[1]
-            key_caches[dst_indices] = key_caches[src_indices]
-            value_caches[dst_indices] = value_caches[src_indices]
-

 @dataclass
 class AscendDecodeMetadata:
@@ -117,40 +84,15 @@ class AscendDecodeMetadata:


 @dataclass
-class AscendTorchairMetadata:
-    num_actual_tokens: int  # Number of tokens excluding padding.
-    # (batch_size, max_blocks_per_seq).
-    # Block addresses per sequence. (Seq id -> list of physical block)
-    block_tables: torch.Tensor
-    # (batch_size,). The sequence length per sequence. Sequence length means
-    # the computed tokens + new tokens None if it is a decoding.
-    query_start_loc: torch.Tensor
-    query_lens: torch.Tensor
-    seq_lens: torch.Tensor
-    # Maximum query length in the batch. None for decoding.
-    max_query_len: Optional[int] = None
-    # (num_tokens,). The indices of the token slots that input tokens will be
-    # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
-    # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
-    # in block 0, and 1st slot in block 1, respectively.
-    slot_mapping: torch.Tensor = None
-    # Current state of this attention run.
-    attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
-    attn_mask: Optional[torch.Tensor] = None
+class AscendTorchairMetadata(AscendMetadata):

     decode: Optional[AscendDecodeMetadata] = None

-    enable_dbo_across_dp: bool = False

-
-class AscendAttentionTorchairMetadataBuilder:
+class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder):

     def __init__(self, runner):
-        self.runner = runner
-
-    def reorder_batch(self, input_batch: "InputBatch",
-                      scheduler_output: "SchedulerOutput") -> bool:
-        return False
+        super().__init__(runner)

     def _get_graph_runner_block_tables(
             self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
@@ -222,11 +164,16 @@ def build(self,
               num_reqs,
               num_actual_tokens,
               max_query_len,
-              graph_pad_size: int = -1,
               enable_dbo_across_dp: bool = False,
+              is_only_prefill: bool = False,
               *args,
               **kwargs):

+        if 'graph_pad_size' in kwargs:
+            graph_pad_size = kwargs['graph_pad_size']
+        else:
+            graph_pad_size = -1  # default value
+
         device = self.runner.device

         block_table = self.runner.input_batch.block_table[0].get_device_tensor(
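
A side note on the new `graph_pad_size` handling in `build()`: since the value is now read out of `**kwargs`, the four-line if/else is behaviourally equivalent to a single `dict.get` lookup. A minimal sketch follows; `_resolve_graph_pad_size` is a hypothetical helper name used only to keep the snippet self-contained, not part of this change:

```python
def _resolve_graph_pad_size(**kwargs) -> int:
    # Same behaviour as the if/else added in build(): fall back to -1
    # when the caller does not pass graph_pad_size.
    return kwargs.get('graph_pad_size', -1)


assert _resolve_graph_pad_size() == -1
assert _resolve_graph_pad_size(graph_pad_size=8) == 8
```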
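The metadata and builder changes follow the same pattern: `AscendTorchairMetadata` now inherits the shared fields from `AscendMetadata` and only adds the torchair-specific `decode` slot, while `AscendAttentionTorchairMetadataBuilder` defers its setup to the base `AscendAttentionMetadataBuilder.__init__`. A standalone sketch of that dataclass-inheritance pattern, using toy stand-in names rather than the real vllm_ascend classes:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class BaseMetadata:
    # Plays the role of AscendMetadata: fields shared by all Ascend backends.
    num_actual_tokens: int = 0
    max_query_len: Optional[int] = None


@dataclass
class TorchairMetadata(BaseMetadata):
    # Plays the role of AscendTorchairMetadata: only the extra decode field.
    decode: Optional[dict] = None


meta = TorchairMetadata(num_actual_tokens=4, decode={"block_table": None})
assert isinstance(meta, BaseMetadata)  # shared code can treat both uniformly
```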