33 changes: 18 additions & 15 deletions vllm/model_executor/models/ultravox.py
@@ -69,18 +69,16 @@ class UltravoxAudioFeatureInputs(TensorSchema):
     type: Literal["audio_features"]
     data: Annotated[
         Union[torch.Tensor, list[torch.Tensor], list[list[torch.Tensor]]],
-        TensorShape("b", "n", "nmb", "t", dynamic_dims={"n"}),
+        TensorShape("bn", "nmb", "t"),
     ]
-    lens: Annotated[
-        Union[torch.Tensor, list[torch.Tensor]],
-        TensorShape("b", "n", dynamic_dims={"n"}),
-    ]
-    """Length of the audio frames. Used for attention mask in WhisperEncoder."""
-    token_len: Annotated[
-        Union[torch.Tensor, list[torch.Tensor]],
-        TensorShape("b", "n", dynamic_dims={"n"}),
-    ]
-    """Length of the audio tokens. Used for flattening the audio features."""
+    lens: Annotated[torch.Tensor, TensorShape("bn")]
+    """
+    Length of the audio frames per chunk. Used for attention mask in WhisperEncoder.
+    """
+    token_len: Annotated[torch.Tensor, TensorShape("bn")]
+    """Length of the audio tokens per chunk. Used for flattening the audio features."""
+    num_chunks: Annotated[torch.Tensor, TensorShape("n")]
+    """Number of chunks per audio. Used for flattening the audio features."""


 class UltravoxAudioEmbeddingInputs(TensorSchema):
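For orientation, a minimal sketch of tensors that would satisfy the flattened schema above. The field names follow UltravoxAudioFeatureInputs; the sizes (80 mel bins, 3000 frames, the chunk and token counts) are made up for illustration.

import torch

# Two audios: the first split into 3 mel-spectrogram chunks, the second into 1.
num_chunks = torch.tensor([3, 1])              # TensorShape("n")
data = torch.randn(4, 80, 3000)                # TensorShape("bn", "nmb", "t"); 4 = 3 + 1 chunks
lens = torch.tensor([3000, 3000, 1500, 2200])  # valid frames per chunk, TensorShape("bn")
token_len = torch.tensor([188, 188, 94, 138])  # audio tokens per chunk, TensorShape("bn")

# The chunks belonging to each original audio can be recovered by splitting along dim 0.
per_audio = data.split(num_chunks.tolist())    # shapes: [3, 80, 3000] and [1, 80, 3000]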
@@ -421,6 +419,8 @@ def forward(
     dummy_inputs=UltravoxDummyInputsBuilder,
 )
 class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"],
Expand Down Expand Up @@ -519,6 +519,7 @@ def _parse_and_validate_audio_input(
audio_embeds = kwargs.pop("audio_embeds", None)
audio_lens = kwargs.pop("audio_lens", None)
audio_token_len = kwargs.pop("audio_token_len", None)
audio_num_chunks = kwargs.pop("audio_num_chunks", None)

if audio_features is None and audio_embeds is None:
return None
@@ -529,6 +530,7 @@
                 data=audio_features,
                 lens=audio_lens,
                 token_len=audio_token_len,
+                num_chunks=audio_num_chunks,
             )

         if audio_embeds is not None:
@@ -547,9 +549,8 @@ def _process_audio_input(
         # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
         audio_features = pad_and_concat_to_dim3(audio_input["data"])

-        # [B1, B2] -> [B1+B2]
-        audio_lens = flatten_bn(audio_input["lens"], concat=True)
-        audio_token_len = flatten_bn(audio_input["token_len"], concat=True)
+        audio_lens = audio_input["lens"]
+        audio_token_len = audio_input["token_len"]

         embeddings = self._audio_features_to_embeddings(audio_features, audio_lens)

@@ -568,7 +569,8 @@

         # Return one tensor per input audio
         embed_lens = [
-            token_len_item.sum().item() for token_len_item in audio_input["token_len"]
+            chunk_lens.sum().item()
+            for chunk_lens in audio_token_len.split(audio_input["num_chunks"].tolist())
         ]
         return flattened_embeddings.split(embed_lens)

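A self-contained sketch of the regrouping above, with toy numbers. Only the split()/sum() bookkeeping is shown; in the model the flattened embeddings come from the audio tower and projector, and the hidden size of 4096 is a placeholder.

import torch

token_len = torch.tensor([188, 188, 94, 138])         # audio tokens per chunk, over all audios
num_chunks = [3, 1]                                    # chunks per original audio
flattened = torch.randn(int(token_len.sum()), 4096)    # [total_audio_tokens, hidden_size]

# Sum the per-chunk token counts within each audio, then split the flat embeddings.
embed_lens = [chunk_lens.sum().item() for chunk_lens in token_len.split(num_chunks)]
per_audio_embeds = flattened.split(embed_lens)         # shapes: [470, 4096] and [138, 4096]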
@@ -663,6 +665,7 @@ def pad_and_concat_to_dim3(
         if features.ndim > 3:
             # Flatten [B, N, 80, M] -> [B * N, 80, M]
             features = flatten_bn(features)
+
         return features

     features = [pad_and_concat_to_dim3(f) for f in features]
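A rough stand-in for what pad_and_concat_to_dim3 produces, based only on the [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)] comment above; the real helper also recurses into nested lists and flattens 4-D inputs.

import torch
import torch.nn.functional as F

chunks = [torch.randn(2, 80, 1500), torch.randn(3, 80, 3000)]  # per-item [B_i, 80, M_i]
max_len = max(c.shape[-1] for c in chunks)

# Right-pad each tensor along the time dimension, then concatenate all chunks on dim 0.
padded = torch.cat([F.pad(c, (0, max_len - c.shape[-1])) for c in chunks])
print(padded.shape)  # torch.Size([5, 80, 3000])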
5 changes: 3 additions & 2 deletions vllm/model_executor/models/voxtral.py
@@ -61,7 +61,7 @@
 )

 from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription
-from .utils import flatten_bn, init_vllm_registered_model, maybe_prefix
+from .utils import init_vllm_registered_model, maybe_prefix

 logger = init_logger(__name__)

@@ -337,6 +337,8 @@ def _get_data_parser(self) -> MultiModalDataParser:
 class VoxtralForConditionalGeneration(
     nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsTranscription
 ):
+    merge_by_field_config = True
+
     supported_languages = ISO639_1_SUPPORTED_LANGS

     packed_modules_mapping = {
@@ -445,7 +447,6 @@ def _parse_and_validate_audio_arrays(
                 f"Incorrect type of audio_arrays. Got type: {type(audio_arrays)}"
             )

-        audio_arrays = flatten_bn(audio_arrays)
         if isinstance(audio_arrays, torch.Tensor):
             audio_arrays = list(audio_arrays.unbind(0))
         return audio_arrays
16 changes: 8 additions & 8 deletions vllm/model_executor/models/whisper.py
@@ -36,7 +36,7 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader.utils import set_default_torch_dtype
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
+from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (
     MultiModalDataDict,
     MultiModalFieldConfig,
@@ -51,6 +51,7 @@
 )
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.transformers_utils.processor import cached_get_processor
+from vllm.utils.jsontree import json_map_leaves
 from vllm.utils.tensor_schema import TensorSchema, TensorShape

 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsTranscription
@@ -135,7 +136,10 @@ class WhisperAudioInputs(TensorSchema):
     - t: Time frames (M)
     """

-    input_features: Annotated[Optional[NestedTensors], TensorShape("b", "nmb", "t")]
+    input_features: Annotated[
+        Optional[list[torch.Tensor]],
+        TensorShape("b", "nmb", "t"),
+    ]


 class WhisperEncoderAttention(MultiHeadAttention):
@@ -781,6 +785,7 @@ def _get_prompt_updates(
 class WhisperForConditionalGeneration(
     nn.Module, SupportsTranscription, SupportsMultiModal
 ):
+    merge_by_field_config = True
     packed_modules_mapping = {
         "self_attn.qkv_proj": [
             "self_attn.q_proj",
@@ -936,12 +941,7 @@ def _parse_and_validate_audio_input(self, **kwargs: object) -> WhisperAudioInput
         input_features = kwargs.pop("input_features", None)

         if input_features is not None:
-            if not isinstance(input_features, (torch.Tensor, list)):
-                raise ValueError(
-                    "Incorrect type of audio features. "
-                    f"Got type: {type(input_features)}"
-                )
-            input_features = torch.cat([feat.to(self.dtype) for feat in input_features])
+            input_features = json_map_leaves(lambda x: x.to(self.dtype), input_features)

         return WhisperAudioInputs(input_features=input_features)

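To make the json_map_leaves call above concrete, a small standalone sketch. The import path is taken from the diff; the feature sizes and the float16 target dtype are made-up examples. The point is that tensor leaves of a nested structure are cast, without the old isinstance check and torch.cat concatenation.

import torch
from vllm.utils.jsontree import json_map_leaves

# Per-request mel-spectrogram features, possibly of different lengths, in float32.
input_features = [torch.randn(128, 3000), torch.randn(128, 1500)]

# Cast every tensor leaf to the target dtype; the surrounding list structure is preserved.
input_features = json_map_leaves(lambda x: x.to(torch.float16), input_features)
print([f.dtype for f in input_features])  # [torch.float16, torch.float16]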
3 changes: 3 additions & 0 deletions vllm/multimodal/inputs.py
@@ -677,6 +677,9 @@ def __init__(self, field: BaseMultiModalField, modality: str) -> None:
         self.field = field
         self.modality = modality

+    def __repr__(self) -> str:
+        return f"MultiModalFieldConfig(field={self.field}, modality={self.modality})"
+
     def build_elems(
         self,
         key: str,
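For reference, a quick sketch of what the new __repr__ surfaces when debugging field configs. MultiModalFieldConfig.batched is one of the existing factory methods; the printed text for the field part depends on that field class's own repr.

from vllm.multimodal.inputs import MultiModalFieldConfig

cfg = MultiModalFieldConfig.batched("audio")
print(cfg)  # roughly: MultiModalFieldConfig(field=MultiModalBatchedField(), modality=audio)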