@@ -13,7 +13,9 @@
 
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
-import vllm_ascend.envs as envs_ascend
+from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
+from vllm_ascend.multistream.context import get_multistream_comm_context
+from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
 
 if TYPE_CHECKING:
@@ -444,9 +446,14 @@ def __init__(
         self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
         self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
 
-        self.enable_kv_nz = envs_ascend.VLLM_ENABLE_KV_NZ
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
+        # Adapt torch air graph mode with spec decoding.
+        speculative_config = get_current_vllm_config().speculative_config
+        if speculative_config is not None:
+            self.spec_token_num = speculative_config.num_speculative_tokens
+            assert self.spec_token_num > 0
 
     def _v_up_proj_and_o_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
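
Side note on the `__init__` hunk above: `enable_kv_nz` now comes from `ascend_config.torchair_graph_config` instead of the `VLLM_ENABLE_KV_NZ` env var, and `spec_token_num` is cached from the speculative config for the TorchAir spec-decoding path. Below is a minimal, self-contained sketch of that selection logic; the `TorchairGraphConfig`/`SpeculativeConfig` dataclasses are simplified stand-ins for the real config objects, not their actual definitions.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class TorchairGraphConfig:      # stand-in for ascend_config.torchair_graph_config
    enabled: bool = False
    enable_kv_nz: bool = False

@dataclass
class SpeculativeConfig:        # stand-in for vllm_config.speculative_config
    num_speculative_tokens: int = 1

def resolve_decode_flags(graph_cfg: TorchairGraphConfig,
                         spec_cfg: Optional[SpeculativeConfig]):
    # Mirrors the added __init__ lines: graph mode and the KV NZ layout are
    # both torchair options now, and spec_token_num is only set when spec
    # decoding is configured.
    enable_kv_nz = graph_cfg.enable_kv_nz
    spec_token_num = None
    if spec_cfg is not None:
        spec_token_num = spec_cfg.num_speculative_tokens
        assert spec_token_num > 0
    return graph_cfg.enabled, enable_kv_nz, spec_token_num

print(resolve_decode_flags(TorchairGraphConfig(True, True),
                           SpeculativeConfig(num_speculative_tokens=2)))
# (True, True, 2)
```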
@@ -679,24 +686,38 @@ def _forward_decode(
                                  dtype=q.dtype,
                                  device=q.device)
         if self.running_in_graph:
+            # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
+            if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
+                assert num_tokens % self.spec_token_num == 0
+                q_nope = q_nope.view(num_tokens // (self.spec_token_num + 1),
+                                     self.spec_token_num + 1, self.num_heads,
+                                     -1)
+                q_pe = q_pe.view(num_tokens // (self.spec_token_num + 1),
+                                 self.spec_token_num + 1, self.num_heads, -1)
+                if not self.enable_kv_nz:
+                    q_nope = q_nope.transpose(1, 2).contiguous()
+                    q_pe = q_pe.transpose(1, 2).contiguous()
+                sparse_mode = 3
+                spec_attn_mask = attn_metadata.decode.attn_mask  # type: ignore
+            else:
+                if self.enable_kv_nz:
+                    q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1)
+                    q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
+                else:
+                    q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
+                    q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
+                sparse_mode = 0
+                spec_attn_mask = None
+            # shape of knope/k_pe for npu graph mode should be:
+            # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
             block_size = kv_c_and_k_pe_cache[0].shape[1]
             if self.enable_kv_nz:
-                # TorchAir's shape is [bs, num_heads_per_rank, seq_len, dim]
-                q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1)
-                q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
-                # shape of knope/k_pe for npu graph mode should be:
-                # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
                 k_nope = k_nope.view(-1, self.num_kv_heads,
                                      self.kv_lora_rank // 16, block_size, 16)
                 k_pe = k_pe.view(-1, self.num_kv_heads,
                                  self.qk_rope_head_dim // 16, block_size, 16)
                 input_layout = "BSND"
             else:
-                # TorchAir's shape is [bs, num_heads_per_rank, seq_len, dim]
-                q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
-                q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
-                # shape of knope/k_pe for npu graph mode should be:
-                # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
                 k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
                                      self.kv_lora_rank)
                 k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
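
A note on the SpecDecoding branch added above: in TorchAir graph mode the flattened decode query is regrouped so that each request's `spec_token_num + 1` positions form the query-sequence axis, and is transposed to `[bs, num_heads, q_seq_len, dim]` when the NZ KV layout is disabled. Below is a minimal, CPU-only torch sketch of that shape handling; the head count, head dim, and batch size are illustrative values, not taken from the model.

```python
import torch

num_heads, head_dim = 16, 128              # illustrative sizes
spec_token_num = 2                         # num_speculative_tokens from the config
batch = 4
num_tokens = batch * (spec_token_num + 1)  # flattened decode tokens

q_nope = torch.randn(num_tokens, num_heads, head_dim)

# SpecDecoding: fold the (spec_token_num + 1) positions of each request into a
# q_seq_len axis -> [bs, q_seq_len, num_heads, dim]
assert num_tokens % (spec_token_num + 1) == 0
q_spec = q_nope.view(num_tokens // (spec_token_num + 1),
                     spec_token_num + 1, num_heads, -1)

# Without the NZ KV layout the fused kernel takes [bs, num_heads, q_seq_len, dim],
# hence the transpose(1, 2).contiguous() in the hunk above.
q_bnsd = q_spec.transpose(1, 2).contiguous()

print(q_spec.shape)   # torch.Size([4, 3, 16, 128])
print(q_bnsd.shape)   # torch.Size([4, 16, 3, 128])
```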
@@ -712,7 +733,8 @@ def _forward_decode(
                 num_heads=self.num_heads,
                 num_key_value_heads=self.num_kv_heads,
                 input_layout=input_layout,
-                atten_mask=attn_metadata.attn_mask,
+                atten_mask=spec_attn_mask,
+                sparse_mode=sparse_mode,
                 scale=self.scale,
                 antiquant_mode=0,
                 antiquant_scale=None,
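
And a note on the `enable_kv_nz` key views in the decode hunk: the paged KV blocks are reinterpreted in the NZ (fractal) layout by splitting the last dimension into 16-element fragments, i.e. `[num_blocks, num_kv_heads, dim // 16, block_size, 16]`. The sketch below only demonstrates the shape arithmetic of those `view` calls with made-up sizes; in the real path the cache buffer is already kept in the NPU's NZ format, so the view is a reinterpretation rather than a data conversion.

```python
import torch

num_blocks, block_size = 8, 128   # illustrative paged-KV geometry
num_kv_heads = 1                  # typically 1 for the MLA latent cache
kv_lora_rank = 512                # must be divisible by 16 for the NZ view

# Flat buffer standing in for kv_c_and_k_pe_cache[0] (illustration only).
k_nope = torch.randn(num_blocks * block_size * num_kv_heads * kv_lora_rank)

# enable_kv_nz branch: [num_blocks, num_kv_heads, kv_lora_rank // 16, block_size, 16]
k_nope_nz = k_nope.view(-1, num_kv_heads, kv_lora_rank // 16, block_size, 16)
# plain branch:        [num_blocks, num_kv_heads, block_size, kv_lora_rank]
k_nope_nd = k_nope.view(-1, num_kv_heads, block_size, kv_lora_rank)

print(k_nope_nz.shape)  # torch.Size([8, 1, 32, 128, 16])
print(k_nope_nd.shape)  # torch.Size([8, 1, 128, 512])
```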