|
18 | 18 | from vllm_ascend.multistream.context import get_multistream_comm_context |
19 | 19 | from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn |
20 | 20 | from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla |
| 21 | +from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor |
21 | 22 |
|
22 | 23 | if TYPE_CHECKING: |
23 | 24 | from vllm.v1.core.sched.output import SchedulerOutput |
@@ -475,6 +476,9 @@ def __init__( |
475 | 476 |
|
476 | 477 | ascend_config = get_ascend_config() |
477 | 478 | self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled |
| 479 | + self.enable_multistream_mla = \ |
| 480 | + ascend_config.torchair_graph_config.enable_multistream_mla |
| 481 | + |
478 | 482 | # Adapt torch air graph mode with spec decoding. |
479 | 483 | speculative_config = get_current_vllm_config().speculative_config |
480 | 484 | if speculative_config is not None: |
@@ -648,17 +652,20 @@ def exec_kv( |
648 | 652 | kv = self.kv_a_proj_with_mqa(hidden_states)[0] |
649 | 653 | # npu_kv_rmsnorm_rope_cache needs [B, N, S, D] |
650 | 654 | kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim) |
651 | | - k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache( |
652 | | - kv, |
653 | | - self.kv_a_layernorm.weight, |
654 | | - cos, |
655 | | - sin, |
656 | | - slots.to(torch.int64), |
657 | | - kv_cache[1], |
658 | | - kv_cache[0], |
659 | | - epsilon=self.kv_a_layernorm.variance_epsilon, |
660 | | - cache_mode="PA", |
661 | | - ) |
| 655 | + with npu_stream_switch("mla_secondary", |
| 656 | + 0, |
| 657 | + enabled=self.enable_multistream_mla): |
| 658 | + k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache( |
| 659 | + kv, |
| 660 | + self.kv_a_layernorm.weight, |
| 661 | + cos, |
| 662 | + sin, |
| 663 | + slots.to(torch.int64), |
| 664 | + kv_cache[1], |
| 665 | + kv_cache[0], |
| 666 | + epsilon=self.kv_a_layernorm.variance_epsilon, |
| 667 | + cache_mode="PA", |
| 668 | + ) |
662 | 669 | return k_pe, k_nope |
663 | 670 |
|
664 | 671 | def rope_single( |
@@ -810,23 +817,38 @@ def forward( |
810 | 817 | if has_decode: |
811 | 818 | decode_k_nope = None |
812 | 819 | assert attn_metadata.decode is not None |
| 820 | + if self.running_in_graph: |
| 821 | + with npu_stream_switch("mla_secondary", |
| 822 | + 0, |
| 823 | + enabled=self.enable_multistream_mla): |
| 824 | + seq_len = self.rotary_emb.max_position_embeddings |
| 825 | + cos = self.rotary_emb.cos_cached[:seq_len].to( |
| 826 | + dtype=decode_hs_or_q_c.dtype) |
| 827 | + sin = self.rotary_emb.sin_cached[:seq_len].to( |
| 828 | + dtype=decode_hs_or_q_c.dtype) |
| 829 | + cos = cos[attn_metadata.decode.input_positions] |
| 830 | + sin = sin[attn_metadata.decode.input_positions] |
| 831 | + cos = cos[:, None, None, :] |
| 832 | + sin = sin[:, None, None, :] |
| 833 | + npu_wait_tensor(decode_hs_or_q_c, |
| 834 | + cos, |
| 835 | + enabled=self.enable_multistream_mla) |
| 836 | + npu_wait_tensor(decode_hs_or_q_c, |
| 837 | + sin, |
| 838 | + enabled=self.enable_multistream_mla) |
813 | 839 | decode_ql_nope, decode_q_pe = \ |
814 | 840 | self._q_proj_and_k_up_proj(decode_hs_or_q_c) |
815 | 841 | if self.running_in_graph: |
816 | | - seq_len = self.rotary_emb.max_position_embeddings |
817 | | - cos = self.rotary_emb.cos_cached[:seq_len].to( |
818 | | - dtype=decode_q_pe.dtype) |
819 | | - sin = self.rotary_emb.sin_cached[:seq_len].to( |
820 | | - dtype=decode_q_pe.dtype) |
821 | | - cos = cos[attn_metadata.decode.input_positions] |
822 | | - sin = sin[attn_metadata.decode.input_positions] |
823 | | - cos = cos[:, None, None, :] |
824 | | - sin = sin[:, None, None, :] |
825 | | - |
826 | | - decode_q_pe = self.rope_single(decode_q_pe, cos, sin) |
827 | 842 | decode_k_pe, decode_k_nope = self.exec_kv( |
828 | 843 | hidden_states_or_kv_c_normed, cos, sin, kv_cache, |
829 | 844 | attn_metadata.slot_mapping) |
| 845 | + with npu_stream_switch("mla_secondary", |
| 846 | + 0, |
| 847 | + enabled=self.enable_multistream_mla): |
| 848 | + npu_wait_tensor(decode_q_pe, |
| 849 | + decode_k_pe, |
| 850 | + enabled=self.enable_multistream_mla) |
| 851 | + decode_q_pe = self.rope_single(decode_q_pe, cos, sin) |
830 | 852 | else: |
831 | 853 | decode_q_pe[...], decode_k_pe[...] = self.rotary_emb( |
832 | 854 | attn_metadata.decode.input_positions, |
|
0 commit comments