Commit d7e34b4

[Model] Move vision_feature_select_strategy into resolve_visual_encoder_outputs (#25938)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent ef6e0e7 commit d7e34b4
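
Every per-model _select_image_features helper deleted below implements the same two strategies: "default" drops the first (class) token, while "full" keeps every feature. This commit moves that logic, together with the renamed select_layers handling, into resolve_visual_encoder_outputs in vllm/model_executor/models/vision.py, which is among the 12 changed files but not shown in this excerpt. The sketch below is only an inference from the call sites and the updated test; the body is illustrative, not the committed implementation.

# Sketch only: inferred from the call sites in this commit. The committed
# implementation lives in vllm/model_executor/models/vision.py (not shown here).
from typing import Literal, Optional, Union

import torch
import torch.nn as nn

VisionFeatureSelectStrategy = Literal["default", "full"]


def resolve_visual_encoder_outputs(
    encoder_outputs: Union[torch.Tensor, list[torch.Tensor]],
    post_layer_norm: Optional[nn.Module],
    *,
    select_layers: Optional[list[int]] = None,
    max_possible_layers: Optional[int] = None,
    feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None,
) -> torch.Tensor:
    if select_layers is None:
        # Single final hidden state; apply the tower's post-norm if present.
        features = encoder_outputs
        if post_layer_norm is not None:
            features = post_layer_norm(features)
    else:
        # The tower may load fewer layers than the config allows, so negative
        # indices are shifted by the number of missing layers before the
        # selected hidden states are concatenated on the feature dimension.
        offset = max_possible_layers - (len(encoder_outputs) - 1)
        features = torch.cat(
            [
                encoder_outputs[i] if i >= 0 else encoder_outputs[i + offset]
                for i in select_layers
            ],
            dim=-1,
        )

    # The strategy handling that used to be duplicated in each model's
    # _select_image_features helper.
    if feature_select_strategy == "default":
        return features[:, 1:]  # drop the class token
    if feature_select_strategy in (None, "full"):
        return features
    raise ValueError(
        f"Unexpected select feature strategy: {feature_select_strategy}")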

File tree

12 files changed, +155 -179 lines changed


tests/models/test_vision.py

Lines changed: 6 additions & 5 deletions
@@ -18,7 +18,7 @@
 
 
 @pytest.mark.parametrize(
-    ("feature_sample_layers", "num_layers_loaded", "max_possible_layers",
+    ("select_layers", "num_layers_loaded", "max_possible_layers",
      "expected_features"),
     [
         # All layers loaded
@@ -28,8 +28,8 @@
         ([1, 10], 10, 20, [1, 10]),
         ([-20, -11], 10, 20, [1, 10]),
     ])
-def test_resolve_visual_encoder_outputs(feature_sample_layers,
-                                        num_layers_loaded, max_possible_layers,
+def test_resolve_visual_encoder_outputs(select_layers, num_layers_loaded,
+                                        max_possible_layers,
                                         expected_features):
     """
     Test that offsets are correctly handled for vision feature layers.
@@ -39,9 +39,10 @@ def test_resolve_visual_encoder_outputs(feature_sample_layers,
     ]
     output_tensor = resolve_visual_encoder_outputs(
         encoder_outputs=encoder_outputs,
-        feature_sample_layers=feature_sample_layers,
         post_layer_norm=None,
-        max_possible_layers=max_possible_layers)
+        select_layers=select_layers,
+        max_possible_layers=max_possible_layers,
+    )
     assert torch.equal(torch.tensor(expected_features), output_tensor)
 
 
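
A concrete reading of the negative-index cases above, assuming the elided setup builds one single-element tensor per loaded hidden state (embeddings included), which is what the expected values imply:

import torch

num_layers_loaded, max_possible_layers = 10, 20
encoder_outputs = [torch.tensor([i]) for i in range(num_layers_loaded + 1)]

# Negative layer indices are shifted by the number of layers that were not
# loaded, so [-20, -11] resolves to the same entries as [1, 10].
offset = max_possible_layers - num_layers_loaded
resolved = [encoder_outputs[i + offset] for i in (-20, -11)]
print(torch.cat(resolved))  # tensor([ 1, 10])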

vllm/model_executor/models/aya_vision.py

Lines changed: 4 additions & 23 deletions
@@ -27,7 +27,6 @@
                                         PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
-from vllm.utils.jsontree import json_map_leaves
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
@@ -350,29 +349,11 @@ def _image_pixels_to_features(
         self,
         vision_tower: SiglipVisionModel,
         pixel_values: torch.Tensor,
-        **kwargs,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
-        target_dtype: torch.dtype = \
-            vision_tower.get_input_embeddings().weight.dtype
-        image_features: Union[torch.Tensor, tuple[torch.Tensor, ...]] = \
-            vision_tower(pixel_values.to(dtype=target_dtype), **kwargs)
-
-        def select_features(leaf: torch.Tensor):
-            return self._select_image_features(
-                leaf,
-                strategy=self.config.vision_feature_select_strategy,
-            )
-
-        return json_map_leaves(select_features, image_features)
-
-    def _select_image_features(self, image_features: torch.Tensor, *,
-                               strategy: str) -> torch.Tensor:
-        if strategy == "default":
-            return image_features[:, 1:]
-        elif strategy == "full":
-            return image_features
-
-        raise ValueError(f"Unexpected select feature strategy: {strategy}")
+        return vision_tower(
+            pixel_values.to(dtype=vision_tower.dtype),
+            feature_select_strategy=self.config.vision_feature_select_strategy,
+        )
 
     def _process_image_input(self, image_input: AyaVisionImagePixelInputs,
                              **kwargs) -> list[torch.Tensor]:
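
A small simplification in the hunk above: the pixel values are now cast with vision_tower.dtype instead of reading the embedding weight's dtype directly. The stand-in class below (not a vLLM API) illustrates why the two lookups agree whenever the tower's parameters share one dtype:

import torch
import torch.nn as nn


class TinyTower(nn.Module):
    """Stand-in for a vision tower whose parameters share one dtype."""

    def __init__(self) -> None:
        super().__init__()
        self.embed = nn.Linear(4, 8, dtype=torch.float16)

    def get_input_embeddings(self) -> nn.Linear:
        return self.embed

    @property
    def dtype(self) -> torch.dtype:
        return next(self.parameters()).dtype


tower = TinyTower()
assert tower.get_input_embeddings().weight.dtype == tower.dtype  # both float16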

vllm/model_executor/models/clip.py

Lines changed: 21 additions & 10 deletions
@@ -19,7 +19,8 @@
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.interfaces import SupportsQuant
 
-from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
+from .vision import (VisionEncoderInfo, VisionFeatureSelectStrategy,
+                     resolve_visual_encoder_outputs)
 
 
 class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]):
@@ -308,24 +309,29 @@ def __init__(
     def forward(
         self,
         pixel_values: torch.Tensor,
-        feature_sample_layers: Optional[list[int]] = None,
+        *,
+        select_layers: Optional[list[int]] = None,
+        feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None,
    ) -> torch.Tensor:
 
         hidden_states = self.embeddings(pixel_values)
         hidden_states = self.pre_layrnorm(hidden_states)
 
-        return_all_hidden_states = feature_sample_layers is not None
-
         # Produces either the last layer output or all of the hidden states,
-        # depending on if we have feature_sample_layers or not
+        # depending on if we have select_layers or not
         encoder_outputs = self.encoder(
             inputs_embeds=hidden_states,
-            return_all_hidden_states=return_all_hidden_states)
+            return_all_hidden_states=select_layers is not None,
+        )
 
         # Handle post-norm (if applicable) and stacks feature layers if needed
         encoder_outputs = resolve_visual_encoder_outputs(
-            encoder_outputs, feature_sample_layers, self.post_layernorm,
-            self.config.num_hidden_layers)
+            encoder_outputs,
+            self.post_layernorm,
+            select_layers=select_layers,
+            max_possible_layers=self.config.num_hidden_layers,
+            feature_select_strategy=feature_select_strategy,
+        )
 
         return encoder_outputs
 
@@ -355,9 +361,14 @@ def __init__(
     def forward(
         self,
         pixel_values: torch.Tensor,
-        feature_sample_layers: Optional[list[int]] = None,
+        select_layers: Optional[list[int]] = None,
+        feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None,
     ) -> torch.Tensor:
-        return self.vision_model(pixel_values, feature_sample_layers)
+        return self.vision_model(
+            pixel_values,
+            select_layers=select_layers,
+            feature_select_strategy=feature_select_strategy,
+        )
 
     @property
     def device(self):
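
The bare * added to the inner transformer's forward (the one using self.embeddings and self.pre_layrnorm) makes select_layers and feature_select_strategy keyword-only, so the old positional feature_sample_layers style now fails loudly instead of binding to the wrong parameter. A stand-alone stub illustrating the effect (forward_stub is a placeholder, not vLLM code):

from typing import Optional

import torch


def forward_stub(pixel_values: torch.Tensor,
                 *,
                 select_layers: Optional[list[int]] = None,
                 feature_select_strategy: Optional[str] = None) -> None:
    """Mirrors the keyword-only surface of the updated forward."""


x = torch.zeros(1, 3, 224, 224)
forward_stub(x, select_layers=[-2], feature_select_strategy="default")  # ok
try:
    forward_stub(x, [-2])  # positional style no longer binds
except TypeError as exc:
    print(exc)  # forward_stub() takes 1 positional argument but 2 were given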

vllm/model_executor/models/llava.py

Lines changed: 4 additions & 21 deletions
@@ -33,7 +33,6 @@
                                         PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
-from vllm.utils.jsontree import json_map_leaves
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .clip import CLIPVisionModel
@@ -604,16 +603,6 @@ def _parse_and_validate_image_input(
 
         raise AssertionError("This line should be unreachable.")
 
-    def _select_image_features(self, image_features: torch.Tensor, *,
-                               strategy: str) -> torch.Tensor:
-        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa
-        if strategy == "default":
-            return image_features[:, 1:]
-        elif strategy == "full":
-            return image_features
-
-        raise ValueError(f"Unexpected select feature strategy: {strategy}")
-
     def _image_pixels_to_features(
         self,
         vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
@@ -622,16 +611,10 @@ def _image_pixels_to_features(
     ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
         # NOTE: we skip the step to select the vision feature layer since
         # this is already done inside the vision tower
-        image_features: Union[torch.Tensor, tuple[torch.Tensor, ...]] = \
-            vision_tower(pixel_values)
-
-        def select_features(leaf: torch.Tensor):
-            return self._select_image_features(
-                leaf,
-                strategy=self.config.vision_feature_select_strategy,
-            )
-
-        return json_map_leaves(select_features, image_features)
+        return vision_tower(
+            pixel_values,
+            feature_select_strategy=self.config.vision_feature_select_strategy,
+        )
 
     def _process_image_pixels(
         self,

vllm/model_executor/models/llava_next.py

Lines changed: 6 additions & 19 deletions
@@ -235,12 +235,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         # Determine the layer up to which we will initialize the vision tower
         if isinstance(vision_feature_layer, int):
             vision_hidden_size = config.vision_config.hidden_size
-            self.feature_sample_layers = None
+            self.select_layers = None
         # Used for multimodal granite models to control encoder outputs
         elif isinstance(vision_feature_layer, (list, tuple)):
             vision_hidden_size = config.vision_config.hidden_size * len(
                 vision_feature_layer)
-            self.feature_sample_layers = vision_feature_layer
+            self.select_layers = vision_feature_layer
         else:
             raise TypeError(
                 f"vision_layer_feature type: {type(vision_feature_layer)}"
@@ -312,30 +312,17 @@ def _parse_and_validate_image_input(
 
         raise AssertionError("This line should be unreachable.")
 
-    def _select_image_features(self, image_features: torch.Tensor, *,
-                               strategy: str) -> torch.Tensor:
-        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa
-        if strategy == "default":
-            return image_features[:, 1:]
-        elif strategy == "full":
-            return image_features
-
-        raise ValueError(f"Unexpected select feature strategy: {strategy}")
-
     def _image_pixels_to_features(
         self,
         vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
         pixel_values: torch.Tensor,
     ) -> torch.Tensor:
-
         # NOTE: we skip the step to select the vision feature layer since
         # this is already done inside the vision tower
-        image_features = vision_tower(
-            pixel_values, feature_sample_layers=self.feature_sample_layers)
-
-        return self._select_image_features(
-            image_features,
-            strategy=self.config.vision_feature_select_strategy,
+        return vision_tower(
+            pixel_values,
+            select_layers=self.select_layers,
+            feature_select_strategy=self.config.vision_feature_select_strategy,
         )
 
     # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
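
The list-valued vision_feature_layer branch kept above explains the hidden_size * len(...) computation: the selected hidden states are concatenated on the feature dimension, so the projector input width grows with the number of selected layers. A small illustration (layer indices and sizes are made up, not taken from this commit):

import torch

hidden_size, num_patches = 1024, 576
vision_feature_layer = [-24, -20, -12, -1]  # illustrative granite-style value

# One (num_patches, hidden_size) hidden state per selected layer, concatenated
# on the last dim, which is why vision_hidden_size is
# config.vision_config.hidden_size * len(vision_feature_layer).
selected = [torch.zeros(num_patches, hidden_size) for _ in vision_feature_layer]
print(torch.cat(selected, dim=-1).shape)  # torch.Size([576, 4096])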

vllm/model_executor/models/llava_next_video.py

Lines changed: 3 additions & 14 deletions
@@ -349,27 +349,16 @@ def _parse_and_validate_video_input(
             "w": expected_w,
         })
 
-    def _select_image_features(self, image_features: torch.Tensor, *,
-                               strategy: str) -> torch.Tensor:
-        if strategy == "default":
-            return image_features[:, 1:]
-        elif strategy == "full":
-            return image_features
-
-        raise ValueError(f"Unexpected select feature strategy: {strategy}")
-
     def _video_pixels_to_features(
         self,
         vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
         pixel_values: torch.Tensor,
     ) -> torch.Tensor:
-
         # NOTE: we skip the step to select the vision feature layer since
         # this is already done inside the vision tower
-        image_features = vision_tower(pixel_values)
-        image_features = self._select_image_features(
-            image_features,
-            strategy=self.config.vision_feature_select_strategy,
+        image_features = vision_tower(
+            pixel_values,
+            feature_select_strategy=self.config.vision_feature_select_strategy,
         )
         image_features = self.vision_resampler(image_features)
         image_features = self.multi_modal_projector(image_features)

vllm/model_executor/models/llava_onevision.py

Lines changed: 6 additions & 19 deletions
@@ -577,27 +577,16 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
 
         return mm_input_by_modality
 
-    def _select_image_features(self, image_features: torch.Tensor, *,
-                               strategy: str) -> torch.Tensor:
-        if strategy == "default":
-            return image_features[:, 1:]
-        elif strategy == "full":
-            return image_features
-
-        raise ValueError(f"Unexpected select feature strategy: {strategy}")
-
     def _image_pixels_to_features(
         self,
         vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
         pixel_values: torch.Tensor,
     ) -> torch.Tensor:
-
         # NOTE: we skip the step to select the vision feature layer since
         # this is already done inside the vision tower
-        image_features = vision_tower(pixel_values)
-        return self._select_image_features(
-            image_features,
-            strategy=self.config.vision_feature_select_strategy,
+        return vision_tower(
+            pixel_values,
+            feature_select_strategy=self.config.vision_feature_select_strategy,
         )
 
     # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
@@ -750,13 +739,11 @@ def _video_pixels_to_features(
         vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
         pixel_values: torch.Tensor,
     ) -> torch.Tensor:
-
         # NOTE: we skip the step to select the vision feature layer since
         # this is already done inside the vision tower
-        video_features = vision_tower(pixel_values)
-        video_features = self._select_image_features(
-            video_features,
-            strategy=self.config.vision_feature_select_strategy,
+        video_features = vision_tower(
+            pixel_values,
+            feature_select_strategy=self.config.vision_feature_select_strategy,
         )
         video_features = self.multi_modal_projector(video_features)
         video_features = self.apply_pooling(video_features)

vllm/model_executor/models/minimax_vl_01.py

Lines changed: 4 additions & 20 deletions
@@ -17,7 +17,6 @@
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalFieldConfig
 from vllm.sequence import IntermediateTensors
-from vllm.utils.jsontree import json_map_leaves
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .clip import CLIPVisionModel
@@ -221,15 +220,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def _select_image_features(self, image_features: torch.Tensor, *,
-                               strategy: str) -> torch.Tensor:
-        if strategy == "default":
-            return image_features[:, 1:]
-        elif strategy == "full":
-            return image_features
-
-        raise ValueError(f"Unexpected select feature strategy: {strategy}")
-
     def _image_pixels_to_features(
         self,
         vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
@@ -238,16 +228,10 @@ def _image_pixels_to_features(
     ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
         # NOTE: we skip the step to select the vision feature layer since
         # this is already done inside the vision tower
-        image_features: tuple[torch.Tensor, ...] = \
-            tuple(vision_tower(p) for p in pixel_values)
-
-        def select_features(leaf: torch.Tensor):
-            return self._select_image_features(
-                leaf,
-                strategy=self.config.vision_feature_select_strategy,
-            )
-
-        return json_map_leaves(select_features, image_features)
+        feature_select_strategy = self.config.vision_feature_select_strategy
+        return tuple(
+            vision_tower(p, feature_select_strategy=feature_select_strategy)
+            for p in pixel_values)
 
     # adapted from https://huggingface.co/MiniMaxAI/MiniMax-VL-01/blob/main/modeling_minimax_vl_01.py#L616-L631
     def pack_image_features(self, image_features: list[torch.Tensor],
