@@ -170,7 +170,7 @@ def __init__(self,
             if metadata_cls is not None else AscendMLAMetadata  # type: ignore
         self.runner = runner
         scheduler_config = runner.scheduler_config
-        self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
+        self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
 
@@ -477,13 +477,7 @@ def __init__(
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
         # Adapt torch air graph mode with spec decoding.
         speculative_config = get_current_vllm_config().speculative_config
-        self.fia_sparse_mode = 0
-        self.use_spec_decode = False
-        # We need to set the sparse_mode of fused_infer_attention op to 3
-        # in spec decoding scenario in order to pass in attention mask.
         if speculative_config is not None:
-            self.fia_sparse_mode = 3
-            self.use_spec_decode = True
             self.spec_token_num = speculative_config.num_speculative_tokens
             assert self.spec_token_num > 0
 
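With the init-time `fia_sparse_mode` / `use_spec_decode` flags removed, the sparse mode and attention mask for the fused-infer-attention call are picked per forward pass from `attn_metadata.attn_state` (see the `_forward_decode` hunks below). A minimal sketch of that selection, using a stand-in enum and a hypothetical helper name rather than the module's real API:

```python
from enum import Enum, auto
from typing import Optional, Tuple

import torch


class AttnState(Enum):
    # Stand-in for AscendAttentionState; only the states relevant to decode.
    DecodeOnly = auto()
    SpecDecoding = auto()


def select_fia_params(
    attn_state: AttnState,
    decode_attn_mask: Optional[torch.Tensor],
) -> Tuple[int, Optional[torch.Tensor]]:
    """Choose fused-infer-attention arguments from the per-batch state.

    sparse_mode=3 plus an explicit mask is only needed while verifying
    speculative tokens; plain decode uses sparse_mode=0 and no mask.
    """
    if attn_state == AttnState.SpecDecoding:
        return 3, decode_attn_mask
    return 0, None


if __name__ == "__main__":
    mask = torch.ones(4, 4, dtype=torch.bool).tril()
    print(select_fia_params(AttnState.SpecDecoding, mask)[0])  # 3
    print(select_fia_params(AttnState.DecodeOnly, mask))       # (0, None)
```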
@@ -575,7 +569,10 @@ def _forward_prefill(
         num_tokens = query.size(0)
         attn_output = None
         # Here is only 2 possibility of input, ChunkedPrefill or PrefillNoCache
-        if attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill:
+        if attn_metadata.attn_state in [
+                AscendAttentionState.ChunkedPrefill,
+                AscendAttentionState.SpecDecoding
+        ]:
             attn_output = torch.empty(num_tokens,
                                       self.num_heads * self.v_head_dim,
                                       dtype=query.dtype,
@@ -622,7 +619,7 @@ def _forward_prefill(
             attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
         else:
             raise RuntimeError(
-                "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache and ChunkedPrefill scenario in forward prefill, please file a bug to vllm-ascend !"
+                "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache, ChunkedPrefill and SpecDecoding scenario in forward prefill, please file a bug to vllm-ascend !"
             )
         attn_output = attn_output.reshape(
             [num_tokens, self.num_heads * self.v_head_dim])
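The prefill hunks above let `SpecDecoding` share the `ChunkedPrefill` path, which pre-allocates a flat `[num_tokens, num_heads * v_head_dim]` output buffer, while any unexpected state still hits the `RuntimeError` guard. A rough sketch of that dispatch, again with a stand-in enum and a hypothetical helper (`alloc_prefill_output`) instead of the real `_forward_prefill`:

```python
from enum import Enum, auto

import torch


class AttnState(Enum):
    # Stand-in for AscendAttentionState.
    PrefillNoCache = auto()
    ChunkedPrefill = auto()
    SpecDecoding = auto()


def alloc_prefill_output(attn_state: AttnState, num_tokens: int,
                         num_heads: int, v_head_dim: int,
                         dtype: torch.dtype = torch.float16) -> torch.Tensor:
    # ChunkedPrefill and SpecDecoding share the pre-allocated flat buffer;
    # PrefillNoCache takes its own path, anything else is a bug.
    if attn_state in (AttnState.ChunkedPrefill, AttnState.SpecDecoding):
        return torch.empty(num_tokens, num_heads * v_head_dim, dtype=dtype)
    if attn_state == AttnState.PrefillNoCache:
        raise NotImplementedError("non-cached prefill handled separately")
    raise RuntimeError(f"unexpected prefill attention state: {attn_state}")


if __name__ == "__main__":
    out = alloc_prefill_output(AttnState.SpecDecoding, 8, 16, 128)
    print(out.shape)  # torch.Size([8, 2048])
```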
@@ -696,7 +693,7 @@ def _forward_decode(
                                   device=q.device)
         if self.running_in_graph:
             # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
-            if self.use_spec_decode:
+            if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
                 assert num_tokens % self.spec_token_num == 0
                 q_nope = (q_nope.view(
                     num_tokens // (self.spec_token_num + 1),
@@ -710,9 +707,13 @@ def _forward_decode(
                     self.num_heads,
                     -1,
                 ).transpose(1, 2).contiguous())
+                sparse_mode = 3
+                spec_attn_mask = attn_metadata.decode.attn_mask  # type:ignore
             else:
                 q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
                 q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
+                sparse_mode = 0
+                spec_attn_mask = None
             # shape of knope/k_pe for npu graph mode should be:
             # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
             block_size = kv_c_and_k_pe_cache[0].shape[1]
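In TorchAir graph mode the spec-decoding queries are folded from a flat `[num_tokens, num_heads, dim]` layout into BNSD `[bs, num_heads, q_seq_len, dim]`, where `q_seq_len = spec_token_num + 1` covers the draft tokens plus the bonus token. A standalone reproduction of just that reshape, with made-up sizes:

```python
import torch

num_heads, head_dim, spec_token_num = 16, 64, 3
q_seq_len = spec_token_num + 1       # draft tokens + one bonus token
bs = 2                               # decode requests in the batch
num_tokens = bs * q_seq_len          # flattened token count

# Flat layout coming out of the projections: [num_tokens, num_heads, dim]
q_nope = torch.randn(num_tokens, num_heads, head_dim)

assert num_tokens % (spec_token_num + 1) == 0
# Fold tokens back into per-request groups, then move heads ahead of the
# query positions to get TorchAir's BNSD layout: [bs, num_heads, q_seq_len, dim]
q_bnsd = (q_nope.view(num_tokens // (spec_token_num + 1),
                      spec_token_num + 1,
                      num_heads,
                      -1).transpose(1, 2).contiguous())
print(q_bnsd.shape)  # torch.Size([2, 16, 4, 64])
```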
@@ -730,8 +731,8 @@ def _forward_decode(
                 num_heads=self.num_heads,
                 num_key_value_heads=self.num_kv_heads,
                 input_layout="BNSD",
-                atten_mask=attn_metadata.decode.attn_mask,  # type:ignore
-                sparse_mode=self.fia_sparse_mode,
+                atten_mask=spec_attn_mask,
+                sparse_mode=sparse_mode,
                 scale=self.scale,
                 antiquant_mode=0,
                 antiquant_scale=None,
@@ -773,7 +774,9 @@ def forward(
         if attn_metadata is None:
             # Profiling run.
             return output
-        self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state == AscendAttentionState.DecodeOnly
+        self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
+            AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
+        ]
         num_actual_toks = attn_metadata.num_actual_tokens
         if k_pe is None and not self.running_in_graph:
             kv_c, k_pe = self.kv_a_proj_with_mqa(
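The gate in `forward` now treats `SpecDecoding` like `DecodeOnly`, so speculative verification also takes the TorchAir graph path when graph mode is enabled. A one-function sketch of the gating, with a stand-in enum rather than the real `AscendAttentionState`:

```python
from enum import Enum, auto


class AttnState(Enum):
    # Stand-in for AscendAttentionState.
    PrefillNoCache = auto()
    ChunkedPrefill = auto()
    DecodeOnly = auto()
    SpecDecoding = auto()


def running_in_graph(torchair_graph_enabled: bool, attn_state: AttnState) -> bool:
    # Graph mode covers pure decode and speculative verification, never prefill.
    return torchair_graph_enabled and attn_state in (AttnState.DecodeOnly,
                                                     AttnState.SpecDecoding)


if __name__ == "__main__":
    print(running_in_graph(True, AttnState.SpecDecoding))    # True
    print(running_in_graph(True, AttnState.ChunkedPrefill))  # False
    print(running_in_graph(False, AttnState.DecodeOnly))     # False
```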