 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
 
+from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
+from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
+
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
     from vllm.v1.worker.gpu_input_batch import InputBatch
@@ -100,6 +103,8 @@ class AscendMLAMetadata:
     # For logging.
     num_input_tokens: int = 0  # Number of tokens including padding.
 
+    query_lens: list[int] = None
+    seq_lens: torch.Tensor = None
     # The dimension of the attention heads
     head_dim: Optional[int] = None
     attn_mask: torch.Tensor = None
@@ -118,6 +123,16 @@ def __post_init__(self):
118123 # f"Only {supported_head_sizes} are supported for head_dim,",
119124 # f"received {self.head_dim}.")
120125
126+ def split_metadata_for_multistream (
127+ self ,
128+ ms_split_config : MSAttentionMetadataSplitConfig ,
129+ ) -> list ["AscendMLAMetadata" ]:
130+ """Split metadata for multi-stream with AscendMLAMetadata"""
131+ return model_input_split_v1_mla_attn (
132+ ms_split_config = ms_split_config ,
133+ attn_metadata = self ,
134+ _metadata_cls = AscendMLAMetadata ,
135+ )
121136
122137M = TypeVar ("M" , bound = AscendMLAMetadata )
123138
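# --- Illustrative sketch (not part of the diff) -------------------------------
# How a dual-batch-overlap (DBO) caller might consume the split hook added
# above.  The MSAttentionMetadataSplitConfig field shown (num_micro_batches)
# is an assumption for this example; only split_metadata_for_multistream()
# itself is introduced by the patch.
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig


def split_for_dbo(attn_metadata: "AscendMLAMetadata",
                  num_micro_batches: int = 2) -> list["AscendMLAMetadata"]:
    ms_config = MSAttentionMetadataSplitConfig(num_micro_batches=num_micro_batches)
    # Returns one metadata object per micro-batch, with query_lens / seq_lens
    # partitioned so each stream sees a consistent prefill/decode slice.
    return attn_metadata.split_metadata_for_multistream(ms_config)
# -------------------------------------------------------------------------------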
@@ -315,6 +330,8 @@ def build(self,
 
         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
+            query_lens=query_lens.tolist(),
+            seq_lens=seq_lens,
             slot_mapping=slot_mapping,
             head_dim=self.runner.model_config.get_head_size(),
             num_decodes=self._num_decodes,
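# --- Illustrative sketch (not part of the diff) -------------------------------
# What the two new metadata fields carry: build() now caches the per-request
# query lengths and total sequence lengths so the multistream split can
# partition requests without re-deriving them.  The values below are made up
# for the example.
import torch

query_lens = [1, 1, 8, 16]                # decode requests first, then prefills
seq_lens = torch.tensor([33, 57, 8, 16])  # context + query tokens per request
# A micro-batch boundary is then chosen from token counts derived from these,
# e.g. the first sum(query_lens[:k]) tokens go to the first micro-batch.
# -------------------------------------------------------------------------------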
@@ -783,16 +800,34 @@ def forward(
                 key_cache=kv_cache,
                 slot_indices=attn_metadata.slot_mapping.flatten())
         if has_prefill:
-            output[num_decode_tokens:] = self._forward_prefill(
+            # FIX: the AI Core move/copy should also be placed on the comm stream in DBO,
+            # otherwise it may affect accuracy or disturb the overlap of the next stage.
+            # TODO: use a more elegant way to avoid this
+            output_prefill = self._forward_prefill(
                 prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
                 attn_metadata)
+            from vllm_ascend.multistream.context import get_multistream_comm_context
+            current_ms_metadata = get_multistream_comm_context()
+            if current_ms_metadata is not None:
+                with torch.npu.stream(current_ms_metadata.comm_stream):
+                    output[num_decode_tokens:] = output_prefill
+                    current_ms_metadata.after_comm_event.record()
+            else:
+                output[num_decode_tokens:] = output_prefill
         if has_decode:
             if self.running_in_graph:
                 return self._forward_decode(decode_ql_nope, decode_q_pe,
                                             decode_k_nope, decode_k_pe,
                                             kv_cache, attn_metadata)
             else:
-                output[:num_decode_tokens] = self._forward_decode(
-                    decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe,
-                    kv_cache, attn_metadata)
+                from vllm_ascend.multistream.context import get_multistream_comm_context
+                current_ms_metadata = get_multistream_comm_context()
+                output_decode = self._forward_decode(
+                    decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe,
+                    kv_cache, attn_metadata)
+                if current_ms_metadata is not None:
+                    with torch.npu.stream(current_ms_metadata.comm_stream):
+                        output[:num_decode_tokens] = output_decode
+                else:
+                    output[:num_decode_tokens] = output_decode
         return output_padded
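# --- Illustrative sketch (not part of the diff) -------------------------------
# The copy-on-comm-stream pattern used in both branches above, in isolation:
# when a multistream comm context is active, the result copy is issued on its
# comm stream (and, for prefill, an event is recorded so the next stage can
# wait on it); otherwise the copy runs on the default stream.  The context
# fields (comm_stream, after_comm_event) come from the diff; everything else
# here is an assumption for the example.
import torch
import torch_npu  # noqa: F401  # provides the torch.npu stream APIs on Ascend


def copy_on_comm_stream(dst: torch.Tensor, src: torch.Tensor,
                        ms_metadata=None) -> None:
    if ms_metadata is not None:
        # Enqueue the copy on the communication stream so it overlaps with the
        # other micro-batch's compute instead of blocking it.
        with torch.npu.stream(ms_metadata.comm_stream):
            dst.copy_(src)
            ms_metadata.after_comm_event.record()
    else:
        dst.copy_(src)
# -------------------------------------------------------------------------------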