@@ -59,6 +59,7 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
     paged_kv_last_page_len: Optional[torch.Tensor] = None
     # The query indptr, shape: [num_decode + 1]
     qo_indptr: Optional[torch.Tensor] = None
+    max_seqlen_qo: Optional[int] = None
 
     num_kv_splits_indptr: Optional[torch.Tensor] = None
     batch_split_table: Optional[torch.Tensor] = None
@@ -79,7 +80,7 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
 class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
     # TODO(luka, lucas): audit this as part of:
     # https://github.com/vllm-project/vllm/issues/22945
-    #cudagraph_support: ClassVar[AttentionCGSupport] = \
+    # cudagraph_support: ClassVar[AttentionCGSupport] = \
     #     AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
     cudagraph_support: ClassVar[AttentionCGSupport] = \
         AttentionCGSupport.UNIFORM_BATCH
@@ -98,7 +99,6 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         max_num_reqs = vllm_config.scheduler_config.max_num_seqs
         max_num_pages = max_num_reqs * max_num_pages_per_req
 
-
         self.speculative_config = vllm_config.speculative_config
 
         if self.speculative_config:
@@ -209,7 +209,6 @@ def _build_decode(self, block_table_tensor: torch.Tensor,
         max_seqlen_qo = 1
         num_kv_splits_indptr = None
 
-
         # max_seqlen_qo should be set according to the MTP. For example, MTP1 corresponds to max_seqlen_qo=2.
         speculative_config = self.vllm_config.speculative_config
         if speculative_config is not None and speculative_config.num_speculative_tokens is not None:
@@ -243,10 +242,10 @@ def _build_decode(self, block_table_tensor: torch.Tensor,
             topk=-1,
         )
 
-
         attn_metadata = AiterMLADecodeMetadata(
             block_table=block_table_tensor,
             seq_lens=seq_lens_device,
+            max_seqlen_qo=max_seqlen_qo,
             paged_kv_indptr=paged_kv_indptr,
             paged_kv_indices=paged_kv_indices,
             paged_kv_last_page_len=paged_kv_last_page_len,
@@ -257,8 +256,8 @@ def _build_decode(self, block_table_tensor: torch.Tensor,
             reduce_indptr=self.reduce_indptr,
             reduce_final_map=self.reduce_final_map,
             reduce_partial_map=self.reduce_partial_map,
-            qo_indptr=qo_indptr)
-
+            qo_indptr=qo_indptr,
+        )
 
         return attn_metadata
 
@@ -339,30 +338,34 @@ def _forward_decode(
 
         # max_seqlen_qo must be 1 except for MTP
         # TODO: Find the best value for MTP
-        max_seqlen_qo = 2
 
         q_scale_input = None
         if hasattr(layer, '_q_scale_float') and layer._q_scale_float != 1.0:
             q_scale_input = torch.tensor([layer._q_scale_float], dtype=torch.float32, device=q.device)
         elif hasattr(layer, '_q_scale'):
             q_scale_input = layer._q_scale.to(q.device)
-
-        aiter_mla_decode_fwd(q, kv_buffer, o,
-                             attn_metadata.decode.qo_indptr,
-                             attn_metadata.decode.paged_kv_indptr,
-                             attn_metadata.decode.paged_kv_indices,
-                             attn_metadata.decode.paged_kv_last_page_len,
-                             max_seqlen_qo, self.scale,
-                             True, 0.0, 1,
-                             attn_metadata.decode.num_kv_splits_indptr,
-                             attn_metadata.decode.work_metadata,
-                             attn_metadata.decode.work_indptr,
-                             attn_metadata.decode.work_info_set,
-                             attn_metadata.decode.reduce_indptr,
-                             attn_metadata.decode.reduce_final_map,
-                             attn_metadata.decode.reduce_partial_map,
-                             q_scale_input,
-                             )
 
-        return o, None
+        aiter_mla_decode_fwd(
+            q,
+            kv_buffer,
+            o,
+            attn_metadata.decode.qo_indptr,
+            attn_metadata.decode.paged_kv_indptr,
+            attn_metadata.decode.paged_kv_indices,
+            attn_metadata.decode.paged_kv_last_page_len,
+            attn_metadata.decode.max_seqlen_qo,
+            self.scale,
+            True,
+            0.0,
+            1,
+            attn_metadata.decode.num_kv_splits_indptr,
+            attn_metadata.decode.work_metadata,
+            attn_metadata.decode.work_indptr,
+            attn_metadata.decode.work_info_set,
+            attn_metadata.decode.reduce_indptr,
+            attn_metadata.decode.reduce_final_map,
+            attn_metadata.decode.reduce_partial_map,
+            q_scale_input,
+        )
 
+        return o, None
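
Note on the max_seqlen_qo plumbing above: _forward_decode now reads the value from AiterMLADecodeMetadata instead of hard-coding 2, and _build_decode derives it from the speculative config (per the in-code comment, "MTP1 corresponds to max_seqlen_qo=2"). A minimal sketch of that relationship, assuming one extra query token per speculative (MTP) token; the helper name compute_max_seqlen_qo is illustrative only and not part of the vLLM API:

from typing import Optional


def compute_max_seqlen_qo(num_speculative_tokens: Optional[int]) -> int:
    # Plain decode issues one query token per request; MTP-k adds k
    # speculative tokens on top, so MTP1 corresponds to max_seqlen_qo=2.
    if num_speculative_tokens is None:
        return 1
    return 1 + num_speculative_tokens


assert compute_max_seqlen_qo(None) == 1  # regular decode
assert compute_max_seqlen_qo(1) == 2     # MTP1, matching the comment in _build_decode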