                                        LinearBase, RowParallelLinear,
                                        UnquantizedLinearMethod)
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
+from vllm.v1.attention.backends.utils import CommonAttentionMetadata

 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
@@ -57,6 +58,7 @@ class AscendMLAPrefillMetadata:
     seq_lens: list[int]
     context_lens: torch.Tensor
     input_positions: torch.Tensor
+    query_start_loc: torch.Tensor
     block_table: torch.Tensor
     max_query_len: int
     max_seq_lens: int
@@ -90,6 +92,9 @@ class AscendMLAMetadata:

     num_actual_tokens: int  # Number of tokens excluding padding.
     slot_mapping: torch.Tensor
+    query_start_loc: torch.Tensor
+    seq_lens: torch.Tensor
+    block_tables: torch.Tensor

     # New for MLA (compared to FlashAttention)
     # For handling prefill decode split
@@ -231,6 +236,7 @@ def build(self,
               num_actual_tokens: int,
               max_query_len: int,
               common_prefix_len: Optional[int] = None,
+              common_attn_metadata: CommonAttentionMetadata = None,
               graph_pad_size: int = -1) -> AscendMLAMetadata:
         assert self._num_decodes + self._num_prefills == num_reqs

@@ -245,6 +251,7 @@ def build(self,
         input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
             device, non_blocking=True).long()

+        query_start_loc = common_attn_metadata.query_start_loc
         seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
         query_lens = seq_lens_cpu - self.runner.input_batch.num_computed_tokens_cpu_tensor[:
                                                                                            num_reqs]
@@ -258,6 +265,8 @@ def build(self,
             tokens_start = self._num_decode_tokens
             max_query_len = query_lens[tokens_start:].max().item()
             max_seq_lens = seq_lens[tokens_start:].max().item()
+            prefill_query_start_loc = query_start_loc[
+                reqs_start:] - query_start_loc[reqs_start]

             prefill_metadata = AscendMLAPrefillMetadata(
                 attn_mask=self.runner.attn_mask,
@@ -268,6 +277,7 @@ def build(self,
                 block_table=block_table[reqs_start:, ...],
                 max_query_len=max_query_len,
                 max_seq_lens=max_seq_lens,
+                query_start_loc=prefill_query_start_loc,
             )

         decode_metadata = None
@@ -324,6 +334,9 @@ def build(self,
             attn_state=self.runner.attn_state,
             prefill=prefill_metadata,
             decode=decode_metadata,
+            query_start_loc=query_start_loc,
+            block_tables=block_table,
+            seq_lens=seq_lens,
         )


@@ -373,6 +386,12 @@ def __init__(
         self.qk_rope_head_dim = qk_rope_head_dim
         self.qk_head_dim = qk_head_dim
         self.v_head_dim = v_head_dim
+        # TODO: the padding below should be removed once the kernel is ready.
+        # npu_flash_attention only works when head_dim is divisible by 128, so we
+        # pad to the target size here and slice the final result back afterwards.
+        self.padding_head_dim = (
+            (self.qk_nope_head_dim + self.qk_rope_head_dim - 1) // 128 +
+            1) * 128

         # Hack for V1 for now to avoid torch library overhead (since we are
         # already inside an attention custom op), pull out the forward
@@ -470,11 +489,9 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
             [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

         # Convert from (L, N, V) to (N, L, V)
-        self.W_UV = W_UV.transpose(0, 1).contiguous()
+        self.W_UV = W_UV.transpose(0, 1)
         # Convert from (L, N, P) to (N, P, L)
-        self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
-        self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
-        self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
+        self.W_UK_T = W_UK.permute(1, 2, 0)

     def _forward_prefill(
         self,
@@ -514,7 +531,7 @@ def _forward_prefill(
         elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
             attn_output = torch.empty(num_tokens,
                                       self.num_heads,
-                                      self.v_head_dim,
+                                      self.padding_head_dim,
                                       dtype=query.dtype,
                                       device=query.device)
             k_nope, value = self.kv_b_proj(kv_c_normed)[0].view(
@@ -523,17 +540,31 @@ def _forward_prefill(
                 [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
             key = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
                             dim=-1)
+            pad_query = torch.nn.functional.pad(query, [
+                0, self.padding_head_dim - self.qk_rope_head_dim -
+                self.qk_nope_head_dim
+            ],
+                                                value=0)
+            pad_key = torch.nn.functional.pad(key, [
+                0, self.padding_head_dim - self.qk_rope_head_dim -
+                self.qk_nope_head_dim
+            ],
+                                              value=0)
+            pad_value = torch.nn.functional.pad(
+                value, [0, self.padding_head_dim - self.v_head_dim], value=0)
             torch_npu._npu_flash_attention(
-                query=query,
-                key=key,
-                value=value,
+                query=pad_query,
+                key=pad_key,
+                value=pad_value,
                 mask=attn_metadata.attn_mask,
                 seq_len=attn_metadata.prefill.context_lens,
                 scale_value=self.scale,
                 num_heads=self.num_heads,
                 num_kv_heads=self.num_heads,
                 out=attn_output)
-            attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
+            attn_output = attn_output.view(
+                -1, self.num_heads,
+                self.padding_head_dim)[:, :, :self.v_head_dim]
         else:
             raise RuntimeError(
                 "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache and ChunkedPrefill scenario in forward prefill, please file a bug to vllm-ascend !"