 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only MiDashengLM model compatible with HuggingFace weights."""
+
 import collections
 import collections.abc
 from collections.abc import Iterable, Mapping, Sequence
 import torch
 import torch.nn as nn
 import torchaudio.functional as F
+from torch.nn.functional import scaled_dot_product_attention
 from transformers import BatchFeature
 
-from vllm.attention.layer import MultiHeadAttention
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -146,15 +147,19 @@ def __init__(
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
-        self.fc1 = ColumnParallelLinear(input_size=in_features,
-                                        output_size=hidden_features,
-                                        quant_config=quant_config,
-                                        prefix=f"{prefix}.fc1")
+        self.fc1 = ColumnParallelLinear(
+            input_size=in_features,
+            output_size=hidden_features,
+            quant_config=quant_config,
+            prefix=f"{prefix}.fc1",
+        )
         self.act = get_act_fn("gelu")
-        self.fc2 = RowParallelLinear(input_size=hidden_features,
-                                     output_size=out_features,
-                                     quant_config=quant_config,
-                                     prefix=f"{prefix}.fc2")
+        self.fc2 = RowParallelLinear(
+            input_size=hidden_features,
+            output_size=out_features,
+            quant_config=quant_config,
+            prefix=f"{prefix}.fc2",
+        )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x, _ = self.fc1(x)
@@ -170,7 +175,6 @@ def __init__(
         dim: int,
         num_heads: int = 8,
         qkv_bias: bool = False,
-        causal: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
@@ -204,33 +208,30 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.qkv",
         )
-        self.attn = MultiHeadAttention(
-            self.num_heads,
-            self.head_dim,
-            self.scale,
-            num_kv_heads=self.num_kv_heads,
-        )
         self.proj = RowParallelLinear(
             input_size=dim,
             output_size=dim,
             quant_config=quant_config,
             prefix=f"{prefix}.proj",
         )
-        self.causal = causal
 
     def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
         B, N, C = x.shape
 
-        qkv_out, _ = self.qkv(x)
-        q, k, v = qkv_out.split([self.q_size, self.kv_size, self.kv_size],
-                                dim=-1)
-
-        attn_out = self.attn(q, k, v)
-        C_local = attn_out.numel() // (B * N)  # C_local for parallel
-        attn_out = attn_out.view(B, N, C_local)
+        qkv, _ = self.qkv(x)
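+        # Reshape the fused QKV projection to (3, B, num_heads, N, head_dim)
+        # so q, k and v each come out as (B, num_heads, N, head_dim).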
+        qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)
+        qkv = qkv.permute(2, 0, 3, 1, 4)
+        q, k, v = qkv.unbind(0)
 
-        x, _ = self.proj(attn_out)
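+        # An optional (B, N) frame mask is broadcast to (B, 1, 1, N) so it
+        # constrains the key positions for every head and query position.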
+        x = scaled_dot_product_attention(
+            q,
+            k,
+            v,
+            attn_mask=mask[:, None, None, :] if mask is not None else None,
+        )
 
+        x = x.transpose(1, 2).reshape(B, N, C)
+        x, _ = self.proj(x)
         return x
 
 
@@ -462,14 +463,16 @@ def __init__(
                 quant_config=quant_config,
                 prefix=f"{prefix}.net.0",
                 return_bias=False,
-            ), get_act_fn("gelu"),
+            ),
+            get_act_fn("gelu"),
             RowParallelLinear(
                 input_size=out_dim,
                 output_size=out_dim,
                 quant_config=quant_config,
                 prefix=f"{prefix}.net.2",
                 return_bias=False,
-            ))
+            ),
+        )
 
     def forward(self, x, mask=None):
         batch_size, seq_len, dim = x.shape
@@ -566,9 +569,12 @@ def _call_hf_processor(
         # + Padding
         min_audio_len = self.info.get_min_audio_len()
         processed_audios = [
-            np.pad(audio, (0, min_audio_len - audio.shape[-1]),
-                   mode='constant',
-                   constant_values=0) if isinstance(audio, np.ndarray)
+            np.pad(
+                audio,
+                (0, min_audio_len - audio.shape[-1]),
+                mode="constant",
+                constant_values=0,
+            ) if isinstance(audio, np.ndarray)
             and audio.shape[-1] < min_audio_len else audio for audio in audios
         ]
 
@@ -617,8 +623,8 @@ def _get_prompt_updates(
         if audio_length is None:
             audio_output_lengths = []
         else:
-            audio_length_np = audio_length.cpu().numpy() if isinstance(
-                audio_length, torch.Tensor) else audio_length
+            audio_length_np = (audio_length.cpu().numpy() if isinstance(
+                audio_length, torch.Tensor) else audio_length)
             audio_output_lengths = [
                 max(1, calculate_mel_frames_dasheng(
                     int(length)))  # at least one frame
@@ -703,8 +709,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
     def _validate_and_reshape_mm_tensor(self, mm_input: object,
                                         name: str) -> torch.Tensor:
         if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. "
-                             f"Got type: {type(mm_input)}")
+            raise ValueError(
+                f"Incorrect type of {name}. Got type: {type(mm_input)}")
         if isinstance(mm_input, torch.Tensor):
             return mm_input.reshape(-1, *mm_input.shape[2:])
 
@@ -753,8 +759,8 @@ def _process_audio_input(
             audio_input["input_values"].dtype)
         batch_size, max_audio_tokens, embed_dim = audio_embeddings.shape
 
-        audio_length_np = audio_length.cpu().numpy() if isinstance(
-            audio_length, torch.Tensor) else audio_length
+        audio_length_np = (audio_length.cpu().numpy() if isinstance(
+            audio_length, torch.Tensor) else audio_length)
         audio_output_lengths = [
             max(1, calculate_mel_frames_dasheng(
                 int(length)))  # at least one frame
@@ -763,11 +769,11 @@ def _process_audio_input(
         audio_output_lengths = torch.tensor(audio_output_lengths).to(
             audio_embeddings.device)
 
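+        # Boolean mask over the padded token axis: True where a position is a
+        # real audio frame for that sample, False where it is padding.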
-        audio_feature_mask = (torch.arange(
+        audio_feature_mask = torch.arange(
             max_audio_tokens,
             device=audio_embeddings.device).unsqueeze(0).expand(
-                batch_size, max_audio_tokens)
-            < audio_output_lengths.unsqueeze(1))
+                batch_size,
+                max_audio_tokens) < audio_output_lengths.unsqueeze(1)
 
         masked_audio_features = audio_embeddings[audio_feature_mask].view(
             -1, embed_dim)
@@ -805,10 +811,12 @@ def forward(
         )
         input_ids = None
 
-        return self.decoder.model(input_ids,
-                                  positions,
-                                  intermediate_tensors,
-                                  inputs_embeds=inputs_embeds)
+        return self.decoder.model(
+            input_ids,
+            positions,
+            intermediate_tensors,
+            inputs_embeds=inputs_embeds,
+        )
 
     def compute_logits(
         self,