1616
1717from vllm_ascend .attention .attention_v1 import AscendAttentionState
1818from vllm_ascend .ops .attention import vanilla_chunked_prefill_mla
19- from vllm_ascend .worker .model_runner_v1 import NPUModelRunner
2019
2120if TYPE_CHECKING :
2221 from vllm .v1 .core .sched .output import SchedulerOutput
2322 from vllm .v1 .worker .gpu_input_batch import InputBatch
2423
2524
@dataclass
class CommonAttentionMetadata:
    """
    Attention metadata attributes that can be shared by layers in different KV
    cache groups and thus having different block table.
    """

    # Shape (batch_size + 1,): start offset of each request inside the
    # flattened query tensor.
    query_start_loc: torch.Tensor
    # Shape (batch_size,): total length of each request, i.e. previously
    # computed tokens plus newly scheduled tokens.
    seq_lens: torch.Tensor
38+
2639class AscendMLABackend (AttentionBackend ):
2740
2841 accept_output_buffer : bool = True
@@ -57,6 +70,7 @@ class AscendMLAPrefillMetadata:
5770 seq_lens : list [int ]
5871 context_lens : torch .Tensor
5972 input_positions : torch .Tensor
73+ query_start_loc : torch .Tensor
6074 block_table : torch .Tensor
6175 max_query_len : int
6276 max_seq_lens : int
@@ -90,6 +104,9 @@ class AscendMLAMetadata:
90104
91105 num_actual_tokens : int # Number of tokens excluding padding.
92106 slot_mapping : torch .Tensor
107+ query_start_loc : torch .Tensor
108+ seq_lens : torch .Tensor
109+ block_tables : torch .Tensor
93110
94111 # New for MLA (compared to FlashAttention)
95112 # For handling prefill decode split
@@ -130,7 +147,7 @@ class AscendMLAMetadataBuilder:
130147
131148 # _attn_mask_builder = None
132149 def __init__ (self ,
133- runner : "NPUModelRunner" ,
150+ runner ,
134151 metadata_cls : Optional [AscendMLAMetadata ] = None ):
135152 self .metadata_cls : Optional [AscendMLAMetadata ] = metadata_cls \
136153 if metadata_cls is not None else AscendMLAMetadata # type: ignore
@@ -230,6 +247,7 @@ def build(self,
230247 num_reqs : int ,
231248 num_actual_tokens : int ,
232249 max_query_len : int ,
250+ common_attn_metadata : CommonAttentionMetadata ,
233251 common_prefix_len : Optional [int ] = None ,
234252 graph_pad_size : int = - 1 ) -> AscendMLAMetadata :
235253 assert self ._num_decodes + self ._num_prefills == num_reqs
@@ -239,10 +257,8 @@ def build(self,
239257 # it blocks on all previous kernels.
240258 device = self .runner .device
241259
242- block_table = self .runner .input_batch .block_table [0 ].get_device_tensor (
243- )
244- block_table [:num_reqs , :self .runner .max_num_blocks_per_req ] = (
245- block_table [:num_reqs ])
260+ block_table = (self .runner .input_batch .block_table [0 ].
261+ get_device_tensor ()[:num_reqs ])
246262 slot_mapping = self .runner .slot_mapping_cpu [:num_actual_tokens ].to (
247263 device , non_blocking = True )
248264 input_positions = self .runner .positions_cpu [:num_actual_tokens ].to (
@@ -254,13 +270,17 @@ def build(self,
254270 seq_lens = seq_lens_cpu
255271 max_query_len = query_lens .max ().item ()
256272 max_seq_lens = seq_lens .max ().item ()
273+ query_start_loc = None
257274
258275 prefill_metadata = None
259276 if self ._num_prefills > 0 :
260277 reqs_start = self ._num_decodes # prefill_start
261278 tokens_start = self ._num_decode_tokens
262279 max_query_len = query_lens [tokens_start :].max ().item ()
263280 max_seq_lens = seq_lens [tokens_start :].max ().item ()
281+ query_start_loc = common_attn_metadata .query_start_loc
282+ prefill_query_start_loc = query_start_loc [
283+ reqs_start :] - query_start_loc [reqs_start ]
264284
265285 prefill_metadata = AscendMLAPrefillMetadata (
266286 attn_mask = self .runner .attn_mask ,
@@ -271,6 +291,7 @@ def build(self,
271291 block_table = block_table [reqs_start :, ...],
272292 max_query_len = max_query_len ,
273293 max_seq_lens = max_seq_lens ,
294+ query_start_loc = prefill_query_start_loc ,
274295 )
275296
276297 decode_metadata = None
@@ -327,6 +348,9 @@ def build(self,
327348 attn_state = self .runner .attn_state ,
328349 prefill = prefill_metadata ,
329350 decode = decode_metadata ,
351+ query_start_loc = query_start_loc ,
352+ block_tables = block_table ,
353+ seq_lens = seq_lens ,
330354 )
331355
332356
0 commit comments