
Commit e5949e5

chenxi-yang and Chenxi Yang authored
Remove index_put from MM embeddings merging (#22105)
Co-authored-by: Chenxi Yang <cxyang@meta.com>
1 parent 49bcd89 commit e5949e5

File tree

1 file changed: +24 -18 lines changed


vllm/model_executor/models/utils.py

Lines changed: 24 additions & 18 deletions
@@ -393,7 +393,7 @@ def merge_multimodal_embeddings_from_map(
         inputs_embeds: torch.Tensor, multimodal_embeddings: NestedTensors,
         placeholder_map: MultiModalPlaceholderMap.IndexMap) -> torch.Tensor:
     """
-    Merge ``multimodal_embeddings`` into ``inputs_embeds`` using the provided
+    Merge ``multimodal_embeddings`` into ``inputs_embeds`` using the provided
     placeholder map .
 
     Note:
@@ -418,17 +418,23 @@ def _merge_multimodal_embeddings(
     Note:
         This updates ``inputs_embeds`` in place.
     """
-    num_expected_tokens = is_multimodal.sum().item()
-    assert isinstance(num_expected_tokens, int)
-
     flattened = _flatten_embeddings(multimodal_embeddings)
-    if flattened.shape[0] != num_expected_tokens:
-        expr = _embedding_count_expression(multimodal_embeddings)
-        raise ValueError(
-            f"Attempted to assign {expr} = {flattened.shape[0]} "
-            f"multimodal tokens to {num_expected_tokens} placeholders")
+    try:
+        # This is equivalent to: inputs_embeds[is_multimodal] = flattened.
+        inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1), flattened)
+    except RuntimeError as e:
+        num_expected_tokens = is_multimodal.sum().item()
+        assert isinstance(num_expected_tokens, int)
+
+        if flattened.shape[0] != num_expected_tokens:
+            expr = _embedding_count_expression(multimodal_embeddings)
+            raise ValueError(
+                f"Attempted to assign {expr} = {flattened.shape[0]} "
+                f"multimodal tokens to {num_expected_tokens} placeholders"
+            ) from e
+        else:
+            raise ValueError("Error during masked scatter operation") from e
 
-    inputs_embeds[is_multimodal] = flattened
     return inputs_embeds
 
 
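Editor's note: the hunk above replaces the boolean-index assignment `inputs_embeds[is_multimodal] = flattened` (which dispatches to `index_put`) with an in-place `masked_scatter_`; the token-count check (`is_multimodal.sum().item()`) now runs only on the failure path. A minimal standalone sketch of why the two forms agree, independent of vLLM; the tensor shapes and values here are illustrative assumptions, not taken from the commit:

    import torch

    # Illustrative shapes: 6 token positions, hidden size 4, 3 multimodal slots.
    seq_len, hidden = 6, 4
    inputs_embeds = torch.zeros(seq_len, hidden)
    is_multimodal = torch.tensor([False, True, True, False, True, False])
    flattened = torch.arange(3 * hidden, dtype=torch.float32).view(3, hidden)

    # Old path: boolean-index assignment (an index_put under the hood).
    ref = inputs_embeds.clone()
    ref[is_multimodal] = flattened

    # New path: mask broadcast over the hidden dimension; masked_scatter_
    # consumes `flattened` in row-major order, matching the masked rows.
    out = inputs_embeds.clone()
    out.masked_scatter_(is_multimodal.unsqueeze(-1), flattened)

    assert torch.equal(ref, out)

Both fill the masked rows with the flattened embeddings in order. A size mismatch between the source and the masked positions can surface from `masked_scatter_` as a `RuntimeError`, which the new code catches and re-raises as the more descriptive `ValueError`.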

@@ -478,11 +484,11 @@ def merge_multimodal_embeddings(
     Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
     positions in ``inputs_embeds`` corresponding to placeholder tokens in
     ``input_ids``.
-
-    ``placeholder_token_id`` can be a list of token ids (e.g, token ids
-    of img_start, img_break, and img_end tokens) when needed: This means
-    the order of these tokens in the ``input_ids`` MUST MATCH the order of
-    their embeddings in ``multimodal_embeddings`` since we need to
+
+    ``placeholder_token_id`` can be a list of token ids (e.g, token ids
+    of img_start, img_break, and img_end tokens) when needed: This means
+    the order of these tokens in the ``input_ids`` MUST MATCH the order of
+    their embeddings in ``multimodal_embeddings`` since we need to
     slice-merge instead of individually scattering.
 
     For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where
@@ -491,9 +497,9 @@ def merge_multimodal_embeddings(
     - I is image embedding token
     - B is image break token
     - E is image end token.
-
-    Then the image embeddings (that correspond to I's) from vision encoder
-    must be padded with embeddings of S, B, and E in the same order of
+
+    Then the image embeddings (that correspond to I's) from vision encoder
+    must be padded with embeddings of S, B, and E in the same order of
     input_ids for a correct embedding merge.
 
     Note:
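Editor's note: the docstring above covers the case where ``placeholder_token_id`` is a list (e.g. image start/break/end tokens). A hedged sketch of how such a list selects the positions to overwrite; the token id values (T, S, I, B, E) and hidden size are made up for illustration and are not vLLM's actual ids:

    import torch

    # Made-up token ids: T=text, S=img_start, I=image, B=img_break, E=img_end.
    T, S, I, B, E = 1, 10, 11, 12, 13
    input_ids = torch.tensor(
        [T, T, T, T, T, S, I, I, I, B, I, I, I, B, I, I, I, E, T, T, T])

    # With a list of placeholder ids, every S/I/B/E position is treated as a
    # placeholder, so the flattened multimodal embeddings must already contain
    # the S, B, and E embeddings interleaved in this same order.
    placeholder_token_id = [S, I, B, E]
    is_multimodal = torch.isin(input_ids, torch.tensor(placeholder_token_id))

    hidden = 4
    inputs_embeds = torch.zeros(input_ids.numel(), hidden)
    mm_embeds = torch.randn(int(is_multimodal.sum()), hidden)

    # Same in-place merge as the new _merge_multimodal_embeddings path.
    inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1), mm_embeds)

This is why the docstring stresses that the order of S/B/E tokens in ``input_ids`` must match the order of their embeddings in ``multimodal_embeddings``: the merge is a single ordered scatter, not a per-token lookup.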
