 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
 from vllm_ascend.utils import vllm_version_is
-from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
 
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
     from vllm.v1.worker.gpu_input_batch import InputBatch
 
 
+@dataclass
+class CommonAttentionMetadata:
+    """
+    Attention metadata attributes that can be shared by layers in different KV
+    cache groups and thus have different block tables.
+    """
+
+    query_start_loc: torch.Tensor
+    """(batch_size + 1,), the start location of each request in the query tensor"""
+    seq_lens: torch.Tensor
+    """(batch_size,), the length of each request including both computed tokens
+    and newly scheduled tokens"""
+
+
 class AscendMLABackend(AttentionBackend):
 
     accept_output_buffer: bool = True
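
For reference, a minimal sketch of how a caller such as the model runner might populate the new CommonAttentionMetadata before handing it to the metadata builder; the batch values below are purely illustrative and not part of this change:

    import torch

    # Hypothetical batch: two decode requests (1 new token each) followed by
    # one prefill request scheduling 4 new tokens on top of 3 computed tokens.
    common_attn_metadata = CommonAttentionMetadata(
        # (batch_size + 1,): cumulative offsets into the flattened query tensor.
        query_start_loc=torch.tensor([0, 1, 2, 6], dtype=torch.int32),
        # (batch_size,): computed tokens + newly scheduled tokens per request.
        seq_lens=torch.tensor([9, 12, 7], dtype=torch.int32),
    )
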
@@ -58,6 +71,7 @@ class AscendMLAPrefillMetadata:
     seq_lens: list[int]
     context_lens: torch.Tensor
     input_positions: torch.Tensor
+    query_start_loc: torch.Tensor
     block_table: torch.Tensor
     max_query_len: int
     max_seq_lens: int
@@ -91,6 +105,9 @@ class AscendMLAMetadata:
 
     num_actual_tokens: int  # Number of tokens excluding padding.
     slot_mapping: torch.Tensor
+    query_start_loc: torch.Tensor
+    seq_lens: torch.Tensor
+    block_tables: torch.Tensor
 
     # New for MLA (compared to FlashAttention)
     # For handling prefill decode split
@@ -131,7 +148,7 @@ class AscendMLAMetadataBuilder:
 
     # _attn_mask_builder = None
     def __init__(self,
-                 runner: "NPUModelRunner",
+                 runner,
                  metadata_cls: Optional[AscendMLAMetadata] = None):
         self.metadata_cls: Optional[AscendMLAMetadata] = metadata_cls \
             if metadata_cls is not None else AscendMLAMetadata  # type: ignore
@@ -231,6 +248,7 @@ def build(self,
               num_reqs: int,
               num_actual_tokens: int,
               max_query_len: int,
+              common_attn_metadata: CommonAttentionMetadata,
               common_prefix_len: Optional[int] = None,
               graph_pad_size: int = -1) -> AscendMLAMetadata:
         assert self._num_decodes + self._num_prefills == num_reqs
@@ -243,10 +261,8 @@ def build(self,
             block_table = (self.runner.input_batch.block_table.
                            get_device_tensor()[:num_reqs])
         else:
-            block_table = self.runner.input_batch.block_table[
-                0].get_device_tensor()
-            block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
-                block_table[:num_reqs])
+            block_table = (self.runner.input_batch.block_table[0].
+                           get_device_tensor()[:num_reqs])
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
             device, non_blocking=True)
         input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
@@ -258,13 +274,17 @@ def build(self,
         seq_lens = seq_lens_cpu
         max_query_len = query_lens.max().item()
         max_seq_lens = seq_lens.max().item()
+        query_start_loc = None
 
         prefill_metadata = None
         if self._num_prefills > 0:
             reqs_start = self._num_decodes  # prefill_start
             tokens_start = self._num_decode_tokens
             max_query_len = query_lens[tokens_start:].max().item()
             max_seq_lens = seq_lens[tokens_start:].max().item()
+            query_start_loc = common_attn_metadata.query_start_loc
+            prefill_query_start_loc = query_start_loc[
+                reqs_start:] - query_start_loc[reqs_start]
 
             prefill_metadata = AscendMLAPrefillMetadata(
                 attn_mask=self.runner.attn_mask,
@@ -275,6 +295,7 @@ def build(self,
                 block_table=block_table[reqs_start:, ...],
                 max_query_len=max_query_len,
                 max_seq_lens=max_seq_lens,
+                query_start_loc=prefill_query_start_loc,
             )
 
         decode_metadata = None
@@ -331,6 +352,9 @@ def build(self,
             attn_state=self.runner.attn_state,
             prefill=prefill_metadata,
             decode=decode_metadata,
+            query_start_loc=query_start_loc,
+            block_tables=block_table,
+            seq_lens=seq_lens,
         )
 
 
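
As a quick sanity check on the new prefill slicing in build(): query_start_loc covers all requests with decodes ordered before prefills, so the prefill-only offsets are rebased to start at zero. A small worked example, reusing the illustrative batch from above (hypothetical values, not taken from this change):

    import torch

    query_start_loc = torch.tensor([0, 1, 2, 6])  # 2 decode tokens, then a 4-token prefill
    reqs_start = 2  # self._num_decodes: index of the first prefill request

    prefill_query_start_loc = query_start_loc[reqs_start:] - query_start_loc[reqs_start]
    # tensor([0, 4]): per-prefill start offsets into the prefill-only query slice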