
Commit afc47e4

[Model] Use merge_by_field_config for MM models (M-N) (#26710)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent e3b90c1 commit afc47e4

11 files changed: +126 -330 lines changed

vllm/model_executor/models/interns1.py

Lines changed: 10 additions & 4 deletions
@@ -631,8 +631,11 @@ def _parse_and_validate_image_input(
             )
 
         image_token_id = kwargs["image_token_id"]
-        assert isinstance(image_token_id, torch.Tensor)
-        self.img_context_token_id = image_token_id.flatten().unique().item()
+        if isinstance(image_token_id, torch.Tensor):
+            image_token_id = image_token_id.flatten().unique().item()
+
+        assert isinstance(image_token_id, int)
+        self.img_context_token_id = image_token_id
 
         if pixel_values is not None:
             h, w = self.config.vision_config.image_size
@@ -665,8 +668,11 @@ def _parse_and_validate_video_input(
             )
 
         video_token_id = kwargs["video_token_id"]
-        assert isinstance(video_token_id, torch.Tensor)
-        self.video_context_token_id = video_token_id.flatten().unique().item()
+        if isinstance(video_token_id, torch.Tensor):
+            video_token_id = video_token_id.flatten().unique().item()
+
+        assert isinstance(video_token_id, int)
+        self.video_context_token_id = video_token_id
 
         if pixel_values_flat_video is not None:
             h, w = self.config.vision_config.image_size
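The two hunks above (and the identical change in internvl.py below) normalize the multimodal token id, which may now arrive either as a tensor of repeated ids or as a plain int. A minimal standalone sketch of the pattern; the helper name normalize_token_id and the example id 32000 are hypothetical, only the tensor/int handling mirrors the diff:

import torch

def normalize_token_id(token_id) -> int:
    # Accept either a tensor of (possibly repeated) ids or a plain int,
    # as in the updated _parse_and_validate_image_input/_video_input.
    if isinstance(token_id, torch.Tensor):
        token_id = token_id.flatten().unique().item()
    assert isinstance(token_id, int)
    return token_id

print(normalize_token_id(torch.tensor([[32000, 32000]])))  # -> 32000
print(normalize_token_id(32000))                           # -> 32000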

vllm/model_executor/models/internvl.py

Lines changed: 10 additions & 4 deletions
@@ -1232,8 +1232,11 @@ def _parse_and_validate_image_input(
             )
 
         image_token_id = kwargs["image_token_id"]
-        assert isinstance(image_token_id, torch.Tensor)
-        self.img_context_token_id = image_token_id.flatten().unique().item()
+        if isinstance(image_token_id, torch.Tensor):
+            image_token_id = image_token_id.flatten().unique().item()
+
+        assert isinstance(image_token_id, int)
+        self.img_context_token_id = image_token_id
 
         if pixel_values_flat is not None:
             expected_h = expected_w = self.config.vision_config.image_size
@@ -1265,8 +1268,11 @@ def _parse_and_validate_video_input(
             )
 
         video_token_id = kwargs["video_token_id"]
-        assert isinstance(video_token_id, torch.Tensor)
-        self.video_context_token_id = video_token_id.flatten().unique().item()
+        if isinstance(video_token_id, torch.Tensor):
+            video_token_id = video_token_id.flatten().unique().item()
+
+        assert isinstance(video_token_id, int)
+        self.video_context_token_id = video_token_id
 
         if pixel_values_flat_video is not None:
             expected_h = expected_w = self.config.vision_config.image_size

vllm/model_executor/models/midashenglm.py

Lines changed: 27 additions & 53 deletions
@@ -26,7 +26,7 @@
 import collections
 import collections.abc
 from collections.abc import Callable, Iterable, Mapping, Sequence
-from typing import Any, TypeAlias, TypedDict, cast
+from typing import Annotated, Any, TypeAlias, cast
 
 import numpy as np
 import torch
@@ -62,6 +62,7 @@
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.midashenglm import DashengConfig
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
@@ -508,11 +509,16 @@ def forward(self, x, mask=None):
 
 
 # === Audio Inputs === #
-class MiDashengLMAudioInputs(TypedDict):
-    input_values: torch.Tensor
-    """Shape: `(num_audios, num_sampling_points)`"""
-    audio_length: torch.Tensor
-    """Shape: `(num_audios, 1)`"""
+class MiDashengLMAudioInputs(TensorSchema):
+    """
+
+    Dimensions:
+        - bn: Batch size * number of audios
+        - p: Number of sampling points
+    """
+
+    input_values: Annotated[torch.Tensor, TensorShape("n", "p")]
+    audio_length: Annotated[torch.Tensor, TensorShape("n")]
 
 
 class MiDashengLMProcessingInfo(BaseProcessingInfo):
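The hunk above replaces the TypedDict with a TensorSchema whose annotated fields name their expected dimensions. A minimal usage sketch, assuming the schema is constructed with keyword arguments and indexed like a mapping (as the parsing code further down in this file does); the concrete sizes are made up:

import torch

from vllm.model_executor.models.midashenglm import MiDashengLMAudioInputs

# Two audios (n=2), each with p sampling points; audio_length records the
# unpadded length of every item.
audio_input = MiDashengLMAudioInputs(
    input_values=torch.zeros(2, 16000),
    audio_length=torch.tensor([16000, 8000]),
)
input_values = audio_input["input_values"]  # dict-style access, as in _process_audio_input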
@@ -676,6 +682,8 @@ def get_replacement_midashenglm(item_idx: int):
     dummy_inputs=MiDashengLMDummyInputsBuilder,
 )
 class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -728,26 +736,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.decoder.make_empty_intermediate_tensors
         )
 
-    def _validate_and_reshape_mm_tensor(
-        self, mm_input: object, name: str
-    ) -> torch.Tensor:
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}")
-        if isinstance(mm_input, torch.Tensor):
-            return mm_input.reshape(-1, *mm_input.shape[2:])
-
-        if name == "input_values":
-            max_length = max(tensor.shape[1] for tensor in mm_input)
-            padded_mm_input = [
-                torch.nn.functional.pad(tensor, (0, max_length - tensor.shape[1]))
-                if tensor.shape[1] < max_length
-                else tensor
-                for tensor in mm_input
-            ]
-            return torch.concat(padded_mm_input)
-
-        return torch.concat(mm_input)
-
     def _parse_and_validate_audio_input(
         self, **kwargs: object
     ) -> MiDashengLMAudioInputs | None:
@@ -756,24 +744,22 @@ def _parse_and_validate_audio_input(
 
         if input_values is None:
             return None
-        input_values = self._validate_and_reshape_mm_tensor(
-            input_values, "input_values"
-        )
-        audio_length = self._validate_and_reshape_mm_tensor(
-            audio_length, "audio_length"
-        )
-        if not isinstance(input_values, (torch.Tensor, list)):
-            raise ValueError(
-                "Incorrect type of audio input features. "
-                f"Got type: {type(input_values)}"
+
+        if isinstance(input_values, list):
+            input_values = torch.nn.utils.rnn.pad_sequence(
+                input_values,
+                batch_first=True,
             )
 
         return MiDashengLMAudioInputs(
             input_values=input_values,
             audio_length=audio_length,
         )
 
-    def _process_audio_input(self, audio_input: MiDashengLMAudioInputs) -> torch.Tensor:
+    def _process_audio_input(
+        self,
+        audio_input: MiDashengLMAudioInputs,
+    ) -> tuple[torch.Tensor, ...]:
         # Process audio through encoder and projector
         input_values = audio_input["input_values"]
         audio_length = audio_input["audio_length"]
@@ -783,17 +769,13 @@ def _process_audio_input(self, audio_input: MiDashengLMAudioInputs) -> torch.Tensor:
         audio_embeddings = audio_embeddings.to(audio_input["input_values"].dtype)
         batch_size, max_audio_tokens, embed_dim = audio_embeddings.shape
 
-        audio_length_np = (
-            audio_length.cpu().numpy()
-            if isinstance(audio_length, torch.Tensor)
-            else audio_length
-        )
         audio_output_lengths = [
             max(1, calculate_mel_frames_dasheng(int(length)))  # at least one frame
-            for length in audio_length_np
+            for length in audio_length.tolist()
         ]
-        audio_output_lengths = torch.tensor(audio_output_lengths).to(
-            audio_embeddings.device
+        audio_output_lengths = torch.tensor(
+            audio_output_lengths,
+            device=audio_embeddings.device,
         )
 
         audio_feature_mask = torch.arange(
@@ -826,14 +808,6 @@ def forward(
     ) -> torch.Tensor | IntermediateTensors:
         if intermediate_tensors is not None:
             inputs_embeds = None
-        elif inputs_embeds is None:
-            multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(
-                input_ids,
-                multimodal_embeddings,
-                is_multimodal=input_ids == self.config.audio_token_id,
-            )
-            input_ids = None
 
         return self.decoder.model(
             input_ids,
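Two points from the hunks above: the model class now sets merge_by_field_config = True, and the hand-rolled _validate_and_reshape_mm_tensor padding is replaced by torch.nn.utils.rnn.pad_sequence, which right-pads a list of variable-length waveforms into a single batch tensor. A small standalone sketch of that padding call, assuming each list element is a 1-D waveform; the lengths are arbitrary:

import torch

# Two raw waveforms of different lengths, one per audio item.
waveforms = [torch.randn(16000), torch.randn(8000)]

# batch_first=True stacks them as (num_audios, max_len); shorter items are
# zero-padded on the right, like the old manual torch.nn.functional.pad path.
batch = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True)
print(batch.shape)  # torch.Size([2, 16000])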

vllm/model_executor/models/minicpmo.py

Lines changed: 5 additions & 42 deletions
@@ -71,7 +71,7 @@
     MiniCPMVProcessingInfo,
     _minicpmv_field_config,
 )
-from .utils import AutoWeightsLoader, cast_overflow_tensors, flatten_bn, maybe_prefix
+from .utils import AutoWeightsLoader, cast_overflow_tensors, maybe_prefix
 
 CPU_DEVICE = torch.device("cpu")
 
@@ -132,15 +132,11 @@ class MiniCPMOAudioEmbeddingInputs(TensorSchema):
 
 
 def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
-    audio_features = hf_inputs.get("audio_features", torch.empty(0))
-    num_audios = len(audio_features)
-
     return dict(
         **_minicpmv_field_config(hf_inputs),
         audio_features=MultiModalFieldConfig.batched("audio"),
         audio_feature_lens=MultiModalFieldConfig.batched("audio"),
         audio_embeds=MultiModalFieldConfig.batched("audio"),
-        audio_token_id=MultiModalFieldConfig.shared("audio", num_audios),
     )
 
 
@@ -332,10 +328,6 @@ def process_audios(
         ]
         audio_inputs["audio_features"] = unpadded_audio_features
 
-        tokenizer = self.info.get_tokenizer()
-        unk_token_id = tokenizer.get_vocab()["<unk>"]
-        audio_inputs["audio_token_id"] = torch.tensor(unk_token_id)
-
         return audio_inputs
 
     def process_mm_inputs(
@@ -436,12 +428,10 @@ def forward(
         attention_mask: torch.Tensor,
     ) -> torch.Tensor:
         residual = hidden_states
-        past_key_values = None
         hidden_states = self.self_attn_layer_norm(hidden_states)
-        hidden_states, attn_weights, past_key_values = self.self_attn(
+        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
-            past_key_value=past_key_values,
        )
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.dropout, training=self.training
@@ -567,8 +557,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             vllm_config=vllm_config, prefix=maybe_prefix(prefix, "apm")
         )
 
-        self.audio_token_id = None
-
     def init_audio_module(self, *, vllm_config: VllmConfig, prefix: str = ""):
         # Do not use parameters temporarily
         audio_config = self.config.audio_config
@@ -731,43 +719,18 @@ def _parse_and_validate_audio_input(
         if audio_features is None and audio_embeds is None:
             return None
 
-        audio_token_id = kwargs.pop("audio_token_id")
-        if audio_token_id is not None:
-            assert isinstance(audio_token_id, torch.Tensor)
-            self.mm_token_ids.add(audio_token_id.flatten().unique().item())
-
         if audio_embeds is not None:
-            if not isinstance(audio_embeds, (torch.Tensor, list)):
-                raise ValueError(
-                    f"Incorrect type of audio_embeds. Got type: {type(audio_embeds)}"
-                )
-
-            audio_embeds_flat = flatten_bn(audio_embeds)
-
             return MiniCPMOAudioEmbeddingInputs(
                 type="audio_embeds",
-                audio_embeds=audio_embeds_flat,
-            )
-
-        if not isinstance(audio_features, (torch.Tensor, list)):
-            raise ValueError(
-                f"Incorrect type of audio_features. Got type: {type(audio_features)}"
+                audio_embeds=audio_embeds,
             )
 
         audio_feature_lens = kwargs.pop("audio_feature_lens")
-        if not isinstance(audio_feature_lens, (torch.Tensor, list)):
-            raise ValueError(
-                "Incorrect type of audio_feature_lens. "
-                f"Got type: {type(audio_feature_lens)}"
-            )
-
-        audio_features_flat = flatten_bn(audio_features)
-        audio_feature_lens_flat = flatten_bn(audio_feature_lens)
 
         return MiniCPMOAudioFeatureInputs(
             type="audio_features",
-            audio_features=audio_features_flat,
-            audio_feature_lens=audio_feature_lens_flat,
+            audio_features=audio_features,
+            audio_feature_lens=audio_feature_lens,
         )
 
     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
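With the processor no longer emitting an audio_token_id entry, _minicpmo_field_config keeps only per-item batched audio fields, and _parse_and_validate_audio_input drops its manual type checks and flatten_bn calls. A minimal sketch of the resulting field mapping; the import path is an assumption, while the three batched(...) calls are taken verbatim from the diff:

from vllm.multimodal.inputs import MultiModalFieldConfig  # import path assumed

def example_audio_field_config():
    # Every audio field is batched per "audio" item; the shared
    # audio_token_id entry from the old config is gone.
    return dict(
        audio_features=MultiModalFieldConfig.batched("audio"),
        audio_feature_lens=MultiModalFieldConfig.batched("audio"),
        audio_embeds=MultiModalFieldConfig.batched("audio"),
    )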
