                                                LinearBase, RowParallelLinear,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
+from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
+from vllm_ascend.utils import vllm_version_is
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
 
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
     from vllm.v1.worker.gpu_input_batch import InputBatch
 
+if vllm_version_is("main"):
+    from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+
 
 class AscendMLABackend(AttentionBackend):
 
@@ -57,6 +62,7 @@ class AscendMLAPrefillMetadata:
     seq_lens: list[int]
     context_lens: torch.Tensor
     input_positions: torch.Tensor
+    query_start_loc: torch.Tensor
     block_table: torch.Tensor
     max_query_len: int
     max_seq_lens: int
@@ -90,6 +96,9 @@ class AscendMLAMetadata:
 
     num_actual_tokens: int  # Number of tokens excluding padding.
     slot_mapping: torch.Tensor
+    query_start_loc: torch.Tensor
+    seq_lens: torch.Tensor
+    block_tables: torch.Tensor
 
     # New for MLA (compared to FlashAttention)
     # For handling prefill decode split
@@ -231,6 +240,7 @@ def build(self,
               num_actual_tokens: int,
               max_query_len: int,
               common_prefix_len: Optional[int] = None,
+              common_attn_metadata: CommonAttentionMetadata = None,
               graph_pad_size: int = -1) -> AscendMLAMetadata:
         assert self._num_decodes + self._num_prefills == num_reqs
 
@@ -245,6 +255,7 @@ def build(self,
         input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
             device, non_blocking=True).long()
 
+        query_start_loc = common_attn_metadata.query_start_loc
         seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
         query_lens = seq_lens_cpu - self.runner.input_batch.num_computed_tokens_cpu_tensor[:
             num_reqs]
@@ -258,6 +269,8 @@ def build(self,
             tokens_start = self._num_decode_tokens
             max_query_len = query_lens[tokens_start:].max().item()
             max_seq_lens = seq_lens[tokens_start:].max().item()
+            prefill_query_start_loc = query_start_loc[
+                reqs_start:] - query_start_loc[reqs_start]
 
             prefill_metadata = AscendMLAPrefillMetadata(
                 attn_mask=self.runner.attn_mask,
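
The rebasing of `query_start_loc` above is plain offset arithmetic: subtracting `query_start_loc[reqs_start]` makes the prefill-only slice start at zero. A minimal sketch with assumed request sizes (not part of the patch):

```python
import torch

# Assume 2 decode requests (1 token each) followed by 2 prefill requests of
# 3 and 5 tokens; query_start_loc holds cumulative token offsets per request.
query_start_loc = torch.tensor([0, 1, 2, 5, 10])
reqs_start = 2  # index of the first prefill request

# Rebase so the prefill view starts at token 0.
prefill_query_start_loc = query_start_loc[reqs_start:] - query_start_loc[reqs_start]
assert prefill_query_start_loc.tolist() == [0, 3, 8]
```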
@@ -268,6 +281,7 @@ def build(self,
                 block_table=block_table[reqs_start:, ...],
                 max_query_len=max_query_len,
                 max_seq_lens=max_seq_lens,
+                query_start_loc=prefill_query_start_loc,
             )
 
         decode_metadata = None
@@ -324,6 +338,9 @@ def build(self,
             attn_state=self.runner.attn_state,
             prefill=prefill_metadata,
             decode=decode_metadata,
+            query_start_loc=query_start_loc,
+            block_tables=block_table,
+            seq_lens=seq_lens,
         )
 
 
@@ -373,6 +390,12 @@ def __init__(
         self.qk_rope_head_dim = qk_rope_head_dim
         self.qk_head_dim = qk_head_dim
         self.v_head_dim = v_head_dim
+        # TODO: this padding should be removed once the kernel is ready.
+        # npu_flash_attention only works with a head_dim divisible by 128, so we
+        # pad the head_dim to the target size here and slice the final result.
+        self.padding_head_dim = (
+            (self.qk_nope_head_dim + self.qk_rope_head_dim - 1) // 128 +
+            1) * 128
 
         # Hack for V1 for now to avoid torch library overhead (since we are
         # already inside an attention custom op), pull out the forward
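
The `padding_head_dim` expression added above rounds the combined nope+rope head dim up to the next multiple of 128. A quick worked example with assumed DeepSeek-style MLA sizes (not taken from the patch):

```python
# Assumed MLA head dims: 128 (nope) + 64 (rope) = 192 for q/k.
qk_nope_head_dim, qk_rope_head_dim = 128, 64

# Same ceiling-to-a-multiple-of-128 formula as in __init__ above.
padding_head_dim = ((qk_nope_head_dim + qk_rope_head_dim - 1) // 128 + 1) * 128
assert padding_head_dim == 256  # 192 gets padded to 256 for npu_flash_attention
```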
@@ -470,11 +493,9 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
             [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
 
         # Convert from (L, N, V) to (N, L, V)
-        self.W_UV = W_UV.transpose(0, 1).contiguous()
+        self.W_UV = W_UV.transpose(0, 1)
         # Convert from (L, N, P) to (N, P, L)
-        self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
-        self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
-        self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
+        self.W_UK_T = W_UK.permute(1, 2, 0)
 
     def _forward_prefill(
         self,
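
The two kept assignments only change the layout of the absorbed up-projection weights; the `.contiguous()` calls and the `npu_format_cast` casts are dropped. A shape-only sketch of the permutes, with assumed dimensions (not part of the patch):

```python
import torch

# Assumed sizes: L = kv_lora_rank, N = num_heads, P = qk_nope_head_dim, V = v_head_dim.
L, N, P, V = 512, 16, 128, 128
W_UK = torch.randn(L, N, P)
W_UV = torch.randn(L, N, V)

W_UK_T = W_UK.permute(1, 2, 0)  # (L, N, P) -> (N, P, L)
W_UV_T = W_UV.transpose(0, 1)   # (L, N, V) -> (N, L, V)
assert W_UK_T.shape == (N, P, L) and W_UV_T.shape == (N, L, V)
```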
@@ -514,7 +535,7 @@ def _forward_prefill(
         elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
             attn_output = torch.empty(num_tokens,
                                       self.num_heads,
-                                      self.v_head_dim,
+                                      self.padding_head_dim,
                                       dtype=query.dtype,
                                       device=query.device)
             k_nope, value = self.kv_b_proj(kv_c_normed)[0].view(
@@ -523,17 +544,31 @@ def _forward_prefill(
                 [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
             key = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
                             dim=-1)
+            pad_query = torch.nn.functional.pad(query, [
+                0, self.padding_head_dim - self.qk_rope_head_dim -
+                self.qk_nope_head_dim
+            ],
+                                                value=0)
+            pad_key = torch.nn.functional.pad(key, [
+                0, self.padding_head_dim - self.qk_rope_head_dim -
+                self.qk_nope_head_dim
+            ],
+                                              value=0)
+            pad_value = torch.nn.functional.pad(
+                value, [0, self.padding_head_dim - self.v_head_dim], value=0)
             torch_npu._npu_flash_attention(
-                query=query,
-                key=key,
-                value=value,
+                query=pad_query,
+                key=pad_key,
+                value=pad_value,
                 mask=attn_metadata.attn_mask,
                 seq_len=attn_metadata.prefill.context_lens,
                 scale_value=self.scale,
                 num_heads=self.num_heads,
                 num_kv_heads=self.num_heads,
                 out=attn_output)
-            attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
+            attn_output = attn_output.view(
+                -1, self.num_heads,
+                self.padding_head_dim)[:, :, :self.v_head_dim]
         else:
             raise RuntimeError(
                 "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache and ChunkedPrefill scenario in forward prefill, please file a bug to vllm-ascend !"