@@ -69,18 +69,16 @@ class UltravoxAudioFeatureInputs(TensorSchema):
     type: Literal["audio_features"]
     data: Annotated[
         Union[torch.Tensor, list[torch.Tensor], list[list[torch.Tensor]]],
-        TensorShape("b", "n", "nmb", "t", dynamic_dims={"n"}),
+        TensorShape("bn", "nmb", "t"),
     ]
-    lens: Annotated[
-        Union[torch.Tensor, list[torch.Tensor]],
-        TensorShape("b", "n", dynamic_dims={"n"}),
-    ]
-    """Length of the audio frames. Used for attention mask in WhisperEncoder."""
-    token_len: Annotated[
-        Union[torch.Tensor, list[torch.Tensor]],
-        TensorShape("b", "n", dynamic_dims={"n"}),
-    ]
-    """Length of the audio tokens. Used for flattening the audio features."""
+    lens: Annotated[torch.Tensor, TensorShape("bn")]
+    """
+    Length of the audio frames per chunk. Used for attention mask in WhisperEncoder.
+    """
+    token_len: Annotated[torch.Tensor, TensorShape("bn")]
+    """Length of the audio tokens per chunk. Used for flattening the audio features."""
+    num_chunks: Annotated[torch.Tensor, TensorShape("n")]
+    """Number of chunks per audio. Used for flattening the audio features."""
 
 
 class UltravoxAudioEmbeddingInputs(TensorSchema):
@@ -421,6 +419,8 @@ def forward(
     dummy_inputs=UltravoxDummyInputsBuilder,
 )
 class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"],
@@ -519,6 +519,7 @@ def _parse_and_validate_audio_input(
         audio_embeds = kwargs.pop("audio_embeds", None)
         audio_lens = kwargs.pop("audio_lens", None)
         audio_token_len = kwargs.pop("audio_token_len", None)
+        audio_num_chunks = kwargs.pop("audio_num_chunks", None)
 
         if audio_features is None and audio_embeds is None:
             return None
@@ -529,6 +530,7 @@ def _parse_and_validate_audio_input(
                 data=audio_features,
                 lens=audio_lens,
                 token_len=audio_token_len,
+                num_chunks=audio_num_chunks,
             )
 
         if audio_embeds is not None:
@@ -547,9 +549,8 @@ def _process_audio_input(
         # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
         audio_features = pad_and_concat_to_dim3(audio_input["data"])
 
-        # [B1, B2] -> [B1+B2]
-        audio_lens = flatten_bn(audio_input["lens"], concat=True)
-        audio_token_len = flatten_bn(audio_input["token_len"], concat=True)
+        audio_lens = audio_input["lens"]
+        audio_token_len = audio_input["token_len"]
 
         embeddings = self._audio_features_to_embeddings(audio_features, audio_lens)
 
@@ -568,7 +569,8 @@ def _process_audio_input(
 
         # Return one tensor per input audio
         embed_lens = [
-            token_len_item.sum().item() for token_len_item in audio_input["token_len"]
+            chunk_lens.sum().item()
+            for chunk_lens in audio_token_len.split(audio_input["num_chunks"].tolist())
         ]
         return flattened_embeddings.split(embed_lens)
 
@@ -663,6 +665,7 @@ def pad_and_concat_to_dim3(
         if features.ndim > 3:
             # Flatten [B, N, 80, M] -> [B * N, 80, M]
             features = flatten_bn(features)
+
         return features
 
     features = [pad_and_concat_to_dim3(f) for f in features]
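
To make the new bookkeeping concrete, here is a minimal sketch (not part of the diff) of the splitting logic that `_process_audio_input` now relies on: once every chunk of every audio is flattened into a single "bn" batch dimension, `num_chunks` is what lets the flattened embeddings be handed back as one tensor per input audio. The tensor values and hidden size below are made up for illustration; only the `split`/`tolist` pattern mirrors the change.

    import torch

    # Batch of 2 audios: the first split into 3 chunks, the second into 1.
    num_chunks = torch.tensor([3, 1])               # shape ("n",): one entry per audio
    audio_token_len = torch.tensor([10, 10, 4, 7])  # shape ("bn",): one entry per chunk

    # All chunk embeddings are stacked along dim 0 (hidden size 8 is arbitrary).
    flattened_embeddings = torch.randn(int(audio_token_len.sum()), 8)

    # Sum the per-chunk token lengths within each audio...
    embed_lens = [
        chunk_lens.sum().item()
        for chunk_lens in audio_token_len.split(num_chunks.tolist())
    ]  # -> [24, 7]

    # ...then split the flattened embeddings back into one tensor per audio.
    per_audio = flattened_embeddings.split(embed_lens)
    assert [t.shape[0] for t in per_audio] == [24, 7]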