import collections
import collections.abc
from collections.abc import Callable, Iterable, Mapping, Sequence
-from typing import Any, TypeAlias, TypedDict, cast
+from typing import Annotated, Any, TypeAlias, cast

import numpy as np
import torch
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.midashenglm import DashengConfig
+from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
@@ -508,11 +509,16 @@ def forward(self, x, mask=None):


# === Audio Inputs === #
-class MiDashengLMAudioInputs(TypedDict):
-    input_values: torch.Tensor
-    """Shape: `(num_audios, num_sampling_points)`"""
-    audio_length: torch.Tensor
-    """Shape: `(num_audios, 1)`"""
+class MiDashengLMAudioInputs(TensorSchema):
+    """
+
+    Dimensions:
+        - bn: Batch size * number of audios
+        - p: Number of sampling points
+    """
+
+    input_values: Annotated[torch.Tensor, TensorShape("bn", "p")]
+    audio_length: Annotated[torch.Tensor, TensorShape("bn")]


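TensorSchema (imported above from vllm.utils.tensor_schema) replaces the plain TypedDict and attaches a symbolic TensorShape to each field, so the named dimensions can be checked against the actual tensors. A minimal sketch with made-up sizes, relying only on the keyword construction and dict-style access that the model code below also uses:

    # Illustrative only; the sizes are placeholders, not values from this commit.
    import torch

    example = MiDashengLMAudioInputs(
        input_values=torch.zeros(2, 16000),        # (bn, p): two mono waveforms
        audio_length=torch.tensor([16000, 8000]),  # (bn,): valid samples per item
    )
    assert example["input_values"].shape == (2, 16000)
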
class MiDashengLMProcessingInfo(BaseProcessingInfo):
@@ -676,6 +682,8 @@ def get_replacement_midashenglm(item_idx: int):
    dummy_inputs=MiDashengLMDummyInputsBuilder,
)
class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
+    merge_by_field_config = True
+
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
@@ -728,26 +736,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
            self.decoder.make_empty_intermediate_tensors
        )

-    def _validate_and_reshape_mm_tensor(
-        self, mm_input: object, name: str
-    ) -> torch.Tensor:
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}")
-        if isinstance(mm_input, torch.Tensor):
-            return mm_input.reshape(-1, *mm_input.shape[2:])
-
-        if name == "input_values":
-            max_length = max(tensor.shape[1] for tensor in mm_input)
-            padded_mm_input = [
-                torch.nn.functional.pad(tensor, (0, max_length - tensor.shape[1]))
-                if tensor.shape[1] < max_length
-                else tensor
-                for tensor in mm_input
-            ]
-            return torch.concat(padded_mm_input)
-
-        return torch.concat(mm_input)
-
    def _parse_and_validate_audio_input(
        self, **kwargs: object
    ) -> MiDashengLMAudioInputs | None:
@@ -756,24 +744,22 @@ def _parse_and_validate_audio_input(

        if input_values is None:
            return None
-        input_values = self._validate_and_reshape_mm_tensor(
-            input_values, "input_values"
-        )
-        audio_length = self._validate_and_reshape_mm_tensor(
-            audio_length, "audio_length"
-        )
-        if not isinstance(input_values, (torch.Tensor, list)):
-            raise ValueError(
-                "Incorrect type of audio input features. "
-                f"Got type: {type(input_values)}"
+
+        if isinstance(input_values, list):
+            input_values = torch.nn.utils.rnn.pad_sequence(
+                input_values,
+                batch_first=True,
            )

        return MiDashengLMAudioInputs(
            input_values=input_values,
            audio_length=audio_length,
        )

-    def _process_audio_input(self, audio_input: MiDashengLMAudioInputs) -> torch.Tensor:
+    def _process_audio_input(
+        self,
+        audio_input: MiDashengLMAudioInputs,
+    ) -> tuple[torch.Tensor, ...]:
        # Process audio through encoder and projector
        input_values = audio_input["input_values"]
        audio_length = audio_input["audio_length"]
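With merge_by_field_config enabled, the multimodal kwargs may arrive either as one pre-batched tensor or as a list of per-item 1-D waveform tensors; the pad_sequence call above handles the list case by right-padding every waveform to the longest one, which is also why the manual _validate_and_reshape_mm_tensor helper could be dropped. A standalone sketch of that padding behavior, with made-up lengths:

    import torch

    # Three hypothetical waveforms of different lengths (placeholder data).
    waves = [torch.randn(16000), torch.randn(12000), torch.randn(8000)]

    # Right-pad with zeros to the longest waveform; result shape (3, 16000).
    batched = torch.nn.utils.rnn.pad_sequence(waves, batch_first=True)
    assert batched.shape == (3, 16000)
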
@@ -783,17 +769,13 @@ def _process_audio_input(self, audio_input: MiDashengLMAudioInputs) -> torch.Ten
        audio_embeddings = audio_embeddings.to(audio_input["input_values"].dtype)
        batch_size, max_audio_tokens, embed_dim = audio_embeddings.shape

-        audio_length_np = (
-            audio_length.cpu().numpy()
-            if isinstance(audio_length, torch.Tensor)
-            else audio_length
-        )
        audio_output_lengths = [
            max(1, calculate_mel_frames_dasheng(int(length)))  # at least one frame
-            for length in audio_length_np
+            for length in audio_length.tolist()
        ]
-        audio_output_lengths = torch.tensor(audio_output_lengths).to(
-            audio_embeddings.device
+        audio_output_lengths = torch.tensor(
+            audio_output_lengths,
+            device=audio_embeddings.device,
        )

        audio_feature_mask = torch.arange(
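The audio_feature_mask being built above follows the standard arange-based length-mask pattern: broadcast a [0, max_len) range against the per-item lengths so that padded positions come out False. A generic sketch of the pattern (not the exact code from this file):

    import torch

    # Hypothetical per-audio token counts with a padded length of 6.
    lengths = torch.tensor([3, 6, 2])
    max_len = 6

    # mask[i, j] is True only where position j is a real token of item i; shape (3, 6).
    mask = torch.arange(max_len).unsqueeze(0) < lengths.unsqueeze(1)
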
@@ -826,14 +808,6 @@ def forward(
    ) -> torch.Tensor | IntermediateTensors:
        if intermediate_tensors is not None:
            inputs_embeds = None
-        elif inputs_embeds is None:
-            multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(
-                input_ids,
-                multimodal_embeddings,
-                is_multimodal=input_ids == self.config.audio_token_id,
-            )
-            input_ids = None

        return self.decoder.model(
            input_ids,