25 changes: 25 additions & 0 deletions tests/config/test_multimodal_config.py
@@ -0,0 +1,25 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest

from vllm.attention.backends.registry import _Backend
from vllm.config.multimodal import MultiModalConfig


def test_mm_encoder_attn_backend_str_conversion():
config = MultiModalConfig(mm_encoder_attn_backend="FLASH_ATTN")
assert config.mm_encoder_attn_backend == _Backend.FLASH_ATTN


def test_mm_encoder_attn_backend_invalid():
with pytest.raises(ValueError):
MultiModalConfig(mm_encoder_attn_backend="not_a_backend")


def test_mm_encoder_attn_backend_hash_updates():
base_hash = MultiModalConfig().compute_hash()
overridden_hash = MultiModalConfig(
mm_encoder_attn_backend=_Backend.FLASH_ATTN
).compute_hash()
assert base_hash != overridden_hash
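
A minimal usage sketch of the new field, based only on the validator and hash changes in this PR (lowercase strings are normalized via `.upper()` before the enum lookup):

```python
from vllm.attention.backends.registry import _Backend
from vllm.config.multimodal import MultiModalConfig

# A lowercase string resolves to the same enum member as the uppercase form.
cfg = MultiModalConfig(mm_encoder_attn_backend="flash_attn")
assert cfg.mm_encoder_attn_backend is _Backend.FLASH_ATTN

# The override is folded into compute_hash(), so an overridden config no
# longer hashes to the same value as the default config.
assert cfg.compute_hash() != MultiModalConfig().compute_hash()
```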
11 changes: 10 additions & 1 deletion vllm/attention/layer.py
@@ -16,6 +16,7 @@
from vllm.attention.selector import get_attn_backend
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.config.multimodal import MultiModalConfig
from vllm.config.vllm import VllmConfig
from vllm.distributed.kv_transfer import (
get_kv_transfer_group,
@@ -443,6 +444,7 @@ def __init__(
# This has no effect, it is only here to make it easier to swap
# between Attention and MultiHeadAttention
prefix: str = "",
multimodal_config: MultiModalConfig | None = None,
) -> None:
super().__init__()
self.num_heads = num_heads
@@ -462,7 +464,14 @@ def __init__(
dtype = torch.get_default_dtype()

# Determine the attention backend
backend = get_vit_attn_backend(head_size=head_size, dtype=dtype)
attn_backend_override = None
Inline review comment (Member): Should we rename this layer to VisionAttention, by the way?

Reply (Member): This layer will be renamed to MMEncoderAttention in #27147, but we can also rename it here.

Reply (Member): Let's do that in the other PR then.

if multimodal_config is not None:
attn_backend_override = multimodal_config.mm_encoder_attn_backend
backend = get_vit_attn_backend(
head_size=head_size,
dtype=dtype,
attn_backend_override=attn_backend_override,
)

# Some auto-selected backends can be upgraded
# to upstream flash attention if available.
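
The hunk above only threads the override through to `get_vit_attn_backend`; the sketch below shows the contract this call is assumed to follow (an illustration, not the actual vLLM implementation):

```python
from vllm.platforms import current_platform


def get_vit_attn_backend(head_size, dtype, attn_backend_override=None):
    # Sketch: a user-supplied override is assumed to take precedence over
    # the platform's automatic ViT backend selection.
    if attn_backend_override is not None:
        return attn_backend_override
    return current_platform.get_vit_attn_backend(head_size, dtype)
```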
5 changes: 5 additions & 0 deletions vllm/config/model.py
@@ -50,13 +50,15 @@

import vllm.model_executor.layers.quantization as me_quant
import vllm.model_executor.models as me_models
from vllm.attention.backends.registry import _Backend
from vllm.config.load import LoadConfig
from vllm.config.parallel import ParallelConfig
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.v1.sample.logits_processor import LogitsProcessor
else:
PretrainedConfig = Any

_Backend = Any
me_quant = LazyLoader(
"model_executor", globals(), "vllm.model_executor.layers.quantization"
)
@@ -307,6 +309,7 @@ class ModelConfig:
mm_processor_cache_type: InitVar[MMCacheType | None] = None
mm_shm_cache_max_object_size_mb: InitVar[int | None] = None
mm_encoder_tp_mode: InitVar[MMEncoderTPMode | None] = None
mm_encoder_attn_backend: InitVar[_Backend | str | None] = None
interleave_mm_strings: InitVar[bool | None] = None
skip_mm_profiling: InitVar[bool | None] = None
video_pruning_rate: InitVar[float | None] = None
@@ -424,6 +427,7 @@ def __post_init__(
mm_processor_cache_type: MMCacheType | None,
mm_shm_cache_max_object_size_mb: int | None,
mm_encoder_tp_mode: MMEncoderTPMode | None,
mm_encoder_attn_backend: _Backend | str | None,
interleave_mm_strings: bool | None,
skip_mm_profiling: bool | None,
video_pruning_rate: float | None,
@@ -733,6 +737,7 @@ def _task_to_convert(task: TaskOption) -> ConvertType:
mm_processor_cache_type=mm_processor_cache_type,
mm_shm_cache_max_object_size_mb=mm_shm_cache_max_object_size_mb,
mm_encoder_tp_mode=mm_encoder_tp_mode,
mm_encoder_attn_backend=mm_encoder_attn_backend,
interleave_mm_strings=interleave_mm_strings,
skip_mm_profiling=skip_mm_profiling,
video_pruning_rate=video_pruning_rate,
42 changes: 38 additions & 4 deletions vllm/config/multimodal.py
@@ -3,13 +3,18 @@

import hashlib
from collections.abc import Mapping
from typing import Any, Literal, TypeAlias
from typing import TYPE_CHECKING, Any, Literal, TypeAlias

from pydantic import ConfigDict, Field, field_validator, model_validator
from pydantic.dataclasses import dataclass

from vllm.config.utils import config

if TYPE_CHECKING:
from vllm.attention.backends.registry import _Backend
else:
_Backend = Any


@dataclass
class BaseDummyOptions:
@@ -112,6 +117,10 @@ class MultiModalConfig:
DP (which is controlled by `--data-parallel-size`).
This is only supported on a per-model basis and falls back to
`"weights"` if the encoder does not support DP."""
mm_encoder_attn_backend: _Backend | None = None
"""Optional override for the multi-modal encoder attention backend when
using vision transformers. Accepts any value from
`vllm.attention.backends.registry._Backend` (e.g. `FLASH_ATTN`)."""
interleave_mm_strings: bool = False
"""Enable fully interleaved support for multimodal prompts, while using
--chat-template-content-format=string."""
@@ -148,6 +157,29 @@ def _validate_limit_per_prompt(
value[k] = BaseDummyOptions(**v)
return value

@field_validator("mm_encoder_attn_backend", mode="before")
@classmethod
def _validate_mm_encoder_attn_backend(cls, value: object) -> _Backend | None:
from vllm.attention.backends.registry import (
_Backend as BackendEnum,
)
from vllm.attention.backends.registry import (
backend_name_to_enum,
)

if value is None or isinstance(value, BackendEnum):
return value

if isinstance(value, str):
candidate = backend_name_to_enum(value.upper())
if candidate is not None:
return candidate

valid_backends = ", ".join(sorted(BackendEnum.__members__.keys()))
Inline review comment (Member): Perhaps we can add a supported_vit_backend to the Platform interface in a follow-up PR, to detect an invalid backend for a specific platform before initializing the model.

Reply (Member Author, @ywang96, Oct 20, 2025): Yeah, right now it'll just show all possible _Backend values, with some of them auto-resolved inside the corresponding platform.get_vit_attn_backend. For example (vllm/platforms/rocm.py, lines 203 to 211 in 9fce7be):

```python
@classmethod
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
    from vllm.attention.backends.registry import _Backend

    if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
        return _Backend.ROCM_AITER_FA
    if on_gfx9():
        return _Backend.FLASH_ATTN
    return _Backend.TORCH_SDPA
```

I think we can shrink this selection by just having a specific _MHA_Backend enum.

Reply (Collaborator): @ywang96 Yes, I like this idea; we should separate out the ViT attention backend enums.

raise ValueError(
f"Invalid mm encoder attention backend. Expected one of: {valid_backends}."
)

@model_validator(mode="after")
def _validate_multimodal_config(self):
if self.mm_processor_cache_type != "shm" and (
@@ -172,9 +204,11 @@ def compute_hash(self) -> str:
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
factors: list[Any] = [
self.mm_encoder_attn_backend.name
if self.mm_encoder_attn_backend is not None
else None
]
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str

9 changes: 9 additions & 0 deletions vllm/engine/arg_utils.py
@@ -32,6 +32,7 @@
from typing_extensions import TypeIs, deprecated

import vllm.envs as envs
from vllm.attention.backends.registry import _Backend
from vllm.config import (
CacheConfig,
CompilationConfig,
@@ -451,6 +452,9 @@ class EngineArgs:
MultiModalConfig.mm_shm_cache_max_object_size_mb
)
mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
mm_encoder_attn_backend: _Backend | str | None = (
MultiModalConfig.mm_encoder_attn_backend
)
io_processor_plugin: str | None = None
skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
video_pruning_rate: float = MultiModalConfig.video_pruning_rate
@@ -914,6 +918,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
multimodal_group.add_argument(
"--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"]
)
multimodal_group.add_argument(
"--mm-encoder-attn-backend",
**multimodal_kwargs["mm_encoder_attn_backend"],
)
multimodal_group.add_argument(
"--interleave-mm-strings", **multimodal_kwargs["interleave_mm_strings"]
)
@@ -1160,6 +1168,7 @@ def create_model_config(self) -> ModelConfig:
mm_processor_cache_type=self.mm_processor_cache_type,
mm_shm_cache_max_object_size_mb=self.mm_shm_cache_max_object_size_mb,
mm_encoder_tp_mode=self.mm_encoder_tp_mode,
mm_encoder_attn_backend=self.mm_encoder_attn_backend,
pooler_config=self.pooler_config,
override_pooler_config=self.override_pooler_config,
logits_processor_pattern=self.logits_processor_pattern,
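
With this plumbing in place, the override can be set either via the new `--mm-encoder-attn-backend` CLI flag or programmatically through `EngineArgs`. A short sketch (the model name is illustrative, not part of this PR):

```python
from vllm.engine.arg_utils import EngineArgs

# Equivalent to passing `--mm-encoder-attn-backend FLASH_ATTN` on the CLI.
engine_args = EngineArgs(
    model="Qwen/Qwen2-VL-2B-Instruct",  # illustrative model name
    mm_encoder_attn_backend="FLASH_ATTN",
)
```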
19 changes: 17 additions & 2 deletions vllm/model_executor/models/dots_ocr.py
@@ -256,6 +256,7 @@ def __init__(
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None,
) -> None:
super().__init__()

@@ -288,7 +289,9 @@ def __init__(
)
# Select attention backend
self.attn_backend = get_vit_attn_backend(
self.hidden_size_per_attention_head, torch.get_default_dtype()
self.hidden_size_per_attention_head,
torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
self.use_upstream_fa = False

@@ -510,6 +513,7 @@ def __init__(
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None,
):
super().__init__()

@@ -521,6 +525,7 @@
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
self.norm1 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
self.mlp = DotsSwiGLUFFN(
@@ -561,6 +566,7 @@ def __init__(
require_post_norm: bool | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None,
) -> None:
super().__init__()
self.config = config
@@ -571,7 +577,9 @@
head_dim = config.embed_dim // config.num_attention_heads
self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim, dtype=torch.get_default_dtype()
head_size=head_dim,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
torch.get_default_dtype()
@@ -591,6 +599,7 @@ def __init__(
quant_config=quant_config,
prefix=f"{prefix}.blocks.{i}",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
for i in range(num_layers)
]
@@ -750,11 +759,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.config.vision_config = vision_config
else:
vision_config = self.config.vision_config
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.vision_tower = DotsVisionTransformer(
vision_config,
quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "vision_tower"),
use_data_parallel=self.use_data_parallel,
attn_backend_override=attn_backend_override,
)
self.language_model: Qwen2ForCausalLM = init_vllm_registered_model(
vllm_config=vllm_config,
16 changes: 15 additions & 1 deletion vllm/model_executor/models/ernie45_vl.py
@@ -164,6 +164,7 @@ def __init__(
projection_size: int,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
attn_backend_override: _Backend | None = None,
) -> None:
super().__init__()
# Per attention head and per partition values.
@@ -196,6 +197,7 @@ def __init__(
self.attn_backend = get_vit_attn_backend(
head_size=self.hidden_size_per_attention_head,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)

self.use_upstream_fa = False
@@ -367,6 +369,7 @@ def __init__(
norm_layer: Callable[[int], nn.Module] | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
attn_backend_override: _Backend | None = None,
) -> None:
super().__init__()

@@ -382,6 +385,7 @@
projection_size=dim,
quant_config=quant_config,
prefix=f"{prefix}.attn",
attn_backend_override=attn_backend_override,
)

self.mlp = Ernie4_5_VisionMLP(
@@ -458,6 +462,7 @@ def __init__(
norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
attn_backend_override: _Backend | None = None,
) -> None:
super().__init__()
patch_size = vision_config.patch_size
@@ -493,6 +498,7 @@ def __init__(
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}",
attn_backend_override=attn_backend_override,
)
for layer_idx in range(depth)
]
@@ -504,7 +510,9 @@ def __init__(
self.ln = nn.LayerNorm(hidden_size, eps=1e-6)

self.attn_backend = get_vit_attn_backend(
head_size=head_dim, dtype=torch.get_default_dtype()
head_size=head_dim,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
torch.get_default_dtype()
@@ -1327,11 +1335,17 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
self.config = config
self.multimodal_config = multimodal_config

attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.vision_model = Ernie4_5_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "vision_model"),
attn_backend_override=attn_backend_override,
)

self.language_model = Ernie4_5_VLMoeForCausalLM(