
Commit ca8a702

DarkLight1337 authored and xuebwang-amd committed
[Model] Use merge_by_field_config for MM models (O-P) (vllm-project#26776)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent 287c8fa commit ca8a702
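
For context, the change applied to all three models below follows one pattern: the model class opts in with `merge_by_field_config = True`, and the manual batch handling that `flatten_bn` / `cat_with_pad` previously did inside `_parse_and_validate_*` is dropped, since (judging from the deletions here) the multimodal kwargs now arrive already merged per field. The snippet below is a minimal standalone sketch of that before/after difference, not vLLM code; the helper name `toy_flatten_bn` and the tensor shapes are illustrative assumptions only.

import torch

def toy_flatten_bn(batched: list[torch.Tensor]) -> list[torch.Tensor]:
    # Illustrative stand-in for the removed flatten_bn helper: collapse the
    # outer "one entry per prompt" list into a flat list of per-image tensors.
    return [item for per_prompt in batched for item in per_prompt]

# Before this commit, _parse_and_validate_image_input received one entry per
# prompt and flattened it manually:
per_prompt = [torch.randn(2, 3, 336, 336),   # prompt 1: two images
              torch.randn(1, 3, 336, 336)]   # prompt 2: one image
flat = toy_flatten_bn(per_prompt)            # three (3, 336, 336) tensors

# After the commit, merge_by_field_config = True means the kwargs already
# arrive in this flat per-item form, so the model passes them straight into
# its TensorSchema (e.g. Phi3VImagePixelInputs) without reshaping.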

File tree: 3 files changed (+30 −122 lines)

vllm/model_executor/models/phi3v.py

Lines changed: 7 additions & 16 deletions
@@ -56,7 +56,6 @@
 )
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
-from vllm.utils import is_list_of
 from vllm.utils.tensor_schema import TensorSchema, TensorShape

 from .clip import CLIPVisionModel
@@ -70,7 +69,6 @@
     AutoWeightsLoader,
     WeightsMapper,
     _merge_multimodal_embeddings,
-    flatten_bn,
     init_vllm_registered_model,
     maybe_prefix,
 )
@@ -564,6 +562,8 @@ def _apply_prompt_updates(
     dummy_inputs=Phi3VDummyInputsBuilder,
 )
 class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant):
+    merge_by_field_config = True
+
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
             "model.vision_embed_tokens.wte": "embed_tokens",
@@ -631,8 +631,8 @@ def _parse_and_validate_image_input(
         if pixel_values is not None:
             return Phi3VImagePixelInputs(
                 type="pixel_values",
-                pixel_values=flatten_bn(pixel_values),
-                image_sizes=flatten_bn(image_sizes, concat=True),
+                pixel_values=pixel_values,
+                image_sizes=image_sizes,
                 resolve_bindings={
                     "h": CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size,
                     "w": CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size,
@@ -642,7 +642,7 @@ def _parse_and_validate_image_input(
         if image_embeds is not None:
             return Phi3VImageEmbeddingInputs(
                 type="image_embeds",
-                data=flatten_bn(image_embeds),
+                data=image_embeds,
             )

         raise AssertionError("This line should be unreachable.")
@@ -652,19 +652,10 @@ def _process_image_input(
         image_input: Phi3VImageInputs,
     ) -> torch.Tensor:
         if image_input["type"] == "image_embeds":
-            image_data = image_input["data"]
-            if is_list_of(image_data, torch.Tensor):
-                # it's already a list of tensors
-                return image_data
-            if len(image_data.shape) == 3:
-                # 3D tensor
-                return list(torch.unbind(image_data, dim=0))
-            raise ValueError(
-                "We expect batched 2D tensors; "
-                "this can be either a list of 2D tensors or a single 3D tensor."
-            )
+            return image_input["data"]

         assert self.vision_embed_tokens is not None
+
         image_embeds = self.vision_embed_tokens(
             image_input["pixel_values"], image_input["image_sizes"]
         )

vllm/model_executor/models/phi4_multimodal.py

Lines changed: 11 additions & 53 deletions
@@ -64,15 +64,13 @@
 )
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
-from vllm.utils import is_list_of
 from vllm.utils.tensor_schema import TensorSchema, TensorShape

 from .idefics2_vision_model import Idefics2VisionTransformer
 from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
 from .utils import (
     AutoWeightsLoader,
     WeightsMapper,
-    flatten_bn,
     init_vllm_registered_model,
     maybe_prefix,
 )
@@ -672,7 +670,7 @@ class Phi4MMImagePixelInputs(TensorSchema):

     type: Literal["pixel_values"]

-    data: Annotated[
+    pixel_values: Annotated[
         torch.Tensor | list[torch.Tensor],
         TensorShape(
             "bn", "p", 3, "h", "w", dynamic_dims={"p"}
@@ -721,7 +719,7 @@ class Phi4MMAudioFeatureInputs(TensorSchema):

     type: Literal["audio_features"]

-    data: Annotated[
+    audio_features: Annotated[
         torch.Tensor | list[torch.Tensor],
         TensorShape("bn", "t", 80, dynamic_dims={"t"}),
     ]
@@ -1189,6 +1187,8 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
     Implements the Phi-4-multimodal-instruct model in vLLM.
     """

+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": [
             "qkv_proj",
@@ -1273,7 +1273,8 @@ def _parse_and_validate_audio_input(

         if audio_features is not None:
             return Phi4MMAudioFeatureInputs(
-                type="audio_features", data=flatten_bn(audio_features)
+                type="audio_features",
+                audio_features=audio_features,
             )

         if audio_embeds is not None:
@@ -1298,7 +1299,7 @@ def _process_audio_input(
         if audio_input["type"] == "audio_embeds":
             return audio_input["data"]

-        audio_features = audio_input["data"]
+        audio_features = audio_input["audio_features"]
         # (e.g. multiple examples) and the second dim is the multi-audio dim
         # (e.g. multiple audios in the same example)

@@ -1315,8 +1316,8 @@ def _process_audio_input(
     def _parse_and_validate_image_input(
         self, **kwargs: object
     ) -> Phi4MMImagePixelInputs | None:
-        image_pixel_values: NestedTensors = kwargs.get("image_pixel_values")
-        if image_pixel_values is None:
+        pixel_values = kwargs.get("image_pixel_values")
+        if pixel_values is None:
             return None

         image_sizes = kwargs.get("image_sizes")
@@ -1328,52 +1329,9 @@ def _parse_and_validate_image_input(
             and num_img_tokens is not None
         ), "Missing image inputs"

-        if is_list_of(image_pixel_values, torch.Tensor):
-            assert all(p.dim() == 5 for p in image_pixel_values), (
-                "Incorrect image inputs"
-            )
-            # list len is batch_size.
-            # each tensor has dimension: num_img_per_example, num_hd_patches,
-            # channels, height, width.
-            # need to pad along num_hd_patches.
-            # mask size num_img_per_prompt, num_hd_patches, feat_h, heat_w.
-            image_pixel_values = cat_with_pad(image_pixel_values, dim=0)
-        elif isinstance(image_pixel_values, torch.Tensor):
-            # dimension: batch_size, num_img_per_example, num_hd_patches,
-            # channels, height, width.
-            # we flatten first 2 dims to make it a single large batch for
-            # SigLIP Encoder.
-            assert image_pixel_values.dim() == 6, "Incorrect image inputs"
-            image_pixel_values = image_pixel_values.flatten(0, 1)
-        else:
-            raise ValueError("Incorrect image_pixel_values inputs")
-
-        if isinstance(image_attention_mask, list):
-            image_attention_mask = cat_with_pad(image_attention_mask, dim=0)
-        elif isinstance(image_attention_mask, torch.Tensor):
-            image_attention_mask = image_attention_mask.flatten(0, 1)
-        else:
-            raise ValueError("Incorrect image_attention_mask inputs")
-
-        if isinstance(image_sizes, list):
-            image_sizes = torch.cat(image_sizes, dim=0)
-        elif isinstance(image_sizes, torch.Tensor):
-            image_sizes = image_sizes.flatten(0, 1)
-        else:
-            raise ValueError("Incorrect image_sizes inputs")
-
-        if isinstance(num_img_tokens, list):
-            num_img_tokens = [
-                n for num_tensor in num_img_tokens for n in num_tensor.tolist()
-            ]
-        elif isinstance(num_img_tokens, torch.Tensor):
-            num_img_tokens = num_img_tokens.flatten(0, 1).tolist()
-        else:
-            raise ValueError("Incorrect num_img_tokens inputs")
-
         return Phi4MMImagePixelInputs(
             type="pixel_values",
-            data=image_pixel_values,
+            pixel_values=pixel_values,
             image_sizes=image_sizes,
             image_attention_mask=image_attention_mask,
             num_img_tokens=num_img_tokens,
@@ -1405,7 +1363,7 @@ def _process_image_input(
             image_embeds = image_input["image_embeds"].type(self.visual.dtype)
         else:
             dtype = next(self.image_embed.parameters()).dtype
-            pixel_values = image_input["data"].to(dtype)
+            pixel_values = image_input["pixel_values"].to(dtype)
             image_sizes = image_input["image_sizes"]
             image_attention_mask = image_input["image_attention_mask"]
             image_embeds = self.image_embed(
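
A second recurring detail in the two Phi-4 files: the `TensorSchema` field is renamed from the generic `data` to the name of the kwarg it actually carries (`pixel_values`, `audio_features`), so the schema is filled directly from the merged kwargs and read back under the same key in `_process_image_input` / `_process_audio_input`. Below is a condensed before/after of the audio schema from the hunks above; the `Old`/`New` class-name prefixes exist only for this side-by-side comparison, as in the source both versions are named `Phi4MMAudioFeatureInputs`.

from typing import Annotated, Literal

import torch

from vllm.utils.tensor_schema import TensorSchema, TensorShape


class OldPhi4MMAudioFeatureInputs(TensorSchema):
    # Before: generic field name, decoupled from the incoming kwarg.
    type: Literal["audio_features"]
    data: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("bn", "t", 80, dynamic_dims={"t"}),
    ]


class NewPhi4MMAudioFeatureInputs(TensorSchema):
    # After: the field mirrors the kwarg name, so it is populated directly
    # from the merged kwargs and read back as audio_input["audio_features"].
    type: Literal["audio_features"]
    audio_features: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("bn", "t", 80, dynamic_dims={"t"}),
    ]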

vllm/model_executor/models/phi4mm.py

Lines changed: 12 additions & 53 deletions
@@ -50,13 +50,12 @@
 )
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
-from vllm.utils import is_list_of
 from vllm.utils.tensor_schema import TensorSchema, TensorShape

 from .idefics2_vision_model import Idefics2VisionTransformer
 from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
 from .phi4mm_audio import AudioEmbedding
-from .utils import AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix
+from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix

 # <|endoftext10|> (see vocab.json in hf model)
 _IMAGE_PLACEHOLDER_TOKEN_ID = 200010
@@ -467,7 +466,7 @@ class Phi4MMImagePixelInputs(TensorSchema):

     type: Literal["pixel_values"]

-    data: Annotated[
+    pixel_values: Annotated[
         torch.Tensor | list[torch.Tensor],
         TensorShape(
             "bn", "p", 3, "h", "w", dynamic_dims={"p"}
@@ -499,7 +498,7 @@ class Phi4MMAudioFeatureInputs(TensorSchema):

     type: Literal["audio_features"]

-    data: Annotated[
+    audio_features: Annotated[
         torch.Tensor | list[torch.Tensor],
         TensorShape("bn", "t", 80, dynamic_dims={"t"}),
     ]
@@ -986,6 +985,8 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
     Implements the Phi-4-multimodal-instruct model in vLLM.
     """

+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": [
             "qkv_proj",
@@ -1094,7 +1095,8 @@ def _parse_and_validate_audio_input(

         if audio_features is not None:
             return Phi4MMAudioFeatureInputs(
-                type="audio_features", data=flatten_bn(audio_features)
+                type="audio_features",
+                audio_features=audio_features,
             )

         if audio_embeds is not None:
@@ -1119,7 +1121,7 @@ def _process_audio_input(
         if audio_input["type"] == "audio_embeds":
             return audio_input["data"]

-        audio_features = audio_input["data"]
+        audio_features = audio_input["audio_features"]
         # (e.g. multiple examples) and the second dim is the multi-audio dim
         # (e.g. multiple audios in the same example)

@@ -1136,8 +1138,8 @@ def _process_audio_input(
     def _parse_and_validate_image_input(
         self, **kwargs: object
     ) -> Phi4MMImagePixelInputs | None:
-        input_image_embeds: NestedTensors = kwargs.get("input_image_embeds")
-        if input_image_embeds is None:
+        pixel_values = kwargs.get("input_image_embeds")
+        if pixel_values is None:
             return None

         image_sizes = kwargs.get("image_sizes")
@@ -1149,52 +1151,9 @@ def _parse_and_validate_image_input(
             and num_img_tokens is not None
         ), "Missing image inputs"

-        if is_list_of(input_image_embeds, torch.Tensor):
-            assert all(p.dim() == 5 for p in input_image_embeds), (
-                "Incorrect image inputs"
-            )
-            # list len is batch_size.
-            # each tensor has dimension: num_img_per_example, num_hd_patches,
-            # channels, height, width.
-            # need to pad along num_hd_patches.
-            # mask size num_img_per_prompt, num_hd_patches, feat_h, heat_w.
-            input_image_embeds = cat_with_pad(input_image_embeds, dim=0)
-        elif isinstance(input_image_embeds, torch.Tensor):
-            # dimension: batch_size, num_img_per_example, num_hd_patches,
-            # channels, height, width.
-            # we flatten first 2 dims to make it a single large batch for
-            # SigLIP Encoder.
-            assert input_image_embeds.dim() == 6, "Incorrect image inputs"
-            input_image_embeds = input_image_embeds.flatten(0, 1)
-        else:
-            raise ValueError("Incorrect input_image_embeds inputs")
-
-        if isinstance(image_attention_mask, list):
-            image_attention_mask = cat_with_pad(image_attention_mask, dim=0)
-        elif isinstance(image_attention_mask, torch.Tensor):
-            image_attention_mask = image_attention_mask.flatten(0, 1)
-        else:
-            raise ValueError("Incorrect image_attention_mask inputs")
-
-        if isinstance(image_sizes, list):
-            image_sizes = torch.cat(image_sizes, dim=0)
-        elif isinstance(image_sizes, torch.Tensor):
-            image_sizes = image_sizes.flatten(0, 1)
-        else:
-            raise ValueError("Incorrect image_sizes inputs")
-
-        if isinstance(num_img_tokens, list):
-            num_img_tokens = [
-                n for num_tensor in num_img_tokens for n in num_tensor.tolist()
-            ]
-        elif isinstance(num_img_tokens, torch.Tensor):
-            num_img_tokens = num_img_tokens.flatten(0, 1).tolist()
-        else:
-            raise ValueError("Incorrect num_img_tokens inputs")
-
         return Phi4MMImagePixelInputs(
             type="pixel_values",
-            data=input_image_embeds,
+            pixel_values=pixel_values,
             image_sizes=image_sizes,
             image_attention_mask=image_attention_mask,
             num_img_tokens=num_img_tokens,
@@ -1223,7 +1182,7 @@ def _process_image_input(
         self, image_input: Phi4MMImagePixelInputs
     ) -> list[torch.Tensor]:
         dtype = next(self.vision_encoder.parameters()).dtype
-        pixel_values = image_input["data"].to(dtype)
+        pixel_values = image_input["pixel_values"].to(dtype)
         image_sizes = image_input["image_sizes"]
         image_attention_mask = image_input["image_attention_mask"]
         image_embeds = self.vision_encoder(
