 import numpy as np
 import torch
 import torch_npu
-from vllm_ascend.attention.attention_v1 import AscendAttentionState
-from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
-from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
-from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
-from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
-
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
                                               AttentionMetadata,
                                               MLAAttentionImpl)
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 
+from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
+from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
+from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
+from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
+
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
     from vllm.v1.worker.gpu_input_batch import InputBatch
@@ -123,8 +123,8 @@ def __post_init__(self):
         # f"received {self.head_dim}.")
 
     def split_metadata_for_multistream(
-        self,
-        ms_split_config: MSAttentionMetadataSplitConfig,
+        self,
+        ms_split_config: MSAttentionMetadataSplitConfig,
     ) -> list["AscendMLAMetadata"]:
         """Split metadata for multi-stream with AscendMLAMetadata"""
         return model_input_split_v1_mla_attn(
@@ -133,6 +133,7 @@ def split_metadata_for_multistream(
             _metadata_cls=AscendMLAMetadata,
         )
 
+
 M = TypeVar("M", bound=AscendMLAMetadata)
 
 
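Editor's note (not part of the diff): `split_metadata_for_multistream` slices one batch's attention metadata into per-micro-batch copies for dual-batch overlap. A minimal usage sketch, assuming `MSAttentionMetadataSplitConfig` exposes a `num_micro_batches` field (the field name is an assumption here):

from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig

# attn_metadata is an already-built AscendMLAMetadata for the full batch.
split_cfg = MSAttentionMetadataSplitConfig(num_micro_batches=2)  # field name assumed
micro_batches = attn_metadata.split_metadata_for_multistream(split_cfg)
assert len(micro_batches) == 2  # one AscendMLAMetadata per micro-batch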
@@ -574,14 +575,14 @@ def _forward_prefill(
         )
         attn_output = attn_output.reshape(
             [num_tokens, self.num_heads * self.v_head_dim])
-
+
         # A better way is to modify the communication ops or RowParallel Layer in vllm;
         from vllm_ascend.multistream.context import \
             get_multistream_comm_context
-        current_ms_metadata = get_multistream_comm_context()
+        current_ms_metadata = get_multistream_comm_context()
         if current_ms_metadata is None:
             return self.o_proj(attn_output)[0]
-        else:
+        else:
             current_ms_metadata.before_comm_event.record()
             with torch.npu.stream(current_ms_metadata.comm_stream):
                 current_ms_metadata.before_comm_event.wait()
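Editor's note (not part of the diff): the record/wait pair above is what lets the output projection (and, under tensor parallelism, the all-reduce inside `o_proj`) run on a separate communication stream while the other micro-batch keeps computing. A minimal sketch of the pattern, assuming `torch.npu` mirrors the `torch.cuda` stream/event API (which `torch_npu` provides); `attn_output` and `o_proj` below are stand-ins for the real objects:

import torch
import torch_npu  # registers the torch.npu namespace

# Stand-ins for the real tensors/modules in _forward_prefill.
attn_output = torch.randn(8, 2048, device="npu")
o_proj = torch.nn.Linear(2048, 4096).to("npu")

comm_stream = torch.npu.Stream()
before_comm_event = torch.npu.Event()

# ... attention kernels for this micro-batch run on the current stream ...
before_comm_event.record()            # mark the point where compute finished
with torch.npu.stream(comm_stream):   # submit later work to the comm stream
    before_comm_event.wait()          # comm stream blocks until compute is done
    out = o_proj(attn_output)         # projection (and any all-reduce) overlaps
                                      # the other micro-batch's compute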
@@ -687,16 +688,15 @@ def _forward_decode(
             out=attn_output)
         from vllm_ascend.multistream.context import \
             get_multistream_comm_context
-        current_ms_metadata = get_multistream_comm_context()
+        current_ms_metadata = get_multistream_comm_context()
         if current_ms_metadata is None:
             return self._v_up_proj_and_o_proj(attn_output)
-        else:
+        else:
             current_ms_metadata.before_comm_event.record()
             with torch.npu.stream(current_ms_metadata.comm_stream):
                 current_ms_metadata.before_comm_event.wait()
                 return self._v_up_proj_and_o_proj(attn_output)
 
-
     def forward(
         self,
         layer: AttentionLayer,
@@ -820,14 +820,15 @@ def forward(
                 key_cache=kv_cache,
                 slot_indices=attn_metadata.slot_mapping.flatten())
         if has_prefill:
-            # FIX: aicore move should be also placed on the comm stream in dbo,
-            # otherwise it may affect the accuracy
+            # FIX: aicore move should be also placed on the comm stream in dbo,
+            # otherwise it may affect the accuracy
             # TODO: use an elegant way to overlap
             from vllm_ascend.multistream.context import \
                 get_multistream_comm_context
-            output_prefill = self._forward_prefill(
-                prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
-                attn_metadata)
+            output_prefill = self._forward_prefill(prefill_q,
+                                                   prefill_k_c_normed,
+                                                   prefill_k_pe, kv_cache,
+                                                   attn_metadata)
             current_ms_metadata = get_multistream_comm_context()
             if current_ms_metadata is not None:
                 with torch.npu.stream(current_ms_metadata.comm_stream):
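Editor's note (not part of the diff): the `slot_indices` passed to the cache write above follow vLLM's paged-KV convention, where each token gets one flat slot in the block-structured cache. A small sketch of the arithmetic (the values are hypothetical):

import torch

block_size = 16                           # tokens per KV-cache block
block_ids = torch.tensor([3, 3, 7])       # block assigned to each of 3 tokens
offsets = torch.tensor([14, 15, 0])       # position of each token in its block
slot_mapping = block_ids * block_size + offsets
print(slot_mapping)                       # tensor([ 62,  63, 112])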
@@ -836,7 +837,6 @@ def forward(
             else:
                 output[num_decode_tokens:] = output_prefill
 
-
         if has_decode:
             if self.running_in_graph:
                 return self._forward_decode(decode_ql_nope, decode_q_pe,
@@ -845,16 +845,17 @@ def forward(
             else:
                 from vllm_ascend.multistream.context import \
                     get_multistream_comm_context
-                output_decode = self._forward_decode(
-                    decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe,
-                    kv_cache, attn_metadata)
-                current_ms_metadata = get_multistream_comm_context()
+                output_decode = self._forward_decode(decode_ql_nope,
+                                                     decode_q_pe,
+                                                     decode_k_nope,
+                                                     decode_k_pe, kv_cache,
+                                                     attn_metadata)
+                current_ms_metadata = get_multistream_comm_context()
                 if current_ms_metadata is not None:
                     with torch.npu.stream(current_ms_metadata.comm_stream):
                         output[:num_decode_tokens] = output_decode
                         current_ms_metadata.after_comm_event.record()
                 else:
                     output[:num_decode_tokens] = output_decode
 
-
         return output_padded
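Editor's note (not part of the diff): when the copies into `output` are issued on the comm stream, `after_comm_event` is what a later consumer must wait on before reading `output` from the default stream. A hedged sketch of that contract, assuming the `torch.cuda`-style event semantics noted earlier (names and shapes are illustrative):

import torch
import torch_npu  # registers the torch.npu namespace

comm_stream = torch.npu.Stream()
after_comm_event = torch.npu.Event()
output = torch.empty(8, 16, device="npu")
output_decode = torch.ones(4, 16, device="npu")

with torch.npu.stream(comm_stream):
    output[:4] = output_decode      # async copy issued on the comm stream
    after_comm_event.record()       # mark the copy as finished

after_comm_event.wait()             # default stream waits before reading output
print(output[:4].sum())             # safe: the copy is ordered before this read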