@@ -393,7 +393,7 @@ def merge_multimodal_embeddings_from_map(
         inputs_embeds: torch.Tensor, multimodal_embeddings: NestedTensors,
         placeholder_map: MultiModalPlaceholderMap.IndexMap) -> torch.Tensor:
     """
-    Merge ``multimodal_embeddings`` into ``inputs_embeds`` using the provided
+    Merge ``multimodal_embeddings`` into ``inputs_embeds`` using the provided
     placeholder map.

     Note:
@@ -418,17 +418,23 @@ def _merge_multimodal_embeddings(
     Note:
         This updates ``inputs_embeds`` in place.
     """
-    num_expected_tokens = is_multimodal.sum().item()
-    assert isinstance(num_expected_tokens, int)
-
     flattened = _flatten_embeddings(multimodal_embeddings)
-    if flattened.shape[0] != num_expected_tokens:
-        expr = _embedding_count_expression(multimodal_embeddings)
-        raise ValueError(
-            f"Attempted to assign {expr} = {flattened.shape[0]} "
-            f"multimodal tokens to {num_expected_tokens} placeholders")
+    try:
+        # This is equivalent to: inputs_embeds[is_multimodal] = flattened.
+        inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1), flattened)
+    except RuntimeError as e:
+        num_expected_tokens = is_multimodal.sum().item()
+        assert isinstance(num_expected_tokens, int)
+
+        if flattened.shape[0] != num_expected_tokens:
+            expr = _embedding_count_expression(multimodal_embeddings)
+            raise ValueError(
+                f"Attempted to assign {expr} = {flattened.shape[0]} "
+                f"multimodal tokens to {num_expected_tokens} placeholders"
+            ) from e
+        else:
+            raise ValueError("Error during masked scatter operation") from e

-    inputs_embeds[is_multimodal] = flattened
     return inputs_embeds


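The hunk above swaps the up-front length check for an in-place `masked_scatter_` and only recomputes the expected token count when the scatter fails with a `RuntimeError`, which the new `except` branch converts into the descriptive `ValueError`. A minimal standalone sketch (hypothetical shapes, not part of the patch) of the equivalence it relies on: broadcasting the boolean row mask over the hidden dimension makes `masked_scatter_` fill the masked rows with rows of `flattened` in order, exactly like the boolean-index assignment it replaces.

```python
import torch

# Hypothetical shapes for illustration: 6 token positions, hidden size 4,
# 3 of which are multimodal placeholders.
hidden = 4
inputs_embeds = torch.zeros(6, hidden)
is_multimodal = torch.tensor([False, True, True, False, True, False])
flattened = torch.arange(3 * hidden, dtype=torch.float32).reshape(3, hidden)

# Path used by the patch: the (6, 1) mask broadcasts across the hidden
# dimension, so rows of `flattened` fill the masked rows in order.
scattered = inputs_embeds.clone()
scattered.masked_scatter_(is_multimodal.unsqueeze(-1), flattened)

# Path being replaced: plain boolean-index assignment.
indexed = inputs_embeds.clone()
indexed[is_multimodal] = flattened

assert torch.equal(scattered, indexed)
```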
@@ -478,11 +484,11 @@ def merge_multimodal_embeddings(
     Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
     positions in ``inputs_embeds`` corresponding to placeholder tokens in
     ``input_ids``.
-
-    ``placeholder_token_id`` can be a list of token ids (e.g, token ids
-    of img_start, img_break, and img_end tokens) when needed: This means
-    the order of these tokens in the ``input_ids`` MUST MATCH the order of
-    their embeddings in ``multimodal_embeddings`` since we need to
+
+    ``placeholder_token_id`` can be a list of token ids (e.g, token ids
+    of img_start, img_break, and img_end tokens) when needed: This means
+    the order of these tokens in the ``input_ids`` MUST MATCH the order of
+    their embeddings in ``multimodal_embeddings`` since we need to
     slice-merge instead of individually scattering.

     For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where
@@ -491,9 +497,9 @@ def merge_multimodal_embeddings(
     - I is image embedding token
     - B is image break token
     - E is image end token.
-
-    Then the image embeddings (that correspond to I's) from vision encoder
-    must be padded with embeddings of S, B, and E in the same order of
+
+    Then the image embeddings (that correspond to I's) from vision encoder
+    must be padded with embeddings of S, B, and E in the same order of
     input_ids for a correct embedding merge.

     Note:
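To make the ordering requirement in the docstring concrete, here is a small hypothetical sketch (illustrative token ids and shapes, not vLLM's actual code path) showing how a list of placeholder token ids yields a single boolean mask, and why the vision-side embeddings must already be laid out in the same order as those placeholders appear in ``input_ids``.

```python
import torch

# Hypothetical token ids for illustration only.
IMG_START, IMG_BREAK, IMG_END, IMG_TOKEN = 10, 11, 12, 13
placeholder_token_id = [IMG_START, IMG_BREAK, IMG_END, IMG_TOKEN]

# A "T T S I I B I I E T T"-style sequence; text tokens are 1.
input_ids = torch.tensor([1, 1, IMG_START, IMG_TOKEN, IMG_TOKEN, IMG_BREAK,
                          IMG_TOKEN, IMG_TOKEN, IMG_END, 1, 1])

# One boolean mask over all placeholder positions, regardless of which
# placeholder id each position holds.
is_multimodal = torch.isin(input_ids, torch.tensor(placeholder_token_id))

hidden = 8
inputs_embeds = torch.zeros(input_ids.shape[0], hidden)

# The merged embeddings must be supplied in the same order as the masked
# positions appear in input_ids (S, I, I, B, I, I, E here), one row each.
multimodal_embeds = torch.randn(int(is_multimodal.sum()), hidden)

# Slice-merge: overwrite all placeholder rows in order.
inputs_embeds[is_multimodal] = multimodal_embeds
```

Because the merge is a single slice assignment over all masked positions, nothing dispatches on the individual token ids; the row order of the multimodal embeddings is the only thing that ties each embedding to the right placeholder.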