Merged
19 changes: 3 additions & 16 deletions vllm/model_executor/models/llava.py
@@ -41,7 +41,7 @@
 from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     init_vllm_registered_model, maybe_prefix)
-from .vision import get_vision_encoder_info
+from .vision import get_num_selected_vision_tokens, get_vision_encoder_info
 
 
 class LlavaImagePixelInputs(TensorSchema):
@@ -147,19 +147,6 @@ def get_hf_processor(self, **kwargs: object) -> LlavaLikeProcessor:
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         return {"image": None}
 
-    def _apply_feature_select_strategy(
-        self,
-        strategy: str,
-        encoder_num_image_tokens: int,
-    ) -> int:
-        if strategy == "default":
-            return encoder_num_image_tokens - 1
-        if strategy == "full":
-            return encoder_num_image_tokens
-
-        msg = f"Unexpected feature select strategy: {strategy!r}"
-        raise NotImplementedError(msg)
-
     def get_num_image_tokens(
         self,
         *,
@@ -169,12 +156,12 @@ def get_num_image_tokens(
         hf_config = self.get_hf_config()
         vision_encoder_info = self.get_vision_encoder_info()
 
-        return self._apply_feature_select_strategy(
-            hf_config.vision_feature_select_strategy,
+        return get_num_selected_vision_tokens(
             vision_encoder_info.get_num_image_tokens(
                 image_width=image_width,
                 image_height=image_height,
             ),
+            hf_config.vision_feature_select_strategy,
         )
 
     def get_image_size_with_most_features(self) -> ImageSize:
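
For intuition, here is what the shared helper computes for LLaVA in place of the deleted method. A quick worked sketch; the CLIP ViT-L/14 @ 336px numbers below are the typical LLaVA vision tower and are illustrative, not taken from this diff:

    from vllm.model_executor.models.vision import get_num_selected_vision_tokens

    # A CLIP-style ViT-L/14 at 336x336 produces 24x24 patch tokens plus one CLS token.
    num_vision_tokens = 24 * 24 + 1  # 577

    # "default" drops the CLS token, "full" keeps everything, "class" keeps only CLS.
    assert get_num_selected_vision_tokens(num_vision_tokens, "default") == 576
    assert get_num_selected_vision_tokens(num_vision_tokens, "full") == 577
    assert get_num_selected_vision_tokens(num_vision_tokens, "class") == 1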
5 changes: 3 additions & 2 deletions vllm/model_executor/models/llava_next.py
@@ -27,6 +27,7 @@
 from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     init_vllm_registered_model, maybe_prefix)
+from .vision import get_num_selected_vision_tokens
 
 
 class LlavaNextImagePixelInputs(TensorSchema):
@@ -95,12 +96,12 @@ def get_num_image_tokens(
         hf_config = self.get_hf_config()
         vision_encoder_info = self.get_vision_encoder_info()
 
-        base_feature_size = self._apply_feature_select_strategy(
-            hf_config.vision_feature_select_strategy,
+        base_feature_size = get_num_selected_vision_tokens(
             vision_encoder_info.get_num_image_tokens(
                 image_width=image_width,
                 image_height=image_height,
             ),
+            hf_config.vision_feature_select_strategy,
         )
 
         num_patch_height, num_patch_width = get_anyres_image_grid_shape(
23 changes: 6 additions & 17 deletions vllm/model_executor/models/tarsier.py
@@ -40,7 +40,8 @@
 from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                     maybe_prefix)
-from .vision import VisionEncoderInfo, get_vision_encoder_info
+from .vision import (VisionEncoderInfo, get_num_selected_vision_tokens,
+                     get_vision_encoder_info)
 
 
 class TarsierImagePixelInputs(TensorSchema):
@@ -201,18 +202,6 @@ def get_hf_processor(self, **kwargs: object) -> TarsierProcessor:
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         return {"image": None}
 
-    def _apply_feature_select_strategy(
-        self,
-        strategy: str,
-        encoder_num_image_tokens: int,
-    ) -> int:
-        if strategy == "default":
-            return encoder_num_image_tokens - 1
-        if strategy == "full":
-            return encoder_num_image_tokens
-        msg = f"Unexpected feature select strategy: {strategy!r}"
-        raise NotImplementedError(msg)
-
     def get_num_image_tokens(
         self,
         *,
@@ -221,21 +210,21 @@
     ) -> int:
         hf_config = self.get_hf_config()
         vision_encoder_info = self.get_vision_encoder_info()
-        num_projected_patches = self._apply_feature_select_strategy(
-            hf_config.vision_feature_select_strategy,
+        num_projected_patches = get_num_selected_vision_tokens(
             vision_encoder_info.get_num_image_tokens(
                 image_width=image_width,
                 image_height=image_height,
             ),
+            hf_config.vision_feature_select_strategy,
         )
         if num_projected_patches <= 0:
             default_size = self.get_image_size_with_most_features()
-            num_projected_patches_default = self._apply_feature_select_strategy(
-                hf_config.vision_feature_select_strategy,
+            num_projected_patches_default = get_num_selected_vision_tokens(
                 vision_encoder_info.get_num_image_tokens(
                     image_width=default_size.width,
                     image_height=default_size.height,
                 ),
+                hf_config.vision_feature_select_strategy,
             )
             if num_projected_patches_default <= 0:
                 raise ValueError(
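
Condensed, the guard Tarsier keeps around the shared helper reads as follows. This is a sketch with hypothetical standalone names; the real logic is the method shown above, and the final ValueError message (elided in this diff) is invented here:

    def num_projected_patches_or_raise(info, hf_config, width: int, height: int) -> int:
        def count(w: int, h: int) -> int:
            return get_num_selected_vision_tokens(
                info.get_num_image_tokens(image_width=w, image_height=h),
                hf_config.vision_feature_select_strategy,
            )

        n = count(width, height)
        if n <= 0:
            # Retry with the most feature-rich supported size before giving up.
            default_size = info.get_image_size_with_most_features()
            n = count(default_size.width, default_size.height)
            if n <= 0:
                raise ValueError("Could not determine the number of projected patches")
        return n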
32 changes: 28 additions & 4 deletions vllm/model_executor/models/vision.py
@@ -9,7 +9,6 @@
 
 import torch
 from transformers import PretrainedConfig
-from typing_extensions import assert_never
 
 from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -22,9 +21,13 @@
 _C = TypeVar("_C", bound=PretrainedConfig)
 
 
+class _RootConfig(Protocol[_C]):
+    vision_config: _C
+
+
 class VisionEncoderInfo(ABC, Generic[_C]):
 
-    def __init__(self, hf_config: _C) -> None:
+    def __init__(self, hf_config: _RootConfig[_C]) -> None:
         super().__init__()
 
         self.hf_config = hf_config

[Author comment on the _RootConfig addition: "The previous type annotation was incorrect"]
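
In other words, callers construct a VisionEncoderInfo with the root HF config (the object that carries a vision_config attribute), not with the vision config itself; annotating the parameter as _C was therefore wrong, since _C is the type of the nested vision config. The Protocol states that relationship structurally. A minimal sketch with made-up toy configs (hypothetical, not part of the diff):

    from transformers import PretrainedConfig

    class ToyVisionConfig(PretrainedConfig):
        pass

    class ToyRootConfig(PretrainedConfig):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            # Exposing `vision_config` is all _RootConfig[ToyVisionConfig] requires.
            self.vision_config = ToyVisionConfig()

    # ToyRootConfig now type-checks as _RootConfig[ToyVisionConfig] and can be
    # passed to VisionEncoderInfo.__init__.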
@@ -95,7 +98,7 @@ def get_vit_attn_backend(head_size: int, dtype: torch.dtype) -> _Backend:
 
 
 def _get_vision_feature_selector(
-    strategy: VisionFeatureSelectStrategy,
+    strategy: Union[VisionFeatureSelectStrategy, str],
 ) -> Callable[[torch.Tensor], torch.Tensor]:
     if callable(strategy):
         return strategy
@@ -111,7 +114,28 @@ def _get_vision_feature_selector(
     if strategy == "full":
         return lambda feats: feats
 
-    assert_never(strategy)
+    raise ValueError(f"Unexpected feature select strategy: {strategy!r}")
+
+
+def get_num_selected_vision_tokens(
+    num_vision_tokens: int,
+    strategy: Union[VisionFeatureSelectStrategy, str],
+) -> int:
+    if callable(strategy):
+        dummy_features = torch.empty(1, num_vision_tokens, 64)  # [B, L, D]
+        dummy_selected_features = strategy(dummy_features)
+        return dummy_selected_features.shape[1]
+
+    if strategy == "class":
+        return 1
+
+    if strategy == "default":
+        return num_vision_tokens - 1
+
+    if strategy == "full":
+        return num_vision_tokens
+
+    raise ValueError(f"Unexpected feature select strategy: {strategy!r}")
 
 
 def resolve_visual_encoder_outputs(
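
Finally, a minimal usage sketch of the new helper, based only on the definition above (the import path follows the file location shown; the callable strategy is made up for illustration):

    import torch

    from vllm.model_executor.models.vision import get_num_selected_vision_tokens

    # A callable strategy is probed with a dummy [B, L, D] tensor; the selected
    # sequence length is read back from the result's second dimension.
    def keep_every_other(feats: torch.Tensor) -> torch.Tensor:
        return feats[:, ::2]

    print(get_num_selected_vision_tokens(576, keep_every_other))  # 288
    print(get_num_selected_vision_tokens(577, "default"))         # 576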