
Commit c8cd500

merge conflicts
Signed-off-by: chenwaner <861645847@qq.com>
1 parent 906f651 commit c8cd500

3 files changed: +40 -14

docs/source/user_guide/additional_config.md

Lines changed: 3 additions & 1 deletion
@@ -40,6 +40,7 @@ The details of each config option are as follows:
 | `use_cached_graph` | bool | `False` | Whether to use cached graph |
 | `graph_batch_sizes` | list[int] | `[]` | The batch size for torchair graph cache |
 | `graph_batch_sizes_init` | bool | `False` | Init graph batch size dynamically if `graph_batch_sizes` is empty |
+| `enable_multistream_shared_expert` | bool | `False` | Whether to enable multistream shared expert |
 | `enable_kv_nz` | bool | `False` | Whether to enable kvcache NZ layout |

 **ascend_scheduler_config**
@@ -60,7 +61,8 @@ A full example of additional configuration is as follows:
         "enabled": true,
         "use_cached_graph": true,
         "graph_batch_sizes": [1, 2, 4, 8],
-        "graph_batch_sizes_init": true,
+        "graph_batch_sizes_init": false,
+        "enable_multistream_shared_expert": false,
         "enable_kv_nz": false
     },
     "ascend_scheduler_config": {

vllm_ascend/ascend_config.py

Lines changed: 2 additions & 0 deletions
@@ -55,6 +55,8 @@ def __init__(self, torchair_graph_config):
             "graph_batch_sizes_init", False)
         self.enable_multistream_shared_expert = torchair_graph_config.get(
             "enable_multistream_shared_expert", False)
+        self.enable_view_optimize = torchair_graph_config.get(
+            "enable_view_optimize", True)
         self.enable_kv_nz = torchair_graph_config.get("enable_kv_nz", False)

         if not isinstance(self.graph_batch_sizes, list):
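Note that the new `enable_view_optimize` flag defaults to on, unlike the other torchair options. A standalone sketch of the `dict.get` pattern used above (values are illustrative):

# Illustration of the parsing above: dict.get returns the default when a
# key is absent, so enable_view_optimize is an opt-out flag while
# enable_kv_nz remains opt-in.
torchair_graph_config = {"enabled": True}  # user config omits both keys

enable_view_optimize = torchair_graph_config.get("enable_view_optimize", True)
enable_kv_nz = torchair_graph_config.get("enable_kv_nz", False)

assert enable_view_optimize is True
assert enable_kv_nz is False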

vllm_ascend/attention/mla_v1.py

Lines changed: 35 additions & 13 deletions
@@ -13,7 +13,9 @@

 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
-import vllm_ascend.envs as envs_ascend
+from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
+from vllm_ascend.multistream.context import get_multistream_comm_context
+from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla

 if TYPE_CHECKING:
@@ -444,9 +446,14 @@ def __init__(
         self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
         self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)

-        self.enable_kv_nz = envs_ascend.VLLM_ENABLE_KV_NZ
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
+        # Adapt torch air graph mode with spec decoding.
+        speculative_config = get_current_vllm_config().speculative_config
+        if speculative_config is not None:
+            self.spec_token_num = speculative_config.num_speculative_tokens
+            assert self.spec_token_num > 0

     def _v_up_proj_and_o_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
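A toy check (numbers are illustrative) of the bookkeeping this hunk sets up: with speculative decoding, each decode request contributes its draft tokens plus one target token, which is what `_forward_decode` relies on below when it regroups the flat token batch:

spec_token_num = 2                      # speculative_config.num_speculative_tokens
assert spec_token_num > 0
batch_size = 4
num_tokens = batch_size * (spec_token_num + 1)    # flat decode tokens
assert num_tokens // (spec_token_num + 1) == batch_size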
@@ -679,24 +686,38 @@ def _forward_decode(
                 dtype=q.dtype,
                 device=q.device)
         if self.running_in_graph:
+            # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
+            if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
+                assert num_tokens % self.spec_token_num == 0
+                q_nope = q_nope.view(num_tokens // (self.spec_token_num + 1),
+                                     self.spec_token_num + 1, self.num_heads,
+                                     -1)
+                q_pe = q_pe.view(num_tokens // (self.spec_token_num + 1),
+                                 self.spec_token_num + 1, self.num_heads, -1)
+                if not self.enable_kv_nz:
+                    q_nope = q_nope.transpose(1, 2).contiguous()
+                    q_pe = q_pe.transpose(1, 2).contiguous()
+                sparse_mode = 3
+                spec_attn_mask = attn_metadata.decode.attn_mask  # type:ignore
+            else:
+                if self.enable_kv_nz:
+                    q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1)
+                    q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
+                else:
+                    q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
+                    q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
+                sparse_mode = 0
+                spec_attn_mask = None
+            # shape of knope/k_pe for npu graph mode should be:
+            # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
             block_size = kv_c_and_k_pe_cache[0].shape[1]
             if self.enable_kv_nz:
-                # TorchAir's shape is [bs, num_heads_per_rank, seq_len, dim]
-                q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1)
-                q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
-                # shape of knope/k_pe for npu graph mode should be:
-                # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
                 k_nope = k_nope.view(-1, self.num_kv_heads,
                                      self.kv_lora_rank // 16, block_size, 16)
                 k_pe = k_pe.view(-1, self.num_kv_heads,
                                  self.qk_rope_head_dim // 16, block_size, 16)
                 input_layout = "BSND"
             else:
-                # TorchAir's shape is [bs, num_heads_per_rank, seq_len, dim]
-                q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
-                q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
-                # shape of knope/k_pe for npu graph mode should be:
-                # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
                 k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
                                      self.kv_lora_rank)
                 k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
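A runnable sketch (all dimensions are illustrative) of the two reshapes in this hunk: the spec-decoding path regroups flat query tokens as [bs, q_seq_len, num_heads, dim] and transposes to [bs, num_heads, q_seq_len, dim] when the NZ layout is off, while the NZ branch views the caches with the feature dimension split into 16-element fragments:

import torch

spec_token_num, batch, num_heads, dim = 3, 2, 8, 64
num_tokens = batch * (spec_token_num + 1)

# Spec-decoding query regrouping (non-NZ path), as in the branch above.
q_nope = torch.randn(num_tokens, num_heads, dim)
q_nope = q_nope.view(num_tokens // (spec_token_num + 1),
                     spec_token_num + 1, num_heads, -1)
q_nope = q_nope.transpose(1, 2).contiguous()
assert q_nope.shape == (batch, num_heads, spec_token_num + 1, dim)

# NZ cache view: [num_blocks, num_kv_heads, dim // 16, block_size, 16].
num_blocks, num_kv_heads, block_size, kv_lora_rank = 4, 1, 128, 512
k_nope = torch.randn(num_blocks, block_size, num_kv_heads, kv_lora_rank)
k_nope = k_nope.view(-1, num_kv_heads, kv_lora_rank // 16, block_size, 16)
assert k_nope.shape == (num_blocks, num_kv_heads, 32, block_size, 16)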
@@ -712,7 +733,8 @@ def _forward_decode(
             num_heads=self.num_heads,
             num_key_value_heads=self.num_kv_heads,
             input_layout=input_layout,
-            atten_mask=attn_metadata.attn_mask,
+            atten_mask=spec_attn_mask,
+            sparse_mode=sparse_mode,
             scale=self.scale,
             antiquant_mode=0,
             antiquant_scale=None,
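A condensed, runnable distillation (flags are illustrative) of how the two arguments changed here are selected by the branch added earlier in `_forward_decode`: spec decoding passes the decode attention mask with `sparse_mode=3`, while the plain decode path passes no mask with `sparse_mode=0`:

import torch

def pick_mask_args(is_spec_decoding: bool, decode_attn_mask: torch.Tensor):
    # Mirrors the sparse_mode / spec_attn_mask selection above.
    if is_spec_decoding:
        return 3, decode_attn_mask
    return 0, None

mask = torch.ones(4, 4, dtype=torch.bool)
sparse_mode, attn_mask = pick_mask_args(True, mask)
assert sparse_mode == 3 and attn_mask is mask
sparse_mode, attn_mask = pick_mask_args(False, mask)
assert sparse_mode == 0 and attn_mask is None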
