[MM][Core] Decouple ViT backend from LM backend #27061
Changes from all commits: 120a937, 6a588c0, 8c5b40e, 474dd7f, 5e424d6, 74af03a, cecc4ee
New test file (25 lines, exercising the `mm_encoder_attn_backend` override):

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest

from vllm.attention.backends.registry import _Backend
from vllm.config.multimodal import MultiModalConfig


def test_mm_encoder_attn_backend_str_conversion():
    config = MultiModalConfig(mm_encoder_attn_backend="FLASH_ATTN")
    assert config.mm_encoder_attn_backend == _Backend.FLASH_ATTN


def test_mm_encoder_attn_backend_invalid():
    with pytest.raises(ValueError):
        MultiModalConfig(mm_encoder_attn_backend="not_a_backend")


def test_mm_encoder_attn_backend_hash_updates():
    base_hash = MultiModalConfig().compute_hash()
    overridden_hash = MultiModalConfig(
        mm_encoder_attn_backend=_Backend.FLASH_ATTN
    ).compute_hash()
    assert base_hash != overridden_hash
```
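For orientation, here is a minimal sketch of how the new override is exercised from Python, mirroring the tests above. String values are normalized by the `before` validator added in this PR; the set of available backend names depends on the installed platform.

```python
# Minimal sketch mirroring the tests above: the new MultiModalConfig field
# accepts either a _Backend member or its (case-insensitive) string name.
from vllm.attention.backends.registry import _Backend
from vllm.config.multimodal import MultiModalConfig

# String input is upper-cased and converted to the enum by the validator.
config = MultiModalConfig(mm_encoder_attn_backend="flash_attn")
assert config.mm_encoder_attn_backend == _Backend.FLASH_ATTN

# Leaving the field unset preserves the old behavior (no ViT-specific override).
assert MultiModalConfig().mm_encoder_attn_backend is None
```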
Changes to `vllm.config.multimodal` (the `MultiModalConfig` definition):

```diff
@@ -3,13 +3,18 @@
 import hashlib
 from collections.abc import Mapping
-from typing import Any, Literal, TypeAlias
+from typing import TYPE_CHECKING, Any, Literal, TypeAlias
 
 from pydantic import ConfigDict, Field, field_validator, model_validator
 from pydantic.dataclasses import dataclass
 
 from vllm.config.utils import config
 
+if TYPE_CHECKING:
+    from vllm.attention.backends.registry import _Backend
+else:
+    _Backend = Any
+
 
 @dataclass
 class BaseDummyOptions:
@@ -112,6 +117,10 @@ class MultiModalConfig:
     DP (which is controlled by `--data-parallel-size`).
     This is only supported on a per-model basis and falls back to
     `"weights"` if the encoder does not support DP."""
+    mm_encoder_attn_backend: _Backend | None = None
+    """Optional override for the multi-modal encoder attention backend when
+    using vision transformers. Accepts any value from
+    `vllm.attention.backends.registry._Backend` (e.g. `FLASH_ATTN`)."""
     interleave_mm_strings: bool = False
     """Enable fully interleaved support for multimodal prompts, while using
     --chat-template-content-format=string."""
@@ -148,6 +157,29 @@ def _validate_limit_per_prompt(
             value[k] = BaseDummyOptions(**v)
         return value
 
+    @field_validator("mm_encoder_attn_backend", mode="before")
+    @classmethod
+    def _validate_mm_encoder_attn_backend(cls, value: object) -> _Backend | None:
+        from vllm.attention.backends.registry import (
+            _Backend as BackendEnum,
+        )
+        from vllm.attention.backends.registry import (
+            backend_name_to_enum,
+        )
+
+        if value is None or isinstance(value, BackendEnum):
+            return value
+
+        if isinstance(value, str):
+            candidate = backend_name_to_enum(value.upper())
+            if candidate is not None:
+                return candidate
+
+        valid_backends = ", ".join(sorted(BackendEnum.__members__.keys()))
+        raise ValueError(
+            f"Invalid mm encoder attention backend. Expected one of: {valid_backends}."
+        )
+
     @model_validator(mode="after")
     def _validate_multimodal_config(self):
         if self.mm_processor_cache_type != "shm" and (
```
A review thread is attached to the `valid_backends` error message:

Member: Perhaps we can add a …

Author: Yea, right now it'll just show all possible … (Lines 203 to 211 in 9fce7be). I think we can shrink this selection by just having a specific …

Collaborator: @ywang96 Yes, I like this idea. We should separate out the ViT attention backend enums.
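For illustration only, a rough sketch of what such a split could look like. The enum name `_MMEncoderBackend`, its members beyond `FLASH_ATTN`, and the parsing helper below are hypothetical and not part of this PR:

```python
# Hypothetical sketch of a dedicated ViT / MM-encoder attention backend enum,
# so that error messages (and CLI choices) only list backends the encoder
# path can actually use. Names and members here are illustrative.
import enum


class _MMEncoderBackend(enum.Enum):
    FLASH_ATTN = enum.auto()
    TORCH_SDPA = enum.auto()
    XFORMERS = enum.auto()


def parse_mm_encoder_backend(value: str) -> _MMEncoderBackend:
    try:
        return _MMEncoderBackend[value.upper()]
    except KeyError:
        valid = ", ".join(sorted(_MMEncoderBackend.__members__))
        raise ValueError(
            f"Invalid mm encoder attention backend {value!r}. "
            f"Expected one of: {valid}."
        ) from None
```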
```diff
@@ -172,9 +204,11 @@ def compute_hash(self) -> str:
         excluding anything before input ids/embeddings and after
         the final hidden states.
         """
-        # no factors to consider.
-        # this config will not affect the computation graph.
-        factors: list[Any] = []
+        factors: list[Any] = [
+            self.mm_encoder_attn_backend.name
+            if self.mm_encoder_attn_backend is not None
+            else None
+        ]
         hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
         return hash_str
```
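To make the intent of the `compute_hash` change concrete, here is a standalone illustration of the same pattern (the real implementation is the method shown in the diff above): configs that differ only in the ViT attention backend override must hash differently, so cached artifacts are not shared across backends.

```python
# Standalone illustration of the hashing pattern above: the backend name now
# participates in the config hash, so a config with a ViT backend override no
# longer collides with one without it.
import hashlib


def compute_hash(mm_encoder_attn_backend_name: str | None) -> str:
    factors = [mm_encoder_attn_backend_name]
    return hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()


assert compute_hash(None) != compute_hash("FLASH_ATTN")
assert compute_hash("FLASH_ATTN") == compute_hash("FLASH_ATTN")
```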
A second review thread discusses the naming of the attention layer touched by this PR:

- "Should we rename this layer to `VisionAttention` btw?"
- "This layer will be renamed to `MMEncoderAttention` in #27147. But we can also rename it here."
- "Let's do that in the other PR then."