diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index 718b15e58785..1d78295a9781 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -37,7 +37,7 @@
                                             is_block_tables_empty)
 from vllm.attention.layer import Attention
 from vllm.attention.ops.paged_attn import PagedAttention
-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.logger import init_logger
 from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
                         make_tensor_with_pad)
@@ -128,12 +128,10 @@ def get_per_layer_parameters(
     to use during `plan`.
     """
 
-    layers = vllm_config.compilation_config.static_forward_context
+    layers = get_layers_from_vllm_config(vllm_config, Attention)
     per_layer_params: Dict[str, PerLayerParameters] = {}
 
     for key, layer in layers.items():
-        assert isinstance(layer, Attention)
-
         impl = layer.impl
         assert isinstance(impl, FlashInferImpl)
 
diff --git a/vllm/config.py b/vllm/config.py
index 0ac3cc46b063..24f2762d0002 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -3441,7 +3441,8 @@ def model_post_init(self, __context: Any) -> None:
     compilation_time: float = PrivateAttr
 
     # Per-model forward context
-    # Map from layer name to the attention cls
+    # Map from layer name to layer objects that need to be accessed outside
+    # model code, e.g., Attention, FusedMOE when dp_size>1.
     static_forward_context: dict[str, Any] = PrivateAttr
 
     def compute_hash(self) -> str:
@@ -4075,3 +4076,16 @@ def assert_hashable(text):
         f"vLLM tried to hash some configs that may have Python objects ids "
         f"in them. This is a bug, please file an issue. "
         f"Text being hashed: {text}")
+
+
+T = TypeVar("T")
+
+
+def get_layers_from_vllm_config(vllm_config: VllmConfig,
+                                layer_type: type[T]) -> dict[str, T]:
+    return {
+        layer_name: layer
+        for layer_name, layer in
+        vllm_config.compilation_config.static_forward_context.items()
+        if isinstance(layer, layer_type)
+    }
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 17341ecfa4fe..bce446bd2b82 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -14,7 +14,8 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionType)
 from vllm.attention.layer import Attention
-from vllm.config import VllmConfig, get_current_vllm_config
+from vllm.config import (VllmConfig, get_current_vllm_config,
+                         get_layers_from_vllm_config)
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
 
@@ -81,12 +82,10 @@ def get_per_layer_parameters(
     to use during `plan`.
     """
 
-    layers = vllm_config.compilation_config.static_forward_context
+    layers = get_layers_from_vllm_config(vllm_config, Attention)
     per_layer_params: dict[str, PerLayerParameters] = {}
 
     for key, layer in layers.items():
-        assert isinstance(layer, Attention)
-
         impl = layer.impl
         assert isinstance(impl, FlashInferImpl)
 
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 7910481762ef..29104f70b5c0 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -12,13 +12,13 @@
 
 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.layer import Attention
-from vllm.config import CompilationLevel, VllmConfig
+from vllm.config import (CompilationLevel, VllmConfig,
+                         get_layers_from_vllm_config)
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
 from vllm.distributed.parallel_state import get_pp_group, graph_capture
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -1736,17 +1736,12 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         format. Layers that do not need KV cache are not included.
         """
 
-        forward_ctx = self.vllm_config.compilation_config.static_forward_context
+        layers = get_layers_from_vllm_config(self.vllm_config, Attention)
         block_size = self.vllm_config.cache_config.block_size
         use_mla = self.vllm_config.model_config.use_mla
         kv_cache_spec: dict[str, KVCacheSpec] = {}
-        for layer_name, attn_module in forward_ctx.items():
-            if isinstance(attn_module, FusedMoE):
-                continue
-
-            # TODO: Support other attention modules, e.g., sliding window,
-            # cross-attention
-            assert isinstance(attn_module, Attention)
+        for layer_name, attn_module in layers.items():
+            # TODO: Support other attention modules, e.g., cross-attention
             if attn_module.attn_type == AttentionType.DECODER:
                 if attn_module.sliding_window is not None:
                     kv_cache_spec[layer_name] = SlidingWindowSpec(
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index e9cb0dbe8b5e..20a3b60172c2 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -17,7 +17,7 @@
 from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.layer import Attention
 from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
@@ -430,11 +430,10 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         format. Layers that do not need KV cache are not included.
         """
 
-        forward_ctx = self.vllm_config.compilation_config.static_forward_context
+        layers = get_layers_from_vllm_config(self.vllm_config, Attention)
         block_size = self.vllm_config.cache_config.block_size
        kv_cache_spec: dict[str, KVCacheSpec] = {}
-        for layer_name, attn_module in forward_ctx.items():
-            assert isinstance(attn_module, Attention)
+        for layer_name, attn_module in layers.items():
            if attn_module.attn_type == AttentionType.DECODER:
                 if attn_module.sliding_window is not None:
                     kv_cache_spec[layer_name] = SlidingWindowSpec(