@@ -19,7 +19,6 @@
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
-from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor
 
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -482,9 +481,6 @@ def __init__(
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
         self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
-        self.enable_multistream_mla = \
-            ascend_config.torchair_graph_config.enable_multistream_mla
-
         # Adapt torch air graph mode with spec decoding.
         speculative_config = get_current_vllm_config().speculative_config
         if speculative_config is not None:
@@ -668,20 +664,17 @@ def exec_kv(
         # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
         kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
         cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
-        with npu_stream_switch("mla_secondary",
-                               0,
-                               enabled=self.enable_multistream_mla):
-            k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
-                kv,
-                self.kv_a_layernorm.weight,
-                cos,
-                sin,
-                slots.to(torch.int64),
-                kv_cache[1],
-                kv_cache[0],
-                epsilon=self.kv_a_layernorm.variance_epsilon,
-                cache_mode=cache_mode,
-            )
+        k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
+            kv,
+            self.kv_a_layernorm.weight,
+            cos,
+            sin,
+            slots.to(torch.int64),
+            kv_cache[1],
+            kv_cache[0],
+            epsilon=self.kv_a_layernorm.variance_epsilon,
+            cache_mode=cache_mode,
+        )
         return k_pe, k_nope
 
     def exec_kv_prefill(
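
Note: the removed `npu_stream_switch` context manager (from `vllm_ascend.utils`) moved the k-side kernel onto a secondary NPU stream so it could overlap with the q-side projection on the default stream. For illustration only, here is a minimal CUDA-stream analogue of that overlap pattern; the function name, the tensors, and the `relu` stand-in kernel are hypothetical, not the real Ascend helper:

```python
import torch


def overlap_k_side(q_c: torch.Tensor, kv: torch.Tensor):
    """Hypothetical sketch: run the k-side kernel on a secondary stream so it
    overlaps with the q-side matmul, mirroring what npu_stream_switch enabled."""
    secondary = torch.cuda.Stream()
    # The secondary stream must see kv fully produced on the default stream.
    secondary.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(secondary):
        kv.record_stream(secondary)   # keep the caching allocator stream-aware
        k = torch.relu(kv)            # stand-in for npu_kv_rmsnorm_rope_cache
    q = q_c @ q_c.transpose(-1, -2)   # q-side matmul runs concurrently
    # Re-join before anything on the default stream consumes k.
    torch.cuda.current_stream().wait_stream(secondary)
    return q, k
```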
@@ -874,38 +867,23 @@ def forward(
         if has_decode:
             decode_k_nope = None
             assert attn_metadata.decode is not None
+            decode_ql_nope, decode_q_pe = \
+                self._q_proj_and_k_up_proj(decode_hs_or_q_c)
             if self.running_in_graph:
                 seq_len = self.rotary_emb.max_position_embeddings
                 cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=decode_hs_or_q_c.dtype)
+                    dtype=decode_q_pe.dtype)
                 sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=decode_hs_or_q_c.dtype)
+                    dtype=decode_q_pe.dtype)
                 cos = cos[attn_metadata.decode.input_positions]
                 sin = sin[attn_metadata.decode.input_positions]
                 cos = cos[:, None, None, :]
                 sin = sin[:, None, None, :]
-                # Without explicitly controlling the order, IndexByTensor operations
-                # would be placed after `matmul W_KV_T` hindering the overlapping of
-                # KvRmsNormRopeCache and SingleRope.
-                npu_wait_tensor(decode_hs_or_q_c,
-                                cos,
-                                enabled=self.enable_multistream_mla)
-                npu_wait_tensor(decode_hs_or_q_c,
-                                sin,
-                                enabled=self.enable_multistream_mla)
-            decode_ql_nope, decode_q_pe = \
-                self._q_proj_and_k_up_proj(decode_hs_or_q_c)
-            if self.running_in_graph:
+
+                decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
                 decode_k_pe, decode_k_nope = self.exec_kv(
                     hidden_states_or_kv_c_normed, cos, sin, kv_cache,
                     attn_metadata.slot_mapping)
-                with npu_stream_switch("mla_secondary",
-                                       0,
-                                       enabled=self.enable_multistream_mla):
-                    npu_wait_tensor(decode_q_pe,
-                                    decode_k_pe,
-                                    enabled=self.enable_multistream_mla)
-                    decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
             else:
                 decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
                     attn_metadata.decode.input_positions,
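
The deleted comment describes pinning op order in torchair graph mode: `npu_wait_tensor(a, b)` makes `a` artificially depend on `b`, forcing the cos/sin gather (IndexByTensor) to be scheduled before the `matmul W_KV_T` so KvRmsNormRopeCache could overlap with SingleRope. A value-preserving sketch of that fake-dependency trick, purely for illustration (hypothetical helper, not the `vllm_ascend.utils` implementation; assumes floating-point tensors):

```python
import torch


def wait_tensor(dst: torch.Tensor, src: torch.Tensor, *,
                enabled: bool = True) -> torch.Tensor:
    """Make dst data-depend on src without changing its value, so a graph
    compiler cannot reorder src's producer after dst's consumers."""
    if not enabled:
        return dst
    # Adding 0 * src keeps dst's value but inserts an explicit graph edge.
    return dst + src.flatten()[0].to(dst.dtype) * 0
```

With the multistream path reverted, no such ordering hint is needed: `rope_single` simply runs right after the q projection on the default stream, as the hunk above shows.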