Skip to content

Commit 15084b9

Browse files
Isotr0py authored and Mu Huai committed
[VLM] Cleanup siglip legacy code and fix broken paligemma multimodal processor (vllm-project#14602)
Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>
1 parent 21ac3af commit 15084b9

File tree

2 files changed

+14
-76
lines changed

2 files changed

+14
-76
lines changed

vllm/model_executor/models/paligemma.py

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -24,9 +24,10 @@
2424
from vllm.sequence import IntermediateTensors
2525

2626
from .interfaces import SupportsMultiModal, SupportsPP
27-
from .siglip import SiglipVisionModel, get_max_siglip_image_tokens
27+
from .siglip import SiglipVisionModel
2828
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
2929
maybe_prefix, merge_multimodal_embeddings)
30+
from .vision import get_vision_encoder_info
3031

3132
logger = init_logger(__name__)
3233

@@ -67,6 +68,9 @@ class PaliGemmaProcessingInfo(BaseProcessingInfo):
6768
def get_hf_config(self):
6869
return self.ctx.get_hf_config(PaliGemmaConfig)
6970

71+
def get_vision_encoder_info(self):
72+
return get_vision_encoder_info(self.get_hf_config())
73+
7074
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
7175
return {"image": 1}
7276

@@ -78,9 +82,8 @@ def get_mm_max_tokens_per_item(
7882
return {"image": self.get_num_image_tokens()}
7983

8084
def get_num_image_tokens(self) -> int:
81-
hf_config = self.get_hf_config()
82-
vision_config = hf_config.vision_config
83-
return get_max_siglip_image_tokens(vision_config)
85+
vision_encoder_info = self.get_vision_encoder_info()
86+
return vision_encoder_info.get_max_image_tokens()
8487

8588

8689
class PaliGemmaDummyInputsBuilder(
@@ -173,8 +176,10 @@ def apply(
173176
prompt: Union[str, list[int]],
174177
mm_data: MultiModalDataDict,
175178
hf_processor_mm_kwargs: Mapping[str, object],
179+
return_mm_hashes: bool = False,
176180
) -> MultiModalInputs:
177-
mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
181+
mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
182+
return_mm_hashes)
178183
prompt_token_ids = mm_inputs["prompt_token_ids"]
179184

180185
tokenizer = self.info.get_tokenizer()

vllm/model_executor/models/siglip.py

Lines changed: 4 additions & 71 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,6 @@
66
from typing import Iterable, Optional, Set, Tuple, Union
77

88
import torch
9-
from PIL import Image
109
from torch import nn
1110
from transformers import SiglipVisionConfig
1211

@@ -20,74 +19,10 @@
2019
from vllm.model_executor.layers.vocab_parallel_embedding import (
2120
VocabParallelEmbedding)
2221
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
23-
from vllm.multimodal.utils import consecutive_placeholder_ranges
24-
from vllm.sequence import SequenceData
2522

2623
from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
2724

2825

29-
def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
30-
# Since interpolation is applied, the image size need not be divisible
31-
# assert image_size % patch_size == 0
32-
return image_size // patch_size
33-
34-
35-
def get_siglip_num_patches(*, image_size: int, patch_size: int) -> int:
36-
grid_length = get_siglip_patch_grid_length(image_size=image_size,
37-
patch_size=patch_size)
38-
return grid_length * grid_length
39-
40-
41-
def get_siglip_image_feature_size(hf_config: SiglipVisionConfig) -> int:
42-
return get_siglip_num_patches(image_size=hf_config.image_size,
43-
patch_size=hf_config.patch_size)
44-
45-
46-
def get_max_siglip_image_tokens(hf_config: SiglipVisionConfig) -> int:
47-
return get_siglip_image_feature_size(hf_config)
48-
49-
50-
def dummy_seq_data_for_siglip(
51-
hf_config: SiglipVisionConfig,
52-
seq_len: int,
53-
num_images: int,
54-
*,
55-
image_token_id: int,
56-
image_feature_size_override: Optional[int] = None,
57-
mm_key: str = "image",
58-
):
59-
if image_feature_size_override is None:
60-
image_feature_size = get_siglip_image_feature_size(hf_config)
61-
else:
62-
image_feature_size = image_feature_size_override
63-
64-
return SequenceData.from_prompt_token_counts(
65-
(image_token_id, image_feature_size * num_images),
66-
(0, seq_len - image_feature_size * num_images),
67-
), {
68-
mm_key:
69-
consecutive_placeholder_ranges(num_items=num_images,
70-
item_size=image_feature_size)
71-
}
72-
73-
74-
def dummy_image_for_siglip(
75-
hf_config: SiglipVisionConfig,
76-
num_images: int,
77-
*,
78-
image_width_override: Optional[int] = None,
79-
image_height_override: Optional[int] = None,
80-
):
81-
width = height = hf_config.image_size
82-
if image_width_override is not None:
83-
width = image_width_override
84-
if image_height_override is not None:
85-
height = image_height_override
86-
87-
image = Image.new("RGB", (width, height), color=0)
88-
return {"image": image if num_images == 1 else [image] * num_images}
89-
90-
9126
class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):
9227

9328
def get_num_image_tokens(
@@ -96,10 +31,10 @@ def get_num_image_tokens(
9631
image_width: int,
9732
image_height: int,
9833
) -> int:
99-
return get_siglip_image_feature_size(self.vision_config)
34+
return self.get_patch_grid_length()**2
10035

10136
def get_max_image_tokens(self) -> int:
102-
return get_max_siglip_image_tokens(self.vision_config)
37+
return self.get_patch_grid_length()**2
10338

10439
def get_image_size(self) -> int:
10540
return self.vision_config.image_size
@@ -108,10 +43,8 @@ def get_patch_size(self) -> int:
10843
return self.vision_config.patch_size
10944

11045
def get_patch_grid_length(self) -> int:
111-
return get_siglip_patch_grid_length(
112-
image_size=self.vision_config.image_size,
113-
patch_size=self.vision_config.patch_size,
114-
)
46+
image_size, patch_size = self.get_image_size(), self.get_patch_size()
47+
return image_size // patch_size
11548

11649

11750
# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa

0 commit comments

Comments
 (0)