@@ -85,6 +85,7 @@ class AscendMLADecodeMetadata:
     seq_lens: torch.Tensor
     max_seq_lens: int
     seq_lens_list: list[int]
+    attn_mask: torch.Tensor


 @dataclass
@@ -170,11 +171,12 @@ def reorder_batch(self, input_batch: "InputBatch",

         for i, req_id in enumerate(input_batch.req_ids):
             num_tokens = scheduler_output.num_scheduled_tokens[req_id]
+            num_spec_tokens = len(scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
             # for now treat 1 scheduled token as "decode" even if its not,
             # we should update this to something like < 8 in the future but
             # currently the TritonMLA._forward_decode only supports
             # num_tokens = 1
-            if num_tokens == 1:
+            if num_tokens - num_spec_tokens == 1:
                 decodes.append(i)
                 num_decode_tokens += num_tokens
             else:
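
Note (not part of the patch): a quick worked check of the updated decode test. With speculative decoding, a decode request schedules its one verified token plus its draft tokens, so subtracting num_spec_tokens restores the old num_tokens == 1 behavior. The values below are hypothetical:

    # Hypothetical values for one decode request with a single draft token.
    num_tokens = 2        # 1 verified token + 1 speculative (draft) token
    num_spec_tokens = 1
    assert num_tokens - num_spec_tokens == 1  # still classified as a decode
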
@@ -270,7 +272,7 @@ def build(self,
         seq_lens = seq_lens_cpu
         max_query_len = query_lens.max().item()
         max_seq_lens = seq_lens.max().item()
-        query_start_loc = None
+        query_start_loc = common_attn_metadata.query_start_loc

         prefill_metadata = None
         if self._num_prefills > 0:
@@ -335,7 +337,8 @@ def build(self,
                 block_table=block_table,
                 seq_lens=seq_lens,
                 seq_lens_list=seq_lens.tolist(),
-                max_seq_lens=max_seq_lens)
+                max_seq_lens=max_seq_lens,
+                attn_mask=self.runner.spec_attn_mask)

         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
@@ -424,6 +427,17 @@ def __init__(

         self.enable_graph_mode = False
         additional_config = get_current_vllm_config().additional_config
+        speculative_config = get_current_vllm_config().speculative_config
+        self.fia_sparse_mode = 0
+        self.use_spec_decode = False
+        # We need to set sparse_mode of the fused_infer_attention op to 3 in
+        # the spec decoding scenario so that the attention mask is passed in.
+        if speculative_config is not None:
+            self.fia_sparse_mode = 3
+            self.use_spec_decode = True
+            self.spec_token_num = speculative_config.num_speculative_tokens
+            assert self.spec_token_num > 0
+
         if additional_config:
             self.enable_graph_mode = additional_config.get(
                 "enable_graph_mode", False)
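
Note (not part of the patch): the runner-side spec_attn_mask passed into the decode metadata is not shown in this diff. As a rough, hypothetical sketch of the kind of mask involved, each draft position attends to itself and to earlier positions only, which a boolean upper-triangular mask expresses:

    import torch

    # Hypothetical stand-in for spec_attn_mask; True marks masked-out slots.
    k = 2  # assumed num_speculative_tokens
    mask = torch.triu(torch.ones(k + 1, k + 1, dtype=torch.bool), diagonal=1)
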
@@ -628,9 +642,32 @@ def _forward_decode(
             dtype=q.dtype,
             device=q.device)
         if self.running_in_graph:
-            # TorchAir's shape is [bs, num_heads_per_rank, seq_len, dim]
-            q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
-            q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
+            # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
+            if self.use_spec_decode:
+                assert num_tokens % self.spec_token_num == 0
+                q_nope = (
+                    q_nope.view(
+                        num_tokens // self.spec_token_num,
+                        self.spec_token_num,
+                        self.num_heads,
+                        -1,
+                    )
+                    .transpose(1, 2)
+                    .contiguous()
+                )
+                q_pe = (
+                    q_pe.view(
+                        num_tokens // self.spec_token_num,
+                        self.spec_token_num,
+                        self.num_heads,
+                        -1,
+                    )
+                    .transpose(1, 2)
+                    .contiguous()
+                )
+            else:
+                q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
+                q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
             # shape of knope/k_pe for npu graph mode should be:
             # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
             block_size = kv_c_and_k_pe_cache[0].shape[1]
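
Note (not part of the patch): a worked shape example for the spec-decode branch above, with made-up sizes. Eight flattened query rows at two tokens per request fold back into the [bs, num_heads_per_rank, q_seq_len, dim] (BNSD) layout the TorchAir graph expects:

    import torch

    num_tokens, spec_token_num, num_heads, dim = 8, 2, 16, 64
    q_nope = torch.randn(num_tokens, num_heads, dim)
    bnsd = (q_nope.view(num_tokens // spec_token_num, spec_token_num,
                        num_heads, -1).transpose(1, 2).contiguous())
    assert bnsd.shape == (4, 16, 2, 64)  # [bs, heads, q_seq_len, dim]
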
@@ -648,7 +685,8 @@ def _forward_decode(
                 num_heads=self.num_heads,
                 num_key_value_heads=self.num_kv_heads,
                 input_layout="BNSD",
-                atten_mask=attn_metadata.attn_mask,
+                atten_mask=attn_metadata.decode.attn_mask,  # type:ignore
+                sparse_mode=self.fia_sparse_mode,
                 scale=self.scale,
                 antiquant_mode=0,
                 antiquant_scale=None,