diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 05f6dd40a9ea..73b47f897439 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -313,7 +313,8 @@ def create_deterministic_logits(token_ids): # Mock runner for attention metadata building proposer.runner = mock.MagicMock() - proposer.runner.attn_metadata_builders = [attn_metadata_builder] + proposer.runner.attn_groups.append([mock.MagicMock()]) + proposer.runner.attn_groups[0][0].metadata_builder = attn_metadata_builder result = proposer.propose(target_token_ids=target_token_ids, target_positions=target_positions, diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 231dfcbb6884..e151d388c293 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -417,12 +417,12 @@ def rnd_stride_order(): return rnd_stride # Patch the attention backend class and re-trigger the KV cache creation. - for attn_backend in model_runner.attn_backends: + for attn_group in model_runner._attn_group_iterator(): + attn_backend = attn_group.backend monkeypatch.setattr(attn_backend, "get_kv_cache_stride_order", rnd_stride_order) - model_runner.attn_backends = [] - model_runner.attn_metadata_builders = [] + model_runner.attn_groups = [] model_runner.initialize_kv_cache(model_runner.kv_cache_config) # Shape is unchanged, but layout may differ diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index ba20da4fd75f..2417fe06a675 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -106,6 +106,10 @@ def advance_step(self, model_input: "ModelRunnerInputBase", block_size: int, num_seqs: int, num_queries: int) -> None: raise NotImplementedError + @classmethod + def full_cls_name(cls) -> tuple[str, str]: + return (cls.__module__, cls.__qualname__) + @dataclass class AttentionMetadata: diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 178453ecdc4e..b4c3cbd7c9d6 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -9,6 +9,7 @@ import vllm.envs as envs from vllm.attention import AttentionType +from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.selector import backend_name_to_enum, get_attn_backend from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target from vllm.config import CacheConfig, get_current_vllm_config @@ -80,6 +81,7 @@ def __init__( prefix: str = "", attn_type: str = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, + attn_backend: Optional[type[AttentionBackend]] = None, **extra_impl_args, ) -> None: """ @@ -137,15 +139,6 @@ def __init__( self.num_kv_heads = num_kv_heads self.sliding_window = sliding_window - # For v1 we have backend agnostic iRoPE (local chunked attention) - # we have to store the flag on the layer so gpu model runner can - # set KVSpec appropriately (and pop it so it doesnt get passed to - # the backends) - if envs.VLLM_USE_V1: - self.use_irope = extra_impl_args.pop("use_irope", False) - else: - self.use_irope = extra_impl_args.get("use_irope", False) - quant_method = quant_config.get_quant_method( self, prefix=prefix) if quant_config else None if quant_method is not None and not isinstance( @@ -166,18 +159,22 @@ def __init__( # During model initialization, the default dtype is set as the model # weight and activation dtype. 
dtype = torch.get_default_dtype() - attn_backend = get_attn_backend(head_size, - dtype, - kv_cache_dtype, - block_size, - is_attention_free, - use_mla=use_mla) - impl_cls = attn_backend.get_impl_cls() + if attn_backend is None: + self.attn_backend = get_attn_backend(head_size, + dtype, + kv_cache_dtype, + block_size, + is_attention_free, + use_mla=use_mla) + else: + self.attn_backend = attn_backend + + impl_cls = self.attn_backend.get_impl_cls() self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, logits_soft_cap, attn_type, kv_sharing_target_layer_name, **extra_impl_args) - self.backend = backend_name_to_enum(attn_backend.get_name()) + self.backend = backend_name_to_enum(self.attn_backend.get_name()) self.dtype = dtype # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how @@ -187,7 +184,7 @@ def __init__( self.use_direct_call = not current_platform.is_cuda_alike( ) and not current_platform.is_cpu() - self.use_output = attn_backend.accept_output_buffer + self.use_output = self.attn_backend.accept_output_buffer compilation_config = get_current_vllm_config().compilation_config if prefix in compilation_config.static_forward_context: raise ValueError(f"Duplicate layer name: {prefix}") @@ -309,6 +306,9 @@ def process_weights_after_loading(self, act_dtype: torch.dtype): if hasattr(self.impl, "process_weights_after_loading"): self.impl.process_weights_after_loading(act_dtype) + def get_attn_backend(self) -> type[AttentionBackend]: + return self.attn_backend + class MultiHeadAttention(nn.Module): """Multi-headed attention without any cache, used for ViT.""" diff --git a/vllm/attention/layers/chunked_local_attention.py b/vllm/attention/layers/chunked_local_attention.py new file mode 100644 index 000000000000..892077ba91e0 --- /dev/null +++ b/vllm/attention/layers/chunked_local_attention.py @@ -0,0 +1,88 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import functools +from typing import List, Optional + +import torch + +from vllm import envs +from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.selector import get_attn_backend +from vllm.config import CacheConfig, QuantizationConfig +from vllm.v1.attention.backends.utils import ( + CommonAttentionMetadata, make_local_attention_virtual_batches, + subclass_attention_backend, subclass_attention_metadata_builder) + +from ..layer import Attention + + +@functools.lru_cache +def create_chunked_local_attention_backend( + underlying_attn_backend: AttentionBackend, + attention_chunk_size: int, + block_size: int, +) -> type[AttentionBackend]: + prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_" + + def build_preprocess_fn(cm: CommonAttentionMetadata): + return make_local_attention_virtual_batches(attention_chunk_size, cm, + block_size) + + # Dynamically create a new attention backend that wraps the + # underlying attention backend but applies + # `make_local_attention_virtual_batches` before calling `build(...)` + builder_cls = subclass_attention_metadata_builder( + name_prefix=prefix, + builder_cls=underlying_attn_backend.get_builder_cls(), + build_preprocess_fn=build_preprocess_fn) + attn_backend = subclass_attention_backend( + name_prefix=prefix, + attention_backend_cls=underlying_attn_backend, + builder_cls=builder_cls) + + return attn_backend + + +class ChunkedLocalAttention(Attention): + + def __init__(self, + num_heads: int, + head_size: int, + scale: float, + 
attention_chunk_size: int, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + kv_sharing_target_layer_name: Optional[str] = None, + prefix: str = ""): + dtype = torch.get_default_dtype() + if cache_config is not None: + kv_cache_dtype = cache_config.cache_dtype + block_size = cache_config.block_size + else: + kv_cache_dtype = "auto" + block_size = 16 + + if envs.VLLM_USE_V1: + underlying_attn_backend = get_attn_backend(head_size, dtype, + kv_cache_dtype, + block_size) + + attn_backend = create_chunked_local_attention_backend( + underlying_attn_backend, attention_chunk_size, block_size) + else: + # in v0 the local attention is handled inside the backends + attn_backend = None + + super().__init__( + num_heads=num_heads, + head_size=head_size, + scale=scale, + num_kv_heads=num_kv_heads, + alibi_slopes=alibi_slopes, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + kv_sharing_target_layer_name=kv_sharing_target_layer_name, + attn_backend=attn_backend) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 596c556e54f0..508470bb363e 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -142,7 +142,7 @@ def get_attn_backend( dtype: torch.dtype, kv_cache_dtype: Optional[str], block_size: int, - is_attention_free: bool, + is_attention_free: bool = False, use_mla: bool = False, ) -> type[AttentionBackend]: """Selects which attention backend to use and lazily imports it.""" diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index 60098209c39a..1f8b9d074479 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -25,6 +25,7 @@ from transformers import Llama4TextConfig from vllm.attention import Attention +from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size @@ -194,17 +195,18 @@ def __init__(self, is_neox_style=is_neox_style, ) if not self.nope else None - self.attn = Attention( + attn_cls = Attention if self.nope else ChunkedLocalAttention + self.attn = attn_cls( self.num_heads, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, quant_config=quant_config, - per_layer_sliding_window=None, - use_irope=not self.nope, prefix=f"{prefix}.attn", - ) + **({ + "attention_chunk_size": config.attention_chunk_size + } if not self.nope else {})) def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor: floor = torch.floor((positions + 1.0) / self.floor_scale) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 7aeea40b25a6..e6104c2ed7b2 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -5,12 +5,12 @@ import functools from abc import abstractmethod from dataclasses import dataclass, make_dataclass -from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar +from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Generic, Optional, + TypeVar) import numpy as np import torch -from vllm.attention.layer import Attention from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.utils import cdiv @@ -20,6 +20,8 @@ from vllm.v1.worker.gpu_input_batch import InputBatch import 
vllm.envs as envs +from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.layer import Attention from vllm.distributed.kv_transfer.kv_connector.utils import ( get_kv_connector_cache_layout) from vllm.logger import init_logger @@ -522,6 +524,48 @@ def make_local_attention_virtual_batches( ) +def subclass_attention_metadata_builder( + name_prefix: str, + builder_cls: type[AttentionMetadataBuilder[M]], + build_preprocess_fn: Callable[[CommonAttentionMetadata], + CommonAttentionMetadata], +) -> type[AttentionMetadataBuilder[M]]: + """ + Return a new subclass of `builder_cls` whose .build(...) method + first calls build_preprocess_fn(common_attn_metadata) on the metadata. + """ + name: str = name_prefix + builder_cls.__name__ # type: ignore + + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False): + return builder_cls.build(self, common_prefix_len, + build_preprocess_fn(common_attn_metadata), + fast_build) + + Wrapped = type( + name, + (builder_cls, ), # inherit from the original + { + "build": build, + }) + return Wrapped # type: ignore + + +def subclass_attention_backend( + name_prefix: str, attention_backend_cls: type[AttentionBackend], + builder_cls: type[AttentionMetadataBuilder[M]] +) -> type[AttentionBackend]: + """ + Return a new subclass where `get_builder_cls` returns `builder_cls`. + """ + name: str = name_prefix + attention_backend_cls.__name__ # type: ignore + + return type(name, (attention_backend_cls, ), + {"get_builder_cls": lambda: builder_cls}) + + def split_decodes_and_prefills( common_attn_metadata: CommonAttentionMetadata, decode_threshold: int = 1, diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index b2380bb3dd5a..3c36971fe5b4 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -158,9 +158,9 @@ def propose( assert self.runner is not None # FIXME: need to consider multiple kv_cache_groups - attn_metadata = self.runner.attn_metadata_builders[ - 0].build_for_drafting(common_attn_metadata=common_attn_metadata, - draft_index=0) + attn_metadata = self.runner.attn_groups[0][0].metadata_builder\ + .build_for_drafting(common_attn_metadata=common_attn_metadata, + draft_index=0) # At this moment, we assume all eagle layers belong to the same KV # cache group, thus using the same attention metadata. 
@@ -349,7 +349,8 @@ def propose_tree( hidden_states: torch.Tensor, common_attn_metadata: CommonAttentionMetadata, ) -> list[torch.Tensor]: - tree_attn_metadata_builder = self.runner.attn_metadata_builders[0] + tree_attn_metadata_builder = \ + self.runner.attn_groups[0][0].metadata_builder assert isinstance(tree_attn_metadata_builder, TreeAttentionMetadataBuilder) diff --git a/vllm/v1/worker/cpu_model_runner.py b/vllm/v1/worker/cpu_model_runner.py index d8f3e0d89a96..11b96d946365 100644 --- a/vllm/v1/worker/cpu_model_runner.py +++ b/vllm/v1/worker/cpu_model_runner.py @@ -53,11 +53,11 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: raise ValueError("Multiple KVCacheGroups is not" "currently supported with CPU model runner.") - assert type( - self.attn_metadata_builders[0]) is TorchSDPAMetadataBuilderV1 + assert type(self.attn_groups[0] + [0].metadata_builder) is TorchSDPAMetadataBuilderV1 - self.attn_metadata_builders[0].reorder_batch(self.input_batch, - scheduler_output) + self.attn_groups[0][0].metadata_builder.reorder_batch( + self.input_batch, scheduler_output) def _postprocess_tenosrs(self) -> None: # Note: replace device tensors with cpu tensors diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 85976fc1c825..76a29d5b643b 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3,7 +3,10 @@ import dataclasses import gc +import itertools import time +from collections import defaultdict +from collections.abc import Iterator from contextlib import contextmanager from typing import TYPE_CHECKING, Any, Optional, Union, cast @@ -14,9 +17,9 @@ from tqdm import tqdm import vllm.envs as envs -from vllm.attention import AttentionType, get_attn_backend +from vllm.attention import Attention, AttentionType from vllm.attention.backends.abstract import AttentionBackend -from vllm.attention.layer import Attention +from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention from vllm.compilation.counter import compilation_counter from vllm.config import (CompilationLevel, VllmConfig, get_layers_from_vllm_config, update_config) @@ -50,7 +53,6 @@ from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, make_kv_sharing_fast_prefill_attention_metadata, - make_local_attention_virtual_batches, reorder_batch_to_split_decodes_and_prefills) from vllm.v1.kv_cache_interface import (AttentionSpec, ChunkedLocalAttentionSpec, @@ -73,8 +75,8 @@ from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from ..sample.logits_processor import LogitsProcessorManager -from .utils import (MultiModalBudget, bind_kv_cache, gather_mm_placeholders, - initialize_kv_cache_for_kv_sharing, +from .utils import (AttentionGroup, MultiModalBudget, bind_kv_cache, + gather_mm_placeholders, initialize_kv_cache_for_kv_sharing, sanity_check_mm_encoder_outputs, scatter_mm_placeholders) if TYPE_CHECKING: @@ -162,8 +164,8 @@ def __init__( # self.model: nn.Module # Set after load_model # Initialize in initialize_kv_cache self.kv_caches: list[torch.Tensor] = [] - self.attn_metadata_builders: list[AttentionMetadataBuilder] = [] - self.attn_backends: list[type[AttentionBackend]] = [] + # indexes: [kv_cache_group_id][attn_group] + self.attn_groups: list[list[AttentionGroup]] = [] # self.kv_cache_config: KVCacheConfig # req_id -> (input_id -> encoder_output) @@ -830,81 +832,51 @@ def _prepare_inputs( spec_decode_common_attn_metadata is None: 
spec_decode_common_attn_metadata = common_attn_metadata - if isinstance(kv_cache_group_spec.kv_cache_spec, - ChunkedLocalAttentionSpec): - common_attn_metadata = make_local_attention_virtual_batches( - kv_cache_group_spec.kv_cache_spec.attention_chunk_size, - common_attn_metadata, self.cache_config.block_size) - - # Prepare for cascade attention if enabled & beneficial. - common_prefix_len = 0 - builder = self.attn_metadata_builders[kv_cache_group_id] - if self.cascade_attn_enabled: - common_prefix_len = self._compute_cascade_attn_prefix_len( - num_scheduled_tokens, - scheduler_output. - num_common_prefix_blocks[kv_cache_group_id], - kv_cache_group_spec.kv_cache_spec, - builder, - ) - - attn_metadata_i = (builder.build( - common_prefix_len=common_prefix_len, - common_attn_metadata=common_attn_metadata, - )) - - fast_prefill_metadata = attn_metadata_i - if (self.cache_config.kv_sharing_fast_prefill - and self.kv_sharing_fast_prefill_eligible_layers): - # Dynamically create a a dataclass type that inherits - # from attention metadata type but includes additional - # fields logits_indices_padded and num_logits_indices - # which are required for prefill truncation - fast_prefill_metadata_type = ( - make_kv_sharing_fast_prefill_attention_metadata( - metadata_cls=type(attn_metadata_i), )) - fast_prefill_metadata = fast_prefill_metadata_type( - **dataclasses.asdict(attn_metadata_i), - logits_indices_padded=logits_indices_padded, - num_logits_indices=logits_indices.size(0), - ) - - for layer_name in kv_cache_group_spec.layer_names: - if (self.cache_config.kv_sharing_fast_prefill and layer_name - in self.kv_sharing_fast_prefill_eligible_layers): - attn_metadata[layer_name] = fast_prefill_metadata - continue + for attn_group in self.attn_groups[kv_cache_group_id]: + # Prepare for cascade attention if enabled & beneficial. + common_prefix_len = 0 + builder = attn_group.metadata_builder + if self.cascade_attn_enabled: + common_prefix_len = self._compute_cascade_attn_prefix_len( + num_scheduled_tokens, + scheduler_output. + num_common_prefix_blocks[kv_cache_group_id], + kv_cache_group_spec.kv_cache_spec, + builder, + ) - attn_metadata[layer_name] = attn_metadata_i - - # Hack for now to fix chunked local attention + no hybrid kv cache - # manager we can remove this once - # https://github.com/vllm-project/vllm/pull/21588 - # is merged (i.e. 
properly handle different attention backends for - # the same kv_cache_spec) - if self.attention_chunk_size is not None \ - and self.scheduler_config.disable_hybrid_kv_cache_manager: - if not hasattr(self, "local_attention_layers"): - self.local_attention_layers = [] - attn_layers = get_layers_from_vllm_config( - self.vllm_config, Attention) - for layer_name, attn_module in attn_layers.items(): - if attn_module.use_irope: - self.local_attention_layers.append(layer_name) - - local_attn_metadata_i = (builder.build( - common_prefix_len=0, - common_attn_metadata=make_local_attention_virtual_batches( - self.attention_chunk_size, common_attn_metadata, - self.cache_config.block_size), + attn_metadata_i = (builder.build( + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata, )) - for layer_name in self.local_attention_layers: - attn_metadata[layer_name] = local_attn_metadata_i + fast_prefill_metadata = attn_metadata_i + if (self.cache_config.kv_sharing_fast_prefill + and self.kv_sharing_fast_prefill_eligible_layers): + # Dynamically create a a dataclass type that inherits + # from attention metadata type but includes additional + # fields logits_indices_padded and num_logits_indices + # which are required for prefill truncation + fast_prefill_metadata_type = ( + make_kv_sharing_fast_prefill_attention_metadata( + metadata_cls=type(attn_metadata_i), )) + fast_prefill_metadata = fast_prefill_metadata_type( + **dataclasses.asdict(attn_metadata_i), + logits_indices_padded=logits_indices_padded, + num_logits_indices=logits_indices.size(0), + ) + + for layer_name in attn_group.layer_names: + if (self.cache_config.kv_sharing_fast_prefill + and layer_name + in self.kv_sharing_fast_prefill_eligible_layers): + attn_metadata[layer_name] = fast_prefill_metadata + continue + attn_metadata[layer_name] = attn_metadata_i attention_cuda_graphs = all( - b.can_run_in_cudagraph(common_attn_metadata) - for b in self.attn_metadata_builders) + g.metadata_builder.can_run_in_cudagraph(common_attn_metadata) + for g in self._attn_group_iterator()) # Hot-Swap lora model if self.lora_config: @@ -2224,11 +2196,11 @@ def _dummy_run( block_table[kv_cache_group_id].slot_mapping[:num_tokens], causal=True) - attn_metadata_i = self.attn_metadata_builders[ - kv_cache_group_id].build_for_cudagraph_capture( - common_attn_metadata) - for layer_name in kv_cache_group_spec.layer_names: - attn_metadata[layer_name] = attn_metadata_i + for attn_group in self.attn_groups[kv_cache_group_id]: + attn_metadata_i = attn_group.metadata_builder\ + .build_for_cudagraph_capture(common_attn_metadata) + for layer_name in kv_cache_group_spec.layer_names: + attn_metadata[layer_name] = attn_metadata_i with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens): @@ -2560,88 +2532,100 @@ def freeze_gc(): logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", elapsed_time, cuda_graph_size / (1 << 30)) - def _initialize_single_attn_backend( - self, kv_cache_spec: KVCacheSpec, layer_names: list[str] - ) -> tuple[AttentionBackend, AttentionMetadataBuilder]: - if isinstance(kv_cache_spec, AttentionSpec): - attn_backend_i = get_attn_backend( - kv_cache_spec.head_size, - self.dtype, - kv_cache_spec.dtype, - kv_cache_spec.block_size, - self.model_config.is_attention_free, - use_mla=kv_cache_spec.use_mla, - ) - if attn_backend_i is None: - error_msg = (f"Error with get_attn_backend: " - f"{kv_cache_spec.head_size=}, " - f"{self.dtype=}, {kv_cache_spec.dtype=}, " - f"{kv_cache_spec.block_size=}, " - 
f"{self.model_config.is_attention_free=}, " - f"{kv_cache_spec.use_mla=}") - logger.error(error_msg) - raise NotImplementedError( - "Non-Attention backend is not supported by V1 " - "GPUModelRunner.") - elif isinstance(kv_cache_spec, MambaSpec): - attn_backend_i = get_mamba_attn_backend(kv_cache_spec.mamba_type) - else: - raise ValueError( - f"Unknown KV cache spec type: {type(kv_cache_spec)}") - - attn_metadata_builder_i = attn_backend_i.get_builder_cls()( - kv_cache_spec, - layer_names, - self.vllm_config, - self.device, - ) - - if self.full_cuda_graph: - if attn_metadata_builder_i.attn_cudagraph_support == \ - AttentionCGSupport.NEVER: - raise ValueError(f"Full CUDAGraph not supported for " - f"{attn_backend_i.__name__}. Turn off " - f"CompilationConfig.full_cuda_graph or use a " - f" different attention backend.") - if attn_metadata_builder_i.attn_cudagraph_support == \ - AttentionCGSupport.PURE_DECODE_ONLY: - # Limit the max cudagraph size to the max number of - # sequences for pure decode only cudagraph backend, - # whose max_query_len is 1. - self.cudagraph_batch_sizes = [ - size for size in self.cudagraph_batch_sizes - if size <= self.scheduler_config.max_num_seqs - ] - return attn_backend_i, attn_metadata_builder_i - def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize the attention backends and attention metadata builders. """ - assert len(self.attn_backends) == 0 and len( - self.attn_metadata_builders - ) == 0, "Attention backends are already initialized" - for i, kv_cache_group_spec in enumerate( - kv_cache_config.kv_cache_groups): + assert len(self.attn_groups) == 0, \ + "Attention backends are already initialized" + attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) + + def get_attn_backends_for_layers( + layer_names: list[str] + ) -> dict[type[AttentionBackend], list[str]]: + attn_backends = {} + attn_backend_layers = defaultdict(list) + # Dedupe based on full class name; this is a bit safer than using + # using the class itself as the key because when we create dynamic + # attention backend subclasses (e.g. ChunkedLocalAttention) unless + # they are cached correctly, there will be different objects per + # layer. + for layer_name in layer_names: + attn_backend = attn_layers[layer_name].get_attn_backend() + key = attn_backend.full_cls_name() + attn_backends[key] = attn_backend + attn_backend_layers[key].append(layer_name) + return { + attn_backends[k]: v + for k, v in attn_backend_layers.items() + } + + def create_attn_groups( + attn_backends_map: dict[AttentionBackend, list[str]], + kv_cache_spec: KVCacheSpec, + ) -> list[AttentionGroup]: + attn_groups: list[AttentionGroup] = [] + for attn_backend, layer_names in attn_backends_map.items(): + attn_metadata_builder_i = attn_backend.get_builder_cls()( + kv_cache_spec, + layer_names, + self.vllm_config, + self.device, + ) + attn_group = AttentionGroup(attn_backend, + attn_metadata_builder_i, + layer_names) + attn_groups.append(attn_group) + + if self.full_cuda_graph: + if attn_metadata_builder_i.attn_cudagraph_support == \ + AttentionCGSupport.NEVER: + raise ValueError( + f"Full CUDAGraph not supported for " + f"{attn_backend.__name__}. Turn off " + f"CompilationConfig.full_cuda_graph or use a " + f" different attention backend.") + if attn_metadata_builder_i.attn_cudagraph_support == \ + AttentionCGSupport.PURE_DECODE_ONLY: + # Limit the max cudagraph size to the max number of + # sequences for pure decode only cudagraph backend, + # whose max_query_len is 1. 
+ self.cudagraph_batch_sizes = [ + size for size in self.cudagraph_batch_sizes + if size <= self.scheduler_config.max_num_seqs + ] + + return attn_groups + + for kv_cache_group_spec in kv_cache_config.kv_cache_groups: kv_cache_spec = kv_cache_group_spec.kv_cache_spec + if isinstance(kv_cache_spec, AttentionSpec): + attn_backends = get_attn_backends_for_layers( + kv_cache_group_spec.layer_names) + # TODO(lucas): move `get_mamba_attn_backend` into the mamba + # layers like above + elif isinstance(kv_cache_spec, MambaSpec): + attn_backends = { + get_mamba_attn_backend(kv_cache_spec.mamba_type): + kv_cache_group_spec.layer_names + } + else: + raise ValueError( + f"Unknown KV cache spec type: {type(kv_cache_spec)}") - attn_backend_i, attn_metadata_builder_i = ( - self._initialize_single_attn_backend( - kv_cache_spec, kv_cache_group_spec.layer_names)) - self.attn_backends.append(attn_backend_i) - self.attn_metadata_builders.append(attn_metadata_builder_i) + self.attn_groups.append( + create_attn_groups(attn_backends, kv_cache_spec)) # Calculate reorder batch threshold (if neeeded) self.calculate_reorder_batch_threshold() - if len(self.attn_backends) > 0: + if len(self.attn_groups) > 0: return # Check if model is encoder-only block_size = self.vllm_config.cache_config.block_size use_mla = self.vllm_config.model_config.use_mla attn_specs = list[AttentionSpec]() - attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) for attn_module in attn_layers.values(): if attn_module.attn_type == AttentionType.ENCODER_ONLY: @@ -2661,11 +2645,10 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: assert len(attn_specs) == len(attn_layers), \ "All or none of the layers are expected to be encoder-only" - attn_backend, attn_metadata_builder = ( - self._initialize_single_attn_backend(attn_specs[0], - attn_layers.keys())) - self.attn_backends.append(attn_backend) - self.attn_metadata_builders.append(attn_metadata_builder) + attn_backends = get_attn_backends_for_layers(attn_layers.keys()) + + self.attn_groups.append( + create_attn_groups(attn_backends, attn_specs[0])) self.is_encoder_only_model = True def calculate_reorder_batch_threshold(self) -> None: @@ -2673,7 +2656,9 @@ def calculate_reorder_batch_threshold(self) -> None: Check that if any backends reorder batches; that the reordering is compatible (e.g., decode threshold is the same) """ - for attn_metadata_builder_i in self.attn_metadata_builders: + for group in self._attn_group_iterator(): + attn_metadata_builder_i = group.metadata_builder + # check that if any backends reorder batches; that the reordering # is compatible (e.g., decode threshold is the same) reorder_batch_threshold_i = ( @@ -2747,6 +2732,18 @@ def _allocate_kv_cache_tensors( )), "Some layers are not correctly initialized" return kv_cache_raw_tensors + def _attn_group_iterator(self) -> Iterator[AttentionGroup]: + return itertools.chain.from_iterable(self.attn_groups) + + def _kv_cache_spec_attn_group_iterator( + self) -> Iterator[tuple[KVCacheSpec, AttentionGroup]]: + if not self.kv_cache_config.kv_cache_groups: + return + for kv_cache_spec_id, attn_groups in enumerate(self.attn_groups): + for attn_group in attn_groups: + yield self.kv_cache_config.kv_cache_groups[ + kv_cache_spec_id].kv_cache_spec, attn_group + def _reshape_kv_cache_tensors( self, kv_cache_config: KVCacheConfig, @@ -2765,23 +2762,22 @@ def _reshape_kv_cache_tensors( """ kv_caches: dict[str, torch.Tensor] = {} has_attn, has_mamba = False, False - for i, kv_cache_group_spec in 
enumerate( - kv_cache_config.kv_cache_groups): - kv_cache_spec = kv_cache_group_spec.kv_cache_spec - for layer_name in kv_cache_group_spec.layer_names: + for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator(): + attn_backend = group.backend + for layer_name in group.layer_names: raw_tensor = kv_cache_raw_tensors[layer_name] assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 num_blocks = (raw_tensor.numel() // kv_cache_spec.page_size_bytes) if isinstance(kv_cache_spec, AttentionSpec): has_attn = True - kv_cache_shape = self.attn_backends[i].get_kv_cache_shape( + kv_cache_shape = attn_backend.get_kv_cache_shape( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) dtype = kv_cache_spec.dtype try: - kv_cache_stride_order = self.attn_backends[ - i].get_kv_cache_stride_order() + kv_cache_stride_order = \ + attn_backend.get_kv_cache_stride_order() assert len(kv_cache_stride_order) == len( kv_cache_shape) except (AttributeError, NotImplementedError): @@ -2845,15 +2841,14 @@ def _verify_hybrid_attention_mamba_layout( kv_cache_raw_tensors: The KV cache buffer of each layer. """ - for i, kv_cache_group_spec in enumerate( - kv_cache_config.kv_cache_groups): - kv_cache_spec = kv_cache_group_spec.kv_cache_spec - for layer_name in kv_cache_group_spec.layer_names: + for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator(): + for layer_name in group.layer_names: raw_tensor = kv_cache_raw_tensors[layer_name] num_blocks = (raw_tensor.numel() // kv_cache_spec.page_size_bytes) if isinstance(kv_cache_spec, AttentionSpec): - kv_cache_shape = self.attn_backends[i].get_kv_cache_shape( + + kv_cache_shape = group.backend.get_kv_cache_shape( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) if kv_cache_shape[0] != num_blocks or kv_cache_shape[ @@ -2888,6 +2883,7 @@ def initialize_kv_cache_tensors( self.shared_kv_cache_layers, kv_cache_config.kv_cache_groups, kv_caches, + self.attn_groups, ) attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) @@ -2953,9 +2949,9 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: continue # TODO: Support other attention modules, e.g., cross-attention + # TODO(lucas): move the attention specs into the model layers like + # the attention backends if attn_module.attn_type == AttentionType.DECODER: - use_local_attention = (self.attention_chunk_size is not None - and attn_module.use_irope) if attn_module.sliding_window is not None: kv_cache_spec[layer_name] = SlidingWindowSpec( block_size=block_size, @@ -2964,10 +2960,8 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: dtype=self.kv_cache_dtype, sliding_window=attn_module.sliding_window, use_mla=use_mla) - assert not use_local_attention, ( - "attention module can not be with ", - "both local attention and sliding window") - elif use_local_attention: + elif self.attention_chunk_size is not None \ + and isinstance(attn_module, ChunkedLocalAttention): kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, @@ -3038,7 +3032,7 @@ def _build_encoder_only_attn_metadata( # Use the first attention metadata builder # to create encoder attention metadata - builder = self.attn_metadata_builders[0] + builder = self.attn_groups[0][0].metadata_builder dummy_block_table = torch.zeros((num_reqs, 1), dtype=torch.int32, diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 5f3188efdb24..81252f9b606a 100644 --- 
a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -15,8 +15,9 @@ import torch_xla.runtime as xr import vllm.envs as envs +from vllm.attention import Attention from vllm.attention.backends.abstract import AttentionType -from vllm.attention.layer import Attention +from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher from vllm.config import (ParallelConfig, VllmConfig, get_layers_from_vllm_config, update_config) @@ -518,7 +519,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: continue if attn_module.attn_type == AttentionType.DECODER: - if attn_module.use_irope: + if isinstance(attn_module, ChunkedLocalAttention): logger.warning_once( "Using irope in Pallas is not supported yet, it " "will fall back to global attention for long context.") diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 6761b3c5e41d..e7079235d651 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -1,14 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections import defaultdict +from dataclasses import dataclass from typing import TYPE_CHECKING, Optional import torch +from vllm.attention.backends.abstract import AttentionBackend from vllm.config import ModelConfig, SchedulerConfig from vllm.model_executor.models.interfaces import MultiModalEmbeddings from vllm.model_executor.models.utils import extract_layer_index from vllm.multimodal.registry import MultiModalRegistry +from vllm.v1.attention.backends.utils import AttentionMetadataBuilder from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.kv_cache_interface import KVCacheGroupSpec @@ -122,6 +125,13 @@ def get_max_items( return max_items_per_prompt, max_items_per_batch +@dataclass +class AttentionGroup: + backend: type[AttentionBackend] + metadata_builder: AttentionMetadataBuilder + layer_names: list[str] + + def sanity_check_mm_encoder_outputs( mm_embeddings: MultiModalEmbeddings, expected_num_items: int, @@ -196,6 +206,8 @@ def initialize_kv_cache_for_kv_sharing( shared_kv_cache_layers: dict[str, str], kv_cache_groups: list[KVCacheGroupSpec], kv_caches: dict[str, torch.Tensor], + # Optional for now to avoid breaking TPU + attn_groups: Optional[list[list[AttentionGroup]]] = None, ) -> None: """ Sets up KV cache sharing by reusing the allocated KV caches in `kv_caches` @@ -225,6 +237,15 @@ def initialize_kv_cache_for_kv_sharing( group_idx = layer_to_kv_cache_group_idx[target_layer_name] kv_cache_groups[group_idx].layer_names.append(layer_name) + if attn_groups is not None: + assert len(attn_groups[group_idx]) == 1, ( + "Only one attention group per KV cache group is supported " + "for KV-cache sharing for now.") + # TODO(lucas): I think in the future the layers that re-use a + # KV cache will be in a different attention group so we can + # remove this code from here. + attn_groups[group_idx][0].layer_names.append(layer_name) + def bind_kv_cache( kv_caches: dict[str, torch.Tensor],
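
The central abstraction added by this change is the `AttentionGroup`: within each KV cache group, layers are bucketed by their attention backend class (deduplicated via `full_cls_name()`), and each bucket carries the backend, one metadata builder, and its layer names. A minimal, self-contained sketch of that bucketing, using hypothetical stand-in backend classes rather than the real vLLM types:

```python
# Minimal sketch of the per-KV-cache-group bucketing performed by
# GPUModelRunner.initialize_attn_backend. FlashAttnBackend and
# LocalFlashAttnBackend are hypothetical stand-ins, not real vLLM backends.
from collections import defaultdict
from dataclasses import dataclass


class FlashAttnBackend:
    pass


class LocalFlashAttnBackend(FlashAttnBackend):
    pass


@dataclass
class AttentionGroup:  # mirrors the new dataclass in vllm/v1/worker/utils.py
    backend: type
    metadata_builder: object
    layer_names: list[str]


def group_layers_by_backend(
        layer_to_backend: dict[str, type]) -> list[AttentionGroup]:
    backends: dict[tuple[str, str], type] = {}
    backend_layers: dict[tuple[str, str], list[str]] = defaultdict(list)
    for layer_name, backend in layer_to_backend.items():
        # Dedupe on (module, qualname), as AttentionBackend.full_cls_name()
        # does, so equivalent dynamically created subclasses collapse into a
        # single group even if they are distinct class objects.
        key = (backend.__module__, backend.__qualname__)
        backends[key] = backend
        backend_layers[key].append(layer_name)
    return [
        AttentionGroup(backends[key], None, layers)
        for key, layers in backend_layers.items()
    ]


if __name__ == "__main__":
    groups = group_layers_by_backend({
        "layers.0.attn": FlashAttnBackend,
        "layers.1.attn": LocalFlashAttnBackend,
        "layers.2.attn": FlashAttnBackend,
    })
    for g in groups:
        print(g.backend.__name__, g.layer_names)
    # FlashAttnBackend ['layers.0.attn', 'layers.2.attn']
    # LocalFlashAttnBackend ['layers.1.attn']
```

In the real runner each group also gets its own metadata builder, which is why call sites that previously indexed `self.attn_metadata_builders[i]` now read `self.attn_groups[i][0].metadata_builder`.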
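
`ChunkedLocalAttention` no longer relies on the `use_irope` flag plus special-casing in the model runner; instead it wraps the selected backend in a dynamically created subclass whose metadata builder preprocesses the common metadata (via `make_local_attention_virtual_batches`) before delegating to the original `build(...)`. A toy sketch of that wrapping pattern, with a string transform standing in for the real preprocessing and `BaseBuilder` for the real `AttentionMetadataBuilder`:

```python
# Toy sketch of the dynamic subclassing used by
# subclass_attention_metadata_builder.
from typing import Callable


class BaseBuilder:

    def build(self, common_prefix_len, common_attn_metadata, fast_build=False):
        return {"prefix_len": common_prefix_len, "meta": common_attn_metadata}


def subclass_builder(name_prefix: str, builder_cls: type,
                     preprocess: Callable) -> type:

    def build(self, common_prefix_len, common_attn_metadata, fast_build=False):
        # Preprocess the common metadata, then defer to the original builder.
        return builder_cls.build(self, common_prefix_len,
                                 preprocess(common_attn_metadata), fast_build)

    return type(name_prefix + builder_cls.__name__, (builder_cls, ),
                {"build": build})


LocalBuilder = subclass_builder("ChunkedLocal_", BaseBuilder,
                                lambda meta: f"virtual_batches({meta})")
print(LocalBuilder().build(0, "common_metadata"))
# {'prefix_len': 0, 'meta': 'virtual_batches(common_metadata)'}
```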
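
Two safeguards keep the dynamically created backends from fragmenting into one group per layer: `create_chunked_local_attention_backend` is memoized with `functools.lru_cache`, so layers sharing an `attention_chunk_size` and block size receive the identical class, and the runner additionally dedupes on `(module, qualname)` via `full_cls_name()`. A tiny sketch of the memoization assumption, with a hypothetical `make_backend` factory in place of the real one:

```python
# The factory is memoized, so repeated calls with the same parameters return
# the identical dynamically created class rather than a fresh one per layer.
import functools


@functools.lru_cache
def make_backend(attention_chunk_size: int, block_size: int) -> type:
    name = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_Backend"
    return type(name, (), {})


assert make_backend(8192, 16) is make_backend(8192, 16)
```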