                                               MLAAttentionImpl)
 from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import get_current_vllm_config
-from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               LinearBase, RowParallelLinear,
+from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
-from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
@@ -117,6 +115,8 @@ class AscendMLAMetadata:
     # For logging.
     num_input_tokens: int = 0  # Number of tokens including padding.
 
+    with_prefill_across_dp: bool = False
+
     # The dimension of the attention heads
     head_dim: Optional[int] = None
     attn_mask: torch.Tensor = None
@@ -260,6 +260,10 @@ def build_dummy(self, num_reqs: int,
                                   PAD_SLOT_ID,
                                   dtype=torch.int32,
                                   device=device)
+        query_start_loc = torch.full((num_reqs, ),
+                                     -1,
+                                     dtype=torch.int32,
+                                     device=device)
         decode_metadata = AscendMLADecodeMetadata(
             input_positions=input_positions,
             block_table=block_table,
@@ -278,15 +282,21 @@ def build_dummy(self, num_reqs: int,
             attn_state=AscendAttentionState.DecodeOnly,
             prefill=None,
             decode=decode_metadata,
+            query_start_loc=query_start_loc,
+            seq_lens=seq_lens,
+            block_tables=block_table,
         )
 
-    def build(self,
-              num_reqs: int,
-              num_actual_tokens: int,
-              max_query_len: int,
-              common_attn_metadata: CommonAttentionMetadata,
-              common_prefix_len: Optional[int] = None,
-              graph_pad_size: int = -1) -> AscendMLAMetadata:
+    def build(
+        self,
+        num_reqs: int,
+        num_actual_tokens: int,
+        max_query_len: int,
+        common_attn_metadata: CommonAttentionMetadata,
+        common_prefix_len: Optional[int] = None,
+        graph_pad_size: int = -1,
+        with_prefill_across_dp: bool = False,
+    ) -> AscendMLAMetadata:
         assert self._num_decodes + self._num_prefills == num_reqs
 
         # Note(simon): be careful about the CPU <> GPU memory movement in this
@@ -388,6 +398,7 @@ def build(self,
             query_start_loc=query_start_loc,
             block_tables=block_table,
             seq_lens=seq_lens,
+            with_prefill_across_dp=with_prefill_across_dp,
         )
 
 
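Note: a minimal sketch of how the new flag would be threaded into the builder from the model runner. Only the build() signature above comes from this diff; the builder attribute, the runner-side variable names, and the commented semantics of the flag are assumptions.

    attn_metadata = self.attn_metadata_builder.build(
        num_reqs=num_reqs,
        num_actual_tokens=num_actual_tokens,
        max_query_len=max_query_len,
        common_attn_metadata=common_attn_metadata,
        graph_pad_size=graph_pad_size,
        # Assumed semantics: True when any data-parallel rank still has
        # prefill work in this step; the value is carried into the metadata
        # via the with_prefill_across_dp field added above.
        with_prefill_across_dp=with_prefill_across_dp,
    )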
@@ -409,20 +420,7 @@ def __init__(
         blocksparse_params: Optional[dict[str, Any]],
         logits_soft_cap: Optional[float],
         attn_type: str,
-        # MLA Specific Arguments
-        q_lora_rank: Optional[int],
-        kv_lora_rank: int,
-        qk_nope_head_dim: int,
-        qk_rope_head_dim: int,
-        qk_head_dim: int,
-        v_head_dim: int,
-        rotary_emb: RotaryEmbedding,
-        # q_proj should be q_b_proj if q_lora_rank is not None, but from an
-        # attention backend perspective we rely on the layer to pass in the
-        # correct matrix
-        q_proj: ColumnParallelLinear,
-        kv_b_proj: ColumnParallelLinear,
-        o_proj: RowParallelLinear,
+        kv_sharing_target_layer_name: Optional[str] = None,
         **kwargs,
     ) -> None:
         self.num_heads = num_heads
@@ -431,25 +429,20 @@ def __init__(
         self.num_kv_heads = num_kv_heads
         self.kv_cache_dtype = kv_cache_dtype
 
-        self.q_lora_rank = q_lora_rank
-        self.kv_lora_rank = kv_lora_rank
-        self.qk_nope_head_dim = qk_nope_head_dim
-        self.qk_rope_head_dim = qk_rope_head_dim
-        self.qk_head_dim = qk_head_dim
-        self.v_head_dim = v_head_dim
-
-        # Hack for V1 for now to avoid torch library overhead (since we are
-        # already inside an attention custom op), pull out the forward
-        # method from the rotary embedding and call it directly
-        # TODO(lucas): we should probably find a cleaner way to do this
-        self.rotary_emb = rotary_emb
-
-        self.q_proj = q_proj
-        self.kv_b_proj = kv_b_proj
-        self.o_proj = o_proj
-
+        # MLA Args
+        self.q_lora_rank = kwargs['q_lora_rank']
+        self.kv_lora_rank = kwargs['kv_lora_rank']
+        self.qk_nope_head_dim = kwargs['qk_nope_head_dim']
+        self.qk_rope_head_dim = kwargs['qk_rope_head_dim']
+        self.qk_head_dim = kwargs['qk_head_dim']
+        self.v_head_dim = kwargs['v_head_dim']
+        self.rotary_emb = kwargs['rotary_emb']
+        self.q_proj = kwargs['q_proj']
+        self.kv_b_proj = kwargs['kv_b_proj']
+        self.o_proj = kwargs['o_proj']
         self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
         self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
+
         # Handle the differences between the flash_attn_varlen from flash_attn
         # and the one from vllm_flash_attn. The former is used on RoCM and the
         # latter has an additional parameter to control FA2 vs FA3
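Note: with the impl now reading the MLA-specific modules out of **kwargs, a caller would look roughly like the sketch below. The class name AscendMLAImpl, the leading positional arguments, and the concrete head dimensions are illustrative assumptions; only the kwarg keys are taken from the hunk above.

    mla_kwargs = dict(
        q_lora_rank=None,                 # example values, not from this PR
        kv_lora_rank=512,
        qk_nope_head_dim=128,
        qk_rope_head_dim=64,
        qk_head_dim=192,
        v_head_dim=128,
        rotary_emb=rotary_emb,            # modules built by the attention layer
        q_proj=q_proj,
        kv_b_proj=kv_b_proj,
        o_proj=o_proj,
        kv_a_proj_with_mqa=kv_a_proj_with_mqa,  # optional: read via kwargs.get()
        kv_a_layernorm=kv_a_layernorm,          # optional: read via kwargs.get()
    )
    impl = AscendMLAImpl(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name=None, **mla_kwargs)

Required keys are indexed directly (kwargs['q_lora_rank']) and fail fast with a KeyError if the layer omits them, while the two optional projections fall back to None via kwargs.get().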
@@ -621,7 +614,7 @@ def exec_kv(
         kv = self.kv_a_proj_with_mqa(hidden_states)[0]
         # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
         kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
-        k_pe, k_nope, _, _ = torch.ops.npu_inference.npu_kv_rmsnorm_rope_cache(
+        k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
             kv,
             self.kv_a_layernorm.weight,
             cos,
@@ -643,7 +636,7 @@ def rope_single(
         B, N, D = x.shape
         S = 1
         x = x.view(B, N, S, D)
-        x = torch.ops.npu_inference.npu_interleave_rope(x, cos, sin)
+        x = torch_npu.npu_interleave_rope(x, cos, sin)
         return x.view(B, N, D)
 
     def _forward_decode(
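Note: both fused kernels are now invoked through the module-level torch_npu API rather than torch.ops.npu_inference. A minimal availability check, purely illustrative and not part of this PR:

    try:
        # torch_npu exposes npu_kv_rmsnorm_rope_cache and npu_interleave_rope
        import torch_npu  # noqa: F401
    except ImportError as err:
        raise RuntimeError(
            "vllm-ascend MLA requires torch_npu with the fused "
            "RMSNorm/RoPE kernels") from err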
@@ -766,6 +759,7 @@ def forward(
             sin = sin[attn_metadata.decode.input_positions]
             cos = cos[:, None, None, :]
             sin = sin[:, None, None, :]
+
             decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
             decode_k_pe, decode_k_nope = self.exec_kv(
                 hidden_states_or_kv_c_normed, cos, sin, kv_cache,