                                               AttentionMetadata,
                                               MLAAttentionImpl)
 from vllm.attention.backends.utils import PAD_SLOT_ID
+from vllm.config import get_current_vllm_config
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
 
@@ -86,6 +87,7 @@ class AscendMLADecodeMetadata:
     seq_lens: torch.Tensor
     max_seq_lens: int
     seq_lens_list: list[int]
+    attn_mask: Optional[torch.Tensor] = None
 
 
 @dataclass
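The new `attn_mask` field defaults to `None`, so decode metadata built for the plain (non-speculative) path keeps constructing exactly as before. A reduced stand-in (most fields trimmed; the class name is shortened here for illustration) showing why the default matters:

```python
from dataclasses import dataclass
from typing import Optional

import torch


# Reduced stand-in for AscendMLADecodeMetadata. Existing call sites that
# never pass attn_mask still work, while the spec-decode path can thread
# a precomputed mask through the same dataclass.
@dataclass
class DecodeMeta:
    seq_lens: torch.Tensor
    max_seq_lens: int
    seq_lens_list: list[int]
    attn_mask: Optional[torch.Tensor] = None


meta = DecodeMeta(seq_lens=torch.tensor([4]),
                  max_seq_lens=4,
                  seq_lens_list=[4])
assert meta.attn_mask is None
```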
@@ -169,6 +171,8 @@ def __init__(self,
         self.runner = runner
         scheduler_config = runner.scheduler_config
         self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
 
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
@@ -185,16 +189,24 @@ def reorder_batch(self, input_batch: "InputBatch",
 
         for i, req_id in enumerate(input_batch.req_ids):
             num_tokens = scheduler_output.num_scheduled_tokens[req_id]
-            # for now treat 1 scheduled token as "decode" even if its not,
-            # we should update this to something like < 8 in the future but
-            # currently the TritonMLA._forward_decode only supports
-            # num_tokens = 1
-            if num_tokens == 1:
-                decodes.append(i)
-                num_decode_tokens += num_tokens
+            num_spec_tokens = len(
+                scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
+            # For torchair graph mode we treat spec decoding as decode.
+            if self.torchair_graph_enabled:
+                if num_tokens - num_spec_tokens == 1:
+                    decodes.append(i)
+                    num_decode_tokens += num_tokens
+                else:
+                    prefills.append(i)
+                    num_prefill_tokens += num_tokens
+            # For eager mode we treat spec decoding as chunked prefill.
             else:
-                prefills.append(i)
-                num_prefill_tokens += num_tokens
+                if num_tokens == 1:
+                    decodes.append(i)
+                    num_decode_tokens += num_tokens
+                else:
+                    prefills.append(i)
+                    num_prefill_tokens += num_tokens
 
         # We hope that this is fairly minimal since decodes
         # should be around for a number of iterations so hopefully they are
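The split above is the core scheduling change: in torchair graph mode a request whose scheduled tokens are exactly one verified token plus its draft tokens still counts as decode (the padded graph kernel can handle the extra draft tokens), while eager mode keeps the old single-token rule and pushes speculative requests through chunked prefill. A minimal sketch of that rule; the helper name is illustrative, not the vllm-ascend API:

```python
# Minimal sketch of the classification rule in reorder_batch above.
def classify(num_tokens: int, num_spec_tokens: int,
             torchair_graph_enabled: bool) -> str:
    if torchair_graph_enabled:
        # Graph mode: one verified token plus any number of draft tokens
        # still fits the decode kernel, so treat it as decode.
        return "decode" if num_tokens - num_spec_tokens == 1 else "prefill"
    # Eager mode: only a single scheduled token counts as decode; spec
    # decoding is handled as chunked prefill.
    return "decode" if num_tokens == 1 else "prefill"


# 1 verified + 2 draft tokens scheduled for one request:
assert classify(3, 2, torchair_graph_enabled=True) == "decode"
assert classify(3, 2, torchair_graph_enabled=False) == "prefill"
```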
@@ -284,7 +296,8 @@ def build_dummy(self, num_reqs: int,
             block_table=block_table,
             seq_lens=seq_lens,
             seq_lens_list=seq_lens.tolist(),
-            max_seq_lens=1)
+            max_seq_lens=1,
+            attn_mask=self.runner.spec_attn_mask)
         return self.metadata_cls(  # type: ignore
             num_input_tokens=num_actual_tokens,
             num_actual_tokens=num_actual_tokens,
@@ -332,7 +345,7 @@ def build(
             seq_lens = seq_lens_cpu
         max_query_len = query_lens.max().item()
         max_seq_lens = seq_lens.max().item()
-        query_start_loc = None
+        query_start_loc = common_attn_metadata.query_start_loc
 
         prefill_metadata = None
         if self._num_prefills > 0:
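`query_start_loc` was previously dropped (`None`); it is now taken from `common_attn_metadata` because spec decoding makes decode requests multi-token, so downstream consumers need per-request token offsets into the flattened batch. It follows the usual vLLM convention of an exclusive prefix sum over per-request query lengths; an illustration with assumed sizes:

```python
import torch

# Two requests, each scheduling 3 tokens (1 verified + 2 draft). Tokens
# for request i occupy the half-open range
# [query_start_loc[i], query_start_loc[i + 1]) in the flat token batch.
query_lens = torch.tensor([3, 3])
query_start_loc = torch.cat(
    [torch.zeros(1, dtype=torch.long),
     query_lens.cumsum(0)])
print(query_start_loc)  # tensor([0, 3, 6])
```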
@@ -397,7 +410,8 @@ def build(
                 block_table=block_table,
                 seq_lens=seq_lens,
                 seq_lens_list=seq_lens.tolist(),
-                max_seq_lens=max_seq_lens)
+                max_seq_lens=max_seq_lens,
+                attn_mask=self.runner.spec_attn_mask)
 
         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
@@ -461,6 +475,11 @@ def __init__(
 
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        # Adapt torchair graph mode with spec decoding.
+        speculative_config = get_current_vllm_config().speculative_config
+        if speculative_config is not None:
+            self.spec_token_num = speculative_config.num_speculative_tokens
+            assert self.spec_token_num > 0
 
     def _v_up_proj_and_o_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
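Note that `spec_token_num` is only assigned when a speculative config exists, so only `SpecDecoding` code paths may read it. The decode path further down also relies on a simple batch invariant, worth stating explicitly (sizes assumed here for illustration):

```python
# Invariant the speculative decode reshape relies on: with
# k = num_speculative_tokens draft tokens, every decode request carries
# k + 1 query tokens (1 verified + k drafts), so a batch of B requests
# flattens to B * (k + 1) tokens and can be folded back losslessly.
k, batch = 2, 4
num_tokens = batch * (k + 1)
assert num_tokens % (k + 1) == 0
assert num_tokens // (k + 1) == batch
```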
@@ -550,7 +569,10 @@ def _forward_prefill(
         num_tokens = query.size(0)
         attn_output = None
         # Here is only 2 possibility of input, ChunkedPrefill or PrefillNoCache
-        if attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill:
+        if attn_metadata.attn_state in [
+                AscendAttentionState.ChunkedPrefill,
+                AscendAttentionState.SpecDecoding
+        ]:
             attn_output = torch.empty(num_tokens,
                                       self.num_heads * self.v_head_dim,
                                       dtype=query.dtype,
@@ -597,7 +619,7 @@ def _forward_prefill(
             attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
         else:
             raise RuntimeError(
-                "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache and ChunkedPrefill scenario in forward prefill, please file a bug to vllm-ascend !"
+                "Unexpected path reached: AscendMLAImpl should only see PrefillNoCache, ChunkedPrefill and SpecDecoding scenarios in forward prefill, please file a bug to vllm-ascend!"
             )
         attn_output = attn_output.reshape(
             [num_tokens, self.num_heads * self.v_head_dim])
@@ -670,9 +692,28 @@ def _forward_decode(
                                  dtype=q.dtype,
                                  device=q.device)
         if self.running_in_graph:
-            # TorchAir's shape is [bs, num_heads_per_rank, seq_len, dim]
-            q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
-            q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
+            # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
+            if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
+                assert num_tokens % (self.spec_token_num + 1) == 0
+                q_nope = (q_nope.view(
+                    num_tokens // (self.spec_token_num + 1),
+                    self.spec_token_num + 1,
+                    self.num_heads,
+                    -1,
+                ).transpose(1, 2).contiguous())
+                q_pe = (q_pe.view(
+                    num_tokens // (self.spec_token_num + 1),
+                    self.spec_token_num + 1,
+                    self.num_heads,
+                    -1,
+                ).transpose(1, 2).contiguous())
+                sparse_mode = 3
+                spec_attn_mask = attn_metadata.decode.attn_mask  # type:ignore
+            else:
+                q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
+                q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
+                sparse_mode = 0
+                spec_attn_mask = None
             # shape of knope/k_pe for npu graph mode should be:
             # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
             block_size = kv_c_and_k_pe_cache[0].shape[1]
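The speculative branch folds the flat token axis back into [batch, q_seq_len] and transposes into TorchAir's BNSD layout; note the reshape requires `num_tokens` to be divisible by `spec_token_num + 1` (each request contributes its verified token plus all drafts), which is what the assert checks. A self-contained shape walkthrough with assumed sizes:

```python
import torch

# Shape walkthrough for the speculative reshape above (illustrative
# sizes; head count and head dim are arbitrary). Flat [tokens, H, D]
# becomes BNSD: [batch, heads, q_seq_len, dim].
spec_token_num, batch, num_heads, dim = 2, 4, 8, 64
num_tokens = batch * (spec_token_num + 1)
q_nope = torch.randn(num_tokens, num_heads, dim)
q_bnsd = (q_nope.view(batch, spec_token_num + 1, num_heads,
                      -1).transpose(1, 2).contiguous())
assert q_bnsd.shape == (batch, num_heads, spec_token_num + 1, dim)
```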
@@ -690,7 +731,8 @@ def _forward_decode(
                 num_heads=self.num_heads,
                 num_key_value_heads=self.num_kv_heads,
                 input_layout="BNSD",
-                atten_mask=attn_metadata.attn_mask,
+                atten_mask=spec_attn_mask,
+                sparse_mode=sparse_mode,
                 scale=self.scale,
                 antiquant_mode=0,
                 antiquant_scale=None,
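On my reading of Ascend's fused-attention documentation, `sparse_mode=3` selects the kernel's built-in causal pattern aligned to the bottom-right of the score matrix, which is the right alignment when the `q_seq_len` speculative queries correspond to the newest cache positions; treat the exact kernel semantics as an assumption. A toy PyTorch equivalent of that mask shape, not the Ascend kernel itself:

```python
import torch

# Bottom-right aligned causal mask for q_seq_len speculative queries over
# kv_seq_len cached positions. True (1) marks entries to mask out; row i
# may attend every position up to offset + i.
q_seq_len, kv_seq_len = 3, 8
offset = kv_seq_len - q_seq_len
mask = torch.ones(q_seq_len, kv_seq_len).triu(offset + 1).bool()
print(mask.int())
# tensor([[0, 0, 0, 0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 0, 0, 0, 0]])
```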
@@ -732,7 +774,9 @@ def forward(
         if attn_metadata is None:
             # Profiling run.
             return output
-        self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state == AscendAttentionState.DecodeOnly
+        self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
+            AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
+        ]
         num_actual_toks = attn_metadata.num_actual_tokens
         if k_pe is None and not self.running_in_graph:
             kv_c, k_pe = self.kv_a_proj_with_mqa(