@@ -17,6 +17,7 @@
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
+from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor

 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
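For context on the two new helpers, here is a hedged sketch of what `vllm_ascend.utils` plausibly provides; the real implementations may differ, but the hunks below rely only on these semantics: `npu_stream_switch` is a context manager that routes the enclosed ops onto a named secondary stream when enabled (a no-op otherwise), and `npu_wait_tensor` records a dependency so that consumers of one tensor wait for another.

```python
# Hedged sketch of the helpers imported above; the actual
# vllm_ascend.utils implementations may differ.  Assumed here: they
# gate torchair's scope-based multistream primitives behind an
# `enabled` flag so call sites can stay unconditional.
from contextlib import nullcontext

import torch
from torchair.scope import npu_stream_switch as _npu_stream_switch
from torchair.scope import npu_wait_tensor as _npu_wait_tensor


def npu_stream_switch(tag: str, priority: int, enabled: bool = True):
    # Context manager: run the enclosed ops on the stream named `tag`
    # when enabled, or leave them on the current stream otherwise.
    return _npu_stream_switch(tag, priority) if enabled else nullcontext()


def npu_wait_tensor(self: torch.Tensor,
                    dependency: torch.Tensor,
                    enabled: bool = True) -> torch.Tensor:
    # Record a cross-stream dependency: ops consuming `self` will not
    # start until `dependency` has been produced.
    return _npu_wait_tensor(self, dependency) if enabled else self
```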
@@ -461,6 +462,8 @@ def __init__(

         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.enable_multistream_mla = \
+            ascend_config.torchair_graph_config.enable_multistream_mla

     def _v_up_proj_and_o_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
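A hedged usage sketch of how the new flag would be switched on; the key layout is assumed from the attribute path `ascend_config.torchair_graph_config.enable_multistream_mla` read above, and the model name is purely illustrative:

```python
# Hypothetical launch snippet; config keys assumed from the attribute
# path in __init__ above, model name illustrative only.
from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # any MLA-based model
    additional_config={
        "torchair_graph_config": {
            "enabled": True,                 # the code path is graph-mode only
            "enable_multistream_mla": True,  # offload MLA ops to a second stream
        },
    },
)
```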
@@ -626,17 +629,19 @@ def exec_kv(
         kv = self.kv_a_proj_with_mqa(hidden_states)[0]
         # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
         kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
-        k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
-            kv,
-            self.kv_a_layernorm.weight,
-            cos,
-            sin,
-            slots.to(torch.int64),
-            kv_cache[1],
-            kv_cache[0],
-            epsilon=self.kv_a_layernorm.variance_epsilon,
-            cache_mode="PA",
-        )
+        with npu_stream_switch("mla_secondary", 0,
+                               enabled=self.enable_multistream_mla):
+            k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
+                kv,
+                self.kv_a_layernorm.weight,
+                cos,
+                sin,
+                slots.to(torch.int64),
+                kv_cache[1],
+                kv_cache[0],
+                epsilon=self.kv_a_layernorm.variance_epsilon,
+                cache_mode="PA",
+            )
         return k_pe, k_nope

     def rope_single(
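This hunk wraps the fused `npu_kv_rmsnorm_rope_cache` kernel in the `mla_secondary` stream so it can overlap with query-side work still running on the default stream. An eager-mode analogue of the pattern (illustrative only; under torchair graph mode the scope annotation is lowered by the compiler, and `kv_kernel` stands in for the fused call):

```python
# Eager-mode analogue of the scoped stream switch above (illustrative;
# kv_kernel is a stand-in for the fused npu_kv_rmsnorm_rope_cache call).
import torch
import torch_npu  # noqa: F401  -- registers the torch.npu device backend


def kv_kernel(kv: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Placeholder for the fused rmsnorm + rope + cache-write kernel.
    return kv * 2.0, kv + 1.0


kv = torch.randn(4, 1, 1, 64, device="npu")
secondary = torch.npu.Stream()      # plays the role of "mla_secondary"

with torch.npu.stream(secondary):   # enqueue KV work off the default stream
    k_pe, k_nope = kv_kernel(kv)    # overlaps with default-stream Q work

# Join before any default-stream op consumes k_pe / k_nope.
torch.npu.current_stream().wait_stream(secondary)
```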
@@ -769,20 +774,25 @@ def forward(
             decode_ql_nope, decode_q_pe = \
                 self._q_proj_and_k_up_proj(decode_hs_or_q_c)
             if self.running_in_graph:
-                seq_len = self.rotary_emb.max_position_embeddings
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=decode_q_pe.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=decode_q_pe.dtype)
-                cos = cos[attn_metadata.decode.input_positions]
-                sin = sin[attn_metadata.decode.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
-
-                decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
+                with npu_stream_switch("mla_secondary", 0,
+                                       enabled=self.enable_multistream_mla):
+                    seq_len = self.rotary_emb.max_position_embeddings
+                    cos = self.rotary_emb.cos_cached[:seq_len].to(
+                        dtype=decode_q_pe.dtype)
+                    sin = self.rotary_emb.sin_cached[:seq_len].to(
+                        dtype=decode_q_pe.dtype)
+                    cos = cos[attn_metadata.decode.input_positions]
+                    sin = sin[attn_metadata.decode.input_positions]
+                    cos = cos[:, None, None, :]
+                    sin = sin[:, None, None, :]
                 decode_k_pe, decode_k_nope = self.exec_kv(
                     hidden_states_or_kv_c_normed, cos, sin, kv_cache,
                     attn_metadata.slot_mapping)
+                with npu_stream_switch("mla_secondary", 0,
+                                       enabled=self.enable_multistream_mla):
+                    npu_wait_tensor(decode_q_pe, decode_k_pe,
+                                    self.enable_multistream_mla)
+                    decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
             else:
                 decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
                     attn_metadata.decode.input_positions,
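Design note on the reordering above: the cos/sin gather moves onto the secondary stream, and the query-side RoPE now runs after `exec_kv` (whose fused kernel is itself scoped to `mla_secondary` by the earlier hunk), with `npu_wait_tensor(decode_q_pe, decode_k_pe, ...)` serializing the RoPE behind the KV kernel's output. When the flag is off, both scopes collapse to the default stream and the wait is a no-op, so the computation is unchanged.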