                                               AttentionMetadata,
                                               MLAAttentionImpl)
 from vllm.attention.backends.utils import PAD_SLOT_ID
+from vllm.config import get_current_vllm_config
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
 
@@ -83,6 +84,7 @@ class AscendMLADecodeMetadata:
     seq_lens: torch.Tensor
     max_seq_lens: int
     seq_lens_list: list[int]
+    attn_mask: torch.Tensor
 
 
 @dataclass
@@ -170,11 +172,13 @@ def reorder_batch(self, input_batch: "InputBatch",
 
         for i, req_id in enumerate(input_batch.req_ids):
             num_tokens = scheduler_output.num_scheduled_tokens[req_id]
+            num_spec_tokens = len(
+                scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
             # for now treat 1 scheduled token as "decode" even if it's not,
             # we should update this to something like < 8 in the future but
             # currently the TritonMLA._forward_decode only supports
             # num_tokens = 1
-            if num_tokens == 1:
+            if num_tokens - num_spec_tokens == 1:
                 decodes.append(i)
                 num_decode_tokens += num_tokens
             else:
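
With speculative decoding, a "decode" request schedules its one verified token plus the draft tokens, so `num_tokens` alone no longer identifies decodes; subtracting the draft count restores the `== 1` check. A minimal standalone sketch of that classification (the dicts below are illustrative stand-ins for the scheduler output fields, not vLLM API):

```python
# Sketch: decode vs. prefill split when spec decoding is enabled.
scheduled = {"req0": 3, "req1": 17}   # tokens scheduled per request
spec = {"req0": [101, 102]}           # req0 carries 2 draft tokens

decodes, prefills = [], []
for i, req_id in enumerate(scheduled):
    num_tokens = scheduled[req_id]
    num_spec_tokens = len(spec.get(req_id, []))
    # A decode step schedules exactly one non-speculative token.
    if num_tokens - num_spec_tokens == 1:
        decodes.append(i)
    else:
        prefills.append(i)

assert decodes == [0] and prefills == [1]
```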
@@ -269,7 +273,8 @@ def build_dummy(self, num_reqs: int,
             block_table=block_table,
             seq_lens=seq_lens,
             seq_lens_list=seq_lens.tolist(),
-            max_seq_lens=1)
+            max_seq_lens=1,
+            attn_mask=self.runner.spec_attn_mask)
         return self.metadata_cls(  # type: ignore
             num_input_tokens=num_actual_tokens,
             num_actual_tokens=num_actual_tokens,
@@ -317,7 +322,7 @@ def build(
             seq_lens = seq_lens_cpu
         max_query_len = query_lens.max().item()
         max_seq_lens = seq_lens.max().item()
-        query_start_loc = None
+        query_start_loc = common_attn_metadata.query_start_loc
 
         prefill_metadata = None
         if self._num_prefills > 0:
@@ -382,7 +387,8 @@ def build(
                 block_table=block_table,
                 seq_lens=seq_lens,
                 seq_lens_list=seq_lens.tolist(),
-                max_seq_lens=max_seq_lens)
+                max_seq_lens=max_seq_lens,
+                attn_mask=self.runner.spec_attn_mask)
 
         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
@@ -445,6 +451,17 @@ def __init__(
 
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        # Adapt torchair graph mode to spec decoding.
+        speculative_config = get_current_vllm_config().speculative_config
+        self.fia_sparse_mode = 0
+        self.use_spec_decode = False
+        # We need to set sparse_mode of the fused_infer_attention op to 3
+        # in the spec decoding scenario in order to pass in the attention mask.
+        if speculative_config is not None:
+            self.fia_sparse_mode = 3
+            self.use_spec_decode = True
+            self.spec_token_num = speculative_config.num_speculative_tokens
+            assert self.spec_token_num > 0
 
     def _v_up_proj_and_o_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
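
The mask itself comes from `self.runner.spec_attn_mask`, whose construction is not part of this diff. A common shape for such a mask is a step-causal pattern over the `spec_token_num + 1` query positions of each decode step; the helper below is a hypothetical sketch of that idea (name, signature, and `True`-means-masked convention are assumptions, not the runner's actual code):

```python
import torch

def build_spec_attn_mask(spec_token_num: int, max_seq_len: int) -> torch.Tensor:
    # Hypothetical sketch: one row per query position in the spec-decode
    # window (1 verified token + spec_token_num draft tokens). Row i may
    # attend to everything up to and including its own position.
    q_len = spec_token_num + 1
    mask = torch.zeros(q_len, max_seq_len, dtype=torch.bool)
    for i in range(q_len):
        # The i-th query must not see the draft tokens after it.
        mask[i, max_seq_len - q_len + i + 1:] = True
    return mask

print(build_spec_attn_mask(2, 6).int())
# tensor([[0, 0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 0, 0]], dtype=torch.int32)
```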
@@ -646,9 +663,24 @@ def _forward_decode(
             dtype=q.dtype,
             device=q.device)
         if self.running_in_graph:
-            # TorchAir's shape is [bs, num_heads_per_rank, seq_len, dim]
-            q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
-            q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
+            # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
+            if self.use_spec_decode:
+                assert num_tokens % (self.spec_token_num + 1) == 0
+                q_nope = (q_nope.view(
+                    num_tokens // (self.spec_token_num + 1),
+                    self.spec_token_num + 1,
+                    self.num_heads,
+                    -1,
+                ).transpose(1, 2).contiguous())
+                q_pe = (q_pe.view(
+                    num_tokens // (self.spec_token_num + 1),
+                    self.spec_token_num + 1,
+                    self.num_heads,
+                    -1,
+                ).transpose(1, 2).contiguous())
+            else:
+                q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
+                q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
             # shape of knope/k_pe for npu graph mode should be:
             # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
             block_size = kv_c_and_k_pe_cache[0].shape[1]
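
The reshape above packs the flat `[num_tokens, num_heads, dim]` query into TorchAir's BNSD layout, grouping each request's `spec_token_num + 1` query positions into the S axis. A self-contained sketch of the same transform (shapes are illustrative):

```python
import torch

num_heads, head_dim, spec_token_num = 8, 64, 2
q_seq_len = spec_token_num + 1      # 1 verified token + 2 draft tokens
bs = 4                              # decode requests in the batch
num_tokens = bs * q_seq_len         # flat token count entering the kernel

q = torch.randn(num_tokens, num_heads, head_dim)

# Same view/transpose as the diff: [num_tokens, H, D] ->
# [bs, q_seq_len, H, D] -> BNSD = [bs, H, q_seq_len, D].
q_bnsd = (q.view(num_tokens // q_seq_len, q_seq_len, num_heads,
                 -1).transpose(1, 2).contiguous())

assert q_bnsd.shape == (bs, num_heads, q_seq_len, head_dim)
```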
@@ -666,7 +698,8 @@ def _forward_decode(
                 num_heads=self.num_heads,
                 num_key_value_heads=self.num_kv_heads,
                 input_layout="BNSD",
-                atten_mask=attn_metadata.attn_mask,
+                atten_mask=attn_metadata.decode.attn_mask,  # type: ignore
+                sparse_mode=self.fia_sparse_mode,
                 scale=self.scale,
                 antiquant_mode=0,
                 antiquant_scale=None,