From 7f63963c1d3e4a0b8040598f185e1af83ee4ad21 Mon Sep 17 00:00:00 2001
From: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Date: Wed, 21 May 2025 17:58:02 +0200
Subject: [PATCH 1/5] Enable hybrid attention models for Transformers backend

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
---
 vllm/config.py                             | 18 ++++---
 vllm/model_executor/models/transformers.py | 57 +++++++++++++++++++---
 2 files changed, 61 insertions(+), 14 deletions(-)

diff --git a/vllm/config.py b/vllm/config.py
index 3fa1db0e8390..ed65a90f73f3 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -536,13 +536,16 @@ def __post_init__(self) -> None:
             self.model, hf_token=self.hf_token, revision=self.revision)
         self.dtype = _get_and_verify_dtype(self.hf_config, self.dtype)
 
-        interleaved_attn_models = ["gemma2", "gemma3_text", "cohere2"]
+        # Workaround for Gemma 2 which uses interleaved sliding window
+        # attention, but it's not specified in its config.
+        if self.hf_text_config.model_type == "gemma2":
+            self.hf_text_config.sliding_window_pattern = 2
+
         sliding_window = getattr(self.hf_text_config, "sliding_window", None)
-        has_interleaved_attention = (sliding_window is not None) and (
-            isinstance(sliding_window, list) or
-            (self.hf_text_config.model_type in interleaved_attn_models))
+        sliding_window_pattern = getattr(self.hf_text_config,
+                                         "sliding_window_pattern", None)
 
-        if (not self.disable_sliding_window and has_interleaved_attention):
+        if not (self.disable_sliding_window or sliding_window_pattern is None):
             if (backend :=
                     envs.VLLM_ATTENTION_BACKEND) in ("XFORMERS", "FLASHINFER"):
                 sliding_window_len_min = get_min_sliding_window(
@@ -1040,8 +1043,7 @@ def verify_with_parallel_config(
             if self.use_async_output_proc:
                 self.use_async_output_proc = False
 
-    def get_hf_config_sliding_window(
-            self) -> Union[Optional[int], list[Optional[int]]]:
+    def get_hf_config_sliding_window(self) -> Optional[int]:
         """Get the sliding window size, or None if disabled."""
 
         # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
@@ -1052,7 +1054,7 @@ def get_hf_config_sliding_window(
             return None
         return getattr(self.hf_text_config, "sliding_window", None)
 
-    def get_sliding_window(self) -> Optional[Union[int, list[Optional[int]]]]:
+    def get_sliding_window(self) -> Optional[int]:
         """Get the sliding window size, or None if disabled.
         """
         # If user disables sliding window, return None.
diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py
index a8f30b2f27bf..b22d81d88abe 100644
--- a/vllm/model_executor/models/transformers.py
+++ b/vllm/model_executor/models/transformers.py
@@ -16,6 +16,7 @@
 """Wrapper around `transformers` models"""
 import re
 from collections.abc import Iterable
+from contextlib import nullcontext
 from typing import Literal, Optional, Union
 
 import torch
@@ -110,6 +111,33 @@ def replace_linear_class(
     )
 
 
+class ConfigOverride:
+    """Context manager to temporarily override config attributes."""
+
+    def __init__(self, config: PretrainedConfig, **kwargs):
+        self.config = config
+        self.kwargs = kwargs
+        self.kwargs_original = {}
+        self.kwargs_delete = set()
+
+    def __enter__(self):
+        """Override config attributes."""
+        for key, value in self.kwargs.items():
+            if not hasattr(self.config, key):
+                self.kwargs_delete.add(key)
+            self.kwargs_original[key] = getattr(self.config, key, None)
+            setattr(self.config, key, value)
+        return self.config
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        """Restore original config attributes."""
+        for key, value in self.kwargs_original.items():
+            if key in self.kwargs_delete:
+                delattr(self.config, key)
+            else:
+                setattr(self.config, key, value)
+
+
 class TransformersModel(nn.Module):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -135,8 +163,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.pp_rank = self.pp_group.rank_in_group
         self.tp_size = get_tensor_model_parallel_world_size()
 
+        # vLLM handles interleaved sliding window attention by creating a new
+        # interleaved_sliding_window attribute and deleting the sliding_window
+        # attribute. This breaks the constructors in Transformers so we
+        # temporarily add the attribute back to construct the model.
+        config_override = nullcontext()
+        if hasattr(config, "interleaved_sliding_window"):
+            config_override = ConfigOverride(
+                config, sliding_window=config.interleaved_sliding_window)
+
         # Use meta device to delay allocating GPU tensors
-        with torch.device("meta"):
+        with torch.device("meta"), config_override:
             # FIXME(Isotr0py): We need to refactor this part in the future to
             # avoid registering an extra model layer, otherwise we will need a
             # weights mapper to rename weights.
@@ -262,9 +299,17 @@ def create_attention_instances(self) -> dict[int, Attention]:
         num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
         start, end = get_pp_indices(self.config.num_hidden_layers,
                                     self.pp_rank, self.pp_size)
-        return {
-            i:
-            Attention(
+
+        attention_instances = {}
+        for i in range(start, end):
+            # Handle interleaved sliding window attention
+            sliding_window = None
+            if (hasattr(self.config, "interleaved_sliding_window")
+                    and hasattr(self.config, "sliding_window_pattern")
+                    and ((i + 1) % self.config.sliding_window_pattern > 0)):
+                sliding_window = self.config.interleaved_sliding_window
+
+            attention_instances[i] = Attention(
                 num_heads=num_heads,
                 head_size=head_size,
                 # NOTE: We use Llama scale as default, if it's set by
@@ -273,9 +318,9 @@ def create_attention_instances(self) -> dict[int, Attention]:
                 num_kv_heads=num_kv_heads,
                 cache_config=self.cache_config,
                 quant_config=self.quant_config,
+                per_layer_sliding_window=sliding_window,
                 prefix=f"{i}.attn")
-            for i in range(start, end)
-        }
+        return attention_instances
 
     def init_buffers(self, module: nn.Module):
         """

From e68d0b5ffa5a1fe11c617c7747f6a7f7f70327ce Mon Sep 17 00:00:00 2001
From: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Date: Wed, 21 May 2025 18:07:32 +0200
Subject: [PATCH 2/5] Add a TODO

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
---
 vllm/config.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/vllm/config.py b/vllm/config.py
index ed65a90f73f3..e8b0608b390a 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -537,7 +537,8 @@ def __post_init__(self) -> None:
         self.dtype = _get_and_verify_dtype(self.hf_config, self.dtype)
 
         # Workaround for Gemma 2 which uses interleaved sliding window
-        # attention, but it's not specified in its config.
+        # attention, but it's not specified in its config. TODO: remove this
+        # when Gemma 2 is fixed in Transformers.
         if self.hf_text_config.model_type == "gemma2":
             self.hf_text_config.sliding_window_pattern = 2
 

From 6122b18e18d2c9a7c69c4ccabd7b7a166da14b13 Mon Sep 17 00:00:00 2001
From: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Date: Thu, 22 May 2025 12:05:48 +0200
Subject: [PATCH 3/5] Add test

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
---
 tests/models/test_transformers.py | 49 +++++++++++++++++++++++--------
 1 file changed, 37 insertions(+), 12 deletions(-)

diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py
index 6e38c4c7cadb..d39b88c5df75 100644
--- a/tests/models/test_transformers.py
+++ b/tests/models/test_transformers.py
@@ -1,37 +1,50 @@
 # SPDX-License-Identifier: Apache-2.0
 """Test the functionality of the Transformers backend."""
+from typing import Any, Union
+
 import pytest
 
 from vllm.platforms import current_platform
 
 from ..conftest import HfRunner, VllmRunner
+from ..core.block.e2e.test_correctness_sliding_window import prep_prompts
 from ..utils import multi_gpu_test
 from .utils import check_logprobs_close
 
 
 def check_implementation(
-    hf_runner: type[HfRunner],
-    vllm_runner: type[VllmRunner],
+    runner_ref: type[Union[HfRunner, VllmRunner]],
+    runner_test: type[VllmRunner],
     example_prompts: list[str],
     model: str,
+    kwargs_ref: dict[str, Any] = None,
+    kwargs_test: dict[str, Any] = None,
     **kwargs,
 ):
+    if kwargs_ref is None:
+        kwargs_ref = {}
+    if kwargs_test is None:
+        kwargs_test = {}
+
     max_tokens = 32
     num_logprobs = 5
 
-    with vllm_runner(model, **kwargs) as vllm_model:
-        vllm_outputs = vllm_model.generate_greedy_logprobs(
-            example_prompts, max_tokens, num_logprobs)
+    args = (example_prompts, max_tokens, num_logprobs)
+
+    with runner_test(model, **kwargs_test, **kwargs) as model_test:
+        outputs_test = model_test.generate_greedy_logprobs(*args)
 
-    with hf_runner(model) as hf_model:
-        hf_outputs = hf_model.generate_greedy_logprobs_limit(
-            example_prompts, max_tokens, num_logprobs)
+    with runner_ref(model, **kwargs_ref) as model_ref:
+        if isinstance(model_ref, VllmRunner):
+            outputs_ref = model_ref.generate_greedy_logprobs(*args)
+        else:
+            outputs_ref = model_ref.generate_greedy_logprobs_limit(*args)
 
     check_logprobs_close(
-        outputs_0_lst=hf_outputs,
-        outputs_1_lst=vllm_outputs,
-        name_0="hf",
-        name_1="vllm",
+        outputs_0_lst=outputs_ref,
+        outputs_1_lst=outputs_test,
+        name_0="ref",
+        name_1="test",
     )
 
 
@@ -58,6 +71,18 @@ def test_models(
                          model_impl=model_impl)
 
 
+def test_hybrid_attention(vllm_runner: type[VllmRunner]) -> None:
+    prompts, _, _ = prep_prompts(4, (800, 801))
+    kwargs_ref = {"max_model_len": 8192, "enforce_eager": True}
+    kwargs_test = {"model_impl": "transformers", **kwargs_ref}
+    check_implementation(vllm_runner,
+                         vllm_runner,
+                         prompts,
+                         model="hmellor/tiny-random-Gemma2ForCausalLM",
+                         kwargs_ref=kwargs_ref,
+                         kwargs_test=kwargs_test)
+
+
 @multi_gpu_test(num_gpus=2)
 def test_distributed(
     hf_runner: type[HfRunner],

From e3d6a862514bc1978707084c7067e1428f1febc8 Mon Sep 17 00:00:00 2001
From: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Date: Thu, 22 May 2025 12:24:44 +0200
Subject: [PATCH 4/5] pre-commit

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
---
 tests/models/test_transformers.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py
index d39b88c5df75..1a51b4aeab04 100644
--- a/tests/models/test_transformers.py
+++ b/tests/models/test_transformers.py
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 """Test the functionality of the Transformers backend."""
-from typing import Any, Union
+from typing import Any, Optional, Union
 
 import pytest
 
@@ -17,8 +17,8 @@ def check_implementation(
     runner_test: type[VllmRunner],
     example_prompts: list[str],
     model: str,
-    kwargs_ref: dict[str, Any] = None,
-    kwargs_test: dict[str, Any] = None,
+    kwargs_ref: Optional[dict[str, Any]] = None,
+    kwargs_test: Optional[dict[str, Any]] = None,
     **kwargs,
 ):
     if kwargs_ref is None:
@@ -90,8 +90,11 @@ def test_distributed(
     example_prompts,
 ):
     kwargs = {"model_impl": "transformers", "tensor_parallel_size": 2}
-    check_implementation(hf_runner, vllm_runner, example_prompts,
-                         "meta-llama/Llama-3.2-1B-Instruct", **kwargs)
+    check_implementation(hf_runner,
+                         vllm_runner,
+                         example_prompts,
+                         "meta-llama/Llama-3.2-1B-Instruct",
+                         kwargs_test=kwargs)
 
 
 @pytest.mark.skipif(

From 43347ae35ba50d11779965ba2f37c8d4559752af Mon Sep 17 00:00:00 2001
From: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Date: Thu, 22 May 2025 14:49:01 +0200
Subject: [PATCH 5/5] Update `basic.md`

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
---
 docs/source/contributing/model/basic.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/source/contributing/model/basic.md b/docs/source/contributing/model/basic.md
index ad31995f76be..1fa56dc4728d 100644
--- a/docs/source/contributing/model/basic.md
+++ b/docs/source/contributing/model/basic.md
@@ -117,7 +117,7 @@ For models with interleaving sliding windows (e.g. `google/gemma-2-2b-it` and `m
 
 To support a model with interleaving sliding windows, we need to take care of the following details:
 
-- Make sure [this line](https://github.com/vllm-project/vllm/blob/996357e4808ca5eab97d4c97c7d25b3073f46aab/vllm/config.py#L308) evaluates `has_interleaved_attention` to `True` for this model, and set `self.hf_text_config.interleaved_sliding_window` to the format of interleaving sliding windows the model can understand. Then, `self.hf_text_config.sliding_window` will be deleted, and the model will be treated as a full-attention model.
+- Make sure the model's `config.json` contains `sliding_window_pattern`. vLLM then sets `self.hf_text_config.interleaved_sliding_window` to the value of `self.hf_text_config.sliding_window` and deletes `sliding_window` from `self.hf_text_config`. The model will then be treated as a full-attention model.
 - In the modeling code, parse the correct sliding window value for every layer, and pass it to the attention layer's `per_layer_sliding_window` argument. For reference, check [this line](https://github.com/vllm-project/vllm/blob/996357e4808ca5eab97d4c97c7d25b3073f46aab/vllm/model_executor/models/llama.py#L171).
 
 With these two steps, interleave sliding windows should work with the model.
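
Note for reviewers (not part of the patches above): the per-layer selection rule that patch 1 adds to create_attention_instances() can be summarised with the small, self-contained sketch below. It assumes a Gemma-2-style config where every `sliding_window_pattern`-th layer uses full attention and the remaining layers use the interleaved sliding window; `SimpleNamespace` stands in for a real Transformers `PretrainedConfig`, and `per_layer_sliding_window` is an illustrative helper, not a vLLM API.

    # Sketch only: mirrors the logic added in create_attention_instances().
    from types import SimpleNamespace


    def per_layer_sliding_window(config, layer_idx: int):
        """Return the sliding window for this layer, or None for full attention."""
        if (hasattr(config, "interleaved_sliding_window")
                and hasattr(config, "sliding_window_pattern")
                and (layer_idx + 1) % config.sliding_window_pattern > 0):
            return config.interleaved_sliding_window
        return None


    # Gemma-2-style settings: window of 4096, every 2nd layer is full attention.
    config = SimpleNamespace(interleaved_sliding_window=4096,
                             sliding_window_pattern=2)
    print([per_layer_sliding_window(config, i) for i in range(4)])
    # -> [4096, None, 4096, None]

The computed value is what the patch passes to each Attention layer via `per_layer_sliding_window`, so full-attention layers simply receive None.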