|
38 | 38 | from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, |
39 | 39 | SequenceData) |
40 | 40 | from vllm.transformers_utils.configs.ultravox import UltravoxConfig |
| 41 | +from vllm.utils import is_list_of |
41 | 42 |
|
42 | 43 | from .interfaces import SupportsMultiModal, SupportsPP |
43 | 44 |
|
@@ -119,6 +120,10 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object): |
119 | 120 | if not isinstance(data, list): |
120 | 121 | data = [data] |
121 | 122 |
|
| 123 | + # If the audio inputs are embeddings, no need for preprocessing |
| 124 | + if is_list_of(data, torch.Tensor, check="all"): |
| 125 | + return MultiModalInputs({"audio_embeds": data}) |
| 126 | + |
122 | 127 | audio_features = [] |
123 | 128 | for audio_input in data: |
124 | 129 | if not isinstance(audio_input, tuple): |
@@ -165,25 +170,30 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs): |
165 | 170 | audios = [audios] |
166 | 171 |
|
167 | 172 | audio_token_counts = [] |
168 | | - for audio_data, sample_rate in audios: |
169 | | - audio_length = audio_data.shape[0] |
170 | | - if sample_rate != feature_extractor.sampling_rate: |
171 | | - # Account for resampling. |
172 | | - adjustment = feature_extractor.sampling_rate / sample_rate |
173 | | - audio_length = math.ceil(adjustment * audio_length) |
174 | | - |
175 | | - feature_extractor_output_length = math.ceil( |
176 | | - (audio_length - (feature_extractor.hop_length - 1)) / |
177 | | - feature_extractor.hop_length) |
178 | | - |
179 | | - uv_config = ctx.get_hf_config(UltravoxConfig) |
180 | | - audio_num_tokens = min( |
181 | | - max( |
182 | | - 1, |
183 | | - math.ceil(feature_extractor_output_length / |
184 | | - (uv_config.stack_factor * 2))), |
185 | | - get_ultravox_max_audio_tokens(ctx)) |
186 | | - audio_token_counts.append(audio_num_tokens) |
| 173 | + for audio in audios: |
| 174 | + if isinstance(audio, torch.Tensor): |
| 175 | + audio_num_tokens = audio.shape[1] |
| 176 | + audio_token_counts.append(audio_num_tokens) |
| 177 | + else: |
| 178 | + audio_data, sample_rate = audio |
| 179 | + audio_length = audio_data.shape[0] |
| 180 | + if sample_rate != feature_extractor.sampling_rate: |
| 181 | + # Account for resampling. |
| 182 | + adjustment = feature_extractor.sampling_rate / sample_rate |
| 183 | + audio_length = math.ceil(adjustment * audio_length) |
| 184 | + |
| 185 | + feature_extractor_output_length = math.ceil( |
| 186 | + (audio_length - (feature_extractor.hop_length - 1)) / |
| 187 | + feature_extractor.hop_length) |
| 188 | + |
| 189 | + uv_config = ctx.get_hf_config(UltravoxConfig) |
| 190 | + audio_num_tokens = min( |
| 191 | + max( |
| 192 | + 1, |
| 193 | + math.ceil(feature_extractor_output_length / |
| 194 | + (uv_config.stack_factor * 2))), |
| 195 | + get_ultravox_max_audio_tokens(ctx)) |
| 196 | + audio_token_counts.append(audio_num_tokens) |
187 | 197 |
|
188 | 198 | tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer) |
189 | 199 |
|
|
0 commit comments