From e66f22068ec461c0d932ce98ad8a578684c508c2 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 14 Oct 2025 07:37:14 +0000 Subject: [PATCH 1/3] [Model] Use merge_by_field_config for MM models (O-P) Signed-off-by: DarkLight1337 --- vllm/model_executor/models/phi3v.py | 24 +++------ vllm/model_executor/models/phi4_multimodal.py | 53 ++---------------- vllm/model_executor/models/phi4mm.py | 54 ++----------------- 3 files changed, 15 insertions(+), 116 deletions(-) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 93cc7af176d2..a1a5d57a343f 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -56,7 +56,6 @@ ) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors -from vllm.utils import is_list_of from vllm.utils.tensor_schema import TensorSchema, TensorShape from .clip import CLIPVisionModel @@ -70,7 +69,6 @@ AutoWeightsLoader, WeightsMapper, _merge_multimodal_embeddings, - flatten_bn, init_vllm_registered_model, maybe_prefix, ) @@ -564,6 +562,8 @@ def _apply_prompt_updates( dummy_inputs=Phi3VDummyInputsBuilder, ) class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant): + merge_by_field_config = True + hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ "model.vision_embed_tokens.wte": "embed_tokens", @@ -631,8 +631,8 @@ def _parse_and_validate_image_input( if pixel_values is not None: return Phi3VImagePixelInputs( type="pixel_values", - pixel_values=flatten_bn(pixel_values), - image_sizes=flatten_bn(image_sizes, concat=True), + pixel_values=pixel_values, + image_sizes=image_sizes, resolve_bindings={ "h": CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size, "w": CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size, @@ -642,7 +642,7 @@ def _parse_and_validate_image_input( if image_embeds is not None: return Phi3VImageEmbeddingInputs( type="image_embeds", - data=flatten_bn(image_embeds), + data=image_embeds, ) raise AssertionError("This line should be unreachable.") @@ -651,20 +651,8 @@ def _process_image_input( self, image_input: Phi3VImageInputs, ) -> torch.Tensor: - if image_input["type"] == "image_embeds": - image_data = image_input["data"] - if is_list_of(image_data, torch.Tensor): - # it's already a list of tensors - return image_data - if len(image_data.shape) == 3: - # 3D tensor - return list(torch.unbind(image_data, dim=0)) - raise ValueError( - "We expect batched 2D tensors; " - "this can be either a list of 2D tensors or a single 3D tensor." - ) - assert self.vision_embed_tokens is not None + image_embeds = self.vision_embed_tokens( image_input["pixel_values"], image_input["image_sizes"] ) diff --git a/vllm/model_executor/models/phi4_multimodal.py b/vllm/model_executor/models/phi4_multimodal.py index b99e3a5a1fd8..87094d38811d 100644 --- a/vllm/model_executor/models/phi4_multimodal.py +++ b/vllm/model_executor/models/phi4_multimodal.py @@ -64,7 +64,6 @@ ) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors -from vllm.utils import is_list_of from vllm.utils.tensor_schema import TensorSchema, TensorShape from .idefics2_vision_model import Idefics2VisionTransformer @@ -72,7 +71,6 @@ from .utils import ( AutoWeightsLoader, WeightsMapper, - flatten_bn, init_vllm_registered_model, maybe_prefix, ) @@ -1189,6 +1187,8 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): Implements the Phi-4-multimodal-instruct model in vLLM. 
""" + merge_by_field_config = True + packed_modules_mapping = { "qkv_proj": [ "qkv_proj", @@ -1272,9 +1272,7 @@ def _parse_and_validate_audio_input( return None if audio_features is not None: - return Phi4MMAudioFeatureInputs( - type="audio_features", data=flatten_bn(audio_features) - ) + return Phi4MMAudioFeatureInputs(type="audio_features", data=audio_features) if audio_embeds is not None: return Phi4MMAudioEmbeddingInputs(type="audio_embeds", data=audio_embeds) @@ -1315,7 +1313,7 @@ def _process_audio_input( def _parse_and_validate_image_input( self, **kwargs: object ) -> Phi4MMImagePixelInputs | None: - image_pixel_values: NestedTensors = kwargs.get("image_pixel_values") + image_pixel_values = kwargs.get("image_pixel_values") if image_pixel_values is None: return None @@ -1328,49 +1326,6 @@ def _parse_and_validate_image_input( and num_img_tokens is not None ), "Missing image inputs" - if is_list_of(image_pixel_values, torch.Tensor): - assert all(p.dim() == 5 for p in image_pixel_values), ( - "Incorrect image inputs" - ) - # list len is batch_size. - # each tensor has dimension: num_img_per_example, num_hd_patches, - # channels, height, width. - # need to pad along num_hd_patches. - # mask size num_img_per_prompt, num_hd_patches, feat_h, heat_w. - image_pixel_values = cat_with_pad(image_pixel_values, dim=0) - elif isinstance(image_pixel_values, torch.Tensor): - # dimension: batch_size, num_img_per_example, num_hd_patches, - # channels, height, width. - # we flatten first 2 dims to make it a single large batch for - # SigLIP Encoder. - assert image_pixel_values.dim() == 6, "Incorrect image inputs" - image_pixel_values = image_pixel_values.flatten(0, 1) - else: - raise ValueError("Incorrect image_pixel_values inputs") - - if isinstance(image_attention_mask, list): - image_attention_mask = cat_with_pad(image_attention_mask, dim=0) - elif isinstance(image_attention_mask, torch.Tensor): - image_attention_mask = image_attention_mask.flatten(0, 1) - else: - raise ValueError("Incorrect image_attention_mask inputs") - - if isinstance(image_sizes, list): - image_sizes = torch.cat(image_sizes, dim=0) - elif isinstance(image_sizes, torch.Tensor): - image_sizes = image_sizes.flatten(0, 1) - else: - raise ValueError("Incorrect image_sizes inputs") - - if isinstance(num_img_tokens, list): - num_img_tokens = [ - n for num_tensor in num_img_tokens for n in num_tensor.tolist() - ] - elif isinstance(num_img_tokens, torch.Tensor): - num_img_tokens = num_img_tokens.flatten(0, 1).tolist() - else: - raise ValueError("Incorrect num_img_tokens inputs") - return Phi4MMImagePixelInputs( type="pixel_values", data=image_pixel_values, diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py index dce31f9d0aac..027b4c7b2010 100644 --- a/vllm/model_executor/models/phi4mm.py +++ b/vllm/model_executor/models/phi4mm.py @@ -50,13 +50,12 @@ ) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors -from vllm.utils import is_list_of from vllm.utils.tensor_schema import TensorSchema, TensorShape from .idefics2_vision_model import Idefics2VisionTransformer from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal from .phi4mm_audio import AudioEmbedding -from .utils import AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix +from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix # <|endoftext10|> (see vocab.json in hf model) _IMAGE_PLACEHOLDER_TOKEN_ID = 200010 @@ -986,6 +985,8 @@ class 
Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): Implements the Phi-4-multimodal-instruct model in vLLM. """ + merge_by_field_config = True + packed_modules_mapping = { "qkv_proj": [ "qkv_proj", @@ -1093,9 +1094,7 @@ def _parse_and_validate_audio_input( return None if audio_features is not None: - return Phi4MMAudioFeatureInputs( - type="audio_features", data=flatten_bn(audio_features) - ) + return Phi4MMAudioFeatureInputs(type="audio_features", data=audio_features) if audio_embeds is not None: return Phi4MMAudioEmbeddingInputs(type="audio_embeds", data=audio_embeds) @@ -1136,7 +1135,7 @@ def _process_audio_input( def _parse_and_validate_image_input( self, **kwargs: object ) -> Phi4MMImagePixelInputs | None: - input_image_embeds: NestedTensors = kwargs.get("input_image_embeds") + input_image_embeds = kwargs.get("input_image_embeds") if input_image_embeds is None: return None @@ -1149,49 +1148,6 @@ def _parse_and_validate_image_input( and num_img_tokens is not None ), "Missing image inputs" - if is_list_of(input_image_embeds, torch.Tensor): - assert all(p.dim() == 5 for p in input_image_embeds), ( - "Incorrect image inputs" - ) - # list len is batch_size. - # each tensor has dimension: num_img_per_example, num_hd_patches, - # channels, height, width. - # need to pad along num_hd_patches. - # mask size num_img_per_prompt, num_hd_patches, feat_h, heat_w. - input_image_embeds = cat_with_pad(input_image_embeds, dim=0) - elif isinstance(input_image_embeds, torch.Tensor): - # dimension: batch_size, num_img_per_example, num_hd_patches, - # channels, height, width. - # we flatten first 2 dims to make it a single large batch for - # SigLIP Encoder. - assert input_image_embeds.dim() == 6, "Incorrect image inputs" - input_image_embeds = input_image_embeds.flatten(0, 1) - else: - raise ValueError("Incorrect input_image_embeds inputs") - - if isinstance(image_attention_mask, list): - image_attention_mask = cat_with_pad(image_attention_mask, dim=0) - elif isinstance(image_attention_mask, torch.Tensor): - image_attention_mask = image_attention_mask.flatten(0, 1) - else: - raise ValueError("Incorrect image_attention_mask inputs") - - if isinstance(image_sizes, list): - image_sizes = torch.cat(image_sizes, dim=0) - elif isinstance(image_sizes, torch.Tensor): - image_sizes = image_sizes.flatten(0, 1) - else: - raise ValueError("Incorrect image_sizes inputs") - - if isinstance(num_img_tokens, list): - num_img_tokens = [ - n for num_tensor in num_img_tokens for n in num_tensor.tolist() - ] - elif isinstance(num_img_tokens, torch.Tensor): - num_img_tokens = num_img_tokens.flatten(0, 1).tolist() - else: - raise ValueError("Incorrect num_img_tokens inputs") - return Phi4MMImagePixelInputs( type="pixel_values", data=input_image_embeds, From 9946ca89be4a2d25306fc9ef33d5c3389fe01fc0 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 14 Oct 2025 07:45:49 +0000 Subject: [PATCH 2/3] Fix Signed-off-by: DarkLight1337 --- vllm/model_executor/models/phi3v.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index a1a5d57a343f..b86fe67fb476 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -651,6 +651,9 @@ def _process_image_input( self, image_input: Phi3VImageInputs, ) -> torch.Tensor: + if image_input["type"] == "image_embeds": + return image_input["data"] + assert self.vision_embed_tokens is not None image_embeds = self.vision_embed_tokens( From 
b25b41af5bd942376d2652ec9272a1be86f85f61 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 14 Oct 2025 07:51:01 +0000 Subject: [PATCH 3/3] Improve naming Signed-off-by: DarkLight1337 --- vllm/model_executor/models/phi4_multimodal.py | 19 +++++++++++-------- vllm/model_executor/models/phi4mm.py | 19 +++++++++++-------- 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/vllm/model_executor/models/phi4_multimodal.py b/vllm/model_executor/models/phi4_multimodal.py index 87094d38811d..4c4d6f649232 100644 --- a/vllm/model_executor/models/phi4_multimodal.py +++ b/vllm/model_executor/models/phi4_multimodal.py @@ -670,7 +670,7 @@ class Phi4MMImagePixelInputs(TensorSchema): type: Literal["pixel_values"] - data: Annotated[ + pixel_values: Annotated[ torch.Tensor | list[torch.Tensor], TensorShape( "bn", "p", 3, "h", "w", dynamic_dims={"p"} @@ -719,7 +719,7 @@ class Phi4MMAudioFeatureInputs(TensorSchema): type: Literal["audio_features"] - data: Annotated[ + audio_features: Annotated[ torch.Tensor | list[torch.Tensor], TensorShape("bn", "t", 80, dynamic_dims={"t"}), ] @@ -1272,7 +1272,10 @@ def _parse_and_validate_audio_input( return None if audio_features is not None: - return Phi4MMAudioFeatureInputs(type="audio_features", data=audio_features) + return Phi4MMAudioFeatureInputs( + type="audio_features", + audio_features=audio_features, + ) if audio_embeds is not None: return Phi4MMAudioEmbeddingInputs(type="audio_embeds", data=audio_embeds) @@ -1296,7 +1299,7 @@ def _process_audio_input( if audio_input["type"] == "audio_embeds": return audio_input["data"] - audio_features = audio_input["data"] + audio_features = audio_input["audio_features"] # (e.g. multiple examples) and the second dim is the multi-audio dim # (e.g. multiple audios in the same example) @@ -1313,8 +1316,8 @@ def _process_audio_input( def _parse_and_validate_image_input( self, **kwargs: object ) -> Phi4MMImagePixelInputs | None: - image_pixel_values = kwargs.get("image_pixel_values") - if image_pixel_values is None: + pixel_values = kwargs.get("image_pixel_values") + if pixel_values is None: return None image_sizes = kwargs.get("image_sizes") @@ -1328,7 +1331,7 @@ def _parse_and_validate_image_input( return Phi4MMImagePixelInputs( type="pixel_values", - data=image_pixel_values, + pixel_values=pixel_values, image_sizes=image_sizes, image_attention_mask=image_attention_mask, num_img_tokens=num_img_tokens, @@ -1360,7 +1363,7 @@ def _process_image_input( image_embeds = image_input["image_embeds"].type(self.visual.dtype) else: dtype = next(self.image_embed.parameters()).dtype - pixel_values = image_input["data"].to(dtype) + pixel_values = image_input["pixel_values"].to(dtype) image_sizes = image_input["image_sizes"] image_attention_mask = image_input["image_attention_mask"] image_embeds = self.image_embed( diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py index 027b4c7b2010..f2043aef67f8 100644 --- a/vllm/model_executor/models/phi4mm.py +++ b/vllm/model_executor/models/phi4mm.py @@ -466,7 +466,7 @@ class Phi4MMImagePixelInputs(TensorSchema): type: Literal["pixel_values"] - data: Annotated[ + pixel_values: Annotated[ torch.Tensor | list[torch.Tensor], TensorShape( "bn", "p", 3, "h", "w", dynamic_dims={"p"} @@ -498,7 +498,7 @@ class Phi4MMAudioFeatureInputs(TensorSchema): type: Literal["audio_features"] - data: Annotated[ + audio_features: Annotated[ torch.Tensor | list[torch.Tensor], TensorShape("bn", "t", 80, dynamic_dims={"t"}), ] @@ -1094,7 +1094,10 @@ def 
_parse_and_validate_audio_input( return None if audio_features is not None: - return Phi4MMAudioFeatureInputs(type="audio_features", data=audio_features) + return Phi4MMAudioFeatureInputs( + type="audio_features", + audio_features=audio_features, + ) if audio_embeds is not None: return Phi4MMAudioEmbeddingInputs(type="audio_embeds", data=audio_embeds) @@ -1118,7 +1121,7 @@ def _process_audio_input( if audio_input["type"] == "audio_embeds": return audio_input["data"] - audio_features = audio_input["data"] + audio_features = audio_input["audio_features"] # (e.g. multiple examples) and the second dim is the multi-audio dim # (e.g. multiple audios in the same example) @@ -1135,8 +1138,8 @@ def _process_audio_input( def _parse_and_validate_image_input( self, **kwargs: object ) -> Phi4MMImagePixelInputs | None: - input_image_embeds = kwargs.get("input_image_embeds") - if input_image_embeds is None: + pixel_values = kwargs.get("input_image_embeds") + if pixel_values is None: return None image_sizes = kwargs.get("image_sizes") @@ -1150,7 +1153,7 @@ def _parse_and_validate_image_input( return Phi4MMImagePixelInputs( type="pixel_values", - data=input_image_embeds, + pixel_values=pixel_values, image_sizes=image_sizes, image_attention_mask=image_attention_mask, num_img_tokens=num_img_tokens, @@ -1179,7 +1182,7 @@ def _process_image_input( self, image_input: Phi4MMImagePixelInputs ) -> list[torch.Tensor]: dtype = next(self.vision_encoder.parameters()).dtype - pixel_values = image_input["data"].to(dtype) + pixel_values = image_input["pixel_values"].to(dtype) image_sizes = image_input["image_sizes"] image_attention_mask = image_input["image_attention_mask"] image_embeds = self.vision_encoder(
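Context for the series: with merge_by_field_config = True, the multimodal keyword arguments reach _parse_and_validate_*() already merged along the batch dimension, which is why the flatten_bn calls and the manual list/tensor branching removed above are no longer needed. Below is a minimal, self-contained sketch of that merging behaviour, assuming only torch; demo_merge_by_field is a hypothetical stand-in for what flatten_bn(..., concat=True) did here and is not the vLLM implementation.

import torch


def demo_merge_by_field(per_prompt: list[torch.Tensor]) -> torch.Tensor:
    """Hypothetical stand-in for flatten_bn(..., concat=True): concatenate
    per-prompt tensors into one batch-major tensor along dim 0."""
    return torch.cat(per_prompt, dim=0)


# Previously each prompt contributed its own (num_images, C, H, W) tensor and
# the model merged them itself; with merge_by_field_config the processor
# delivers the already-merged tensor, so the parsing code can use it as-is.
per_prompt = [torch.randn(2, 3, 336, 336), torch.randn(1, 3, 336, 336)]
pixel_values = demo_merge_by_field(per_prompt)  # shape: (3, 3, 336, 336)
assert pixel_values.shape[0] == sum(t.shape[0] for t in per_prompt)

In the patched models, a tensor merged this way is what arrives as pixel_values / image_pixel_values / input_image_embeds, which the TensorShape("bn", ...) annotations in the schemas above then validate directly.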