 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
-from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor
+from vllm_ascend.utils import npu_prefetch, npu_stream_switch, npu_wait_tensor
 from vllm_ascend.worker.npu_input_batch import InputBatch
 
 if TYPE_CHECKING:
@@ -579,13 +579,18 @@ def __init__(
             " please make sure after the tensor parallel split, num_heads / num_kv_heads in "
             "{32, 64, 128}.")
 
-    def _v_up_proj_and_o_proj(self, x):
+    def _v_up_proj_and_o_proj(self, x, enable_multistream_mla: bool = False):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
         # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
         x = torch.bmm(x, self.W_UV)
         # Convert from (N, B, V) to (B, N * V)
         x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
+        MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024  # 16MB
+        npu_prefetch(self.o_proj.weight,
+                     x,
+                     max_size=MAX_O_PROJ_PREFETCH_SIZE,
+                     enabled=enable_multistream_mla)
         return self.o_proj(x, is_prefill=False)[0]
 
     # Return `ql_nope`, `q_pe`
@@ -864,7 +869,6 @@ def exec_kv(
         sin: torch.Tensor,
         kv_cache: Tuple,
         slots: torch.Tensor,
-        enable_multistream_mla: bool = False,
     ):
 
         B = hidden_states.shape[0]
@@ -874,21 +878,18 @@ def exec_kv(
         # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
         kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
         cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
-        with npu_stream_switch("mla_secondary",
-                               0,
-                               enabled=enable_multistream_mla):
-            k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
-                kv,
-                self.kv_a_layernorm.weight,
-                cos,
-                sin,
-                slots.to(torch.int64),
-                kv_cache[1],
-                kv_cache[0],
-                epsilon=self.kv_a_layernorm.variance_epsilon,
-                cache_mode=cache_mode,
-            )
-            return k_pe, k_nope
+        k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
+            kv,
+            self.kv_a_layernorm.weight,
+            cos,
+            sin,
+            slots.to(torch.int64),
+            kv_cache[1],
+            kv_cache[0],
+            epsilon=self.kv_a_layernorm.variance_epsilon,
+            cache_mode=cache_mode,
+        )
+        return k_pe, k_nope, kv
 
     def exec_kv_prefill(
         self,
@@ -940,6 +941,7 @@ def _forward_decode(
         k_pe: torch.Tensor,
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: AscendMLAMetadata,
+        enable_multistream_mla: bool = False,
     ) -> torch.Tensor:
         decode_meta = attn_metadata.decode
         assert decode_meta is not None
@@ -1020,7 +1022,8 @@ def _forward_decode(
                 out=attn_output)
         current_ms_metadata = get_multistream_comm_context()
         if current_ms_metadata is None:
-            return self._v_up_proj_and_o_proj(attn_output)
+            return self._v_up_proj_and_o_proj(attn_output,
+                                              enable_multistream_mla)
         else:
             current_ms_metadata.before_comm_event.record()
             with torch.npu.stream(current_ms_metadata.comm_stream):
@@ -1037,6 +1040,7 @@ def forward(
         attn_metadata: M,
         output: Optional[torch.Tensor] = None,
         enable_multistream_mla: bool = False,
+        ckq: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        assert output is not None, "Output tensor must be provided."
        if attn_metadata is None:
@@ -1091,6 +1095,15 @@ def forward(
            sin = sin[attn_metadata.decode.input_positions]
            cos = cos[:, None, None, :]
            sin = sin[:, None, None, :]
+            with npu_stream_switch("mla_secondary",
+                                   0,
+                                   enabled=enable_multistream_mla):
+                npu_wait_tensor(hidden_states_or_kv_c_normed,
+                                ckq,
+                                enabled=enable_multistream_mla)
+                decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
+                    hidden_states_or_kv_c_normed, cos, sin, kv_cache,
+                    attn_metadata.slot_mapping)
            # Without explicitly controlling the order, IndexByTensor operations
            # would be placed after `matmul W_KV_T` hindering the overlapping of
            # KvRmsNormRopeCache and SingleRope.
@@ -1100,12 +1113,13 @@ def forward(
                npu_wait_tensor(decode_hs_or_q_c,
                                sin,
                                enabled=enable_multistream_mla)
+                npu_wait_tensor(decode_hs_or_q_c,
+                                decode_kv,
+                                enabled=enable_multistream_mla)
+
            decode_ql_nope, decode_q_pe = \
                self._q_proj_and_k_up_proj(decode_hs_or_q_c)
            if self.running_in_graph:
-                decode_k_pe, decode_k_nope = self.exec_kv(
-                    hidden_states_or_kv_c_normed, cos, sin, kv_cache,
-                    attn_metadata.slot_mapping, enable_multistream_mla)
                with npu_stream_switch("mla_secondary",
                                       0,
                                       enabled=enable_multistream_mla):
@@ -1194,7 +1208,8 @@ def forward(
            if self.running_in_graph:
                return self._forward_decode(decode_ql_nope, decode_q_pe,
                                            decode_k_nope, decode_k_pe,
-                                            kv_cache, attn_metadata)
+                                            kv_cache, attn_metadata,
+                                            enable_multistream_mla)
            else:
                output_decode = self._forward_decode(decode_ql_nope,
                                                     decode_q_pe,
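The npu_prefetch(self.o_proj.weight, x, max_size=MAX_O_PROJ_PREFETCH_SIZE, enabled=enable_multistream_mla) call added above issues a bounded weight prefetch ahead of the o_proj matmul, gated on the multistream-MLA flag. Below is a minimal, hypothetical sketch of that call-site contract only; it is not the vllm_ascend.utils.npu_prefetch implementation, and issue_prefetch is an assumed stand-in for the device prefetch primitive.

import torch


def prefetch_sketch(weight: torch.Tensor,
                    dependency: torch.Tensor,
                    max_size: int,
                    enabled: bool,
                    issue_prefetch=None) -> int:
    """Return the byte count that would be prefetched (0 when disabled).

    Hypothetical illustration of a flag-gated, size-capped weight prefetch;
    `issue_prefetch` stands in for the real NPU prefetch call.
    """
    if not enabled:
        return 0
    # Cap at the budget, mirroring the 16MB MAX_O_PROJ_PREFETCH_SIZE above.
    nbytes = min(weight.element_size() * weight.numel(), max_size)
    if issue_prefetch is not None:
        issue_prefetch(weight, dependency, nbytes)
    return nbytes


# Toy check: a 2MB fp16 weight fits under a 16MB budget.
w = torch.empty(1024, 1024, dtype=torch.float16)
x = torch.empty(8, 1024, dtype=torch.float16)
assert prefetch_sketch(w, x, 16 * 1024 * 1024, enabled=True) == 2 * 1024 * 1024
assert prefetch_sketch(w, x, 16 * 1024 * 1024, enabled=False) == 0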