From ff33b3170217a8b1ef1d55cea86067d560ee522c Mon Sep 17 00:00:00 2001
From: ganyi
Date: Tue, 14 Oct 2025 01:59:23 +0000
Subject: [PATCH 1/5] enable persistent mla kernel

Signed-off-by: ganyi
---
 vllm/attention/ops/rocm_aiter_mla.py | 24 +++++++
 .../attention/backends/mla/rocm_aiter_mla.py | 62 ++++++++++++++++++-
 2 files changed, 85 insertions(+), 1 deletion(-)

diff --git a/vllm/attention/ops/rocm_aiter_mla.py b/vllm/attention/ops/rocm_aiter_mla.py
index 6308f63cc4e7..0ce9c403d100 100644
--- a/vllm/attention/ops/rocm_aiter_mla.py
+++ b/vllm/attention/ops/rocm_aiter_mla.py
@@ -33,6 +33,12 @@ def aiter_mla_decode_fwd(
     kv_indices: torch.Tensor | None = None,
     kv_last_page_lens: torch.Tensor | None = None,
     logit_cap: float = 0.0,
+    work_meta_data: torch.Tensor | None = None,
+    work_indptr: torch.Tensor | None = None,
+    work_info_set: torch.Tensor | None = None,
+    reduce_indptr: torch.Tensor | None = None,
+    reduce_final_map: torch.Tensor | None = None,
+    reduce_partial_map: torch.Tensor | None = None,
 ):
     torch.ops.vllm.rocm_aiter_mla_decode_fwd(
         q,
@@ -45,6 +51,12 @@ def aiter_mla_decode_fwd(
         kv_last_page_lens,
         sm_scale=sm_scale,
         logit_cap=logit_cap,
+        work_meta_data = work_meta_data,
+        work_indptr = work_indptr,
+        work_info_set= work_info_set,
+        reduce_indptr = reduce_indptr,
+        reduce_final_map = reduce_final_map,
+        reduce_partial_map = reduce_partial_map,
     )


@@ -59,6 +71,12 @@ def mla_decode_fwd_impl(
     kv_last_page_lens: torch.Tensor | None = None,
     sm_scale: float = 1.0,
     logit_cap: float = 0.0,
+    work_meta_data: torch.Tensor = None,
+    work_indptr: torch.Tensor = None,
+    work_info_set: torch.Tensor = None,
+    reduce_indptr: torch.Tensor = None,
+    reduce_final_map: torch.Tensor = None,
+    reduce_partial_map: torch.Tensor = None,
 ) -> None:
     from aiter.mla import mla_decode_fwd

@@ -73,6 +91,12 @@ def mla_decode_fwd_impl(
         max_seqlen_qo,
         sm_scale=sm_scale,
         logit_cap=logit_cap,
+        work_meta_data=work_meta_data,
+        work_indptr=work_indptr,
+        work_info_set=work_info_set,
+        reduce_indptr=reduce_indptr,
+        reduce_final_map=reduce_final_map,
+        reduce_partial_map=reduce_partial_map,
     )


diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
index d935c02243bd..23afd9cedb05 100644
--- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -2,7 +2,8 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from dataclasses import dataclass
-from typing import ClassVar
+from typing import ClassVar, Optional, Union
+import math

 import torch

@@ -56,6 +57,18 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
     # The query indptr, shape : [num_decode + 1]
     qo_indptr: torch.Tensor | None = None

+    work_metadata: torch.Tensor | None = None
+
+    work_info_set: torch.Tensor | None = None
+
+    work_indptr: torch.Tensor | None = None
+
+    reduce_indptr: torch.Tensor | None = None
+
+    reduce_final_map: torch.Tensor | None = None
+
+    reduce_partial_map: torch.Tensor | None = None
+

 class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
     pass
@@ -82,6 +95,10 @@ def __init__(
             "AITER MLAonly supports block size 1."
         )

+        gpu = torch.cuda.current_device()
+        device_properties = torch.cuda.get_device_properties(gpu)
+        cu_num = device_properties.multi_processor_count
+
         self.compilation_config = vllm_config.compilation_config
         max_num_pages_per_req = cdiv(
             vllm_config.model_config.max_model_len, self.kv_cache_spec.block_size
@@ -89,6 +106,15 @@ def __init__(
         max_num_reqs = vllm_config.scheduler_config.max_num_seqs
         max_num_pages = max_num_reqs * max_num_pages_per_req

+        max_qo_len = 1
+        max_qo_tiles_per_batch = int(math.ceil)
+        self.work_metadata = torch.empty([10], dtype=torch.uint64, device="cuda")
+        self.work_indptr = torch.empty([cu_num + 1], dtype=torch.int32, device='cuda')
+        self.work_info_set = torch.empty([max_num_reqs * max_qo_tiles_per_batch * cu_num, 8], dtype=torch.int32, device='cuda').fill_(-1)
+        self.reduce_indptr = torch.empty([max_num_reqs * max_qo_tiles_per_batch + 1], dtype=torch.int32, device='cuda')
+        self.reduce_final_map = torch.empty([max_num_reqs * max_qo_tiles_per_batch, 2], dtype=torch.int32, device='cuda')
+        self.reduce_partial_map = torch.empty([max_num_reqs * max_qo_tiles_per_batch * cu_num], dtype=torch.int32, device='cuda')
+
         # Preparing persistent buffers
         # TODO: we can disambiguate between decode and mixed-prefill decode here
         # so we can only use the persistent buffer if a cudagraph is actually
@@ -139,6 +165,28 @@ def _build_decode(
                 block_table_bounds.cumsum(dim=0, dtype=torch.int32),
             ]
         )
+        kv_indptr_cpu = torch.zeros([query_start_loc_cpu.size(0)], dtype=torch.int32)
+        torch.cumsum(seq_lens_cpu, dim=0, out=kv_indptr_cpu[1:])
+
+        import aiter
+        aiter.get_mla_metadata_v1(
+            query_start_loc_cpu,
+            kv_indptr_cpu,
+            self.num_heads // self.kv_cache_spec.num_kv_heads,
+            self.kv_cache_spec.num_kv_heads,
+            False,
+            self.work_metadata,
+            self.work_info_set,
+            self.work_indptr,
+            self.reduce_indptr,
+            self.reduce_final_map,
+            self.reduce_partial_map,
+            kv_granularity=max(page_size, 16),
+            max_seqlen_qo=1,
+            uni_seqlen_qo=1,
+            fast_mode=False
+        )
+

         if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
             num_actual_pages = paged_kv_indices.size(0)
@@ -176,6 +224,12 @@ def _build_decode(
             paged_kv_last_page_len=paged_kv_last_page_len,
             qo_indptr=qo_indptr,
             dcp_tot_seq_lens=dcp_tot_seq_lens_device,
+            work_metadata=self.work_metadata,
+            work_info_set=self.work_info_set,
+            work_indptr=self.work_indptr,
+            reduce_indptr=self.reduce_indptr,
+            reduce_final_map=self.reduce_final_map,
+            reduce_partial_map=self.reduce_partial_map,
         )

         return attn_metadata
@@ -274,6 +328,12 @@ def _forward_decode(
             attn_metadata.decode.paged_kv_indptr,
             attn_metadata.decode.paged_kv_indices,
             attn_metadata.decode.paged_kv_last_page_len,
+            work_meta_data=attn_metadata.decode.work_metadata,
+            work_indptr=attn_metadata.decode.work_indptr,
+            work_info_set=attn_metadata.decode.work_info_set,
+            reduce_indptr=attn_metadata.decode.reduce_indptr,
+            reduce_final_map=attn_metadata.decode.reduce_final_map,
+            reduce_partial_map=attn_metadata.decode.reduce_partial_map,
         )

         return o, None

From e33bcff170cbeda5e8e156c18d0be1367dadc984 Mon Sep 17 00:00:00 2001
From: ganyi
Date: Thu, 16 Oct 2025 03:51:15 +0000
Subject: [PATCH 2/5] workable

Signed-off-by: ganyi
---
 .../attention/backends/mla/rocm_aiter_mla.py | 37 +++++++++++++------
 1 file changed, 25 insertions(+), 12 deletions(-)

diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
index 23afd9cedb05..eacaf396ff95 100644
--- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -57,6 +57,8 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
     # The query indptr, shape : [num_decode + 1]
     qo_indptr: torch.Tensor | None = None

+    max_seqlen_qo: int = 1
+
     work_metadata: torch.Tensor | None = None

     work_info_set: torch.Tensor | None = None
@@ -106,8 +108,11 @@ def __init__(
         max_num_reqs = vllm_config.scheduler_config.max_num_seqs
         max_num_pages = max_num_reqs * max_num_pages_per_req

-        max_qo_len = 1
-        max_qo_tiles_per_batch = int(math.ceil)
+        # num_mtp = vllm_config.speculative_config.num_speculative_tokens
+        # num_mtp = 1 if num_mtp is None else num_mtp
+        max_seqlen_qo = 1 if vllm_config.speculative_config is None else vllm_config.speculative_config.num_speculative_tokens
+
+        max_qo_tiles_per_batch = int(math.ceil(max_seqlen_qo * self.num_heads / 128))
         self.work_metadata = torch.empty([10], dtype=torch.uint64, device="cuda")
         self.work_indptr = torch.empty([cu_num + 1], dtype=torch.int32, device='cuda')
         self.work_info_set = torch.empty([max_num_reqs * max_qo_tiles_per_batch * cu_num, 8], dtype=torch.int32, device='cuda').fill_(-1)
@@ -165,16 +170,24 @@ def _build_decode(
                 block_table_bounds.cumsum(dim=0, dtype=torch.int32),
             ]
         )
-        kv_indptr_cpu = torch.zeros([query_start_loc_cpu.size(0)], dtype=torch.int32)
-        torch.cumsum(seq_lens_cpu, dim=0, out=kv_indptr_cpu[1:])
+        kv_indptr = torch.zeros([query_start_loc_cpu.size(0)], dtype=torch.int32, device='cuda')
+        torch.cumsum(seq_lens_device, dim=0, out=kv_indptr[1:])
+        query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
+        max_seqlen_qo = torch.max(query_lens).item()

         import aiter
+        print("the work_metadata: ", self.work_metadata, flush=True)
+        print("the work_info_set: ", self.work_info_set, flush=True)
+        print("the work_indptr: ", self.work_indptr, flush=True)
+        print("the input buffer: ", self.reduce_indptr, flush=True)
+        print("the input buffer: ", self.reduce_final_map, flush=True)
+        print("the input buffer: ", self.reduce_partial_map, flush=True)
         aiter.get_mla_metadata_v1(
-            query_start_loc_cpu,
-            kv_indptr_cpu,
+            query_start_loc_device,
+            kv_indptr,
             self.num_heads // self.kv_cache_spec.num_kv_heads,
             self.kv_cache_spec.num_kv_heads,
-            False,
+            True,
             self.work_metadata,
             self.work_info_set,
             self.work_indptr,
@@ -182,9 +195,9 @@ def _build_decode(
             self.reduce_final_map,
             self.reduce_partial_map,
             kv_granularity=max(page_size, 16),
-            max_seqlen_qo=1,
-            uni_seqlen_qo=1,
-            fast_mode=False
+            max_seqlen_qo=max_seqlen_qo,
+            uni_seqlen_qo=max_seqlen_qo,
+            fast_mode=True
         )


@@ -224,6 +237,7 @@ def _build_decode(
             paged_kv_last_page_len=paged_kv_last_page_len,
             qo_indptr=qo_indptr,
             dcp_tot_seq_lens=dcp_tot_seq_lens_device,
+            max_seqlen_qo=max_seqlen_qo,
             work_metadata=self.work_metadata,
             work_info_set=self.work_info_set,
             work_indptr=self.work_indptr,
@@ -317,14 +331,13 @@ def _forward_decode(

         # max_seqlen_qo must be 1 except for MTP
         # TODO: Find the best value for MTP
-        max_seqlen_qo = 1
         aiter_mla_decode_fwd(
             q,
             kv_buffer,
             o,
             self.scale,
             attn_metadata.decode.qo_indptr,
-            max_seqlen_qo,
+            attn_metadata.decode.max_seqlen_qo,
             attn_metadata.decode.paged_kv_indptr,
             attn_metadata.decode.paged_kv_indices,
             attn_metadata.decode.paged_kv_last_page_len,

From f5bec923056e586977465e6cd9bb951c11bd0915 Mon Sep 17 00:00:00 2001
From: ganyi
Date: Thu, 16 Oct 2025 05:52:46 +0000
Subject: [PATCH 3/5] acc verified

Signed-off-by: ganyi
---
 vllm/v1/attention/backends/mla/rocm_aiter_mla.py | 10 ++--------
 1 file changed, 2 insertions(+), 8 deletions(-)

diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
index eacaf396ff95..502803a1da85 100644
--- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -176,12 +176,6 @@ def _build_decode(
         max_seqlen_qo = torch.max(query_lens).item()

         import aiter
-        print("the work_metadata: ", self.work_metadata, flush=True)
-        print("the work_info_set: ", self.work_info_set, flush=True)
-        print("the work_indptr: ", self.work_indptr, flush=True)
-        print("the input buffer: ", self.reduce_indptr, flush=True)
-        print("the input buffer: ", self.reduce_final_map, flush=True)
-        print("the input buffer: ", self.reduce_partial_map, flush=True)
         aiter.get_mla_metadata_v1(
             query_start_loc_device,
             kv_indptr,
@@ -194,7 +188,7 @@ def _build_decode(
             self.reduce_indptr,
             self.reduce_final_map,
             self.reduce_partial_map,
-            kv_granularity=max(page_size, 16),
+            kv_granularity=max(self.kv_cache_spec.block_size, 16),
             max_seqlen_qo=max_seqlen_qo,
             uni_seqlen_qo=max_seqlen_qo,
             fast_mode=True
@@ -325,7 +319,7 @@ def _forward_decode(
         B = q.shape[0]
         o = torch.zeros(
             B, self.num_heads, self.kv_lora_rank, dtype=q.dtype, device=q.device
-        )
+        ).fill_(-1)

         kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)


From c6f2fd29d522dec3aae8e84c9519c3d90e8c77a1 Mon Sep 17 00:00:00 2001
From: ganyi
Date: Thu, 16 Oct 2025 09:30:00 +0000
Subject: [PATCH 4/5] fp8 mla support

Signed-off-by: ganyi
---
 vllm/attention/ops/rocm_aiter_mla.py | 28 +++++++++++++----
 .../attention/backends/mla/rocm_aiter_mla.py | 6 ++--
 2 files changed, 26 insertions(+), 8 deletions(-)

diff --git a/vllm/attention/ops/rocm_aiter_mla.py b/vllm/attention/ops/rocm_aiter_mla.py
index 0ce9c403d100..41dab7976d1f 100644
--- a/vllm/attention/ops/rocm_aiter_mla.py
+++ b/vllm/attention/ops/rocm_aiter_mla.py
@@ -39,6 +39,8 @@ def aiter_mla_decode_fwd(
     reduce_indptr: torch.Tensor | None = None,
     reduce_final_map: torch.Tensor | None = None,
     reduce_partial_map: torch.Tensor | None = None,
+    q_scale: torch.Tensor | None = None,
+    kv_scale: torch.Tensor | None = None,
 ):
     torch.ops.vllm.rocm_aiter_mla_decode_fwd(
         q,
@@ -57,6 +59,8 @@ def aiter_mla_decode_fwd(
         reduce_indptr = reduce_indptr,
         reduce_final_map = reduce_final_map,
         reduce_partial_map = reduce_partial_map,
+        q_scale = q_scale,
+        kv_scale = kv_scale,
     )


@@ -71,12 +75,14 @@ def mla_decode_fwd_impl(
     kv_last_page_lens: torch.Tensor | None = None,
     sm_scale: float = 1.0,
     logit_cap: float = 0.0,
-    work_meta_data: torch.Tensor = None,
-    work_indptr: torch.Tensor = None,
-    work_info_set: torch.Tensor = None,
-    reduce_indptr: torch.Tensor = None,
-    reduce_final_map: torch.Tensor = None,
-    reduce_partial_map: torch.Tensor = None,
+    work_meta_data: torch.Tensor | None = None,
+    work_indptr: torch.Tensor | None = None,
+    work_info_set: torch.Tensor | None = None,
+    reduce_indptr: torch.Tensor | None = None,
+    reduce_final_map: torch.Tensor | None = None,
+    reduce_partial_map: torch.Tensor | None = None,
+    q_scale: torch.Tensor | None = None,
+    kv_scale: torch.Tensor | None = None
 ) -> None:
     from aiter.mla import mla_decode_fwd

@@ -97,6 +103,8 @@ def mla_decode_fwd_impl(
         reduce_indptr=reduce_indptr,
         reduce_final_map=reduce_final_map,
         reduce_partial_map=reduce_partial_map,
+        q_scale=q_scale,
+        kv_scale=kv_scale,
     )


@@ -111,6 +119,14 @@ def mla_decode_fwd_fake(
     kv_last_page_lens: torch.Tensor | None = None,
     sm_scale: float = 1.0,
     logit_cap: float = 0.0,
+    work_meta_data: torch.Tensor | None = None,
+    work_indptr: torch.Tensor | None = None,
+    work_info_set: torch.Tensor | None = None,
+    reduce_indptr: torch.Tensor | None = None,
+    reduce_final_map: torch.Tensor | None = None,
+    reduce_partial_map: torch.Tensor | None = None,
+    q_scale: torch.Tensor | None = None,
+    kv_scale: torch.Tensor | None = None
 ) -> None:
     pass

diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
index 502803a1da85..51687e043216 100644
--- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -10,7 +10,7 @@
 import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionLayer
 from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd
-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.utils import cdiv
 from vllm.v1.attention.backends.mla.common import (
     MLACommonBackend,
@@ -318,7 +318,7 @@ def _forward_decode(
         assert isinstance(q, torch.Tensor)
         B = q.shape[0]
         o = torch.zeros(
-            B, self.num_heads, self.kv_lora_rank, dtype=q.dtype, device=q.device
+            B, self.num_heads, self.kv_lora_rank, dtype=torch.bfloat16, device=q.device
         ).fill_(-1)

         kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
@@ -341,6 +341,8 @@ def _forward_decode(
             reduce_indptr=attn_metadata.decode.reduce_indptr,
             reduce_final_map=attn_metadata.decode.reduce_final_map,
             reduce_partial_map=attn_metadata.decode.reduce_partial_map,
+            q_scale=layer._q_scale,
+            kv_scale=layer._k_scale,
         )

         return o, None

From 5515fcbd8728f4b5c4dfe4a53c76866a5818a0b8 Mon Sep 17 00:00:00 2001
From: ganyi
Date: Wed, 22 Oct 2025 05:26:45 +0000
Subject: [PATCH 5/5] lint fix

Signed-off-by: ganyi
---
 vllm/attention/ops/rocm_aiter_mla.py | 20 ++-----
 .../attention/backends/mla/rocm_aiter_mla.py | 42 ++++++++++++++-----
 2 files changed, 41 insertions(+), 21 deletions(-)

diff --git a/vllm/attention/ops/rocm_aiter_mla.py b/vllm/attention/ops/rocm_aiter_mla.py
index 41dab7976d1f..ce758b3f3013 100644
--- a/vllm/attention/ops/rocm_aiter_mla.py
+++ b/vllm/attention/ops/rocm_aiter_mla.py
@@ -53,14 +53,14 @@ def aiter_mla_decode_fwd(
         kv_last_page_lens,
         sm_scale=sm_scale,
         logit_cap=logit_cap,
-        work_meta_data = work_meta_data,
-        work_indptr = work_indptr,
-        work_info_set= work_info_set,
-        reduce_indptr = reduce_indptr,
-        reduce_final_map = reduce_final_map,
-        reduce_partial_map = reduce_partial_map,
-        q_scale = q_scale,
-        kv_scale = kv_scale,
+        work_meta_data=work_meta_data,
+        work_indptr=work_indptr,
+        work_info_set=work_info_set,
+        reduce_indptr=reduce_indptr,
+        reduce_final_map=reduce_final_map,
+        reduce_partial_map=reduce_partial_map,
+        q_scale=q_scale,
+        kv_scale=kv_scale,
     )


@@ -82,7 +82,7 @@ def mla_decode_fwd_impl(
     reduce_final_map: torch.Tensor | None = None,
     reduce_partial_map: torch.Tensor | None = None,
     q_scale: torch.Tensor | None = None,
-    kv_scale: torch.Tensor | None = None
+    kv_scale: torch.Tensor | None = None,
 ) -> None:
     from aiter.mla import mla_decode_fwd

@@ -126,7 +126,7 @@ def mla_decode_fwd_fake(
     reduce_final_map: torch.Tensor | None = None,
     reduce_partial_map: torch.Tensor | None = None,
     q_scale: torch.Tensor | None = None,
-    kv_scale: torch.Tensor | None = None
+    kv_scale: torch.Tensor | None = None,
 ) -> None:
     pass

diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
index 51687e043216..49e0e077206a 100644
--- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -1,16 +1,16 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+import math
 from dataclasses import dataclass
 from typing import ClassVar, Optional, Union
-import math

 import torch

 import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionLayer
 from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd
-from vllm.config import VllmConfig, get_current_vllm_config
+from vllm.config import VllmConfig
 from vllm.utils import cdiv
 from vllm.v1.attention.backends.mla.common import (
     MLACommonBackend,
@@ -110,15 +110,33 @@ def __init__(
         # num_mtp = vllm_config.speculative_config.num_speculative_tokens
         # num_mtp = 1 if num_mtp is None else num_mtp
-        max_seqlen_qo = 1 if vllm_config.speculative_config is None else vllm_config.speculative_config.num_speculative_tokens
+        max_seqlen_qo = (
+            1
+            if vllm_config.speculative_config is None
+            else vllm_config.speculative_config.num_speculative_tokens
+        )

         max_qo_tiles_per_batch = int(math.ceil(max_seqlen_qo * self.num_heads / 128))
         self.work_metadata = torch.empty([10], dtype=torch.uint64, device="cuda")
-        self.work_indptr = torch.empty([cu_num + 1], dtype=torch.int32, device='cuda')
-        self.work_info_set = torch.empty([max_num_reqs * max_qo_tiles_per_batch * cu_num, 8], dtype=torch.int32, device='cuda').fill_(-1)
-        self.reduce_indptr = torch.empty([max_num_reqs * max_qo_tiles_per_batch + 1], dtype=torch.int32, device='cuda')
-        self.reduce_final_map = torch.empty([max_num_reqs * max_qo_tiles_per_batch, 2], dtype=torch.int32, device='cuda')
-        self.reduce_partial_map = torch.empty([max_num_reqs * max_qo_tiles_per_batch * cu_num], dtype=torch.int32, device='cuda')
+        self.work_indptr = torch.empty([cu_num + 1], dtype=torch.int32, device="cuda")
+        self.work_info_set = torch.empty(
+            [max_num_reqs * max_qo_tiles_per_batch * cu_num, 8],
+            dtype=torch.int32,
+            device="cuda",
+        ).fill_(-1)
+        self.reduce_indptr = torch.empty(
+            [max_num_reqs * max_qo_tiles_per_batch + 1],
+            dtype=torch.int32,
+            device="cuda",
+        )
+        self.reduce_final_map = torch.empty(
+            [max_num_reqs * max_qo_tiles_per_batch, 2], dtype=torch.int32, device="cuda"
+        )
+        self.reduce_partial_map = torch.empty(
+            [max_num_reqs * max_qo_tiles_per_batch * cu_num],
+            dtype=torch.int32,
+            device="cuda",
+        )

         # Preparing persistent buffers
         # TODO: we can disambiguate between decode and mixed-prefill decode here
         # so we can only use the persistent buffer if a cudagraph is actually
@@ -170,12 +188,15 @@ def _build_decode(
                 block_table_bounds.cumsum(dim=0, dtype=torch.int32),
             ]
         )
-        kv_indptr = torch.zeros([query_start_loc_cpu.size(0)], dtype=torch.int32, device='cuda')
-        torch.cumsum(seq_lens_device, dim=0, out=kv_indptr[1:])
+        kv_indptr = torch.zeros(
+            [query_start_loc_cpu.size(0)], dtype=torch.int32, device="cuda"
+        )
+        torch.cumsum(seq_lens_device, dim=0, out=kv_indptr[1:])
         query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
         max_seqlen_qo = torch.max(query_lens).item()

         import aiter
+
         aiter.get_mla_metadata_v1(
             query_start_loc_device,
             kv_indptr,
@@ -191,10 +212,9 @@ def _build_decode(
             kv_granularity=max(self.kv_cache_spec.block_size, 16),
             max_seqlen_qo=max_seqlen_qo,
             uni_seqlen_qo=max_seqlen_qo,
-            fast_mode=True
+            fast_mode=True,
         )
-

         if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
             num_actual_pages = paged_kv_indices.size(0)
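
The sketch below restates, for reference, how the persistent-kernel scheduling and reduction buffers introduced by this series are sized. It is an illustrative summary only, assuming a single ROCm/CUDA device: the helper name allocate_persistent_mla_buffers and its argument list do not appear in the patches, the shapes and dtypes mirror the torch.empty calls in AiterMLAMetadataBuilder.__init__ as of PATCH 5, and the per-buffer comments are informed guesses from those shapes and initial values rather than documentation of the aiter kernels.

    import math

    import torch


    def allocate_persistent_mla_buffers(cu_num, max_num_reqs, num_heads, max_seqlen_qo):
        # One query tile covers up to 128 rows (query tokens times heads), matching
        # the max_qo_tiles_per_batch computation in the patch.
        max_qo_tiles_per_batch = math.ceil(max_seqlen_qo * num_heads / 128)
        max_tiles = max_num_reqs * max_qo_tiles_per_batch
        return {
            # Small fixed-size header consumed by the persistent decode kernel.
            "work_metadata": torch.empty([10], dtype=torch.uint64, device="cuda"),
            # Per-compute-unit ranges into the flattened work queue.
            "work_indptr": torch.empty([cu_num + 1], dtype=torch.int32, device="cuda"),
            # One 8-int work descriptor per (query tile, CU) slot; initialized to -1
            # like the .fill_(-1) in the patch, presumably marking unused entries.
            "work_info_set": torch.full(
                [max_tiles * cu_num, 8], -1, dtype=torch.int32, device="cuda"
            ),
            # Bookkeeping for combining partial outputs produced by different CUs.
            "reduce_indptr": torch.empty(
                [max_tiles + 1], dtype=torch.int32, device="cuda"
            ),
            "reduce_final_map": torch.empty(
                [max_tiles, 2], dtype=torch.int32, device="cuda"
            ),
            "reduce_partial_map": torch.empty(
                [max_tiles * cu_num], dtype=torch.int32, device="cuda"
            ),
        }

At runtime the series fills these buffers once per batch via aiter.get_mla_metadata_v1 in _build_decode and passes them, together with the q_scale/kv_scale factors added in PATCH 4, to aiter_mla_decode_fwd in _forward_decode.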