Skip to content

Commit 7258e1d

Browse files
valarLip authored and Duyi-Wang committed
fix max_seqlen_qo
1 parent 0ab29bb commit 7258e1d

File tree

1 file changed

+28
-25
lines changed

1 file changed

+28
-25
lines changed

vllm/v1/attention/backends/mla/rocm_aiter_mla.py

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
5959
paged_kv_last_page_len: Optional[torch.Tensor] = None
6060
# The query indptr, shape : [num_decode + 1]
6161
qo_indptr: Optional[torch.Tensor] = None
62+
max_seqlen_qo: Optional[int] = None
6263

6364
num_kv_splits_indptr: Optional[torch.Tensor] = None
6465
batch_split_table: Optional[torch.Tensor] = None
@@ -79,7 +80,7 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
7980
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
8081
# TODO(luka, lucas): audit this as part of:
8182
# https://github.com/vllm-project/vllm/issues/22945
82-
#cudagraph_support: ClassVar[AttentionCGSupport] = \
83+
# cudagraph_support: ClassVar[AttentionCGSupport] = \
8384
# AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
8485
cudagraph_support: ClassVar[AttentionCGSupport] = \
8586
AttentionCGSupport.UNIFORM_BATCH
@@ -98,7 +99,6 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
9899
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
99100
max_num_pages = max_num_reqs * max_num_pages_per_req
100101

101-
102102
self.speculative_config = vllm_config.speculative_config
103103

104104
if self.speculative_config:
@@ -209,7 +209,6 @@ def _build_decode(self, block_table_tensor: torch.Tensor,
209209
max_seqlen_qo = 1
210210
num_kv_splits_indptr = None
211211

212-
213212
# max_seqlen_qo should be set according to the MTP. For example, MTP1 corresponds to max_seqlen_qo=2.
214213
speculative_config = self.vllm_config.speculative_config
215214
if speculative_config is not None and speculative_config.num_speculative_tokens is not None:
@@ -243,10 +242,10 @@ def _build_decode(self, block_table_tensor: torch.Tensor,
243242
topk=-1,
244243
)
245244

246-
247245
attn_metadata = AiterMLADecodeMetadata(
248246
block_table=block_table_tensor,
249247
seq_lens=seq_lens_device,
248+
max_seqlen_qo=max_seqlen_qo,
250249
paged_kv_indptr=paged_kv_indptr,
251250
paged_kv_indices=paged_kv_indices,
252251
paged_kv_last_page_len=paged_kv_last_page_len,
@@ -257,8 +256,8 @@ def _build_decode(self, block_table_tensor: torch.Tensor,
257256
reduce_indptr=self.reduce_indptr,
258257
reduce_final_map=self.reduce_final_map,
259258
reduce_partial_map=self.reduce_partial_map,
260-
qo_indptr=qo_indptr)
261-
259+
qo_indptr=qo_indptr,
260+
)
262261

263262
return attn_metadata
264263

@@ -339,30 +338,34 @@ def _forward_decode(
339338

340339
# max_seqlen_qo must be 1 except for MTP
341340
# TODO: Find the best value for MTP
342-
max_seqlen_qo = 2
343341

344342
q_scale_input = None
345343
if hasattr(layer, '_q_scale_float') and layer._q_scale_float != 1.0:
346344
q_scale_input = torch.tensor([layer._q_scale_float], dtype=torch.float32, device=q.device)
347345
elif hasattr(layer, '_q_scale'):
348346
q_scale_input = layer._q_scale.to(q.device)
349-
350-
aiter_mla_decode_fwd(q, kv_buffer, o,
351-
attn_metadata.decode.qo_indptr,
352-
attn_metadata.decode.paged_kv_indptr,
353-
attn_metadata.decode.paged_kv_indices,
354-
attn_metadata.decode.paged_kv_last_page_len,
355-
max_seqlen_qo, self.scale,
356-
True, 0.0, 1,
357-
attn_metadata.decode.num_kv_splits_indptr,
358-
attn_metadata.decode.work_metadata,
359-
attn_metadata.decode.work_indptr,
360-
attn_metadata.decode.work_info_set,
361-
attn_metadata.decode.reduce_indptr,
362-
attn_metadata.decode.reduce_final_map,
363-
attn_metadata.decode.reduce_partial_map,
364-
q_scale_input,
365-
)
366347

367-
return o, None
348+
aiter_mla_decode_fwd(
349+
q,
350+
kv_buffer,
351+
o,
352+
attn_metadata.decode.qo_indptr,
353+
attn_metadata.decode.paged_kv_indptr,
354+
attn_metadata.decode.paged_kv_indices,
355+
attn_metadata.decode.paged_kv_last_page_len,
356+
attn_metadata.decode.max_seqlen_qo,
357+
self.scale,
358+
True,
359+
0.0,
360+
1,
361+
attn_metadata.decode.num_kv_splits_indptr,
362+
attn_metadata.decode.work_metadata,
363+
attn_metadata.decode.work_indptr,
364+
attn_metadata.decode.work_info_set,
365+
attn_metadata.decode.reduce_indptr,
366+
attn_metadata.decode.reduce_final_map,
367+
attn_metadata.decode.reduce_partial_map,
368+
q_scale_input,
369+
)
368370

371+
return o, None

0 commit comments

Comments (0)