
Commit 2e461a0

DarkLight1337 authored and ilmarkov committed
[Model] Use merge_by_field_config for MM models (Qwen series) (vllm-project#27546)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent a3df51b · commit 2e461a0

File tree

7 files changed: +36, -305 lines
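
All four model files below make the same two changes: set merge_by_field_config = True on the model class, and delete the per-model _validate_and_reshape_mm_tensor helper along with its call sites. For context, here is a minimal standalone sketch of what that deleted helper did, in plain PyTorch (the free function name is illustrative; the method bodies appear verbatim in the diffs below):

import torch

def manual_merge(mm_input, name: str) -> torch.Tensor:
    # The boilerplate this commit removes: accept either a batched
    # (batch, items, ...) tensor or a list of per-item tensors, and
    # flatten it into a single (batch * items, ...) tensor by hand.
    if not isinstance(mm_input, (torch.Tensor, list)):
        raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}")
    if isinstance(mm_input, torch.Tensor):
        return mm_input.reshape(-1, *mm_input.shape[2:])
    return torch.concat(mm_input)

With merge_by_field_config = True, that merging is instead driven by the model's multimodal field config, so each _parse_and_validate_* method can pass the inputs through unchanged.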

vllm/model_executor/models/qwen2_5_omni_thinker.py

Lines changed: 15 additions & 80 deletions
@@ -126,12 +126,12 @@ class Qwen2_5OmniAudioFeatureInputs(TensorSchema):
     type: Literal["audio_features"]
     input_features: Annotated[
         torch.Tensor | list[torch.Tensor],
-        TensorShape("nmb", "tsl"),
+        TensorShape("nmb", "tsl", dynamic_dims={"tsl"}),
     ]

     feature_attention_mask: Annotated[
-        torch.Tensor,
-        TensorShape("na", "msl"),
+        torch.Tensor | list[torch.Tensor],
+        TensorShape("na", "msl", dynamic_dims={"msl"}),
     ]

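A note on the schema change above: marking a dimension as dynamic, and widening the annotation to torch.Tensor | list[torch.Tensor], lets audio features with different sequence lengths arrive as a list of tensors rather than one stacked tensor. A hedged sketch of the invariant this encodes (a hypothetical checker, not vLLM's actual TensorSchema code):

import torch

def check_dynamic_dim(tensors: list[torch.Tensor], dim: int) -> None:
    # Items may differ along the dynamic dim but must agree elsewhere.
    ref = list(tensors[0].shape)
    ref[dim] = -1
    for t in tensors:
        shape = list(t.shape)
        shape[dim] = -1
        if shape != ref:
            raise ValueError(f"mismatch outside dynamic dim {dim}: {tuple(t.shape)}")

# e.g. two mel-feature tensors with different time lengths both pass:
check_dynamic_dim([torch.zeros(80, 100), torch.zeros(80, 42)], dim=1)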
@@ -651,18 +651,6 @@ def _validate_mm_placeholders(


 class Qwen2_5OmniConditionalGenerationMixin:
-    def _validate_and_reshape_mm_tensor(
-        self, mm_input: object, name: str, dim: int = 0
-    ) -> torch.Tensor:
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}")
-        if isinstance(mm_input, torch.Tensor):
-            if dim == 0:
-                return mm_input.reshape(-1, *mm_input.shape[2:])
-            return torch.concat(list(mm_input), dim=dim)
-        else:
-            return torch.concat(mm_input, dim=dim)
-
     def _parse_and_validate_audio_input(
         self, **kwargs: object
     ) -> Qwen2_5OmniAudioFeatureInputs | None:
@@ -671,18 +659,7 @@ def _parse_and_validate_audio_input(
         feature_attention_mask = kwargs.pop("feature_attention_mask", None)
         if input_audio_features is None:
             return None
-        input_audio_features = self._validate_and_reshape_mm_tensor(
-            input_audio_features, "input_audio_features", dim=1
-        )
-        if feature_attention_mask is not None:
-            feature_attention_mask = self._validate_and_reshape_mm_tensor(
-                feature_attention_mask, "feature_attention_mask"
-            )
-        if not isinstance(input_audio_features, (torch.Tensor, list)):
-            raise ValueError(
-                "Incorrect type of audio input features. "
-                f"Got type: {type(input_audio_features)}"
-            )
+
         return Qwen2_5OmniAudioFeatureInputs(
             type="audio_features",
             input_features=input_audio_features,
@@ -702,38 +679,13 @@ def _parse_and_validate_image_input(
             return None

         if pixel_values is not None:
-            pixel_values = self._validate_and_reshape_mm_tensor(
-                pixel_values, "image pixel values"
-            )
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw"
-            )
-
-            if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError(
-                    "Incorrect type of image pixel values. "
-                    f"Got type: {type(pixel_values)}"
-                )
-
             return Qwen2_5_VLImagePixelInputs(
                 type="pixel_values",
                 pixel_values=pixel_values,
                 image_grid_thw=image_grid_thw,
             )

         if image_embeds is not None:
-            image_embeds = self._validate_and_reshape_mm_tensor(
-                image_embeds, "image embeds"
-            )
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw"
-            )
-
-            if not isinstance(image_embeds, torch.Tensor):
-                raise ValueError(
-                    "Incorrect type of image embeddings. "
-                    f"Got type: {type(image_embeds)}"
-                )
             return Qwen2_5_VLImageEmbeddingInputs(
                 type="image_embeds",
                 image_embeds=image_embeds,
@@ -752,27 +704,13 @@ def _parse_and_validate_video_input(
             return None

         if pixel_values_videos is not None:
-            pixel_values_videos = self._validate_and_reshape_mm_tensor(
-                pixel_values_videos, "video pixel values"
-            )
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, "video grid_thw"
-            )
-
             return Qwen2_5_VLVideoPixelInputs(
                 type="pixel_values_videos",
                 pixel_values_videos=pixel_values_videos,
                 video_grid_thw=video_grid_thw,
             )

         if video_embeds is not None:
-            video_embeds = self._validate_and_reshape_mm_tensor(
-                video_embeds, "video embeds"
-            )
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, "video grid_thw"
-            )
-
             if not isinstance(video_embeds, torch.Tensor):
                 raise ValueError(
                     "Incorrect type of video embeddings. "
@@ -787,23 +725,18 @@ def _parse_and_validate_video_input(
     def _process_audio_input(
         self,
         audio_input: Qwen2_5OmniAudioFeatureInputs,
-        audio_hashes: list[str] = None,
-        cached_audio_features: torch.Tensor = None,
+        audio_hashes: list[str] | None = None,
+        cached_audio_features: torch.Tensor | None = None,
     ) -> torch.Tensor:
         input_features = audio_input["input_features"]
         audio_feature_lengths = audio_input["audio_feature_lengths"]
-        if input_features.ndim == 3:
-            assert input_features.shape[0] == 1
-            input_features = input_features.squeeze(0)
-        if audio_feature_lengths.ndim == 2:
-            assert (
-                audio_feature_lengths.shape[0] == 1
-                or audio_feature_lengths.shape[1] == 1
-            )
-            if audio_feature_lengths.shape[0] == 1:
-                audio_feature_lengths = audio_feature_lengths.squeeze(0)
-            else:
-                audio_feature_lengths = audio_feature_lengths.squeeze(1)
+
+        if audio_feature_lengths.shape[0] == 1:
+            audio_feature_lengths = audio_feature_lengths.squeeze(0)
+        elif audio_feature_lengths.shape[1] == 1:
+            audio_feature_lengths = audio_feature_lengths.squeeze(1)
+        else:
+            raise AssertionError(audio_feature_lengths.shape)

         audio_feat_lengths, audio_output_lengths = (
             self.audio_tower._get_feat_extract_output_lengths(audio_feature_lengths)
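
The rewritten branch above normalizes a 2-D lengths tensor to 1-D whichever axis is the singleton, and now fails loudly on any other shape instead of silently asserting. A quick runnable check of the two accepted cases, in plain PyTorch:

import torch

row = torch.tensor([[3, 5, 7]])      # shape (1, 3): squeeze dim 0
col = torch.tensor([[3], [5], [7]])  # shape (3, 1): squeeze dim 1
assert row.squeeze(0).tolist() == [3, 5, 7]
assert col.squeeze(1).tolist() == [3, 5, 7]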
@@ -867,6 +800,8 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
     SupportsMRoPE,
     Qwen2_5OmniConditionalGenerationMixin,
 ):
+    merge_by_field_config = True
+
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
             "thinker.lm_head.": "language_model.lm_head.",

vllm/model_executor/models/qwen2_5_vl.py

Lines changed: 2 additions & 47 deletions
@@ -1071,6 +1071,8 @@ class Qwen2_5_VLForConditionalGeneration(
     SupportsMultiModalPruning,
     SupportsMRoPE,
 ):
+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"],
@@ -1273,24 +1275,6 @@ def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
         num_layers = len(self.language_model.model.layers)
         return (2, num_layers // 2, num_layers - 3)

-    def _validate_and_reshape_mm_tensor(
-        self, mm_input: object, name: str
-    ) -> torch.Tensor:
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}")
-        if isinstance(mm_input, torch.Tensor):
-            if mm_input.ndim == 2:
-                return mm_input
-            if mm_input.ndim != 3:
-                raise ValueError(
-                    f"{name} should be 2D or batched 3D tensor. "
-                    f"Got ndim: {mm_input.ndim} "
-                    f"(shape={mm_input.shape})"
-                )
-            return mm_input.reshape(-1, mm_input.shape[-1])
-        else:
-            return torch.concat(mm_input)
-
     def _parse_and_validate_image_input(
         self, **kwargs: object
     ) -> Qwen2_5_VLImageInputs | None:
@@ -1302,27 +1286,13 @@ def _parse_and_validate_image_input(
             return None

         if pixel_values is not None:
-            pixel_values = self._validate_and_reshape_mm_tensor(
-                pixel_values, "image pixel values"
-            )
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw"
-            )
-
             return Qwen2_5_VLImagePixelInputs(
                 type="pixel_values",
                 pixel_values=pixel_values,
                 image_grid_thw=image_grid_thw,
             )

         if image_embeds is not None:
-            image_embeds = self._validate_and_reshape_mm_tensor(
-                image_embeds, "image embeds"
-            )
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw"
-            )
-
             return Qwen2_5_VLImageEmbeddingInputs(
                 type="image_embeds",
                 image_embeds=image_embeds,
@@ -1341,14 +1311,6 @@ def _parse_and_validate_video_input(
             return None

         if pixel_values_videos is not None:
-            pixel_values_videos = self._validate_and_reshape_mm_tensor(
-                pixel_values_videos, "video pixel values"
-            )
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, "video grid_thw"
-            )
-            if second_per_grid_ts is not None and second_per_grid_ts.ndim == 2:
-                second_per_grid_ts = second_per_grid_ts.squeeze(-1)
             return Qwen2_5_VLVideoPixelInputs(
                 type="pixel_values_videos",
                 pixel_values_videos=pixel_values_videos,
@@ -1357,13 +1319,6 @@ def _parse_and_validate_video_input(
             )

         if video_embeds is not None:
-            video_embeds = self._validate_and_reshape_mm_tensor(
-                video_embeds, "video embeds"
-            )
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, "video grid_thw"
-            )
-
             return Qwen2_5_VLVideoEmbeddingInputs(
                 type="video_embeds",
                 video_embeds=video_embeds,

vllm/model_executor/models/qwen2_audio.py

Lines changed: 2 additions & 23 deletions
@@ -313,6 +313,8 @@ def get_replacement_qwen2_audio(item_idx: int):
     dummy_inputs=Qwen2AudioDummyInputsBuilder,
 )
 class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
+    merge_by_field_config = True
+
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> str | None:
         if modality.startswith("audio"):
@@ -346,16 +348,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.language_model.make_empty_intermediate_tensors
         )

-    def _validate_and_reshape_mm_tensor(
-        self, mm_input: object, name: str
-    ) -> torch.Tensor:
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}")
-        if isinstance(mm_input, torch.Tensor):
-            return mm_input.reshape(-1, *mm_input.shape[2:])
-        else:
-            return torch.concat(mm_input)
-
     def _parse_and_validate_audio_input(
         self, **kwargs: object
     ) -> Qwen2AudioInputs | None:
@@ -367,24 +359,11 @@ def _parse_and_validate_audio_input(
             return None

         if audio_embeds is not None:
-            if not isinstance(audio_embeds, (torch.Tensor, list)):
-                raise ValueError(
-                    f"Incorrect type of audio embeds. Got type: {type(audio_embeds)}"
-                )
-            audio_embeds = self._validate_and_reshape_mm_tensor(
-                audio_embeds, "audio_embeds"
-            )
             return Qwen2AudioEmbeddingInputs(
                 type="audio_embeds", audio_embeds=audio_embeds
             )

         if input_features is not None:
-            input_features = self._validate_and_reshape_mm_tensor(
-                input_features, "input_features"
-            )
-            feature_attention_mask = self._validate_and_reshape_mm_tensor(
-                feature_attention_mask, "feature_attention_mask"
-            )
             return Qwen2AudioFeatureInputs(
                 type="audio_features",
                 input_features=input_features,

vllm/model_executor/models/qwen2_vl.py

Lines changed: 2 additions & 46 deletions
@@ -1213,6 +1213,8 @@ def _get_mm_fields_config(
 class Qwen2VLForConditionalGeneration(
     nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
 ):
+    merge_by_field_config = True
+
     # To ensure correct weight loading and mapping.
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
@@ -1406,24 +1408,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.language_model.make_empty_intermediate_tensors
         )

-    def _validate_and_reshape_mm_tensor(
-        self, mm_input: object, name: str
-    ) -> torch.Tensor:
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}")
-        if isinstance(mm_input, torch.Tensor):
-            if mm_input.ndim == 2:
-                return mm_input
-            if mm_input.ndim != 3:
-                raise ValueError(
-                    f"{name} should be 2D or batched 3D tensor. "
-                    f"Got ndim: {mm_input.ndim} "
-                    f"(shape={mm_input.shape})"
-                )
-            return mm_input.reshape(-1, mm_input.shape[-1])
-        else:
-            return torch.concat(mm_input)
-
     def _parse_and_validate_image_input(
         self, **kwargs: object
     ) -> Qwen2VLImageInputs | None:
@@ -1435,27 +1419,13 @@ def _parse_and_validate_image_input(
             return None

         if pixel_values is not None:
-            pixel_values = self._validate_and_reshape_mm_tensor(
-                pixel_values, "image pixel values"
-            )
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw"
-            )
-
             return Qwen2VLImagePixelInputs(
                 type="pixel_values",
                 pixel_values=pixel_values,
                 image_grid_thw=image_grid_thw,
             )

         if image_embeds is not None:
-            image_embeds = self._validate_and_reshape_mm_tensor(
-                image_embeds, "image embeds"
-            )
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw"
-            )
-
             return Qwen2VLImageEmbeddingInputs(
                 type="image_embeds",
                 image_embeds=image_embeds,
@@ -1473,27 +1443,13 @@ def _parse_and_validate_video_input(
             return None

         if pixel_values_videos is not None:
-            pixel_values_videos = self._validate_and_reshape_mm_tensor(
-                pixel_values_videos, "video pixel values"
-            )
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, "video grid_thw"
-            )
-
             return Qwen2VLVideoPixelInputs(
                 type="pixel_values_videos",
                 pixel_values_videos=pixel_values_videos,
                 video_grid_thw=video_grid_thw,
             )

         if video_embeds is not None:
-            video_embeds = self._validate_and_reshape_mm_tensor(
-                video_embeds, "video embeds"
-            )
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, "video grid_thw"
-            )
-
             return Qwen2VLVideoEmbeddingInputs(
                 type="video_embeds",
                 video_embeds=video_embeds,

0 commit comments