6 changes: 2 additions & 4 deletions vllm/attention/backends/flashinfer.py
@@ -37,7 +37,7 @@
                                            is_block_tables_empty)
 from vllm.attention.layer import Attention
 from vllm.attention.ops.paged_attn import PagedAttention
-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.logger import init_logger
 from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
                         make_tensor_with_pad)
@@ -128,12 +128,10 @@ def get_per_layer_parameters(
     to use during `plan`.
     """

-    layers = vllm_config.compilation_config.static_forward_context
+    layers = get_layers_from_vllm_config(vllm_config, Attention)
     per_layer_params: Dict[str, PerLayerParameters] = {}

     for key, layer in layers.items():
-        assert isinstance(layer, Attention)
-
         impl = layer.impl
         assert isinstance(impl, FlashInferImpl)
16 changes: 15 additions & 1 deletion vllm/config.py
@@ -3441,7 +3441,8 @@ def model_post_init(self, __context: Any) -> None:
     compilation_time: float = PrivateAttr

     # Per-model forward context
-    # Map from layer name to the attention cls
+    # Map from layer name to layer objects that need to be accessed outside
+    # model code, e.g., Attention, FusedMOE when dp_size>1.
     static_forward_context: dict[str, Any] = PrivateAttr

     def compute_hash(self) -> str:
@@ -4075,3 +4076,16 @@ def assert_hashable(text):
         f"vLLM tried to hash some configs that may have Python objects ids "
         f"in them. This is a bug, please file an issue. "
         f"Text being hashed: {text}")
+
+
+T = TypeVar("T")
+
+
+def get_layers_from_vllm_config(vllm_config: VllmConfig,
+                                layer_type: type[T]) -> dict[str, T]:
+    return {
+        layer_name: layer
+        for layer_name, layer in
+        vllm_config.compilation_config.static_forward_context.items()
+        if isinstance(layer, layer_type)
+    }
Comment on lines +4084 to +4091 (the get_layers_from_vllm_config definition)

Member: Are there any caveats to the static forward context? Can we always use this or only sometimes? (I'm not sure.) Could you add a comment if there's anything folks should be careful of?

Collaborator (author): I've updated the comment to show that there may be different types of layers in the forward context.
7 changes: 3 additions & 4 deletions vllm/v1/attention/backends/flashinfer.py
@@ -14,7 +14,8 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionType)
 from vllm.attention.layer import Attention
-from vllm.config import VllmConfig, get_current_vllm_config
+from vllm.config import (VllmConfig, get_current_vllm_config,
+                         get_layers_from_vllm_config)
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention

@@ -81,12 +82,10 @@ def get_per_layer_parameters(
     to use during `plan`.
     """

-    layers = vllm_config.compilation_config.static_forward_context
+    layers = get_layers_from_vllm_config(vllm_config, Attention)
     per_layer_params: dict[str, PerLayerParameters] = {}

     for key, layer in layers.items():
-        assert isinstance(layer, Attention)
-
         impl = layer.impl
         assert isinstance(impl, FlashInferImpl)
15 changes: 5 additions & 10 deletions vllm/v1/worker/gpu_model_runner.py
@@ -12,13 +12,13 @@

 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.layer import Attention
-from vllm.config import CompilationLevel, VllmConfig
+from vllm.config import (CompilationLevel, VllmConfig,
+                         get_layers_from_vllm_config)
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
 from vllm.distributed.parallel_state import get_pp_group, graph_capture
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -1736,17 +1736,12 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         format. Layers that do not need KV cache are not included.
         """

-        forward_ctx = self.vllm_config.compilation_config.static_forward_context
+        layers = get_layers_from_vllm_config(self.vllm_config, Attention)
         block_size = self.vllm_config.cache_config.block_size
         use_mla = self.vllm_config.model_config.use_mla
         kv_cache_spec: dict[str, KVCacheSpec] = {}
-        for layer_name, attn_module in forward_ctx.items():
-            if isinstance(attn_module, FusedMoE):
-                continue
-
-            # TODO: Support other attention modules, e.g., sliding window,
-            # cross-attention
-            assert isinstance(attn_module, Attention)
+        for layer_name, attn_module in layers.items():
+            # TODO: Support other attention modules, e.g., cross-attention
             if attn_module.attn_type == AttentionType.DECODER:
                 if attn_module.sliding_window is not None:
                     kv_cache_spec[layer_name] = SlidingWindowSpec(
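The same call pattern works outside the model runner. A hedged usage sketch (not part of this PR), assuming vLLM with this change is installed and the engine has populated the current VllmConfig; outside an engine context, get_current_vllm_config may just return a default config with an empty forward context:

# Enumerate only the Attention layers registered in the static forward
# context; attn_type and sliding_window are the attributes read above.
from vllm.attention.layer import Attention
from vllm.config import get_current_vllm_config, get_layers_from_vllm_config

vllm_config = get_current_vllm_config()
attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
for layer_name, attn_module in attn_layers.items():
    # Every value is an Attention instance; FusedMoE entries (present when
    # dp_size > 1) are filtered out by the helper.
    print(layer_name, attn_module.attn_type, attn_module.sliding_window)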
7 changes: 3 additions & 4 deletions vllm/v1/worker/tpu_model_runner.py
@@ -17,7 +17,7 @@
 from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.layer import Attention
 from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
@@ -430,11 +430,10 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         format. Layers that do not need KV cache are not included.
         """

-        forward_ctx = self.vllm_config.compilation_config.static_forward_context
+        layers = get_layers_from_vllm_config(self.vllm_config, Attention)
         block_size = self.vllm_config.cache_config.block_size
         kv_cache_spec: dict[str, KVCacheSpec] = {}
-        for layer_name, attn_module in forward_ctx.items():
-            assert isinstance(attn_module, Attention)
+        for layer_name, attn_module in layers.items():
             if attn_module.attn_type == AttentionType.DECODER:
                 if attn_module.sliding_window is not None:
                     kv_cache_spec[layer_name] = SlidingWindowSpec(