Commit 5994430

[Misc] Remove redundant num_embeds (#15443)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent a9e879b commit 5994430
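
Note: the removed `num_embeds` field was redundant because, once `embed_is_patch` is batched into a single `(num_images, num_embeds)` tensor, the per-image embedding count is just the mask's second dimension. A minimal sketch of the idea follows; the last two lines mirror the replacement added in vision.py below, while the mask values themselves are hypothetical:

    import torch

    # Hypothetical batched mask from an HF processor; True marks patch tokens.
    # Shape: (num_images, num_embeds).
    embed_is_patch = torch.tensor([
        [False, True, True, False, True, True, False, False],
        [False, True, True, False, True, True, False, False],
    ])

    # The removed field duplicated what the mask's shape already encodes:
    # every image in the batched tensor has the same number of embeddings.
    num_images, num_embeds = embed_is_patch.shape
    num_embeds_per_image = [num_embeds] * num_images  # [8, 8]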

5 files changed: +25 -64 lines

vllm/model_executor/models/gemma3_mm.py

Lines changed: 0 additions & 16 deletions
@@ -63,9 +63,6 @@ class Gemma3ImagePixelInputs(TypedDict):
     Shape: `(batch_size, num_images, num_embeds)`
     """
 
-    num_embeds: Union[torch.Tensor, list[torch.Tensor]]
-    """Shape: `(batch_size, num_images)`"""
-
 
 Gemma3ImageInputs = Gemma3ImagePixelInputs
 
@@ -317,11 +314,6 @@ def _call_hf_processor(
             tokenizer.encode(image_repl, add_special_tokens=False)
             for image_repl in image_repl_features
         ]
-        num_embeds = [
-            len(image_repl_feature_tokens)
-            for image_repl_feature_tokens in image_repls_feature_tokens
-        ]
-        processed_outputs["num_embeds"] = torch.tensor(num_embeds)
 
         vocab = tokenizer.get_vocab()
         image_token_id = vocab[tokenizer.image_token]
@@ -354,7 +346,6 @@ def _get_mm_fields_config(
                 "image", num_crops + 1),
             num_crops=MultiModalFieldConfig.batched("image"),
             embed_is_patch=MultiModalFieldConfig.batched("image"),
-            num_embeds=MultiModalFieldConfig.batched("image"),
         )
 
     def _get_prompt_updates(
@@ -583,7 +574,6 @@ def _parse_and_validate_image_input(
         pixel_values = kwargs.pop("pixel_values", None)
         num_crops = kwargs.pop("num_crops", None)
         embed_is_patch = kwargs.pop("embed_is_patch", None)
-        num_embeds = kwargs.pop("num_embeds", None)
         image_embeds = kwargs.pop("image_embeds", None)
         assert image_embeds is None, "Gemma3 does not support image_embeds."
         if pixel_values is None:
@@ -601,10 +591,6 @@ def _parse_and_validate_image_input(
             raise ValueError("Incorrect type of embed_is_patch. "
                              f"Got type: {type(embed_is_patch)}")
 
-        if not isinstance(num_embeds, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of num_embeds. "
-                             f"Got type: {type(num_embeds)}")
-
         pixel_values = flatten_bn(pixel_values, concat=True)
         num_crops = flatten_bn(num_crops, concat=True)
 
@@ -613,7 +599,6 @@ def _parse_and_validate_image_input(
             pixel_values=self._validate_pixel_values(pixel_values),
             num_patches=num_crops + 1,
             embed_is_patch=embed_is_patch,
-            num_embeds=num_embeds,
         )
 
     def _image_pixels_to_features(
@@ -656,7 +641,6 @@ def get_multimodal_embeddings(
         return flatten_2d_lists(
             scatter_patch_features(*args) for args in zip(
                 image_features,
-                image_input["num_embeds"],
                 image_input["embed_is_patch"],
             ))
 

vllm/model_executor/models/internvl.py

Lines changed: 0 additions & 14 deletions
@@ -69,9 +69,6 @@ class InternVLImagePixelInputs(TypedDict):
     Shape: `(batch_size, num_images, num_embeds)`
     """
 
-    num_embeds: Union[torch.Tensor, list[torch.Tensor]]
-    """Shape: `(batch_size, num_images)`"""
-
 
 class InternVLImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
@@ -426,7 +423,6 @@ def __call__(
         tokenizer = self.tokenizer
         image_token_id = self.image_token_id
 
-        num_embeds = list[int]()
         embed_is_patch = list[torch.Tensor]()
 
         for pixel_values in pixel_values_lst:
@@ -438,11 +434,9 @@ def __call__(
                                               add_special_tokens=False)
 
             text = [t.replace('<image>', image_repl.full, 1) for t in text]
-            num_embeds.append(len(feature_tokens))
             embed_is_patch.append(
                 torch.tensor(feature_tokens) == image_token_id)
 
-        image_inputs["num_embeds"] = torch.tensor(num_embeds)
         image_inputs["embed_is_patch"] = embed_is_patch
 
         text_inputs = self.tokenizer(text)
@@ -607,7 +601,6 @@ def _get_mm_fields_config(
                 "image", image_num_patches),
             image_num_patches=MultiModalFieldConfig.batched("image"),
             embed_is_patch=MultiModalFieldConfig.batched("image"),
-            num_embeds=MultiModalFieldConfig.batched("image"),
             image_embeds=MultiModalFieldConfig.batched("image"),
             image_token_id=MultiModalFieldConfig.shared("image", num_images),
         )
@@ -840,7 +833,6 @@ def _parse_and_validate_image_input(
         pixel_values_flat = kwargs.pop("pixel_values_flat", None)
         image_num_patches = kwargs.pop("image_num_patches", None)
         embed_is_patch = kwargs.pop("embed_is_patch", None)
-        num_embeds = kwargs.pop("num_embeds", None)
         image_embeds = kwargs.pop("image_embeds", None)
 
         if pixel_values_flat is None and image_embeds is None:
@@ -873,10 +865,6 @@ def _parse_and_validate_image_input(
             raise ValueError("Incorrect type of embed_is_patch. "
                              f"Got type: {type(embed_is_patch)}")
 
-        if not isinstance(num_embeds, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of num_embeds. "
-                             f"Got type: {type(num_embeds)}")
-
         pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
         image_num_patches = flatten_bn(image_num_patches, concat=True)
 
@@ -886,7 +874,6 @@ def _parse_and_validate_image_input(
                 pixel_values_flat),
             num_patches=image_num_patches,
             embed_is_patch=embed_is_patch,
-            num_embeds=num_embeds,
         )
 
     raise AssertionError("This line should be unreachable.")
@@ -941,7 +928,6 @@ def get_multimodal_embeddings(
         return flatten_2d_lists(
             scatter_patch_features(*args) for args in zip(
                 image_features,
-                image_input["num_embeds"],
                 image_input["embed_is_patch"],
             ))
 

vllm/model_executor/models/llava.py

Lines changed: 0 additions & 16 deletions
@@ -76,9 +76,6 @@ class PixtralHFImagePixelInputs(TypedDict):
     Shape: `(batch_size, num_images, num_embeds)`
     """
 
-    num_embeds: Union[torch.Tensor, list[torch.Tensor]]
-    """Shape: `(batch_size, num_images)`"""
-
 
 class LlavaImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
@@ -358,15 +355,10 @@ def _call_hf_processor(
                 image_height=pixel_value.shape[-2],
             ) for pixel_value in processed_outputs["pixel_values"]
         ]
-        num_embeds = torch.tensor([(ncols + 1) * nrows
-                                   for ncols, nrows in tile_sizes])
-        # Each image may result to masks of different sizes, so we need to
-        # later use `num_embeds` to get per-image masks.
         embed_is_patch = [
             torch.tensor(([True] * ncols + [False]) * nrows)
             for ncols, nrows in tile_sizes
         ]
-        processed_outputs["num_embeds"] = num_embeds
         processed_outputs["embed_is_patch"] = embed_is_patch
 
         return processed_outputs
@@ -378,7 +370,6 @@ def _get_mm_fields_config(
     ) -> Mapping[str, MultiModalFieldConfig]:
         return dict(
             pixel_values=MultiModalFieldConfig.batched("image"),
-            num_embeds=MultiModalFieldConfig.batched("image"),
             embed_is_patch=MultiModalFieldConfig.batched("image"),
             image_embeds=MultiModalFieldConfig.batched("image"),
         )
@@ -627,16 +618,10 @@ def _parse_and_validate_image_input(
                 raise ValueError("Incorrect type of embed_is_patch. "
                                  f"Got type: {type(embed_is_patch)}")
 
-            num_embeds = kwargs.pop("num_embeds")
-            if not isinstance(num_embeds, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of num_embeds. "
-                                 f"Got type: {type(num_embeds)}")
-
             return PixtralHFImagePixelInputs(
                 type="pixel_values_pixtral",
                 pixel_values=flatten_bn(pixel_values),
                 embed_is_patch=embed_is_patch,
-                num_embeds=num_embeds,
             )
 
         return LlavaImagePixelInputs(
@@ -738,7 +723,6 @@ def get_multimodal_embeddings(
         return flatten_2d_lists(
             scatter_patch_features(*args) for args in zip(
                 vision_embeddings,
-                image_input["num_embeds"],
                 image_input["embed_is_patch"],
             ))
 

vllm/model_executor/models/pixtral.py

Lines changed: 0 additions & 14 deletions
@@ -77,9 +77,6 @@ class PixtralImagePixelInputs(TypedDict):
     Shape: `(batch_size, num_images, num_embeds)`
     """
 
-    num_embeds: Union[torch.Tensor, list[torch.Tensor]]
-    """Shape: `(batch_size, num_images)`"""
-
 
 class PixtralProcessorAdapter:
     """
@@ -153,7 +150,6 @@ def __call__(
         images_processed = list[torch.Tensor]()
         images_tokens = list[torch.Tensor]()
         images_embed_is_patch = list[torch.Tensor]()
-        images_num_embeds = list[int]()
 
         for image in images:
             image_inputs = self.image_processor(ImageChunk(image=image))
@@ -163,13 +159,11 @@ def __call__(
             images_processed.append(image_processed)
             images_tokens.append(image_tokens)
             images_embed_is_patch.append(image_tokens == image_token_id)
-            images_num_embeds.append(len(image_tokens))
 
         return {
             "input_ids": torch.cat(images_tokens)[None].expand(len(text), -1),
             "images": images_processed,
             "embed_is_patch": images_embed_is_patch,
-            "num_embeds": torch.tensor(images_num_embeds),
         }
 
 
@@ -273,7 +267,6 @@ def _get_mm_fields_config(
         return dict(
             images=MultiModalFieldConfig.batched("image"),
             embed_is_patch=MultiModalFieldConfig.batched("image"),
-            num_embeds=MultiModalFieldConfig.batched("image"),
         )
 
     def _get_prompt_updates(
@@ -394,16 +387,10 @@ def _parse_and_validate_image_input(
             raise ValueError("Incorrect type of embed_is_patch. "
                              f"Got type: {type(embed_is_patch)}")
 
-        num_embeds = kwargs.pop("num_embeds")
-        if not isinstance(num_embeds, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of num_embeds. "
-                             f"Got type: {type(num_embeds)}")
-
        return PixtralImagePixelInputs(
            type="pixel_values",
            images=flatten_bn(images),
            embed_is_patch=embed_is_patch,
-            num_embeds=num_embeds,
        )
 
    def _process_image_input(
@@ -447,7 +434,6 @@ def get_multimodal_embeddings(
        return flatten_2d_lists(
            scatter_patch_features(*args) for args in zip(
                image_features,
-                image_input["num_embeds"],
                image_input["embed_is_patch"],
            ))
 

vllm/model_executor/models/vision.py

Lines changed: 25 additions & 4 deletions
@@ -155,7 +155,6 @@ def resolve_visual_encoder_outputs(
 
 def scatter_patch_features(
     features: torch.Tensor,
-    num_embeds: torch.Tensor,
     embed_is_patch: torch.Tensor,
 ) -> tuple[torch.Tensor, ...]:
     """
@@ -168,13 +167,35 @@ def scatter_patch_features(
     Args:
         features: The patch features, concatenated across each image.
             Shape: `(num_patch, feature_depth)`
-        num_embeds: The number of image embeddings for each image.
-            Shape: `(num_images,)`
         embed_is_patch: A boolean mask indicating which image embeddings
             correspond to patch tokens for each image.
             Shape: `(num_images, num_embeds)`
+
+    Note:
+        The original code only considers patch tokens as feature
+        tokens, but our processor considers all image-related tokens
+        as feature tokens because the feature tokens need to be
+        consecutive in `input_ids`.
+
+    Example:
+        A simplified example for one image:
+
+        .. code-block::
+
+            Embedding tokens (from HF processor):
+            [<start> <patch> <patch> <col> <patch> <patch> <col> <end> ]
+
+            embed_is_patch (from HF processor):
+            [ False   True    True   False  True    True   False False ]
+
+            Encoder outputs (from model):
+            [ p1 p2 p3 p4 ]
+
+            The resulting embedding tensor is:
+            [ nan p1 p2 nan p3 p4 nan nan ]
     """
-    num_embeds_per_image: list[int] = num_embeds.tolist()
+    num_images, num_embeds = embed_is_patch.shape
+    num_embeds_per_image = [num_embeds] * num_images
 
     embeds_flat = features.new_full(
         (sum(num_embeds_per_image), features.shape[-1]),
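
For illustration, here is a self-contained sketch of the scatter-into-nan behavior described in the docstring above. `scatter_patch_features_sketch` is a simplified stand-in written for this page, not the vLLM implementation, and the mask and feature values are hypothetical:

    import torch

    def scatter_patch_features_sketch(
        features: torch.Tensor,        # (num_patches, feature_depth), images concatenated
        embed_is_patch: torch.Tensor,  # (num_images, num_embeds), True marks patch slots
    ) -> tuple[torch.Tensor, ...]:
        # Fill every embedding slot with nan, scatter the encoder outputs
        # into the slots flagged True, then split the result per image.
        num_images, num_embeds = embed_is_patch.shape
        embeds_flat = features.new_full(
            (num_images * num_embeds, features.shape[-1]), torch.nan)
        embeds_flat[embed_is_patch.view(-1)] = features
        return tuple(embeds_flat.view(num_images, num_embeds, -1).unbind(0))

    # Reproduces the docstring example: 4 patch outputs scattered into 8 slots.
    mask = torch.tensor([[False, True, True, False, True, True, False, False]])
    feats = torch.arange(1.0, 5.0).reshape(4, 1)  # p1..p4
    (out,) = scatter_patch_features_sketch(feats, mask)
    # out.squeeze(-1) -> tensor([nan, 1., 2., nan, 3., 4., nan, nan])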
