39 | 39 | from vllm.model_executor.pooling_metadata import PoolingMetadata |
40 | 40 | from vllm.model_executor.sampling_metadata import SamplingMetadata |
41 | 41 | from vllm.multimodal import MULTIMODAL_REGISTRY |
| 42 | +from vllm.multimodal.base import NestedTensors, PlaceholderRange |
42 | 43 | from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token |
43 | 44 | from vllm.sequence import IntermediateTensors, PoolerOutput |
44 | 45 | from vllm.utils import is_list_of |
@@ -500,23 +501,29 @@ def input_processor_for_phi3v(ctx: InputContext, |
500 | 501 |
501 | 502 | # TODO: Move this to utils or integrate with clip. |
502 | 503 | new_token_ids: List[int] = [] |
| 504 | + placeholder_ranges: List[PlaceholderRange] = [] |
503 | 505 | placeholder_idx = 0 |
504 | 506 | while merged_token_ids: |
505 | 507 | token_id = merged_token_ids.pop(0) |
506 | 508 | if token_id == _IMAGE_TOKEN_ID: |
507 | | - new_token_ids.extend( |
508 | | - repeat_and_pad_token( |
509 | | - _IMAGE_TOKEN_ID, |
510 | | - repeat_count=image_feature_size[placeholder_idx], |
511 | | - )) |
| 509 | + replacement_ids = repeat_and_pad_token( |
| 510 | + _IMAGE_TOKEN_ID, |
| 511 | + repeat_count=image_feature_size[placeholder_idx], |
| 512 | + ) |
| 513 | + placeholder_ranges.append({ |
| 514 | + "offset": len(new_token_ids), |
| 515 | + "length": len(replacement_ids) |
| 516 | + }) |
| 517 | + new_token_ids.extend(replacement_ids) |
512 | 518 | placeholder_idx += 1 |
513 | 519 | else: |
514 | 520 | new_token_ids.append(token_id) |
515 | 521 |
516 | 522 | # NOTE: Create a defensive copy of the original inputs |
517 | 523 | return token_inputs(prompt_token_ids=new_token_ids, |
518 | 524 | prompt=new_prompt, |
519 | | - multi_modal_data=multi_modal_data) |
| 525 | + multi_modal_data=multi_modal_data, |
| 526 | + multi_modal_placeholders={"image": placeholder_ranges}) |
520 | 527 |
521 | 528 |
522 | 529 | @MULTIMODAL_REGISTRY.register_image_input_mapper() |
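
The placeholder bookkeeping added in the hunk above can be illustrated outside vLLM. The sketch below reimplements just the offset/length logic with plain Python types; `PlaceholderRange` here is a stand-in `TypedDict` and the token ids and feature sizes are made-up values, not Phi-3V's real ones.

```python
# Standalone sketch of the placeholder-range bookkeeping in the diff above.
# Nothing here imports vLLM: PlaceholderRange is a stand-in TypedDict and
# _IMAGE_TOKEN_ID is a hypothetical sentinel value.
from typing import List, Tuple, TypedDict


class PlaceholderRange(TypedDict):
    offset: int  # index of the first placeholder token in the new prompt
    length: int  # number of consecutive placeholder tokens


_IMAGE_TOKEN_ID = -1  # hypothetical sentinel, not Phi-3V's real token id


def expand_image_tokens(
    token_ids: List[int],
    image_feature_size: List[int],
) -> Tuple[List[int], List[PlaceholderRange]]:
    new_token_ids: List[int] = []
    placeholder_ranges: List[PlaceholderRange] = []
    placeholder_idx = 0
    for token_id in token_ids:
        if token_id == _IMAGE_TOKEN_ID:
            repeat = image_feature_size[placeholder_idx]
            # Record where this image's tokens start and how many there are.
            placeholder_ranges.append({
                "offset": len(new_token_ids),
                "length": repeat,
            })
            new_token_ids.extend([_IMAGE_TOKEN_ID] * repeat)
            placeholder_idx += 1
        else:
            new_token_ids.append(token_id)
    return new_token_ids, placeholder_ranges


# A prompt with one image token that expands to 3 feature tokens.
tokens, ranges = expand_image_tokens([1, _IMAGE_TOKEN_ID, 7], [3])
assert tokens == [1, -1, -1, -1, 7]
assert ranges == [{"offset": 1, "length": 3}]
```

Recording the ranges at expansion time means downstream consumers can locate each image's placeholder span in the expanded prompt without re-scanning for the image token.
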
@@ -669,32 +676,42 @@ def _process_image_input( |
669 | 676 |
670 | 677 | return image_embeds |
671 | 678 |
| 679 | + def process_mm_inputs(self, **kwargs): |
| 680 | + image_input = self._parse_and_validate_image_input(**kwargs) |
| 681 | + if image_input is None: |
| 682 | + return None |
| 683 | + vision_embeddings = self._process_image_input(image_input) |
| 684 | + return vision_embeddings |
| 685 | + |
| 686 | + def get_input_embeddings( |
| 687 | + self, |
| 688 | + input_ids: torch.Tensor, |
| 689 | + vision_embeddings: Optional[NestedTensors] = None, |
| 690 | + ) -> torch.Tensor: |
| 691 | + inputs_embeds = self.embed_tokens(input_ids) |
| 692 | + if vision_embeddings is not None: |
| 693 | + inputs_embeds = merge_multimodal_embeddings( |
| 694 | + input_ids, inputs_embeds, vision_embeddings, |
| 695 | + self.image_token_id) |
| 696 | + return inputs_embeds |
| 697 | + |
672 | 698 | def forward(self, |
673 | 699 | input_ids: torch.Tensor, |
674 | 700 | positions: torch.Tensor, |
675 | 701 | kv_caches: List[torch.Tensor], |
676 | 702 | attn_metadata: AttentionMetadata, |
677 | 703 | intermediate_tensors: Optional[IntermediateTensors] = None, |
| 704 | + inputs_embeds: Optional[torch.Tensor] = None, |
678 | 705 | **kwargs: object): |
679 | 706 | if intermediate_tensors is not None: |
680 | 707 | inputs_embeds = None |
681 | | - else: |
682 | | - image_input = self._parse_and_validate_image_input(**kwargs) |
683 | | - |
684 | | - if image_input is not None: |
685 | | - vision_embeddings = self._process_image_input(image_input) |
686 | | - inputs_embeds = self.embed_tokens(input_ids) |
687 | | - inputs_embeds = merge_multimodal_embeddings( |
688 | | - input_ids, inputs_embeds, vision_embeddings, |
689 | | - self.image_token_id) |
690 | | - else: |
691 | | - inputs_embeds = self.language_model.model.embed_tokens( |
692 | | - input_ids) |
693 | | - |
694 | | - # always pass the input via `inputs_embeds` |
695 | | - # to make sure the computation graph is consistent |
696 | | - # for `torch.compile` integration |
697 | | - input_ids = None |
| 708 | + elif inputs_embeds is None: |
| 709 | + vision_embeddings = self.process_mm_inputs(**kwargs) |
| 710 | + # always pass the input via `inputs_embeds` |
| 711 | + # to make sure the computation graph is consistent |
| 712 | + inputs_embeds = self.get_input_embeddings(input_ids, |
| 713 | + vision_embeddings) |
| 714 | + input_ids = None |
698 | 715 |
699 | 716 | hidden_states = self.language_model.model(input_ids, |
700 | 717 | positions, |
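
For context on the refactored forward path: `merge_multimodal_embeddings` is expected to overwrite the embedding rows at the image-token positions with the vision embeddings, in order. The snippet below is a simplified stand-in with made-up shapes and a placeholder token id, not vLLM's implementation.

```python
# Simplified stand-in for merge_multimodal_embeddings as used in
# get_input_embeddings above: rows of inputs_embeds whose token id equals
# image_token_id are replaced, in order, by the vision embeddings.
# Illustration only; shapes and the token id are made up.
import torch


def merge_multimodal_embeddings(input_ids: torch.Tensor,
                                inputs_embeds: torch.Tensor,
                                vision_embeddings: torch.Tensor,
                                image_token_id: int) -> torch.Tensor:
    mask = input_ids == image_token_id
    merged = inputs_embeds.clone()
    # vision_embeddings must supply exactly one row per placeholder token.
    merged[mask] = vision_embeddings.to(merged.dtype)
    return merged


hidden_size = 4
input_ids = torch.tensor([1, -1, -1, 7])        # -1 marks image placeholders
inputs_embeds = torch.zeros(4, hidden_size)     # text embeddings (all zeros)
vision_embeddings = torch.ones(2, hidden_size)  # one row per placeholder
out = merge_multimodal_embeddings(input_ids, inputs_embeds,
                                  vision_embeddings, image_token_id=-1)
assert out[1:3].eq(1).all() and out[0].eq(0).all() and out[3].eq(0).all()
```

This split lets a caller precompute vision embeddings via `process_mm_inputs` and build the prompt embeddings with `get_input_embeddings`, passing the result to `forward` through the new `inputs_embeds` argument; when `inputs_embeds` is None, `forward` falls back to doing the same work internally.
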