@@ -13,7 +13,9 @@

 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
-import vllm_ascend.envs as envs_ascend
+from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
+from vllm_ascend.multistream.context import get_multistream_comm_context
+from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla

 if TYPE_CHECKING:
@@ -118,6 +120,7 @@ class AscendMLAMetadata:

     with_prefill_across_dp: bool = False

+    query_lens: Optional[list[int]] = None
     # The dimension of the attention heads
     head_dim: Optional[int] = None
     attn_mask: torch.Tensor = None
@@ -136,6 +139,17 @@ def __post_init__(self):
         #     f"Only {supported_head_sizes} are supported for head_dim,",
         #     f"received {self.head_dim}.")

+    def split_metadata_for_multistream(
+        self,
+        ms_split_config: MSAttentionMetadataSplitConfig,
+    ) -> list["AscendMLAMetadata"]:
+        """Split metadata for multi-stream with AscendMLAMetadata"""
+        return model_input_split_v1_mla_attn(
+            ms_split_config=ms_split_config,
+            attn_metadata=self,
+            _metadata_cls=AscendMLAMetadata,
+        )
+

 M = TypeVar("M", bound=AscendMLAMetadata)

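For readers unfamiliar with the dual-batch-overlap (DBO) flow: `split_metadata_for_multistream` delegates to `model_input_split_v1_mla_attn`, which partitions one batch's attention metadata into micro-batches so that one micro-batch's compute can overlap the other's communication. A minimal, self-contained sketch of the idea only; `ToyMetadata` and `split_toy_metadata` are hypothetical and far simpler than the real splitter, which also has to slice slot mappings, masks, and decode/prefill bookkeeping:

```python
from dataclasses import dataclass


@dataclass
class ToyMetadata:  # hypothetical stand-in for AscendMLAMetadata
    query_lens: list[int]
    num_actual_tokens: int


def split_toy_metadata(meta: ToyMetadata, split_req: int) -> list[ToyMetadata]:
    """Split at the request boundary `split_req` into two micro-batches."""
    first, second = meta.query_lens[:split_req], meta.query_lens[split_req:]
    return [
        ToyMetadata(query_lens=first, num_actual_tokens=sum(first)),
        ToyMetadata(query_lens=second, num_actual_tokens=sum(second)),
    ]


halves = split_toy_metadata(ToyMetadata([4, 4, 8, 2], 18), split_req=2)
assert [m.num_actual_tokens for m in halves] == [8, 10]
```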
@@ -387,6 +401,7 @@ def build(

         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
+            query_lens=query_lens.tolist(),
             slot_mapping=slot_mapping,
             head_dim=self.runner.model_config.get_head_size(),
             num_decodes=self._num_decodes,
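Note that `query_lens` is materialized as a plain Python list via `.tolist()`, which lets the splitter search for a micro-batch boundary on the host without a device sync. A hedged sketch of that kind of boundary search; `find_split_index` is hypothetical, not part of this patch:

```python
def find_split_index(query_lens: list[int], target_tokens: int) -> int:
    """Index of the first request boundary at/after `target_tokens` tokens."""
    total = 0
    for i, qlen in enumerate(query_lens):
        total += qlen
        if total >= target_tokens:
            return i + 1
    return len(query_lens)


# e.g. aiming for a roughly even token split of [4, 4, 8, 2] (18 tokens):
assert find_split_index([4, 4, 8, 2], 9) == 3  # first 3 requests cover >= 9
```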
@@ -444,9 +459,9 @@ def __init__(
         self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
         self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)

-        self.enable_kv_nz = envs_ascend.VLLM_ENABLE_KV_NZ
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz

     def _v_up_proj_and_o_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
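This hunk moves `enable_kv_nz` from the `VLLM_ENABLE_KV_NZ` environment variable into `ascend_config.torchair_graph_config`, next to the existing `enabled` flag. A small sketch of the before/after access pattern; the dataclasses below are illustrative stand-ins, not vllm-ascend's actual `AscendConfig`:

```python
import os
from dataclasses import dataclass, field


@dataclass
class TorchairGraphConfig:  # illustrative only; field names mirror the diff
    enabled: bool = False
    enable_kv_nz: bool = False


@dataclass
class AscendConfig:  # illustrative only
    torchair_graph_config: TorchairGraphConfig = field(
        default_factory=TorchairGraphConfig)


# Before: the flag lived in a process-wide env var, parsed on its own.
enable_kv_nz_old = os.environ.get("VLLM_ENABLE_KV_NZ", "0") == "1"

# After: the flag sits beside the other torchair graph options, so it is
# parsed and validated in one place and discoverable alongside `enabled`.
config = AscendConfig()
enable_kv_nz_new = config.torchair_graph_config.enable_kv_nz
```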
@@ -587,7 +602,15 @@ def _forward_prefill(
         )
         attn_output = attn_output.reshape(
             [num_tokens, self.num_heads * self.v_head_dim])
-        return self.o_proj(attn_output)[0]
+
+        current_ms_metadata = get_multistream_comm_context()
+        if current_ms_metadata is None:
+            return self.o_proj(attn_output)[0]
+        else:
+            current_ms_metadata.before_comm_event.record()
+            with torch.npu.stream(current_ms_metadata.comm_stream):
+                current_ms_metadata.before_comm_event.wait()
+                return self.o_proj(attn_output)[0]

     def exec_kv(
         self,
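The pattern above records an event on the current compute stream, then makes a dedicated comm stream wait on it before launching `o_proj`, so the projection and its tensor-parallel all-reduce can overlap the other micro-batch's compute. A minimal sketch of that handoff, using `torch.cuda` streams as a stand-in for `torch.npu` (the stream/event APIs mirror each other, and a CUDA device is assumed at runtime); `CommContext` is a hypothetical stand-in for whatever `get_multistream_comm_context()` returns:

```python
from dataclasses import dataclass, field
from typing import Optional

import torch


@dataclass
class CommContext:  # hypothetical stand-in for the multistream comm context
    comm_stream: torch.cuda.Stream = field(default_factory=torch.cuda.Stream)
    before_comm_event: torch.cuda.Event = field(default_factory=torch.cuda.Event)
    after_comm_event: torch.cuda.Event = field(default_factory=torch.cuda.Event)


def project_with_overlap(o_proj, attn_output, ctx: Optional[CommContext]):
    if ctx is None:
        return o_proj(attn_output)[0]
    # Mark "attention output is ready" on the current compute stream...
    ctx.before_comm_event.record()
    with torch.cuda.stream(ctx.comm_stream):
        # ...and block the comm stream on that point before launching o_proj,
        # whose all-reduce then overlaps the other micro-batch's compute.
        ctx.before_comm_event.wait()
        return o_proj(attn_output)[0]
```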
@@ -731,7 +754,14 @@ def _forward_decode(
             context_lens=attn_metadata.decode.seq_lens,  # type:ignore
             mla_vheadsize=self.kv_lora_rank,
             out=attn_output)
-        return self._v_up_proj_and_o_proj(attn_output)
+        current_ms_metadata = get_multistream_comm_context()
+        if current_ms_metadata is None:
+            return self._v_up_proj_and_o_proj(attn_output)
+        else:
+            current_ms_metadata.before_comm_event.record()
+            with torch.npu.stream(current_ms_metadata.comm_stream):
+                current_ms_metadata.before_comm_event.wait()
+                return self._v_up_proj_and_o_proj(attn_output)

     def forward(
         self,
@@ -863,16 +893,38 @@ def forward(
                 key_cache=kv_cache,
                 slot_indices=attn_metadata.slot_mapping.flatten())
         if has_prefill:
-            output[num_decode_tokens:] = self._forward_prefill(
-                prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
-                attn_metadata)
+            # FIX: the aicore copy should also be placed on the comm stream
+            # in dbo, otherwise it may affect accuracy.
+            # TODO: overlap in a more elegant way
+            output_prefill = self._forward_prefill(prefill_q,
+                                                   prefill_k_c_normed,
+                                                   prefill_k_pe, kv_cache,
+                                                   attn_metadata)
+            current_ms_metadata = get_multistream_comm_context()
+            if current_ms_metadata is not None:
+                with torch.npu.stream(current_ms_metadata.comm_stream):
+                    output[num_decode_tokens:] = output_prefill
+                    current_ms_metadata.after_comm_event.record()
+            else:
+                output[num_decode_tokens:] = output_prefill
+
         if has_decode:
             if self.running_in_graph:
                 return self._forward_decode(decode_ql_nope, decode_q_pe,
                                             decode_k_nope, decode_k_pe,
                                             kv_cache, attn_metadata)
             else:
-                output[:num_decode_tokens] = self._forward_decode(
-                    decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe,
-                    kv_cache, attn_metadata)
+                output_decode = self._forward_decode(decode_ql_nope,
+                                                     decode_q_pe,
+                                                     decode_k_nope,
+                                                     decode_k_pe, kv_cache,
+                                                     attn_metadata)
+                current_ms_metadata = get_multistream_comm_context()
+                if current_ms_metadata is not None:
+                    with torch.npu.stream(current_ms_metadata.comm_stream):
+                        output[:num_decode_tokens] = output_decode
+                        current_ms_metadata.after_comm_event.record()
+                else:
+                    output[:num_decode_tokens] = output_decode
+
         return output_padded
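`forward()` applies the complementary half of the pattern: the attention result is computed first, then the copy into the shared `output` buffer is issued on the comm stream and `after_comm_event` is recorded so the other micro-batch can synchronize before consuming it. Continuing the same hypothetical sketch (same `CommContext` as above, cuda standing in for npu):

```python
def scatter_result(output: torch.Tensor, dst: slice, result: torch.Tensor,
                   ctx: Optional[CommContext]) -> None:
    if ctx is None:
        output[dst] = result
        return
    with torch.cuda.stream(ctx.comm_stream):
        output[dst] = result            # device copy issued on the comm stream
        ctx.after_comm_event.record()   # consumers call after_comm_event.wait()
```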