diff --git a/tests/models/test_vision.py b/tests/models/test_vision.py
index 8744bcbd3a2a..a30a856a81cf 100644
--- a/tests/models/test_vision.py
+++ b/tests/models/test_vision.py
@@ -18,7 +18,7 @@
 
 @pytest.mark.parametrize(
-    ("feature_sample_layers", "num_layers_loaded", "max_possible_layers",
+    ("select_layers", "num_layers_loaded", "max_possible_layers",
      "expected_features"),
     [
         # All layers loaded
@@ -28,8 +28,8 @@
         ([1, 10], 10, 20, [1, 10]),
         ([-20, -11], 10, 20, [1, 10]),
     ])
-def test_resolve_visual_encoder_outputs(feature_sample_layers,
-                                        num_layers_loaded, max_possible_layers,
+def test_resolve_visual_encoder_outputs(select_layers, num_layers_loaded,
+                                        max_possible_layers,
                                         expected_features):
     """
     Test that offsets are correctly handled for vision feature layers.
@@ -39,9 +39,10 @@ def test_resolve_visual_encoder_outputs(feature_sample_layers,
     ]
     output_tensor = resolve_visual_encoder_outputs(
         encoder_outputs=encoder_outputs,
-        feature_sample_layers=feature_sample_layers,
         post_layer_norm=None,
-        max_possible_layers=max_possible_layers)
+        select_layers=select_layers,
+        max_possible_layers=max_possible_layers,
+    )
 
     assert torch.equal(torch.tensor(expected_features), output_tensor)
diff --git a/vllm/model_executor/models/aya_vision.py b/vllm/model_executor/models/aya_vision.py
index f6dfa435ddd4..81bab5b34bc6 100644
--- a/vllm/model_executor/models/aya_vision.py
+++ b/vllm/model_executor/models/aya_vision.py
@@ -27,7 +27,6 @@
                              PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
-from vllm.utils.jsontree import json_map_leaves
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
@@ -350,29 +349,11 @@ def _image_pixels_to_features(
         self,
         vision_tower: SiglipVisionModel,
         pixel_values: torch.Tensor,
-        **kwargs,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
-        target_dtype: torch.dtype = \
-            vision_tower.get_input_embeddings().weight.dtype
-        image_features: Union[torch.Tensor, tuple[torch.Tensor, ...]] = \
-            vision_tower(pixel_values.to(dtype=target_dtype), **kwargs)
-
-        def select_features(leaf: torch.Tensor):
-            return self._select_image_features(
-                leaf,
-                strategy=self.config.vision_feature_select_strategy,
-            )
-
-        return json_map_leaves(select_features, image_features)
-
-    def _select_image_features(self, image_features: torch.Tensor, *,
-                               strategy: str) -> torch.Tensor:
-        if strategy == "default":
-            return image_features[:, 1:]
-        elif strategy == "full":
-            return image_features
-
-        raise ValueError(f"Unexpected select feature strategy: {strategy}")
+        return vision_tower(
+            pixel_values.to(dtype=vision_tower.dtype),
+            feature_select_strategy=self.config.vision_feature_select_strategy,
+        )
 
     def _process_image_input(self, image_input: AyaVisionImagePixelInputs,
                              **kwargs) -> list[torch.Tensor]:
diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py
index dcab00822870..451da2120048 100644
--- a/vllm/model_executor/models/clip.py
+++ b/vllm/model_executor/models/clip.py
@@ -19,7 +19,8 @@
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.interfaces import SupportsQuant
 
-from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
+from .vision import (VisionEncoderInfo, VisionFeatureSelectStrategy,
+                     resolve_visual_encoder_outputs)
 
 
 class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]):
@@ -308,24 +309,29 @@ def __init__(
     def forward(
         self,
         pixel_values: torch.Tensor,
-        feature_sample_layers: Optional[list[int]] = None,
+        *,
+        select_layers: Optional[list[int]] = None,
+        feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None,
     ) -> torch.Tensor:
         hidden_states = self.embeddings(pixel_values)
         hidden_states = self.pre_layrnorm(hidden_states)
 
-        return_all_hidden_states = feature_sample_layers is not None
-
         # Produces either the last layer output or all of the hidden states,
-        # depending on if we have feature_sample_layers or not
+        # depending on if we have select_layers or not
         encoder_outputs = self.encoder(
             inputs_embeds=hidden_states,
-            return_all_hidden_states=return_all_hidden_states)
+            return_all_hidden_states=select_layers is not None,
+        )
 
         # Handle post-norm (if applicable) and stacks feature layers if needed
         encoder_outputs = resolve_visual_encoder_outputs(
-            encoder_outputs, feature_sample_layers, self.post_layernorm,
-            self.config.num_hidden_layers)
+            encoder_outputs,
+            self.post_layernorm,
+            select_layers=select_layers,
+            max_possible_layers=self.config.num_hidden_layers,
+            feature_select_strategy=feature_select_strategy,
+        )
 
         return encoder_outputs
@@ -355,9 +361,14 @@ def __init__(
     def forward(
         self,
         pixel_values: torch.Tensor,
-        feature_sample_layers: Optional[list[int]] = None,
+        select_layers: Optional[list[int]] = None,
+        feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None,
     ) -> torch.Tensor:
-        return self.vision_model(pixel_values, feature_sample_layers)
+        return self.vision_model(
+            pixel_values,
+            select_layers=select_layers,
+            feature_select_strategy=feature_select_strategy,
+        )
 
     @property
     def device(self):
diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index 46cf93be191e..d823e5cb58d2 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -33,7 +33,6 @@
                              PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
-from vllm.utils.jsontree import json_map_leaves
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .clip import CLIPVisionModel
@@ -604,16 +603,6 @@ def _parse_and_validate_image_input(
 
         raise AssertionError("This line should be unreachable.")
 
-    def _select_image_features(self, image_features: torch.Tensor, *,
-                               strategy: str) -> torch.Tensor:
-        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa
-        if strategy == "default":
-            return image_features[:, 1:]
-        elif strategy == "full":
-            return image_features
-
-        raise ValueError(f"Unexpected select feature strategy: {strategy}")
-
     def _image_pixels_to_features(
         self,
         vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
@@ -622,16 +611,10 @@ def _image_pixels_to_features(
     ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
         # NOTE: we skip the step to select the vision feature layer since
         # this is already done inside the vision tower
-        image_features: Union[torch.Tensor, tuple[torch.Tensor, ...]] = \
-            vision_tower(pixel_values)
-
-        def select_features(leaf: torch.Tensor):
-            return self._select_image_features(
-                leaf,
-                strategy=self.config.vision_feature_select_strategy,
-            )
-
-        return json_map_leaves(select_features, image_features)
+        return vision_tower(
+            pixel_values,
+            feature_select_strategy=self.config.vision_feature_select_strategy,
+        )
 
     def _process_image_pixels(
         self,
diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py
index c4f1daaab9bf..3f7e39c02061 100644
--- a/vllm/model_executor/models/llava_next.py
+++ b/vllm/model_executor/models/llava_next.py
@@ -235,12 +235,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         # Determine the layer up to which we will initialize the vision tower
         if isinstance(vision_feature_layer, int):
             vision_hidden_size = config.vision_config.hidden_size
-            self.feature_sample_layers = None
+            self.select_layers = None
         # Used for multimodal granite models to control encoder outputs
         elif isinstance(vision_feature_layer, (list, tuple)):
             vision_hidden_size = config.vision_config.hidden_size * len(
                 vision_feature_layer)
-            self.feature_sample_layers = vision_feature_layer
+            self.select_layers = vision_feature_layer
         else:
             raise TypeError(
                 f"vision_layer_feature type: {type(vision_feature_layer)}"
@@ -312,30 +312,17 @@ def _parse_and_validate_image_input(
 
         raise AssertionError("This line should be unreachable.")
 
-    def _select_image_features(self, image_features: torch.Tensor, *,
-                               strategy: str) -> torch.Tensor:
-        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa
-        if strategy == "default":
-            return image_features[:, 1:]
-        elif strategy == "full":
-            return image_features
-
-        raise ValueError(f"Unexpected select feature strategy: {strategy}")
-
     def _image_pixels_to_features(
         self,
         vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
         pixel_values: torch.Tensor,
     ) -> torch.Tensor:
-
         # NOTE: we skip the step to select the vision feature layer since
         # this is already done inside the vision tower
-        image_features = vision_tower(
-            pixel_values, feature_sample_layers=self.feature_sample_layers)
-
-        return self._select_image_features(
-            image_features,
-            strategy=self.config.vision_feature_select_strategy,
+        return vision_tower(
+            pixel_values,
+            select_layers=self.select_layers,
+            feature_select_strategy=self.config.vision_feature_select_strategy,
         )
 
     # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py
index aebc661d53f8..697b8e819707 100644
--- a/vllm/model_executor/models/llava_next_video.py
+++ b/vllm/model_executor/models/llava_next_video.py
@@ -349,27 +349,16 @@ def _parse_and_validate_video_input(
                 "w": expected_w,
             })
 
-    def _select_image_features(self, image_features: torch.Tensor, *,
-                               strategy: str) -> torch.Tensor:
-        if strategy == "default":
-            return image_features[:, 1:]
-        elif strategy == "full":
-            return image_features
-
-        raise ValueError(f"Unexpected select feature strategy: {strategy}")
-
     def _video_pixels_to_features(
         self,
         vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
         pixel_values: torch.Tensor,
     ) -> torch.Tensor:
-
         # NOTE: we skip the step to select the vision feature layer since
         # this is already done inside the vision tower
-        image_features = vision_tower(pixel_values)
-        image_features = self._select_image_features(
-            image_features,
-            strategy=self.config.vision_feature_select_strategy,
+        image_features = vision_tower(
+            pixel_values,
+            feature_select_strategy=self.config.vision_feature_select_strategy,
         )
         image_features = self.vision_resampler(image_features)
         image_features = self.multi_modal_projector(image_features)
diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py
index 6088195c91d5..924f8ba3585f 100644
--- a/vllm/model_executor/models/llava_onevision.py
+++ b/vllm/model_executor/models/llava_onevision.py
@@ -577,27 +577,16 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
 
         return mm_input_by_modality
 
-    def _select_image_features(self, image_features: torch.Tensor, *,
-                               strategy: str) -> torch.Tensor:
-        if strategy == "default":
-            return image_features[:, 1:]
-        elif strategy == "full":
-            return image_features
-
-        raise ValueError(f"Unexpected select feature strategy: {strategy}")
-
     def _image_pixels_to_features(
         self,
         vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
         pixel_values: torch.Tensor,
     ) -> torch.Tensor:
-
         # NOTE: we skip the step to select the vision feature layer since
         # this is already done inside the vision tower
-        image_features = vision_tower(pixel_values)
-        return self._select_image_features(
-            image_features,
-            strategy=self.config.vision_feature_select_strategy,
+        return vision_tower(
+            pixel_values,
+            feature_select_strategy=self.config.vision_feature_select_strategy,
         )
 
     # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
@@ -750,13 +739,11 @@ def _video_pixels_to_features(
         vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
         pixel_values: torch.Tensor,
     ) -> torch.Tensor:
-
         # NOTE: we skip the step to select the vision feature layer since
         # this is already done inside the vision tower
-        video_features = vision_tower(pixel_values)
-        video_features = self._select_image_features(
-            video_features,
-            strategy=self.config.vision_feature_select_strategy,
+        video_features = vision_tower(
+            pixel_values,
+            feature_select_strategy=self.config.vision_feature_select_strategy,
         )
         video_features = self.multi_modal_projector(video_features)
         video_features = self.apply_pooling(video_features)
diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py
index d41b9d3f14fe..938c9a689fcf 100644
--- a/vllm/model_executor/models/minimax_vl_01.py
+++ b/vllm/model_executor/models/minimax_vl_01.py
@@ -17,7 +17,6 @@
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalFieldConfig
 from vllm.sequence import IntermediateTensors
-from vllm.utils.jsontree import json_map_leaves
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .clip import CLIPVisionModel
@@ -221,15 +220,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def _select_image_features(self, image_features: torch.Tensor, *,
-                               strategy: str) -> torch.Tensor:
-        if strategy == "default":
-            return image_features[:, 1:]
-        elif strategy == "full":
-            return image_features
-
-        raise ValueError(f"Unexpected select feature strategy: {strategy}")
-
     def _image_pixels_to_features(
         self,
         vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
@@ -238,16 +228,10 @@ def _image_pixels_to_features(
     ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
         # NOTE: we skip the step to select the vision feature layer since
         # this is already done inside the vision tower
-        image_features: tuple[torch.Tensor, ...] = \
-            tuple(vision_tower(p) for p in pixel_values)
-
-        def select_features(leaf: torch.Tensor):
-            return self._select_image_features(
-                leaf,
-                strategy=self.config.vision_feature_select_strategy,
-            )
-
-        return json_map_leaves(select_features, image_features)
+        feature_select_strategy = self.config.vision_feature_select_strategy
+        return tuple(
+            vision_tower(p, feature_select_strategy=feature_select_strategy)
+            for p in pixel_values)
 
     # adapted from https://huggingface.co/MiniMaxAI/MiniMax-VL-01/blob/main/modeling_minimax_vl_01.py#L616-L631
     def pack_image_features(self, image_features: list[torch.Tensor],
diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py
index 6344fc394833..bf451c5005b7 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -51,7 +51,8 @@
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import flatten_bn, init_vllm_registered_model, maybe_prefix
-from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
+from .vision import (VisionEncoderInfo, VisionFeatureSelectStrategy,
+                     resolve_visual_encoder_outputs)
 
 try:
     from xformers import ops as xops
@@ -1218,7 +1219,9 @@ def __init__(
     def forward(
         self,
         pixel_values: list[torch.Tensor],
-        feature_sample_layers: Optional[list[int]] = None,
+        *,
+        select_layers: Optional[list[int]] = None,
+        feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None,
     ) -> tuple[torch.Tensor, ...]:
         """
         Args:
@@ -1226,7 +1229,7 @@ def forward(
                 in pixel_values. This means it will be a list of tensors
                 because multiple requests batched can have multiple images,
                 each with their own shape potentially
-            feature_sample_layers: Layer indices whose features should be
+            select_layers: Layer indices whose features should be
                 concatenated and used as the visual encoder output. If none
                 are provided, the last layer is used.
@@ -1267,15 +1270,20 @@ def forward(
             [p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
             patch_embeds)
 
-        return_all_hidden_states = feature_sample_layers is not None
         out = self.transformer(
             patch_embeds,
             attention_mask,
             position_embedding,
-            return_all_hidden_states=return_all_hidden_states)
+            return_all_hidden_states=select_layers is not None,
+        )
 
-        out = resolve_visual_encoder_outputs(out, feature_sample_layers, None,
-                                             self.config.num_hidden_layers)
+        out = resolve_visual_encoder_outputs(
+            out,
+            None,
+            select_layers=select_layers,
+            max_possible_layers=self.config.num_hidden_layers,
+            feature_select_strategy=feature_select_strategy,
+        )
 
         # squeeze dim 0 and split into separate tensors for each image
         return torch.split(out.squeeze(0), embed_sizes)
diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py
index eb49d6d2c335..4c60d96c77d7 100644
--- a/vllm/model_executor/models/siglip.py
+++ b/vllm/model_executor/models/siglip.py
@@ -23,7 +23,8 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
 
-from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
+from .vision import (VisionEncoderInfo, VisionFeatureSelectStrategy,
+                     resolve_visual_encoder_outputs)
 
 
 class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):
@@ -415,28 +416,31 @@ def __init__(
     def forward(
         self,
         pixel_values: torch.Tensor,
-        interpolate_pos_encoding: bool = True,
-        feature_sample_layers: Optional[list[int]] = None,
+        *,
+        interpolate_pos_encoding: bool = False,
+        select_layers: Optional[list[int]] = None,
+        feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None,
     ) -> torch.Tensor:
-
         hidden_states = self.embeddings(
             pixel_values,
             interpolate_pos_encoding=interpolate_pos_encoding,
         )
 
-        return_all_hidden_states = feature_sample_layers is not None
-
         # Produces either the last layer output or all of the hidden states,
-        # depending on if we have feature_sample_layers or not
+        # depending on if we have select_layers or not
         encoder_outputs = self.encoder(
             inputs_embeds=hidden_states,
-            return_all_hidden_states=return_all_hidden_states,
+            return_all_hidden_states=select_layers is not None,
         )
 
         # Handle post-norm (if applicable) and stacks feature layers if needed
         encoder_outputs = resolve_visual_encoder_outputs(
-            encoder_outputs, feature_sample_layers, self.post_layernorm,
-            self.config.num_hidden_layers)
+            encoder_outputs,
+            self.post_layernorm,
+            select_layers=select_layers,
+            max_possible_layers=self.config.num_hidden_layers,
+            feature_select_strategy=feature_select_strategy,
+        )
 
         # TODO: add this back when pooled_output is used in inference.
         # if self.use_head:
@@ -471,16 +475,22 @@ def __init__(
     def get_input_embeddings(self) -> nn.Module:
         return self.vision_model.embeddings.patch_embedding
 
+    @property
+    def dtype(self):
+        return self.get_input_embeddings().weight.dtype
+
     def forward(
         self,
         pixel_values: torch.Tensor,
         interpolate_pos_encoding: bool = False,
-        feature_sample_layers: Optional[list[int]] = None,
+        select_layers: Optional[list[int]] = None,
+        feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None,
    ) -> torch.Tensor:
         return self.vision_model(
             pixel_values=pixel_values,
             interpolate_pos_encoding=interpolate_pos_encoding,
-            feature_sample_layers=feature_sample_layers,
+            select_layers=select_layers,
+            feature_select_strategy=feature_select_strategy,
         )
 
     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/tarsier.py b/vllm/model_executor/models/tarsier.py
index 1145bea41480..ed02fe2c389f 100644
--- a/vllm/model_executor/models/tarsier.py
+++ b/vllm/model_executor/models/tarsier.py
@@ -33,7 +33,6 @@
                              PromptReplacement, PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
-from vllm.utils.jsontree import json_map_leaves
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .clip import CLIPVisionModel
@@ -476,30 +475,16 @@ def _parse_and_validate_image_input(
 
         raise AssertionError("This line should be unreachable.")
 
-    def _select_image_features(self, image_features: torch.Tensor, *,
-                               strategy: str) -> torch.Tensor:
-        if strategy == "default":
-            return image_features[:, 1:]
-        elif strategy == "full":
-            return image_features
-        raise ValueError(f"Unexpected select feature strategy: {strategy}")
-
     def _image_pixels_to_features(
         self,
         vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
         pixel_values: Union[torch.Tensor, list[torch.Tensor]],
     ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
         # From vLLM LLaVA, vision tower output handling
-        image_hidden_states: Union[torch.Tensor, tuple[torch.Tensor, ...]] = \
-            vision_tower(pixel_values)
-
-        def select_features_fn(leaf: torch.Tensor):
-            return self._select_image_features(
-                leaf,
-                strategy=self.config.vision_feature_select_strategy,
-            )
-
-        return json_map_leaves(select_features_fn, image_hidden_states)
+        return vision_tower(
+            pixel_values,
+            feature_select_strategy=self.config.vision_feature_select_strategy,
+        )
 
     def _add_tarsier_split_tokens(
             self, projected_image_features: torch.Tensor) -> torch.Tensor:
diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py
index 08ad8fbeb424..e077691fcec2 100644
--- a/vllm/model_executor/models/vision.py
+++ b/vllm/model_executor/models/vision.py
@@ -4,10 +4,12 @@
 import itertools
 import math
 from abc import ABC, abstractmethod
-from typing import Final, Generic, Literal, Optional, Protocol, TypeVar, Union
+from typing import (Callable, Final, Generic, Literal, Optional, Protocol,
+                    TypeVar, Union)
 
 import torch
 from transformers import PretrainedConfig
+from typing_extensions import assert_never
 
 from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -86,11 +88,39 @@ def get_vit_attn_backend(head_size: int, dtype: torch.dtype) -> _Backend:
     return current_platform.get_vit_attn_backend(head_size, dtype)
 
 
+VisionFeatureSelectStrategy = Union[
+    Literal["class", "default", "full"],
+    Callable[[torch.Tensor], torch.Tensor],
+]
+
+
+def _get_vision_feature_selector(
+    strategy: VisionFeatureSelectStrategy,
+) -> Callable[[torch.Tensor], torch.Tensor]:
+    if callable(strategy):
+        return strategy
+
+    # https://github.com/huggingface/transformers/blob/cd74917ffc3e8f84e4a886052c5ab32b7ac623cc/src/transformers/models/clip/modeling_clip.py#L762
+    if strategy == "class":
+        return lambda feats: feats[:, 0, :]
+
+    # https://github.com/huggingface/transformers/blob/4a02bc7004285bdb12cc033e87ad2578ce2fa900/src/transformers/models/llava/modeling_llava.py#L196
+    if strategy == "default":
+        return lambda feats: feats[:, 1:, :]
+
+    if strategy == "full":
+        return lambda feats: feats
+
+    assert_never(strategy)
+
+
 def resolve_visual_encoder_outputs(
     encoder_outputs: Union[torch.Tensor, list[torch.Tensor]],
-    feature_sample_layers: Optional[list[int]],
     post_layer_norm: Optional[torch.nn.LayerNorm],
-    max_possible_layers: int,
+    *,
+    select_layers: Optional[list[int]] = None,
+    max_possible_layers: Optional[int] = None,
+    feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None,
 ) -> torch.Tensor:
     """Given the outputs a visual encoder module that may correspond
     to the output of the last layer, or a list of hidden states to be
@@ -98,17 +128,32 @@ def resolve_visual_encoder_outputs(
     stacked, handle post normalization and resolve it into a single output.
 
     Args:
         encoder_outputs: Output of encoder's last layer or all hidden states.
-        feature_sample_layers: Optional layer indices to grab from the encoder
-            outputs; if provided, encoder outputs must be a list.
         post_layer_norm: Post norm to apply to the output of the encoder.
+        select_layers: Optional layer indices to grab from the encoder
+            outputs; if provided, encoder outputs must be a list.
         max_possible_layers: Total layers in the fully loaded visual encoder.
-
+        feature_select_strategy: Defines how to select the hidden states
+            from each layer.
     """
-    if feature_sample_layers is None:
+    if select_layers is None:
+        if not isinstance(encoder_outputs, torch.Tensor):
+            raise ValueError("Expected only a single encoder output when "
+                             "`select_layers` is not provided")
+
+        if feature_select_strategy is not None:
+            select_features = _get_vision_feature_selector(
+                feature_select_strategy)
+            encoder_outputs = select_features(encoder_outputs)
+
         if post_layer_norm is not None:
             return post_layer_norm(encoder_outputs)
+
         return encoder_outputs
 
+    if max_possible_layers is None:
+        raise ValueError("`max_possible_layers` must be provided "
+                         "alongside `select_layers`")
+
     # Get the hidden states corresponding to the layer indices.
     # Negative values are relative to the full visual encoder,
     # so offset them depending on how many layers were loaded.
@@ -120,13 +165,18 @@ def resolve_visual_encoder_outputs(
     hs_pool = [
         encoder_outputs[layer_idx]
         if layer_idx >= 0 else encoder_outputs[layer_idx + offset]
-        for layer_idx in feature_sample_layers
+        for layer_idx in select_layers
     ]
 
+    if feature_select_strategy is not None:
+        select_features = _get_vision_feature_selector(feature_select_strategy)
+        hs_pool = [select_features(hs) for hs in hs_pool]
+
     # Apply post-norm on the final hidden state if we are using it
-    uses_last_layer = feature_sample_layers[-1] in (len(hs_pool) - 1, -1)
+    uses_last_layer = select_layers[-1] in (max_possible_layers - 1, -1)
     if post_layer_norm is not None and uses_last_layer:
-        hs_pool[-1] = post_layer_norm(encoder_outputs)
+        hs_pool[-1] = post_layer_norm(hs_pool[-1])
+
     return torch.cat(hs_pool, dim=-1)
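
To make the reworked keyword-only API concrete, here is a minimal usage sketch (not part of the patch) that mirrors the updated unit test; the tensor shapes, layer indices, and random inputs are illustrative assumptions only:

```python
import torch

from vllm.model_executor.models.vision import resolve_visual_encoder_outputs

# Hidden states from a hypothetical 10-layer encoder loaded from a 20-layer
# checkpoint: the embeddings output plus one tensor per loaded layer.
hidden_states = [torch.randn(1, 5, 8) for _ in range(10 + 1)]

# Concatenate layers 1 and 10 and drop the CLS token from each layer via the
# "default" strategy, as the LLaVA-family towers now do internally.
feats = resolve_visual_encoder_outputs(
    hidden_states,
    None,  # no post-norm in this sketch
    select_layers=[1, 10],
    max_possible_layers=20,
    feature_select_strategy="default",
)
print(feats.shape)  # torch.Size([1, 4, 16]): CLS dropped, two layers concatenated on dim=-1
```

Because `feature_select_strategy` is resolved inside the vision towers (and inside `resolve_visual_encoder_outputs`), callers such as `_image_pixels_to_features` no longer need a separate `_select_image_features` step or `json_map_leaves` post-processing.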