 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only MiDashengLM model compatible with HuggingFace weights."""
+
 import collections
 import collections.abc
 from collections.abc import Iterable, Mapping, Sequence
 import numpy as np
 import torch
 import torch.nn as nn
-import torchaudio.transforms as audio_transforms
+import torchaudio.functional as F
+from torch.nn.functional import scaled_dot_product_attention
 from transformers import BatchFeature
 
-from vllm.attention.layer import MultiHeadAttention
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.model_loader.utils import set_default_torch_dtype
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargsItems)
@@ -147,15 +147,19 @@ def __init__(
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
-        self.fc1 = ColumnParallelLinear(input_size=in_features,
-                                        output_size=hidden_features,
-                                        quant_config=quant_config,
-                                        prefix=f"{prefix}.fc1")
+        self.fc1 = ColumnParallelLinear(
+            input_size=in_features,
+            output_size=hidden_features,
+            quant_config=quant_config,
+            prefix=f"{prefix}.fc1",
+        )
         self.act = get_act_fn("gelu")
-        self.fc2 = RowParallelLinear(input_size=hidden_features,
-                                     output_size=out_features,
-                                     quant_config=quant_config,
-                                     prefix=f"{prefix}.fc2")
+        self.fc2 = RowParallelLinear(
+            input_size=hidden_features,
+            output_size=out_features,
+            quant_config=quant_config,
+            prefix=f"{prefix}.fc2",
+        )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x, _ = self.fc1(x)
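
Note on the fc1/fc2 pairing above: this is the standard tensor-parallel MLP layout. The column-parallel fc1 shards its output features across ranks, the GELU is elementwise, and the row-parallel fc2 shards its input features, so a single all-reduce at fc2's output restores the full tensor. A minimal single-process sketch of the equivalence, with plain matmuls and a two-way chunk standing in for vLLM's sharded layers:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 8)
w1 = torch.randn(16, 8)  # fc1 weight [hidden_features, in_features]
w2 = torch.randn(8, 16)  # fc2 weight [out_features, hidden_features]
dense = F.gelu(x @ w1.T) @ w2.T

# Shard fc1 by output rows (column-parallel) and fc2 by input columns
# (row-parallel) across two "ranks"; summing the partial fc2 outputs plays
# the role of the all-reduce.
w1_a, w1_b = w1.chunk(2, dim=0)
w2_a, w2_b = w2.chunk(2, dim=1)
sharded = F.gelu(x @ w1_a.T) @ w2_a.T + F.gelu(x @ w1_b.T) @ w2_b.T
assert torch.allclose(dense, sharded, atol=1e-5)
```
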
@@ -171,7 +175,6 @@ def __init__(
         dim: int,
         num_heads: int = 8,
         qkv_bias: bool = False,
-        causal: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
@@ -205,33 +208,30 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.qkv",
         )
-        self.attn = MultiHeadAttention(
-            self.num_heads,
-            self.head_dim,
-            self.scale,
-            num_kv_heads=self.num_kv_heads,
-        )
         self.proj = RowParallelLinear(
             input_size=dim,
             output_size=dim,
             quant_config=quant_config,
             prefix=f"{prefix}.proj",
         )
-        self.causal = causal
 
     def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
         B, N, C = x.shape
 
-        qkv_out, _ = self.qkv(x)
-        q, k, v = qkv_out.split([self.q_size, self.kv_size, self.kv_size],
-                                dim=-1)
+        qkv, _ = self.qkv(x)
+        qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)
+        qkv = qkv.permute(2, 0, 3, 1, 4)
+        q, k, v = qkv.unbind(0)
 
-        attn_out = self.attn(q, k, v)
-        C_local = attn_out.numel() // (B * N)  # C_local for parallel
-        attn_out = attn_out.view(B, N, C_local)
-
-        x, _ = self.proj(attn_out)
+        x = scaled_dot_product_attention(
+            q,
+            k,
+            v,
+            attn_mask=mask[:, None, None, :] if mask is not None else None,
+        )
 
+        x = x.transpose(1, 2).reshape(B, N, C)
+        x, _ = self.proj(x)
         return x
 
 
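Note on the replacement above: `torch.nn.functional.scaled_dot_product_attention` treats a boolean `attn_mask` as "True may attend" and defaults its scale to `1/sqrt(head_dim)`, which is what the removed vLLM `MultiHeadAttention` wrapper was given as `self.scale`. A small self-contained sketch of the `mask[:, None, None, :]` broadcast (shapes are illustrative):

```python
import torch
from torch.nn.functional import scaled_dot_product_attention

B, H, N, D = 2, 4, 10, 16
q = k = v = torch.randn(B, H, N, D)
lengths = torch.tensor([7, 10])                      # valid keys per sequence
mask = torch.arange(N)[None, :] < lengths[:, None]   # [B, N], True = valid

# [B, N] -> [B, 1, 1, N]: every head and every query position shares the
# same per-key padding mask.
out = scaled_dot_product_attention(q, k, v, attn_mask=mask[:, None, None, :])
print(out.shape)  # torch.Size([2, 4, 10, 16])
```
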
@@ -280,6 +280,63 @@ def forward(
         return x
 
 
+class DashengFrontend(nn.Module):
+
+    def __init__(self, config: DashengConfig):
+        super().__init__()
+        self.config = config
+
+        spectrogram_window = torch.hann_window(self.config.win_length)
+        self.register_buffer(
+            "spectrogram_window",
+            spectrogram_window,
+            persistent=False,
+        )
+        self.spectrogram_window: torch.Tensor
+
+        melscale_fbanks = F.melscale_fbanks(
+            n_freqs=self.config.n_fft // 2 + 1,
+            f_min=self.config.f_min,
+            f_max=self.config.f_max,
+            n_mels=self.config.n_mels,
+            sample_rate=self.config.sample_rate,
+        )
+        self.register_buffer("melscale_fbanks",
+                             melscale_fbanks,
+                             persistent=False)
+        self.melscale_fbanks: torch.Tensor
+
+    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
+        spectrogram = F.spectrogram(
+            waveform=waveform.to(torch.float32),
+            pad=0,
+            window=self.spectrogram_window,
+            n_fft=self.config.n_fft,
+            hop_length=self.config.hop_length,
+            win_length=self.config.win_length,
+            power=2,
+            normalized=False,
+            center=self.config.center,
+        )
+        mel_spectrogram = (
+            spectrogram.mT @ self.melscale_fbanks.to(torch.float32)).mT
+        # mel_spectrogram has shape [batch, freq, time].
+        # F.amplitude_to_DB accepts inputs shaped as:
+        #   - [freq, time]
+        #   - [channel, freq, time]
+        #   - [..., channel, freq, time]
+        # Insert a channel dimension of size 1 before calling it, then
+        # remove that extra dimension afterward.
+        log_mel_spectrogram = F.amplitude_to_DB(
+            mel_spectrogram.unsqueeze(1),
+            multiplier=10,
+            amin=1e-10,
+            db_multiplier=0,
+            top_db=120,
+        ).squeeze(1)
+        return log_mel_spectrogram.to(waveform.dtype)
+
+
 class DashengAudioTransformer(nn.Module):
 
     def __init__(
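
Usage sketch for the new `DashengFrontend` (the config values below are made up for illustration, not the model's real hyperparameters): the module maps a `[batch, samples]` waveform to a `[batch, n_mels, frames]` log-mel spectrogram, computing the STFT and filterbank products in float32 and casting back to the input dtype at the end, so it behaves well under bfloat16 weights:

```python
import torch

# Assumes DashengConfig exposes these fields, as read by DashengFrontend;
# the numbers here are illustrative only.
config = DashengConfig(n_fft=512, win_length=512, hop_length=160,
                       f_min=0.0, f_max=8000.0, n_mels=64,
                       sample_rate=16000, center=True)
frontend = DashengFrontend(config)
waveform = torch.randn(2, 16000, dtype=torch.bfloat16)  # 1 s at 16 kHz
log_mel = frontend(waveform)
print(log_mel.shape, log_mel.dtype)
# torch.Size([2, 64, 101]) torch.bfloat16  (frames = samples // hop_length + 1)
```
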
@@ -293,7 +350,7 @@ def __init__(
         self.target_length = config.target_length
         self.hop_length = config.hop_length
 
-        self._init_front_end(config)
+        self.front_end = DashengFrontend(config)
 
         self.init_bn = nn.BatchNorm2d(config.n_mels, momentum=0.01)
 
@@ -318,34 +375,10 @@ def __init__(
                 qkv_bias=config.qkv_bias,
                 init_values=config.init_values,
                 quant_config=quant_config,
-                prefix=f"{prefix}.block{i}",
+                prefix=f"{prefix}.blocks.{i}",
             ) for i in range(config.depth))
         self.norm = nn.LayerNorm(config.embed_dim, eps=1e-6)
 
-    def _init_front_end(self, config):
-        with set_default_torch_dtype(torch.float32):
-            self.front_end = nn.Sequential(
-                audio_transforms.MelSpectrogram(
-                    f_min=config.f_min,
-                    f_max=config.f_max,
-                    center=config.center,
-                    win_length=config.win_length,
-                    hop_length=config.hop_length,
-                    sample_rate=config.sample_rate,
-                    n_fft=config.n_fft,
-                    n_mels=config.n_mels,
-                ),
-                audio_transforms.AmplitudeToDB(top_db=120),
-            )
-
-            mel_spectrogram = self.front_end[0]
-            fb = mel_spectrogram.mel_scale.fb
-            win = mel_spectrogram.spectrogram.window
-            mel_spectrogram.mel_scale.fb = fb.to(torch.bfloat16).to(
-                torch.float32)
-            mel_spectrogram.spectrogram.window = win.to(torch.bfloat16).to(
-                torch.float32)
-
     def forward_features(
         self,
         x: torch.Tensor,
@@ -430,14 +463,16 @@ def __init__(
                 quant_config=quant_config,
                 prefix=f"{prefix}.net.0",
                 return_bias=False,
-            ), get_act_fn("gelu"),
+            ),
+            get_act_fn("gelu"),
             RowParallelLinear(
                 input_size=out_dim,
                 output_size=out_dim,
                 quant_config=quant_config,
                 prefix=f"{prefix}.net.2",
                 return_bias=False,
-            ))
+            ),
+        )
 
     def forward(self, x, mask=None):
         batch_size, seq_len, dim = x.shape
@@ -534,9 +569,12 @@ def _call_hf_processor(
         # + Padding
         min_audio_len = self.info.get_min_audio_len()
         processed_audios = [
-            np.pad(audio, (0, min_audio_len - audio.shape[-1]),
-                   mode='constant',
-                   constant_values=0) if isinstance(audio, np.ndarray)
+            np.pad(
+                audio,
+                (0, min_audio_len - audio.shape[-1]),
+                mode="constant",
+                constant_values=0,
+            ) if isinstance(audio, np.ndarray)
             and audio.shape[-1] < min_audio_len else audio for audio in audios
         ]
 
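The comprehension above pads only numpy arrays that are shorter than the minimum and passes everything else through unchanged. A condensed sketch of the rule for a single clip (640 is an arbitrary stand-in for `get_min_audio_len()`):

```python
import numpy as np

min_audio_len = 640  # stand-in; the real value comes from the processor info
audio = np.zeros(500, dtype=np.float32)
if isinstance(audio, np.ndarray) and audio.shape[-1] < min_audio_len:
    audio = np.pad(audio, (0, min_audio_len - audio.shape[-1]),
                   mode="constant", constant_values=0)
print(audio.shape)  # (640,)
```
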
@@ -585,8 +623,8 @@ def _get_prompt_updates(
         if audio_length is None:
             audio_output_lengths = []
         else:
-            audio_length_np = audio_length.cpu().numpy() if isinstance(
-                audio_length, torch.Tensor) else audio_length
+            audio_length_np = (audio_length.cpu().numpy() if isinstance(
+                audio_length, torch.Tensor) else audio_length)
             audio_output_lengths = [
                 max(1, calculate_mel_frames_dasheng(
                     int(length)))  # at least one frame
@@ -617,6 +655,17 @@ def get_replacement_midashenglm(item_idx: int):
     dummy_inputs=MiDashengLMDummyInputsBuilder,
 )
 class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
 
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
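
Background on the new `packed_modules_mapping`: vLLM's weight loader uses it to fold per-projection checkpoint tensors into the fused parameters owned by the merged parallel layers, and quantization/LoRA logic consults the same map. A rough sketch of the fusion it describes, with plain tensors standing in for the sharded parameters:

```python
import torch

in_dim, q_size, kv_size = 256, 256, 64
wq = torch.randn(q_size, in_dim)   # checkpoint "q_proj.weight"
wk = torch.randn(kv_size, in_dim)  # checkpoint "k_proj.weight"
wv = torch.randn(kv_size, in_dim)  # checkpoint "v_proj.weight"

# The fused "qkv_proj.weight" is the row-wise concatenation; one matmul then
# yields q, k, v in a single pass, split back out by size.
w_qkv = torch.cat([wq, wk, wv], dim=0)
x = torch.randn(2, in_dim)
q, k, v = (x @ w_qkv.T).split([q_size, kv_size, kv_size], dim=-1)
assert torch.allclose(q, x @ wq.T, atol=1e-5)
```
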
@@ -660,8 +709,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
     def _validate_and_reshape_mm_tensor(self, mm_input: object,
                                         name: str) -> torch.Tensor:
         if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. "
-                             f"Got type: {type(mm_input)}")
+            raise ValueError(
+                f"Incorrect type of {name}. Got type: {type(mm_input)}")
         if isinstance(mm_input, torch.Tensor):
             return mm_input.reshape(-1, *mm_input.shape[2:])
 
@@ -710,8 +759,8 @@ def _process_audio_input(
             audio_input["input_values"].dtype)
         batch_size, max_audio_tokens, embed_dim = audio_embeddings.shape
 
-        audio_length_np = audio_length.cpu().numpy() if isinstance(
-            audio_length, torch.Tensor) else audio_length
+        audio_length_np = (audio_length.cpu().numpy() if isinstance(
+            audio_length, torch.Tensor) else audio_length)
         audio_output_lengths = [
             max(1, calculate_mel_frames_dasheng(
                 int(length)))  # at least one frame
@@ -720,11 +769,11 @@ def _process_audio_input(
         audio_output_lengths = torch.tensor(audio_output_lengths).to(
             audio_embeddings.device)
 
-        audio_feature_mask = (torch.arange(
+        audio_feature_mask = torch.arange(
             max_audio_tokens,
             device=audio_embeddings.device).unsqueeze(0).expand(
-                batch_size, max_audio_tokens)
-            < audio_output_lengths.unsqueeze(1))
+                batch_size,
+                max_audio_tokens) < audio_output_lengths.unsqueeze(1)
 
         masked_audio_features = audio_embeddings[audio_feature_mask].view(
             -1, embed_dim)
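
The mask built above is the standard length-mask pattern; the boolean indexing that follows flattens the valid positions of all sequences into one `[sum(lengths), embed_dim]` tensor. Toy sketch:

```python
import torch

embed = torch.randn(2, 4, 3)                    # [batch, max_tokens, dim]
lengths = torch.tensor([2, 3])                  # valid tokens per sequence
mask = (torch.arange(4).unsqueeze(0).expand(2, 4)
        < lengths.unsqueeze(1))                 # [2, 4] boolean
flat = embed[mask].view(-1, 3)                  # 2 + 3 = 5 rows -> [5, 3]
print(flat.shape)  # torch.Size([5, 3])
```
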
@@ -762,10 +811,12 @@ def forward(
         )
         input_ids = None
 
-        return self.decoder.model(input_ids,
-                                  positions,
-                                  intermediate_tensors,
-                                  inputs_embeds=inputs_embeds)
+        return self.decoder.model(
+            input_ids,
+            positions,
+            intermediate_tensors,
+            inputs_embeds=inputs_embeds,
+        )
 
     def compute_logits(
         self,