From 4179432a0d59fff7959ebda62f9b4ad4eaad4568 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Sat, 2 Nov 2024 01:55:20 -0700 Subject: [PATCH 01/12] update Signed-off-by: Roger Wang --- vllm/model_executor/models/blip2.py | 27 +++++++++++++++----------- vllm/model_executor/models/fuyu.py | 24 ++++++++++++++--------- vllm/model_executor/models/llava.py | 18 +++++++---------- vllm/model_executor/models/ultravox.py | 11 +++++++---- vllm/multimodal/utils.py | 17 +++++++++++++--- 5 files changed, 59 insertions(+), 38 deletions(-) diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index db1f92649bd49..9d8eca3b78a68 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -23,7 +23,7 @@ get_max_blip_image_tokens) from .interfaces import SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, init_vllm_registered_model, - merge_multimodal_embeddings) + merge_multimodal_embeddings_from_map) # We use this internally as placeholders since there is no image token # defined on the HuggingFace repo @@ -472,9 +472,13 @@ def input_processor_for_blip2(ctx: InputContext, inputs: DecoderOnlyInputs): if new_prompt is not None: new_prompt = BLIP2_IMAGE_TOKEN * image_feature_size + new_prompt + ranges = consecutive_placeholder_ranges(num_items=1, + item_size=image_feature_size) + return token_inputs(prompt_token_ids=new_token_ids, prompt=new_prompt, - multi_modal_data=multi_modal_data) + multi_modal_data=multi_modal_data, + multi_modal_placeholders={"image": ranges}) @MULTIMODAL_REGISTRY.register_image_input_mapper() @@ -651,24 +655,25 @@ def forward( :class:`Blip2ImageInputs` """ if intermediate_tensors is not None: - input_ids = None inputs_embeds = None else: + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) + image_input = self._parse_and_validate_image_input(**kwargs) - if image_input is not None: vision_embeddings = self._process_image_input(image_input) inputs_embeds = self.language_model.model.get_input_embeddings( input_ids) - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, - BLIP2_IMAGE_TOKEN_ID) - - input_ids = None - else: - inputs_embeds = None + merge_multimodal_embeddings_from_map( + inputs_embeds, vision_embeddings, + attn_metadata.multi_modal_placeholder_index_maps["image"]) + # always pass the input via `inputs_embeds` + # to make sure the computation graph is consistent + # for `torch.compile` integration + input_ids = None hidden_states = self.language_model.model( input_ids, positions, diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 0de590d1d8372..db4d42666b7d8 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -44,7 +44,7 @@ from vllm.utils import is_list_of from .interfaces import SupportsMultiModal, SupportsPP -from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings +from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings_from_map # Cannot find the following 2 numbers from hf config. 
_IMAGE_TOKEN_ID = 71011 @@ -199,9 +199,12 @@ def input_processor_for_fuyu(ctx: InputContext, inputs: DecoderOnlyInputs): new_prompt_token_ids = image_input_ids + bos_token + prompt_token_ids[ 1:] + boa_token + ranges = consecutive_placeholder_ranges(num_items=len(image_list), + item_size=image_patches.size(1)) return token_inputs(prompt=new_prompt, prompt_token_ids=new_prompt_token_ids, - multi_modal_data=new_multi_modal_data) + multi_modal_data=new_multi_modal_data, + multi_modal_placeholders={"image": ranges}) def input_mapper_for_fuyu(ctx: InputContext, data: object): @@ -313,22 +316,25 @@ def forward( **kwargs: object, ): if intermediate_tensors is not None: - input_ids = None inputs_embeds = None else: - image_input = self._parse_and_validate_image_input(**kwargs) + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) + image_input = self._parse_and_validate_image_input(**kwargs) if image_input is not None: vision_embeddings = self._process_image_input(image_input) inputs_embeds = self.language_model.model.embed_tokens( input_ids) - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, - self.image_token_id) - else: - inputs_embeds = None + merge_multimodal_embeddings_from_map( + inputs_embeds, vision_embeddings, + attn_metadata.multi_modal_placeholder_index_maps["image"]) + # always pass the input via `inputs_embeds` + # to make sure the computation graph is consistent + # for `torch.compile` integration + input_ids = None hidden_states = self.language_model( input_ids=input_ids, positions=positions, diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 7fbd59ebd98fd..dbe14ebdd4455 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -32,7 +32,7 @@ dummy_seq_data_for_siglip, get_max_siglip_image_tokens, input_processor_for_siglip) from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, - merge_multimodal_embeddings) + merge_multimodal_embeddings_from_map) class LlavaImagePixelInputs(TypedDict): @@ -496,24 +496,20 @@ def forward( if intermediate_tensors is not None: inputs_embeds = None else: + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) + image_input = self._parse_and_validate_image_input(**kwargs) if image_input is not None: vision_embeddings = self._process_image_input(image_input) - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) - - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, - self.config.image_token_index) - else: - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) + merge_multimodal_embeddings_from_map( + inputs_embeds, vision_embeddings, + attn_metadata.multi_modal_placeholder_index_maps["image"]) # always pass the input via `inputs_embeds` # to make sure the computation graph is consistent # for `torch.compile` integration input_ids = None - hidden_states = self.language_model.model(input_ids, positions, kv_caches, diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 749750fc9c16e..506cb066db48f 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -467,9 +467,11 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, audio_features: A batch of audio inputs [B, N, 80, M]. 
""" if intermediate_tensors is not None: - input_ids = None inputs_embeds = None else: + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) + audio_input = self._parse_and_validate_audio_input(**kwargs) if audio_input is not None: audio_embeddings = self._process_audio_input(audio_input) @@ -479,10 +481,11 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, merge_multimodal_embeddings_from_map( inputs_embeds, audio_embeddings, attn_metadata.multi_modal_placeholder_index_maps["audio"]) - input_ids = None - else: - inputs_embeds = None + # always pass the input via `inputs_embeds` + # to make sure the computation graph is consistent + # for `torch.compile` integration + input_ids = None hidden_states = self.language_model.model( input_ids=input_ids, positions=positions, diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index c5ff552e06099..2dc042bf0d941 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -329,10 +329,21 @@ def repeat_and_pad_placeholder_tokens( def consecutive_placeholder_ranges(num_items: int, - item_size: int) -> List[PlaceholderRange]: - """Returns a list of consecutive PlaceholderRanges of a fixed size""" + item_size: int, + initial_offset: int = 0, + ) -> List[PlaceholderRange]: + """Returns a list of consecutive PlaceholderRanges of a fixed size + + Args: + num_items: The number of items + item_size: The size of each item + initial_offset: The initial offset/index of the first item + + Returns: + A list of PlaceholderRange objects + """ return [ - PlaceholderRange(offset=i * item_size, length=item_size) + PlaceholderRange(offset=initial_offset + i * item_size, length=item_size) for i in range(num_items) ] From 343b467ad39185e0f9e8dc60a618dd868294a064 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Sat, 2 Nov 2024 02:11:20 -0700 Subject: [PATCH 02/12] format Signed-off-by: Roger Wang --- vllm/model_executor/models/blip2.py | 4 ++-- vllm/model_executor/models/fuyu.py | 5 +++-- vllm/model_executor/models/llava.py | 2 +- vllm/multimodal/utils.py | 13 +++++++------ 4 files changed, 13 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 9d8eca3b78a68..9723c20942d85 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -658,8 +658,8 @@ def forward( inputs_embeds = None else: inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) - + input_ids) + image_input = self._parse_and_validate_image_input(**kwargs) if image_input is not None: vision_embeddings = self._process_image_input(image_input) diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index db4d42666b7d8..6d5177ad39782 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -44,7 +44,8 @@ from vllm.utils import is_list_of from .interfaces import SupportsMultiModal, SupportsPP -from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings_from_map +from .utils import (AutoWeightsLoader, flatten_bn, + merge_multimodal_embeddings_from_map) # Cannot find the following 2 numbers from hf config. 
_IMAGE_TOKEN_ID = 71011 @@ -319,7 +320,7 @@ def forward( inputs_embeds = None else: inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) + input_ids) image_input = self._parse_and_validate_image_input(**kwargs) if image_input is not None: diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index dbe14ebdd4455..de270234193f4 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -497,7 +497,7 @@ def forward( inputs_embeds = None else: inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) + input_ids) image_input = self._parse_and_validate_image_input(**kwargs) if image_input is not None: diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index 2dc042bf0d941..ad1b90327c1ff 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -328,10 +328,11 @@ def repeat_and_pad_placeholder_tokens( return new_prompt, new_token_ids, placeholder_ranges -def consecutive_placeholder_ranges(num_items: int, - item_size: int, - initial_offset: int = 0, - ) -> List[PlaceholderRange]: +def consecutive_placeholder_ranges( + num_items: int, + item_size: int, + initial_offset: int = 0, +) -> List[PlaceholderRange]: """Returns a list of consecutive PlaceholderRanges of a fixed size Args: @@ -344,6 +345,6 @@ def consecutive_placeholder_ranges(num_items: int, """ return [ - PlaceholderRange(offset=initial_offset + i * item_size, length=item_size) - for i in range(num_items) + PlaceholderRange(offset=initial_offset + i * item_size, + length=item_size) for i in range(num_items) ] From 0c472f67bac33a6349893c430c73e5a8ed6937f8 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Mon, 4 Nov 2024 00:48:48 -0800 Subject: [PATCH 03/12] fix assignment order Signed-off-by: Roger Wang --- vllm/config.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 17e9b1c100498..1e222b99ce528 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1062,6 +1062,21 @@ def __init__(self, send_delta_data: bool = False, policy: str = "fcfs") -> None: if max_num_batched_tokens is None: + + if is_multimodal_model: + # The value needs to be at least the number of multimodal tokens + max_num_batched_tokens = max( + max_num_batched_tokens, + _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, + ) + + if task == "embedding": + # For embedding, choose specific value for higher throughput + max_num_batched_tokens = max( + max_num_batched_tokens, + _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS, + ) + if enable_chunked_prefill: if num_scheduler_steps > 1: # Multi-step Chunked-Prefill doesn't allow prompt-chunking @@ -1078,19 +1093,6 @@ def __init__(self, # for higher throughput. 
max_num_batched_tokens = max(max_model_len, 2048) - if task == "embedding": - # For embedding, choose specific value for higher throughput - max_num_batched_tokens = max( - max_num_batched_tokens, - _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS, - ) - if is_multimodal_model: - # The value needs to be at least the number of multimodal tokens - max_num_batched_tokens = max( - max_num_batched_tokens, - _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, - ) - self.max_num_batched_tokens = max_num_batched_tokens if enable_chunked_prefill: From 3998f9d16b17ae2a753f261aaa1592e357bff01a Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Mon, 4 Nov 2024 00:53:58 -0800 Subject: [PATCH 04/12] update Signed-off-by: Roger Wang --- vllm/config.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 1e222b99ce528..376e00586711e 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1062,21 +1062,6 @@ def __init__(self, send_delta_data: bool = False, policy: str = "fcfs") -> None: if max_num_batched_tokens is None: - - if is_multimodal_model: - # The value needs to be at least the number of multimodal tokens - max_num_batched_tokens = max( - max_num_batched_tokens, - _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, - ) - - if task == "embedding": - # For embedding, choose specific value for higher throughput - max_num_batched_tokens = max( - max_num_batched_tokens, - _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS, - ) - if enable_chunked_prefill: if num_scheduler_steps > 1: # Multi-step Chunked-Prefill doesn't allow prompt-chunking @@ -1093,6 +1078,21 @@ def __init__(self, # for higher throughput. max_num_batched_tokens = max(max_model_len, 2048) + if is_multimodal_model: + # The value needs to be at least the number of multimodal + # tokens if chunked prefill is not enabled. 
+ max_num_batched_tokens = max( + max_num_batched_tokens, + _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, + ) + + if task == "embedding": + # For embedding, choose specific value for higher throughput + max_num_batched_tokens = max( + max_num_batched_tokens, + _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS, + ) + self.max_num_batched_tokens = max_num_batched_tokens if enable_chunked_prefill: From 0ea320964a05f59b26e719d3b2deab58157c49e3 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Wed, 6 Nov 2024 10:50:56 -0800 Subject: [PATCH 05/12] refactor Signed-off-by: Roger Wang --- vllm/model_executor/models/blip2.py | 49 ++++++++++++++--------- vllm/model_executor/models/fuyu.py | 62 ++++++++++++++++------------- 2 files changed, 65 insertions(+), 46 deletions(-) diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 9723c20942d85..6cbcd2025ec5b 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -23,7 +23,7 @@ get_max_blip_image_tokens) from .interfaces import SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, init_vllm_registered_model, - merge_multimodal_embeddings_from_map) + merge_multimodal_embeddings) # We use this internally as placeholders since there is no image token # defined on the HuggingFace repo @@ -615,6 +615,26 @@ def _process_image_input(self, return self.language_projection(query_output) + def process_mm_inputs(self, **kwargs: object) -> Optional[torch.Tensor]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_inputs_embeds( + self, input_ids: torch.Tensor, + vision_embeddings: Optional[torch.Tensor]) -> torch.Tensor: + inputs_embeds = self.language_model.model.embed_tokens(input_ids) + if vision_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + multimodal_embeddings=vision_embeddings, + placeholder_token_id=BLIP2_IMAGE_TOKEN_ID) + + return inputs_embeds + def forward( self, input_ids: torch.Tensor, @@ -622,6 +642,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> Union[SamplerOutput, IntermediateTensors]: """Run forward pass for BLIP-2. 
@@ -654,26 +675,16 @@ def forward( See also: :class:`Blip2ImageInputs` """ + if intermediate_tensors is not None: inputs_embeds = None - else: - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) - - image_input = self._parse_and_validate_image_input(**kwargs) - if image_input is not None: - vision_embeddings = self._process_image_input(image_input) - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) - - merge_multimodal_embeddings_from_map( - inputs_embeds, vision_embeddings, - attn_metadata.multi_modal_placeholder_index_maps["image"]) - - # always pass the input via `inputs_embeds` - # to make sure the computation graph is consistent - # for `torch.compile` integration - input_ids = None + + elif inputs_embeds is None: + vision_embeddings = self.process_mm_inputs(**kwargs) + inputs_embeds = self.get_inputs_embeds( + input_ids=input_ids, vision_embeddings=vision_embeddings) + input_ids = None + hidden_states = self.language_model.model( input_ids, positions, diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 6d5177ad39782..85985b3cdcb30 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -44,8 +44,7 @@ from vllm.utils import is_list_of from .interfaces import SupportsMultiModal, SupportsPP -from .utils import (AutoWeightsLoader, flatten_bn, - merge_multimodal_embeddings_from_map) +from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings # Cannot find the following 2 numbers from hf config. _IMAGE_TOKEN_ID = 71011 @@ -307,6 +306,26 @@ def _process_image_input( vision_embeddings, _ = self.vision_embed_tokens(image_input["data"]) return vision_embeddings + def process_mm_inputs(self, **kwargs: object) -> Optional[torch.Tensor]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_inputs_embeds( + self, input_ids: torch.Tensor, + vision_embeddings: Optional[torch.Tensor]) -> torch.Tensor: + inputs_embeds = self.language_model.model.embed_tokens(input_ids) + if vision_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + multimodal_embeddings=vision_embeddings, + placeholder_token_id=self.image_token_id) + + return inputs_embeds + def forward( self, input_ids: torch.Tensor, @@ -314,36 +333,25 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ): if intermediate_tensors is not None: inputs_embeds = None - else: - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) - - image_input = self._parse_and_validate_image_input(**kwargs) - if image_input is not None: - vision_embeddings = self._process_image_input(image_input) - inputs_embeds = self.language_model.model.embed_tokens( - input_ids) - - merge_multimodal_embeddings_from_map( - inputs_embeds, vision_embeddings, - attn_metadata.multi_modal_placeholder_index_maps["image"]) - - # always pass the input via `inputs_embeds` - # to make sure the computation graph is consistent - # for `torch.compile` integration - input_ids = None - hidden_states = self.language_model( - input_ids=input_ids, - positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, + elif inputs_embeds is None: + 
vision_embeddings = self.process_mm_inputs(**kwargs) + inputs_embeds = self.get_inputs_embeds( + input_ids=input_ids, vision_embeddings=vision_embeddings) + input_ids = None + + hidden_states = self.language_model.model( + input_ids, + positions, + kv_caches, + attn_metadata, intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - ) + inputs_embeds=inputs_embeds) + return hidden_states def compute_logits( From 68aebb3a86e524127f90493d3d68066d8fe5f3ca Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Wed, 6 Nov 2024 11:37:39 -0800 Subject: [PATCH 06/12] revert llava changes Signed-off-by: Roger Wang --- vllm/model_executor/models/llava.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index de270234193f4..7fbd59ebd98fd 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -32,7 +32,7 @@ dummy_seq_data_for_siglip, get_max_siglip_image_tokens, input_processor_for_siglip) from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, - merge_multimodal_embeddings_from_map) + merge_multimodal_embeddings) class LlavaImagePixelInputs(TypedDict): @@ -496,20 +496,24 @@ def forward( if intermediate_tensors is not None: inputs_embeds = None else: - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) - image_input = self._parse_and_validate_image_input(**kwargs) if image_input is not None: vision_embeddings = self._process_image_input(image_input) - merge_multimodal_embeddings_from_map( - inputs_embeds, vision_embeddings, - attn_metadata.multi_modal_placeholder_index_maps["image"]) + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) + + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, vision_embeddings, + self.config.image_token_index) + else: + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) # always pass the input via `inputs_embeds` # to make sure the computation graph is consistent # for `torch.compile` integration input_ids = None + hidden_states = self.language_model.model(input_ids, positions, kv_caches, From d918b0f3fa86398a23b86a2fef64210590096d84 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Thu, 7 Nov 2024 02:29:27 -0800 Subject: [PATCH 07/12] update Signed-off-by: Roger Wang --- vllm/model_executor/models/blip2.py | 3 ++ vllm/model_executor/models/fuyu.py | 4 ++ vllm/model_executor/models/internvl.py | 72 ++++++++++++++++++++------ 3 files changed, 62 insertions(+), 17 deletions(-) diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 6cbcd2025ec5b..6bf1d52513639 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -679,6 +679,9 @@ def forward( if intermediate_tensors is not None: inputs_embeds = None + # TODO (ywang96): This is currently needed since embedding generation + # takes place in the model forward pass. Clean this up after V0 is + # fully deprecated. 
elif inputs_embeds is None: vision_embeddings = self.process_mm_inputs(**kwargs) inputs_embeds = self.get_inputs_embeds( diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 85985b3cdcb30..fc6fc2367b690 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -338,6 +338,10 @@ def forward( ): if intermediate_tensors is not None: inputs_embeds = None + + # TODO (ywang96): This is currently needed since embedding generation + # takes place in the model forward pass. Clean this up after V0 is + # fully deprecated. elif inputs_embeds is None: vision_embeddings = self.process_mm_inputs(**kwargs) inputs_embeds = self.get_inputs_embeds( diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index d2ec0ff6e74c6..fc540696245a9 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -26,7 +26,7 @@ InternVisionPatchModel) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.base import MultiModalInputs +from vllm.multimodal.base import MultiModalInputs, PlaceholderRange from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of @@ -323,9 +323,25 @@ def input_processor( num_patches) new_prompt_token_ids = tokenizer.encode(new_prompt) - return token_inputs(prompt=prompt, - prompt_token_ids=new_prompt_token_ids, - multi_modal_data=multi_modal_data) + # Get precise tracking of placeholder positions + token_idx = image_idx = 0 + placeholder_ranges = [] + while token_idx < len(new_prompt_token_ids): + if new_prompt_token_ids[token_idx] == self.img_context_token: + curr_image_featue_size = image_feature_sizes[image_idx] + placeholder_ranges.append( + PlaceholderRange(offset=token_idx, + length=curr_image_featue_size)) + image_idx += 1 + token_idx += curr_image_featue_size + else: + token_idx += 1 + + return token_inputs( + prompt=prompt, + prompt_token_ids=new_prompt_token_ids, + multi_modal_data=multi_modal_data, + multi_modal_placeholders={"image": placeholder_ranges}) def input_mapper( self, @@ -608,6 +624,27 @@ def _get_visual_token_mask(self, input_ids: torch.Tensor) -> torch.Tensor: visual_token_mask = None return visual_token_mask + def process_mm_inputs(self, **kwargs: object) -> Optional[torch.Tensor]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_inputs_embeds( + self, input_ids: torch.Tensor, + vision_embeddings: Optional[torch.Tensor]) -> torch.Tensor: + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) + if vision_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + multimodal_embeddings=vision_embeddings, + placeholder_token_id=self.img_context_token_id) + + return inputs_embeds + def forward( self, input_ids: torch.Tensor, @@ -615,26 +652,27 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> Union[SamplerOutput, IntermediateTensors]: if intermediate_tensors is not None: input_ids = None inputs_embeds = None visual_token_mask = None - else: - image_input = 
self._parse_and_validate_image_input(**kwargs) - if image_input is not None: - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) - vision_embeddings = self._process_image_input(image_input) - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, - self.img_context_token_id) + + # TODO (ywang96): This is currently needed since embedding generation + # takes place in the model forward pass. Clean this up after V0 is + # fully deprecated. + elif inputs_embeds is None: + vision_embeddings = self.process_mm_inputs(**kwargs) + if vision_embeddings is not None: visual_token_mask = self._get_visual_token_mask(input_ids) - input_ids = None - else: - inputs_embeds = None - visual_token_mask = None + inputs_embeds = self.get_inputs_embeds( + input_ids=input_ids, vision_embeddings=vision_embeddings) + input_ids = None + + else: + visual_token_mask = self._get_visual_token_mask(input_ids) forward_kwargs = { "input_ids": input_ids, From d38681800623a5921a23d5ed30ba6f51e21e1b2e Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Fri, 8 Nov 2024 21:08:38 -0800 Subject: [PATCH 08/12] flatten Signed-off-by: Roger Wang --- vllm/model_executor/models/internvl.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 4d62724956d2f..0ced0ba920adb 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -629,6 +629,8 @@ def process_mm_inputs(self, **kwargs: object) -> Optional[torch.Tensor]: if image_input is None: return None vision_embeddings = self._process_image_input(image_input) + hidden_size = vision_embeddings.shape[2] + vision_embeddings = vision_embeddings.reshape(1, -1, hidden_size) return vision_embeddings def get_inputs_embeds( From 816398c349a7230bcc2d7f32b7be4f950f086155 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Fri, 8 Nov 2024 21:49:04 -0800 Subject: [PATCH 09/12] fix Signed-off-by: Roger Wang --- vllm/model_executor/models/internvl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 7a9faea7a3ebf..e4205d9373956 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -26,7 +26,7 @@ InternVisionPatchModel) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.base import MultiModalInputs, PlaceholderRange +from vllm.multimodal.base import MultiModalKwargs, PlaceholderRange from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of From 01403f9e03c79c3ad8dd31110a6e1d671526ef70 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Mon, 11 Nov 2024 02:36:31 -0800 Subject: [PATCH 10/12] fix v1 internvl Signed-off-by: Roger Wang --- vllm/model_executor/models/internvl.py | 32 +++++++++++++++++--------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index e4205d9373956..408ca548d42fd 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -322,21 +322,24 @@ def input_processor( new_prompt = self._expand_image_prompt(prompt, image_feature_sizes, num_patches) new_prompt_token_ids = tokenizer.encode(new_prompt) + img_context_token_id = tokenizer.encode(self.img_context_token)[1] # Get 
precise tracking of placeholder positions token_idx = image_idx = 0 placeholder_ranges = [] while token_idx < len(new_prompt_token_ids): - if new_prompt_token_ids[token_idx] == self.img_context_token: + if new_prompt_token_ids[token_idx] == img_context_token_id: curr_image_featue_size = image_feature_sizes[image_idx] + print(curr_image_featue_size) placeholder_ranges.append( PlaceholderRange(offset=token_idx, length=curr_image_featue_size)) image_idx += 1 token_idx += curr_image_featue_size + print(image_idx) + print(token_idx) else: token_idx += 1 - return token_inputs( prompt=prompt, prompt_token_ids=new_prompt_token_ids, @@ -594,12 +597,11 @@ def _parse_and_validate_image_input( if not isinstance(pixel_values, (torch.Tensor, list)): raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") - # We need to flatten (B, N, P) to (B*N*P), - # so we call flatten_bn twice. + + # Pass as a list of pixel value tensors return InternVLImagePixelInputs( type="pixel_values", - data=self._validate_pixel_values( - flatten_bn(flatten_bn(pixel_values), concat=True)), + data=pixel_values, ) raise AssertionError("This line should be unreachable.") @@ -612,7 +614,15 @@ def _process_image_input( return image_input["data"] assert self.vision_model is not None - image_embeds = self.extract_feature(image_input["data"]) + + # Output as a list of image embeddings + image_embeds = [] + + for pixel_values in image_input["data"]: + vision_embeddings = self.extract_feature(flatten_bn(pixel_values)) + hidden_size = vision_embeddings.shape[2] + vision_embeddings = vision_embeddings.reshape(1, -1, hidden_size) + image_embeds.append(vision_embeddings) return image_embeds @@ -629,13 +639,13 @@ def process_mm_inputs(self, **kwargs: object) -> Optional[torch.Tensor]: if image_input is None: return None vision_embeddings = self._process_image_input(image_input) - hidden_size = vision_embeddings.shape[2] - vision_embeddings = vision_embeddings.reshape(1, -1, hidden_size) return vision_embeddings def get_inputs_embeds( - self, input_ids: torch.Tensor, - vision_embeddings: Optional[torch.Tensor]) -> torch.Tensor: + self, + input_ids: torch.Tensor, + vision_embeddings: Optional[torch.Tensor] = None, + ) -> torch.Tensor: inputs_embeds = self.language_model.model.get_input_embeddings( input_ids) if vision_embeddings is not None: From 1995472784be370edf135a1bbb1e38af1b81a228 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Wed, 13 Nov 2024 10:31:09 +0000 Subject: [PATCH 11/12] Fix Signed-off-by: Roger Wang --- vllm/model_executor/models/blip2.py | 6 +++--- vllm/model_executor/models/fuyu.py | 6 +++--- vllm/model_executor/models/internvl.py | 7 ++----- vllm/model_executor/models/persimmon.py | 8 +++++++- vllm/v1/core/scheduler.py | 4 ++-- 5 files changed, 17 insertions(+), 14 deletions(-) diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 806cf3b75a27b..e5a64fcb49d53 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -620,10 +620,10 @@ def process_mm_inputs(self, **kwargs: object) -> Optional[torch.Tensor]: vision_embeddings = self._process_image_input(image_input) return vision_embeddings - def get_inputs_embeds( + def get_input_embeddings( self, input_ids: torch.Tensor, vision_embeddings: Optional[torch.Tensor]) -> torch.Tensor: - inputs_embeds = self.language_model.model.embed_tokens(input_ids) + inputs_embeds = self.language_model.get_input_embeddings(input_ids) if vision_embeddings is not None: inputs_embeds = 
merge_multimodal_embeddings( input_ids=input_ids, @@ -682,7 +682,7 @@ def forward( # fully deprecated. elif inputs_embeds is None: vision_embeddings = self.process_mm_inputs(**kwargs) - inputs_embeds = self.get_inputs_embeds( + inputs_embeds = self.get_input_embeddings( input_ids=input_ids, vision_embeddings=vision_embeddings) input_ids = None diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 4a232bf317bfc..0b9c8ac9ce42f 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -309,10 +309,10 @@ def process_mm_inputs(self, **kwargs: object) -> Optional[torch.Tensor]: vision_embeddings = self._process_image_input(image_input) return vision_embeddings - def get_inputs_embeds( + def get_input_embeddings( self, input_ids: torch.Tensor, vision_embeddings: Optional[torch.Tensor]) -> torch.Tensor: - inputs_embeds = self.language_model.model.embed_tokens(input_ids) + inputs_embeds = self.language_model.get_input_embeddings(input_ids) if vision_embeddings is not None: inputs_embeds = merge_multimodal_embeddings( input_ids=input_ids, @@ -340,7 +340,7 @@ def forward( # fully deprecated. elif inputs_embeds is None: vision_embeddings = self.process_mm_inputs(**kwargs) - inputs_embeds = self.get_inputs_embeds( + inputs_embeds = self.get_input_embeddings( input_ids=input_ids, vision_embeddings=vision_embeddings) input_ids = None diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 61a07416b6c02..4a08d7b245534 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -330,14 +330,11 @@ def input_processor( while token_idx < len(new_prompt_token_ids): if new_prompt_token_ids[token_idx] == img_context_token_id: curr_image_featue_size = image_feature_sizes[image_idx] - print(curr_image_featue_size) placeholder_ranges.append( PlaceholderRange(offset=token_idx, length=curr_image_featue_size)) image_idx += 1 token_idx += curr_image_featue_size - print(image_idx) - print(token_idx) else: token_idx += 1 return token_inputs( @@ -640,7 +637,7 @@ def process_mm_inputs(self, **kwargs: object) -> Optional[torch.Tensor]: vision_embeddings = self._process_image_input(image_input) return vision_embeddings - def get_inputs_embeds( + def get_input_embeddings( self, input_ids: torch.Tensor, vision_embeddings: Optional[torch.Tensor] = None, @@ -678,7 +675,7 @@ def forward( vision_embeddings = self.process_mm_inputs(**kwargs) if vision_embeddings is not None: visual_token_mask = self._get_visual_token_mask(input_ids) - inputs_embeds = self.get_inputs_embeds( + inputs_embeds = self.get_input_embeddings( input_ids=input_ids, vision_embeddings=vision_embeddings) input_ids = None diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py index 2e34a7cc30873..9168944f190ce 100644 --- a/vllm/model_executor/models/persimmon.py +++ b/vllm/model_executor/models/persimmon.py @@ -235,6 +235,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -267,7 +270,7 @@ def forward( class PersimmonForCausalLM(nn.Module, SupportsPP): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + def __init__(self, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = 
vllm_config.model_config.hf_config self.config = config @@ -282,6 +285,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index ba50a9786d805..ca4583593b16f 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -72,12 +72,12 @@ def __init__( # has the Transformer architecture (e.g., ViT). # FIXME(woosuk): Below are placeholder values. We need to calculate the # actual values from the configurations. - self.max_num_encoder_input_tokens = 2048 + self.max_num_encoder_input_tokens = 8192 # NOTE(woosuk): For the models without encoder (e.g., text-only models), # the encoder cache will not be initialized and used, regardless of # the cache size. This is because the memory space for the encoder cache # is preallocated in the profiling run. - self.encoder_cache_manager = EncoderCacheManager(cache_size=2048) + self.encoder_cache_manager = EncoderCacheManager(cache_size=8192) def schedule(self) -> "SchedulerOutput": # NOTE(woosuk) on the scheduling algorithm: From e83892ed822bb76ed8ede198c2409b388c28a257 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Wed, 13 Nov 2024 17:06:09 +0000 Subject: [PATCH 12/12] remove ruff Signed-off-by: Roger Wang --- vllm/model_executor/models/internvl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 5a88a9921114e..02728b107c492 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -26,7 +26,7 @@ InternVisionPatchModel) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs -from vllm.multimodal.inputs import NestedTensors, PlaceholderRange +from vllm.multimodal.inputs import PlaceholderRange from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of
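
Note on the end state of this series: each touched multimodal model now exposes a process_mm_inputs() step (run the multimodal encoder once per request) and a get_input_embeddings() step (embed the text tokens and merge the encoder outputs into the placeholder positions), and the language model is then driven through inputs_embeds with input_ids set to None so the computation graph stays consistent for torch.compile; per the TODOs, forward() still falls back to building inputs_embeds inline until V0 is deprecated. The standalone sketch below mirrors that call pattern only. The toy classes, tensor sizes, and the simple token-id mask merge are invented for illustration and are not vLLM's actual runner or model code; the consecutive_placeholder_ranges helper reproduces the semantics of the version extended in patches 01/02.

import torch
import torch.nn as nn
from typing import List, NamedTuple, Optional


class PlaceholderRange(NamedTuple):
    # Stand-in for vLLM's placeholder-range record, for this sketch only.
    offset: int
    length: int


def consecutive_placeholder_ranges(num_items: int,
                                   item_size: int,
                                   initial_offset: int = 0
                                   ) -> List[PlaceholderRange]:
    # Same semantics as the helper extended in this series: fixed-size,
    # back-to-back ranges starting at initial_offset.
    return [
        PlaceholderRange(offset=initial_offset + i * item_size,
                         length=item_size) for i in range(num_items)
    ]


class ToyMultiModalLM(nn.Module):
    # Toy model following the process_mm_inputs / get_input_embeddings split;
    # the modules and sizes are placeholders, not a real architecture.
    IMAGE_TOKEN_ID = 3

    def __init__(self, vocab_size: int = 16, hidden_size: int = 8):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.vision_tower = nn.Linear(4, hidden_size)         # stand-in encoder
        self.backbone = nn.Linear(hidden_size, hidden_size)   # stand-in LLM

    def process_mm_inputs(
            self,
            pixel_values: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        # Step 1: run the multimodal encoder once, outside the LM forward.
        if pixel_values is None:
            return None
        return self.vision_tower(pixel_values)  # [num_image_tokens, hidden]

    def get_input_embeddings(
            self,
            input_ids: torch.Tensor,
            vision_embeddings: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Step 2: embed the text tokens and overwrite the placeholder
        # positions with the encoder outputs.
        inputs_embeds = self.embed_tokens(input_ids)
        if vision_embeddings is not None:
            mask = input_ids == self.IMAGE_TOKEN_ID
            inputs_embeds[mask] = vision_embeddings
        return inputs_embeds

    def forward(self,
                input_ids: Optional[torch.Tensor],
                inputs_embeds: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Step 3: the LM consumes inputs_embeds; input_ids is only a fallback,
        # so the compiled graph sees one consistent input form.
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        return self.backbone(inputs_embeds)


if __name__ == "__main__":
    model = ToyMultiModalLM()
    # Prompt with 4 consecutive image-placeholder tokens starting at offset 1.
    input_ids = torch.tensor([5, 3, 3, 3, 3, 7])
    ranges = consecutive_placeholder_ranges(num_items=1, item_size=4,
                                            initial_offset=1)
    pixel_values = torch.randn(4, 4)

    vision_embeddings = model.process_mm_inputs(pixel_values)
    inputs_embeds = model.get_input_embeddings(input_ids, vision_embeddings)
    hidden_states = model(input_ids=None, inputs_embeds=inputs_embeds)
    print(ranges)               # [PlaceholderRange(offset=1, length=4)]
    print(hidden_states.shape)  # torch.Size([6, 8])

In the actual patches the merge is driven by the recorded placeholder ranges or image-token ids (merge_multimodal_embeddings) rather than a hard-coded mask, and per-model details differ (e.g. InternVL also flattens per-image embeddings and builds a visual token mask); the sketch only shows the shared call pattern.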