Skip to content

Commit

Permalink
[Model] Add Ultravox support for multiple audio chunks (vllm-project#…
Browse files Browse the repository at this point in the history
  • Loading branch information
petersalas authored and dsikka committed Sep 5, 2024
1 parent db89b14 commit cfd3280
Show file tree
Hide file tree
Showing 3 changed files with 198 additions and 115 deletions.
58 changes: 34 additions & 24 deletions examples/offline_inference_audio_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,33 @@
from vllm.assets.audio import AudioAsset
from vllm.utils import FlexibleArgumentParser

# Input audio and question
audio_and_sample_rate = AudioAsset("mary_had_lamb").audio_and_sample_rate
question = "What is recited in the audio?"
audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
question_per_audio_count = [
"What is recited in the audio?",
"What sport and what nursery rhyme are referenced?"
]


# Ultravox 0.3
def run_ultravox(question):
def run_ultravox(question, audio_count):
model_name = "fixie-ai/ultravox-v0_3"

tokenizer = AutoTokenizer.from_pretrained(model_name)
messages = [{
'role': 'user',
'content': f"<|reserved_special_token_0|>\n{question}"
'role':
'user',
'content':
"<|reserved_special_token_0|>\n" * audio_count + question
}]
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)

llm = LLM(model=model_name)
llm = LLM(model=model_name,
enforce_eager=True,
enable_chunked_prefill=False,
max_model_len=8192,
limit_mm_per_prompt={"audio": audio_count})
stop_token_ids = None
return llm, prompt, stop_token_ids

Expand All @@ -44,7 +52,9 @@ def main(args):
if model not in model_example_map:
raise ValueError(f"Model type {model} is not supported.")

llm, prompt, stop_token_ids = model_example_map[model](question)
audio_count = args.num_audios
llm, prompt, stop_token_ids = model_example_map[model](
question_per_audio_count[audio_count - 1], audio_count)

# We set temperature to 0.2 so that outputs can be different
# even when all prompts are identical when running batch inference.
Expand All @@ -53,23 +63,18 @@ def main(args):
stop_token_ids=stop_token_ids)

assert args.num_prompts > 0
if args.num_prompts == 1:
# Single inference
inputs = {
"prompt": prompt,
"multi_modal_data": {
"audio": audio_and_sample_rate
},
}

else:
inputs = {
"prompt": prompt,
"multi_modal_data": {
"audio": [
asset.audio_and_sample_rate
for asset in audio_assets[:audio_count]
]
},
}
if args.num_prompts > 1:
# Batch inference
inputs = [{
"prompt": prompt,
"multi_modal_data": {
"audio": audio_and_sample_rate
},
} for _ in range(args.num_prompts)]
inputs = [inputs] * args.num_prompts

outputs = llm.generate(inputs, sampling_params=sampling_params)

Expand All @@ -92,6 +97,11 @@ def main(args):
type=int,
default=1,
help='Number of prompts to run.')
parser.add_argument("--num-audios",
type=int,
default=1,
choices=[1, 2],
help="Number of audio items per prompt.")

args = parser.parse_args()
main(args)
103 changes: 77 additions & 26 deletions tests/models/test_ultravox.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,37 +16,32 @@

AudioTuple = Tuple[np.ndarray, int]

VLLM_PLACEHOLDER = "<|reserved_special_token_0|>"
HF_PLACEHOLDER = "<|audio|>"


@pytest.fixture(scope="session")
def audio_and_sample_rate():
def audio_assets():
from vllm.assets.audio import AudioAsset
return AudioAsset("mary_had_lamb").audio_and_sample_rate
return [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]


@pytest.fixture
def prompts_and_audios(audio_and_sample_rate):
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
@pytest.fixture(scope="module", params=("mary_had_lamb", "winning_call"))
def audio(request):
from vllm.assets.audio import AudioAsset
return AudioAsset(request.param)

vllm_placeholder = "<|reserved_special_token_0|>"
hf_placeholder = "<|audio|>"

question = "What's in the audio?"
vllm_prompt = tokenizer.apply_chat_template(
[{
'role': 'user',
'content': f"{vllm_placeholder}\n{question}"
}],
tokenize=False,
add_generation_prompt=True)
hf_prompt = tokenizer.apply_chat_template(
[{
'role': 'user',
'content': f"{hf_placeholder}\n{question}"
}],
tokenize=False,
add_generation_prompt=True)
def _get_prompt(audio_count, question, placeholder):
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
placeholder = f"{placeholder}\n" * audio_count

return [(vllm_prompt, hf_prompt, audio_and_sample_rate)]
return tokenizer.apply_chat_template([{
'role': 'user',
'content': f"{placeholder}{question}"
}],
tokenize=False,
add_generation_prompt=True)


def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
Expand Down Expand Up @@ -134,15 +129,71 @@ def process(hf_inputs: BatchEncoding):
)


def run_multi_audio_test(
vllm_runner: Type[VllmRunner],
prompts_and_audios: List[Tuple[str, List[AudioTuple]]],
model: str,
*,
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
with vllm_runner(model,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True,
limit_mm_per_prompt={
"audio":
max((len(audio) for _, audio in prompts_and_audios))
}) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
[prompt for prompt, _ in prompts_and_audios],
max_tokens,
num_logprobs=num_logprobs,
audios=[audios for _, audios in prompts_and_audios])

# The HuggingFace model doesn't support multiple audios yet, so
# just assert that some tokens were generated.
assert all(tokens for tokens, *_ in vllm_outputs)


@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, prompts_and_audios, dtype: str,
max_tokens: int, num_logprobs: int) -> None:
def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int,
num_logprobs: int) -> None:

vllm_prompt = _get_prompt(1, "Describe the audio above.", VLLM_PLACEHOLDER)
hf_prompt = _get_prompt(1, "Describe the audio above.", HF_PLACEHOLDER)
run_test(
hf_runner,
vllm_runner,
prompts_and_audios,
[(vllm_prompt, hf_prompt, audio.audio_and_sample_rate)],
MODEL_NAME,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)


@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str,
max_tokens: int,
num_logprobs: int) -> None:

vllm_prompt = _get_prompt(len(audio_assets),
"Describe each of the audios above.",
VLLM_PLACEHOLDER)
run_multi_audio_test(
vllm_runner,
[(vllm_prompt, [audio.audio_and_sample_rate
for audio in audio_assets])],
MODEL_NAME,
dtype=dtype,
max_tokens=max_tokens,
Expand Down
Loading

0 comments on commit cfd3280

Please sign in to comment.