|
13 | 13 | UnquantizedLinearMethod) |
14 | 14 | from vllm.utils import cdiv, round_down |
15 | 15 |
|
| 16 | +from vllm_ascend import envs |
16 | 17 | from vllm_ascend.ascend_config import get_ascend_config |
17 | 18 | from vllm_ascend.attention.attention import _ALLOWED_NUM_QUERIES_PER_KV |
18 | 19 | from vllm_ascend.attention.attention_v1 import AscendAttentionState |
@@ -933,18 +934,12 @@ def _forward_decode( |
933 | 934 | q_pe: torch.Tensor, |
934 | 935 | k_nope: torch.Tensor, |
935 | 936 | k_pe: torch.Tensor, |
936 | | - kv_c_and_k_pe_cache: torch.Tensor, |
| 937 | + kv_c_and_k_pe_cache: Tuple[torch.Tensor, ...], |
937 | 938 | attn_metadata: AscendMLAMetadata, |
938 | 939 | ) -> torch.Tensor: |
939 | 940 | decode_meta = attn_metadata.decode |
940 | 941 | assert decode_meta is not None |
941 | | - |
942 | | - q = torch.cat([q_nope, q_pe], dim=-1) |
943 | | - num_tokens = q.size(0) |
944 | | - attn_output = torch.empty( |
945 | | - [num_tokens, self.num_heads, self.kv_lora_rank], |
946 | | - dtype=q.dtype, |
947 | | - device=q.device) |
| 942 | + num_tokens = q_nope.size(0) |
948 | 943 | if self.running_in_graph: |
949 | 944 | # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim] |
950 | 945 | if attn_metadata.attn_state == AscendAttentionState.SpecDecoding: |
@@ -1003,16 +998,35 @@ def _forward_decode( |
1003 | 998 | actual_seq_lengths_kv=decode_meta.seq_lens_list, |
1004 | 999 | ) |
1005 | 1000 | else: |
1006 | | - torch_npu._npu_paged_attention_mla( |
1007 | | - query=q, |
1008 | | - key_cache=kv_c_and_k_pe_cache, |
1009 | | - num_kv_heads=self.num_kv_heads, |
1010 | | - num_heads=self.num_heads, |
1011 | | - scale_value=self.scale, |
1012 | | - block_table=attn_metadata.decode.block_table, # type:ignore |
1013 | | - context_lens=attn_metadata.decode.seq_lens, # type:ignore |
1014 | | - mla_vheadsize=self.kv_lora_rank, |
1015 | | - out=attn_output) |
| 1001 | + # The MLA_PA path will become the default in the future; `_npu_paged_attention_mla` will be |
| 1002 | + # removed once `torch_npu.atb.npu_multi_head_latent_attention` becomes publicly available in |
| 1003 | + # torch_npu. |
| 1004 | + assert len(kv_c_and_k_pe_cache) > 1 |
| 1005 | + if envs.VLLM_ASCEND_MLA_PA: |
| 1006 | + attn_output = torch_npu.atb.npu_multi_head_latent_attention( |
| 1007 | + q_nope, q_pe, kv_c_and_k_pe_cache[0], |
| 1008 | + kv_c_and_k_pe_cache[1], attn_metadata.decode.block_table, |
| 1009 | + attn_metadata.decode.seq_lens, self.num_heads, self.scale, |
| 1010 | + self.num_kv_heads) |
| 1011 | + else: |
| 1012 | + q = torch.cat([q_nope, q_pe], dim=-1) |
| 1013 | + attn_output = torch.empty( |
| 1014 | + [num_tokens, self.num_heads, self.kv_lora_rank], |
| 1015 | + dtype=q.dtype, |
| 1016 | + device=q.device) |
| 1017 | + k_cache = torch.cat( |
| 1018 | + [kv_c_and_k_pe_cache[0], kv_c_and_k_pe_cache[1]], dim=-1) |
| 1019 | + torch_npu._npu_paged_attention_mla( |
| 1020 | + query=q, |
| 1021 | + key_cache=k_cache, |
| 1022 | + num_kv_heads=self.num_kv_heads, |
| 1023 | + num_heads=self.num_heads, |
| 1024 | + scale_value=self.scale, |
| 1025 | + block_table=attn_metadata.decode. |
| 1026 | + block_table, # type:ignore |
| 1027 | + context_lens=attn_metadata.decode.seq_lens, # type:ignore |
| 1028 | + mla_vheadsize=self.kv_lora_rank, |
| 1029 | + out=attn_output) |
1016 | 1030 | current_ms_metadata = get_multistream_comm_context() |
1017 | 1031 | if current_ms_metadata is None: |
1018 | 1032 | return self._v_up_proj_and_o_proj(attn_output) |
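For reference, the fallback branch above rebuilds the single cache layout that `_npu_paged_attention_mla` expects by concatenating the split caches along the last dimension. Below is a minimal, CPU-runnable shape sketch of that step; the sizes (`kv_lora_rank`, rope dimension, block geometry) are illustrative placeholders, not values taken from this change.

```python
import torch

# Illustrative sizes only; the real values come from the model config.
num_blocks, block_size, num_kv_heads = 4, 16, 1
kv_lora_rank, rope_dim = 512, 64

# kv_c_and_k_pe_cache is now a tuple:
#   [0] -> compressed (nope) kv cache, [1] -> rope key cache.
kv_c_cache = torch.randn(num_blocks, block_size, num_kv_heads, kv_lora_rank)
k_pe_cache = torch.randn(num_blocks, block_size, num_kv_heads, rope_dim)

# The legacy path concatenates the two caches back into one tensor so the
# existing paged-attention MLA kernel can consume it unchanged.
k_cache = torch.cat([kv_c_cache, k_pe_cache], dim=-1)
assert k_cache.shape[-1] == kv_lora_rank + rope_dim
```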
@@ -1193,10 +1207,11 @@ def forward( |
1193 | 1207 | decode_k_nope, decode_k_pe, |
1194 | 1208 | kv_cache, attn_metadata) |
1195 | 1209 | else: |
1196 | | - combined_cache = torch.cat([kv_cache[0], kv_cache[1]], dim=-1) |
1197 | | - output_decode = self._forward_decode( |
1198 | | - decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe, |
1199 | | - combined_cache, attn_metadata) |
| 1210 | + output_decode = self._forward_decode(decode_ql_nope, |
| 1211 | + decode_q_pe, |
| 1212 | + decode_k_nope, |
| 1213 | + decode_k_pe, kv_cache, |
| 1214 | + attn_metadata) |
1200 | 1215 | current_ms_metadata = get_multistream_comm_context() |
1201 | 1216 | if current_ms_metadata is not None: |
1202 | 1217 | with torch.npu.stream(current_ms_metadata.comm_stream): |
|
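Caller-side note: `forward` no longer materializes `combined_cache` eagerly; the split cache tuple is passed straight to `_forward_decode`, so concatenation only happens when the legacy kernel is actually taken. A rough sketch of the resulting dispatch, assuming `VLLM_ASCEND_MLA_PA` is read as a 0/1 environment flag (as surfaced by `vllm_ascend.envs`):

```python
import os
from typing import Tuple

import torch


def pick_decode_kernel(kv_cache: Tuple[torch.Tensor, ...]) -> str:
    """Illustrative only: mirrors the branch added in _forward_decode."""
    assert len(kv_cache) > 1, "expects split (kv_c, k_pe) caches"
    # Assumption: the flag is exported as "0"/"1" in the environment.
    if bool(int(os.getenv("VLLM_ASCEND_MLA_PA", "0"))):
        # Fused path: consumes the split caches directly.
        return "torch_npu.atb.npu_multi_head_latent_attention"
    # Legacy path: runs on a concatenated cache (see the torch.cat above).
    return "torch_npu._npu_paged_attention_mla"
```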