diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 4224d807c2b7..d5b30ac685ac 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -25,9 +25,9 @@
 from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.logger import init_logger
 from vllm.utils import cdiv
-from vllm.v1.attention.backends.utils import (
-    AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
-    make_local_attention_virtual_batches)
+from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+                                              CommonAttentionMetadata,
+                                              get_kv_cache_layout)
 from vllm.v1.kv_cache_interface import AttentionSpec
 
 logger = init_logger(__name__)
@@ -130,18 +130,6 @@ class FlashAttentionMetadata:
     prefix_scheduler_metadata: Optional[torch.Tensor] = None
     max_num_splits: int = 0
 
-    # for local attention
-    @dataclass
-    class LocalAttentionMetadata:
-        local_query_start_loc: torch.Tensor
-        local_seqused_k: torch.Tensor
-        local_block_table: torch.Tensor
-        local_max_query_len: int
-        local_max_seq_len: int
-        local_scheduler_metadata: Optional[torch.Tensor]
-
-    local_attn_metadata: Optional[LocalAttentionMetadata] = None
-
 
 def _get_sliding_window_configs(
         vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
@@ -221,7 +209,6 @@ def build(self,
         max_query_len = common_attn_metadata.max_query_len
         max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
         query_start_loc = common_attn_metadata.query_start_loc
-        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
         seq_lens = common_attn_metadata.seq_lens
         seq_lens_cpu = common_attn_metadata.seq_lens_cpu
         block_table_tensor = common_attn_metadata.block_table_tensor
@@ -266,40 +253,6 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                 )
             return None
 
-        # for local attention
-        local_attn_metadata = None
-        if self.model_config.attention_chunk_size is not None:
-            seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
-                virt_block_table_tensor = make_local_attention_virtual_batches(
-                    self.model_config.attention_chunk_size,
-                    query_start_loc_cpu.numpy(),
-                    seq_lens_cpu.numpy(),
-                    block_table_tensor,
-                    self.block_size,
-                )
-            local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
-                self.device, non_blocking=True)
-            local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
-                self.device, non_blocking=True)
-            local_max_query_len = seqlens_q_local_np.max()
-            local_max_seq_len = virt_k_seqlens_np.max()
-            local_scheduler_metadata = schedule(
-                batch_size=local_query_start_loc.shape[0] - 1,
-                cu_query_lens=local_query_start_loc,
-                max_query_len=local_max_query_len,
-                seqlens=local_seqused_k,
-                max_seq_len=local_max_seq_len,
-                causal=True)
-
-            local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
-                local_query_start_loc=local_query_start_loc,
-                local_seqused_k=local_seqused_k,
-                local_block_table=virt_block_table_tensor,
-                local_max_query_len=local_max_query_len,
-                local_max_seq_len=local_max_seq_len,
-                local_scheduler_metadata=local_scheduler_metadata,
-            )
-
         use_cascade = common_prefix_len > 0
 
         if use_cascade:
@@ -371,7 +324,6 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
             cu_prefix_query_lens=cu_prefix_query_lens,
             prefix_kv_lens=prefix_kv_lens,
             suffix_kv_lens=suffix_kv_lens,
-            local_attn_metadata=local_attn_metadata,
             prefix_scheduler_metadata=prefix_scheduler_metadata,
             max_num_splits=max_num_splits,
         )
@@ -517,27 +469,13 @@ def forward(
                     layer._q_scale)
                 query = query.reshape((num_tokens, num_heads, head_size))
 
-        # Compute attention and update output up to `num_actual_tokens`.
-        use_local_attn = \
-            (self.use_irope and attn_metadata.local_attn_metadata is not None)
-
-        if not attn_metadata.use_cascade or use_local_attn:
-            if use_local_attn:
-                assert attn_metadata.local_attn_metadata is not None
-                local_metadata = attn_metadata.local_attn_metadata
-                cu_seqlens_q = local_metadata.local_query_start_loc
-                seqused_k = local_metadata.local_seqused_k
-                max_seqlen_q = local_metadata.local_max_query_len
-                max_seqlen_k = local_metadata.local_max_seq_len
-                block_table = local_metadata.local_block_table
-                scheduler_metadata = local_metadata.local_scheduler_metadata
-            else:
-                cu_seqlens_q = attn_metadata.query_start_loc
-                seqused_k = attn_metadata.seq_lens
-                max_seqlen_q = attn_metadata.max_query_len
-                max_seqlen_k = attn_metadata.max_seq_len
-                block_table = attn_metadata.block_table
-                scheduler_metadata = attn_metadata.scheduler_metadata
+        if not attn_metadata.use_cascade:
+            cu_seqlens_q = attn_metadata.query_start_loc
+            seqused_k = attn_metadata.seq_lens
+            max_seqlen_q = attn_metadata.max_query_len
+            max_seqlen_k = attn_metadata.max_seq_len
+            block_table = attn_metadata.block_table
+            scheduler_metadata = attn_metadata.scheduler_metadata
 
             descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
 
@@ -565,8 +503,6 @@ def forward(
             )
             return output
 
-        assert not use_local_attn, (
-            "Cascade attention does not support local attention.")
         # Cascade attention (rare case).
         cascade_attention(
             output[:num_actual_tokens],
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 1eb27d57acf0..8cf1aa86459d 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -495,10 +495,6 @@ def __init__(
         kv_sharing_target_layer_name: Optional[int] = None,
         use_irope: bool = False,
     ) -> None:
-        if use_irope:
-            logger.warning_once(
-                "Using irope in FlashInfer is not supported yet, it will fall"
-                " back to global attention for long context.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
@@ -513,6 +509,7 @@ def __init__(
         self.kv_cache_dtype = kv_cache_dtype
         self.logits_soft_cap = logits_soft_cap
         self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
+        self.use_irope = use_irope
 
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py
index 46802bf5c2a9..43fe30a9a89f 100644
--- a/vllm/v1/attention/backends/rocm_aiter_fa.py
+++ b/vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -13,8 +13,6 @@
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.v1.attention.backends.flash_attn import (
-    make_local_attention_virtual_batches)
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import AttentionSpec
 
@@ -201,9 +199,7 @@ def build(self,
         max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
         total_tokens = int(common_attn_metadata.seq_lens_cpu.sum())
         query_start_loc = common_attn_metadata.query_start_loc
-        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
         seq_lens = common_attn_metadata.seq_lens
-        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
         block_table_tensor = common_attn_metadata.block_table_tensor
         slot_mapping = common_attn_metadata.slot_mapping
 
@@ -215,56 +211,6 @@ def build(self,
                      dtype=cu_seq_lens.dtype,
                      out=cu_seq_lens[1:])
 
-        def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
-                     max_seq_len, causal):
-            return None
-
-        # for local attention
-        local_attn_metadata = None
-        if self.model_config.attention_chunk_size is not None:
-            seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
-                virt_block_table_tensor = make_local_attention_virtual_batches(
-                    self.model_config.attention_chunk_size,
-                    query_start_loc_cpu.numpy(),
-                    seq_lens_cpu.numpy(),
-                    block_table_tensor,
-                    self.block_size,
-                )
-            local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
-                self.device, non_blocking=True)
-            local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
-                self.device, non_blocking=True)
-            local_max_query_len = seqlens_q_local_np.max().item()
-            local_max_seq_len = virt_k_seqlens_np.max().item()
-            local_scheduler_metadata = schedule(
-                batch_size=local_query_start_loc.shape[0] - 1,
-                cu_query_lens=local_query_start_loc,
-                max_query_len=local_max_query_len,
-                seqlens=local_seqused_k,
-                max_seq_len=local_max_seq_len,
-                causal=True)
-
-            local_cu_seq_lens = torch.zeros(virt_k_seqlens_np.shape[0] + 1,
-                                            dtype=torch.int32,
-                                            device=self.device)
-            local_cu_seq_lens[1:] = torch.cumsum(
-                torch.from_numpy(virt_k_seqlens_np).to(device=self.device,
-                                                       dtype=torch.int32,
-                                                       non_blocking=True),
-                dim=0)
-
-
-            local_attn_metadata = \
-                AiterFlashAttentionMetadata.LocalAttentionMetadata(
-                local_query_start_loc=local_query_start_loc,
-                local_seqused_k=local_seqused_k,
-                local_block_table=virt_block_table_tensor,
-                local_max_query_len=local_max_query_len,
-                local_max_seq_len=local_max_seq_len,
-                local_cu_seq_lens=local_cu_seq_lens,
-                local_scheduler_metadata=local_scheduler_metadata,
-            )
-
         use_cascade = common_prefix_len > 0
 
         cu_prefix_query_lens = None
@@ -286,7 +232,6 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
             cu_prefix_query_lens=cu_prefix_query_lens,
             prefix_kv_lens=prefix_kv_lens,
             suffix_kv_lens=suffix_kv_lens,
-            local_attn_metadata=local_attn_metadata,
         )
         return attn_metadata
 
@@ -377,19 +322,6 @@ class AiterFlashAttentionMetadata:
     prefix_kv_lens: Optional[torch.Tensor]
     suffix_kv_lens: Optional[torch.Tensor]
 
-    # for local attention
-    @dataclass
-    class LocalAttentionMetadata:
-        local_query_start_loc: torch.Tensor
-        local_seqused_k: torch.Tensor
-        local_block_table: torch.Tensor
-        local_max_query_len: int
-        local_max_seq_len: int
-        local_cu_seq_lens: torch.Tensor
-        local_scheduler_metadata: Optional[torch.Tensor]
-
-    local_attn_metadata: Optional[LocalAttentionMetadata] = None
-
 
 class AiterFlashAttentionImpl(AttentionImpl):
 
@@ -521,25 +453,12 @@ def forward(
                     layer._q_scale)
                 query = query.reshape((num_tokens, num_heads, head_size))
 
-        # Compute attention and update output up to `num_actual_tokens`.
-        use_local_attn = \
-            (self.use_irope and attn_metadata.local_attn_metadata is not None)
-
-        if not attn_metadata.use_cascade or use_local_attn:
-            if use_local_attn:
-                assert attn_metadata.local_attn_metadata is not None
-                local_metadata = attn_metadata.local_attn_metadata
-                cu_seqlens_q = local_metadata.local_query_start_loc
-                seqused_k = local_metadata.local_seqused_k
-                max_seqlen_q = local_metadata.local_max_query_len
-                max_seqlen_k = local_metadata.local_max_seq_len
-                block_table = local_metadata.local_block_table
-            else:
-                cu_seqlens_q = attn_metadata.query_start_loc
-                seqused_k = attn_metadata.seq_lens
-                max_seqlen_q = attn_metadata.max_query_len
-                max_seqlen_k = attn_metadata.max_seq_len
-                block_table = attn_metadata.block_table
+        if not attn_metadata.use_cascade:
+            cu_seqlens_q = attn_metadata.query_start_loc
+            seqused_k = attn_metadata.seq_lens
+            max_seqlen_q = attn_metadata.max_query_len
+            max_seqlen_k = attn_metadata.max_seq_len
+            block_table = attn_metadata.block_table
 
             if max_seqlen_q > 1:
                 cu_seq_lens = attn_metadata.cu_seq_lens
@@ -557,9 +476,7 @@ def forward(
                     alibi_slopes=self.alibi_slopes,
                     window_size=self.sliding_window,
                     block_table=block_table,
-                    cu_seqlens_k=(cu_seq_lens if not use_local_attn else
-                                  local_metadata.local_cu_seq_lens),
-                )
+                    cu_seqlens_k=cu_seq_lens)
 
             _, num_heads, head_size = query.shape
             _PARTITION_SIZE_ROCM = 256
diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index ee95b5af6e47..79796ac14928 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -18,9 +18,8 @@
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
-from vllm.v1.attention.backends.utils import (
-    AttentionMetadataBuilder, CommonAttentionMetadata,
-    make_local_attention_virtual_batches)
+from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+                                              CommonAttentionMetadata)
 from vllm.v1.kv_cache_interface import AttentionSpec
 
 logger = init_logger(__name__)
@@ -55,18 +54,6 @@ class TritonAttentionMetadata:
     scheduler_metadata: Optional[torch.Tensor] = None
     prefix_scheduler_metadata: Optional[torch.Tensor] = None
 
-    # for local attention
-    @dataclass
-    class LocalAttentionMetadata:
-        local_query_start_loc: torch.Tensor
-        local_seqused_k: torch.Tensor
-        local_block_table: torch.Tensor
-        local_max_query_len: int
-        local_max_seq_len: int
-        local_scheduler_metadata: Optional[torch.Tensor]
-
-    local_attn_metadata: Optional[LocalAttentionMetadata] = None
-
 
 class TritonAttentionMetadataBuilder(
         AttentionMetadataBuilder[TritonAttentionMetadata]):
@@ -111,34 +98,6 @@ def build(self,
         block_table_tensor = common_attn_metadata.block_table_tensor
         slot_mapping = common_attn_metadata.slot_mapping
 
-        # for local attention
-        local_attn_metadata = None
-        if self.attention_chunk_size is not None:
-            seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
-                virt_block_table_tensor = make_local_attention_virtual_batches(
-                    self.attention_chunk_size,
-                    common_attn_metadata.query_start_loc_cpu.numpy(),
-                    common_attn_metadata.seq_lens_cpu.numpy(),
-                    block_table_tensor,
-                    self.block_size,
-                )
-            local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
-                self.device, non_blocking=True)
-            local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
-                self.device, non_blocking=True)
-            local_max_query_len = seqlens_q_local_np.max().item()
-            local_max_seq_len = virt_k_seqlens_np.max().item()
-
-            local_attn_metadata = TritonAttentionMetadata \
-            .LocalAttentionMetadata(
-                local_query_start_loc=local_query_start_loc,
-                local_seqused_k=local_seqused_k,
-                local_block_table=virt_block_table_tensor,
-                local_max_query_len=local_max_query_len,
-                local_max_seq_len=local_max_seq_len,
-                local_scheduler_metadata=None,
-            )
-
         use_cascade = common_prefix_len > 0
 
         if use_cascade:
@@ -170,7 +129,6 @@ def build(self,
             cu_prefix_query_lens=cu_prefix_query_lens,
             prefix_kv_lens=prefix_kv_lens,
             suffix_kv_lens=suffix_kv_lens,
-            local_attn_metadata=local_attn_metadata,
            prefix_scheduler_metadata=prefix_scheduler_metadata,
         )
         return attn_metadata
@@ -384,23 +342,11 @@ def forward(
                     layer._q_scale)
             query = query.reshape((num_tokens, num_heads, head_size))
 
-        use_local_attn = \
-            (self.use_irope and attn_metadata.local_attn_metadata is not None)
-
-        if use_local_attn:
-            assert attn_metadata.local_attn_metadata is not None
-            local_metadata = attn_metadata.local_attn_metadata
-            cu_seqlens_q = local_metadata.local_query_start_loc
-            seqused_k = local_metadata.local_seqused_k
-            max_seqlen_q = local_metadata.local_max_query_len
-            max_seqlen_k = local_metadata.local_max_seq_len
-            block_table = local_metadata.local_block_table
-        else:
-            cu_seqlens_q = attn_metadata.query_start_loc
-            seqused_k = attn_metadata.seq_lens
-            max_seqlen_q = attn_metadata.max_query_len
-            max_seqlen_k = attn_metadata.max_seq_len
-            block_table = attn_metadata.block_table
+        cu_seqlens_q = attn_metadata.query_start_loc
+        seqused_k = attn_metadata.seq_lens
+        max_seqlen_q = attn_metadata.max_query_len
+        max_seqlen_k = attn_metadata.max_seq_len
+        block_table = attn_metadata.block_table
 
         if use_prefill_decode_attn:
             # Compute attention and update output up to `num_actual_tokens`.
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index db6eaa558642..b6a06b17bca2 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -272,11 +272,14 @@ def infer_global_hyperparameters(
 # block_table_local : shape[local_virtual_batches, pages_per_local_batch]
 def make_local_attention_virtual_batches(
     attn_chunk_size: int,
-    query_start_loc_np: np.ndarray,
-    seq_lens_np: np.ndarray,
-    block_table: torch.Tensor,
+    common_attn_metadata: CommonAttentionMetadata,
     block_size: int = 0,
-) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]:
+) -> CommonAttentionMetadata:
+    query_start_loc_np = common_attn_metadata.query_start_loc_cpu.numpy()
+    seq_lens_np = common_attn_metadata.seq_lens_cpu.numpy()
+    block_table = common_attn_metadata.block_table_tensor
+    device = common_attn_metadata.query_start_loc.device
+
     q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
     actual_batch_size = seq_lens_np.shape[0]
 
@@ -339,6 +342,7 @@ def make_local_attention_virtual_batches(
                               attn_chunk_size,
                               dtype=np.int32)
     seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
+    num_computed_tokens_local = seqlens_k_local - seqlens_q_local
 
     k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - \
         (rarange * attn_chunk_size + \
@@ -380,8 +384,22 @@ def make_local_attention_virtual_batches(
     block_table_local = block_table[batch_indices, block_indices]\
         .view(virtual_batches, -1)
 
-    return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, \
-        block_table_local
+    query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local)
+    seq_lens_cpu = torch.from_numpy(seqlens_k_local)
+
+    return CommonAttentionMetadata(
+        query_start_loc_cpu=query_start_loc_cpu,
+        query_start_loc=query_start_loc_cpu.to(device=device,
+                                               non_blocking=True),
+        seq_lens_cpu=seq_lens_cpu,
+        seq_lens=seq_lens_cpu.to(device=device, non_blocking=True),
+        num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local),
+        num_reqs=len(seq_lens_cpu),
+        num_actual_tokens=common_attn_metadata.num_actual_tokens,
+        max_query_len=seqlens_q_local.max(),
+        block_table_tensor=block_table_local,
+        slot_mapping=common_attn_metadata.slot_mapping,
+    )
 
 
 def split_decodes_and_prefills(
diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py
index 5b4718038076..1560406c9004 100644
--- a/vllm/v1/core/single_type_kv_cache_manager.py
+++ b/vllm/v1/core/single_type_kv_cache_manager.py
@@ -7,7 +7,8 @@
 from vllm.utils import cdiv
 from vllm.v1.core.block_pool import BlockPool
 from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
-from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
+from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
+                                        FullAttentionSpec, KVCacheSpec,
                                         MambaSpec, SlidingWindowSpec)
 from vllm.v1.request import Request
 
@@ -256,8 +257,10 @@ def find_longest_cache_hit(
         kv_cache_spec: KVCacheSpec,
         use_eagle: bool,
     ) -> tuple[list[KVCacheBlock], ...]:
-        assert isinstance(kv_cache_spec, FullAttentionSpec), (
-            "FullAttentionManager can only be used for full attention groups")
+        assert isinstance(
+            kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec)
+        ), "FullAttentionManager can only be used for full attention " \
+            "and chunked local attention groups"
         computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
             [] for _ in range(len(kv_cache_group_ids)))
         max_num_blocks = max_length // kv_cache_spec.block_size
@@ -432,6 +435,7 @@ def allocate_new_blocks(self, request_id: str,
 
 spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
     FullAttentionSpec: FullAttentionManager,
+    ChunkedLocalAttentionSpec: FullAttentionManager,
     SlidingWindowSpec: SlidingWindowManager,
     MambaSpec: MambaManager,
 }
diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py
index 43456a987def..6726709955f7 100644
--- a/vllm/v1/kv_cache_interface.py
+++ b/vllm/v1/kv_cache_interface.py
@@ -125,6 +125,21 @@ def merge(cls, specs: list[Self]) -> Self:
         return merged_spec
 
 
+@dataclass
+class ChunkedLocalAttentionSpec(AttentionSpec):
+    attention_chunk_size: int
+
+    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
+        max_model_len = vllm_config.model_config.max_model_len
+        return cdiv(max_model_len, self.block_size) * self.page_size_bytes
+
+    @property
+    def type_id(self) -> str:
+        return (
+            f"local_attention_{self.attention_chunk_size}_{self.block_size}_{self.page_size_bytes}"
+        )  # noqa
+
+
 @dataclass
 class SlidingWindowSpec(AttentionSpec):
     sliding_window: int
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 29f519393e4a..fc7f25388810 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -44,11 +44,14 @@
     GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size,
     is_pin_memory_available, round_up)
 from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
-from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
-                                              CommonAttentionMetadata)
+from vllm.v1.attention.backends.utils import (
+    AttentionMetadataBuilder, CommonAttentionMetadata,
+    make_local_attention_virtual_batches)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
-from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
-                                        KVCacheConfig, KVCacheSpec, MambaSpec,
+from vllm.v1.kv_cache_interface import (AttentionSpec,
+                                        ChunkedLocalAttentionSpec,
+                                        FullAttentionSpec, KVCacheConfig,
+                                        KVCacheSpec, MambaSpec,
                                         SlidingWindowSpec)
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
                              ModelRunnerOutput)
@@ -705,6 +708,12 @@ def _prepare_inputs(
                     spec_decode_common_attn_metadata is None:
                 spec_decode_common_attn_metadata = common_attn_metadata
 
+            if isinstance(kv_cache_group_spec.kv_cache_spec,
+                          ChunkedLocalAttentionSpec):
+                common_attn_metadata = make_local_attention_virtual_batches(
+                    kv_cache_group_spec.kv_cache_spec.attention_chunk_size,
+                    common_attn_metadata, self.cache_config.block_size)
+
             # Prepare for cascade attention if enabled & beneficial.
             common_prefix_len = 0
             builder = self.attn_metadata_builders[kv_cache_group_id]
@@ -2589,6 +2598,8 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
 
             # TODO: Support other attention modules, e.g., cross-attention
             if attn_module.attn_type == AttentionType.DECODER:
+                use_local_attention = (self.attention_chunk_size is not None
+                                       and attn_module.impl.use_irope)
                 if attn_module.sliding_window is not None:
                     kv_cache_spec[layer_name] = SlidingWindowSpec(
                         block_size=block_size,
@@ -2597,6 +2608,14 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
                         dtype=self.kv_cache_dtype,
                         sliding_window=attn_module.sliding_window,
                         use_mla=use_mla)
+                elif use_local_attention:
+                    kv_cache_spec[layer_name] = (ChunkedLocalAttentionSpec(
+                        block_size=block_size,
+                        num_kv_heads=attn_module.num_kv_heads,
+                        head_size=attn_module.head_size,
+                        dtype=self.kv_cache_dtype,
+                        attention_chunk_size=self.attention_chunk_size,
+                        use_mla=use_mla))
                 else:
                     kv_cache_spec[layer_name] = FullAttentionSpec(
                         block_size=block_size,
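
The patch above moves the local-attention virtual-batch bookkeeping out of the individual backends and into make_local_attention_virtual_batches, which now consumes and returns a CommonAttentionMetadata (see vllm/v1/attention/backends/utils.py and the call site in gpu_model_runner.py), so every backend just sees an ordinary metadata object. For readers who want a feel for the chunking rule that the helper vectorizes, here is a minimal, standalone per-request sketch; the function name split_into_local_batches and the example numbers are illustrative only and are not part of the patch.

# Standalone illustration (plain Python, no vLLM imports). The vectorized
# implementation lives in
# vllm/v1/attention/backends/utils.py::make_local_attention_virtual_batches;
# the helper name and sample numbers below are invented for this sketch.


def split_into_local_batches(q_len: int, seq_len: int, chunk_size: int):
    """Split one request into per-chunk (q_local, k_local) virtual batches.

    The request holds seq_len tokens in total (cached + new) and the trailing
    q_len tokens are the current queries. Under chunked local attention, a
    token at absolute position p only attends to keys inside its own chunk
    [p // chunk_size * chunk_size, p], so every chunk touched by the query
    span becomes its own "virtual batch".
    """
    assert 0 < q_len <= seq_len
    first_q_pos = seq_len - q_len
    first_chunk = first_q_pos // chunk_size
    last_chunk = (seq_len - 1) // chunk_size

    batches = []
    for chunk in range(first_chunk, last_chunk + 1):
        chunk_start = chunk * chunk_size
        chunk_end = min(chunk_start + chunk_size, seq_len)
        # Queries whose absolute positions fall inside this chunk.
        q_local = chunk_end - max(chunk_start, first_q_pos)
        # Keys visible to those queries: every token of the chunk seen so far
        # (causal masking within the chunk is left to the attention kernel).
        k_local = chunk_end - chunk_start
        batches.append((q_local, k_local))
    return batches


if __name__ == "__main__":
    # A 17-token request with 10 new query tokens and chunk_size=4 splits into
    # four virtual batches: [(1, 4), (4, 4), (4, 4), (1, 1)].
    print(split_into_local_batches(q_len=10, seq_len=17, chunk_size=4))

Running the sketch prints [(1, 4), (4, 4), (4, 4), (1, 1)]: the 10 query tokens of a 17-token request with chunk size 4 land in four chunks, which mirrors the kind of per-chunk query/key split the refactored helper now hands to the backends through CommonAttentionMetadata instead of a backend-specific LocalAttentionMetadata.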