Offline batch inference mode #82

Merged · 26 commits · Aug 20, 2024
6 changes: 6 additions & 0 deletions ultravox/data/datasets.py
@@ -88,6 +88,7 @@ def __call__(self, features, *args, **kwargs):
}
for f in features
]
input_ids_len = torch.LongTensor([f["input_ids"].shape[-1] for f in features])
batch = super().__call__(features, *args, **kwargs)
if self.include_alt_fields:
alt_batch = super().__call__(alt_features, *args, **kwargs)
@@ -101,6 +102,11 @@
batch["audio_values"] = torch.stack(
[F.pad(x, (0, max_len - x.shape[-1])) for x in audio_values]
)
if self.tokenizer.padding_side == "left":
displacement = batch["input_ids"].shape[-1] - input_ids_len
batch["audio_token_start_idx"] += displacement.to(
batch["audio_token_start_idx"].device
)

return batch
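The displacement added above is needed because left padding shifts every real token toward the end of the row, so the precomputed audio_token_start_idx offsets must move by however much padding each sample received. A minimal sketch of the same arithmetic with hypothetical lengths (not taken from the PR):

import torch

# Two hypothetical prompts of 5 and 8 tokens, left-padded to a common length of 8.
input_ids_len = torch.LongTensor([5, 8])
audio_token_start_idx = torch.LongTensor([2, 3])  # offsets within the unpadded prompts

padded_len = 8
displacement = padded_len - input_ids_len            # tensor([3, 0])
audio_token_start_idx = audio_token_start_idx + displacement
print(audio_token_start_idx)                         # tensor([5, 3])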

28 changes: 27 additions & 1 deletion ultravox/inference/infer.py
@@ -38,6 +38,29 @@ def __init__(
None
)

assert self.tokenizer.padding_side == "left"

def batch_infer(
self,
samples: List[datasets.VoiceSample],
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
) -> base.InferenceGenerator:
inputs = [self._dataproc(s, batch=True) for s in samples]
data_collator = datasets.DataCollatorForSeq2SeqWithAudio(
tokenizer=self.tokenizer,
include_alt_fields=False,
)
tensors = data_collator(inputs)
input_len = tensors["input_ids"].shape[1]
output_batch = self._generate(tensors, max_tokens, temperature)
for output in output_batch:
output_tokens = output[input_len:]
output_text = self.tokenizer.decode(output_tokens, skip_special_tokens=True)
output_len = len(output_tokens)
output_text = base.VoiceOutput(output_text, input_len, output_len)
yield output_text
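A rough usage sketch for the new method, assuming inference is a LocalInference and samples is any list of datasets.VoiceSample (for example a slice of an evaluation set, as infer_tool does below); the sample list and decoding parameters here are illustrative, not prescribed by the PR:

# Hypothetical call; batch_infer yields one base.VoiceOutput per input sample.
outputs = inference.batch_infer(
    samples,
    max_tokens=64,
    temperature=0.0,
)
for out in outputs:
    print(out.text)  # generated text; input/output token counts are also carried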

def update_conversation(
self,
past_messages: List[Dict[str, str]] = [],
@@ -118,7 +141,7 @@ def infer_stream(
yield base.InferenceStats(input_tokens, output_tokens)
thread.join()

def _dataproc(self, sample: datasets.VoiceSample):
def _dataproc(self, sample: datasets.VoiceSample, batch=False):
text_input = self.tokenizer.apply_chat_template(
sample.messages, add_generation_prompt=True, tokenize=False
)
@@ -152,6 +175,9 @@ def _dataproc(self, sample: datasets.VoiceSample):
inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
if "audio_values" in inputs:
inputs["audio_values"] = inputs["audio_values"].to(dtype=self.dtype)
if batch:
for key, val in inputs.items():
inputs[key] = val.squeeze(0)
return inputs

@torch.inference_mode()
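The batch=True path in _dataproc above squeezes away the leading batch dimension of 1 that the processor adds, so DataCollatorForSeq2SeqWithAudio can pad and re-stack the samples itself. A tiny sketch with hypothetical shapes:

import torch

# The processor returns tensors shaped like (1, seq_len) for a single sample.
single = {"input_ids": torch.zeros(1, 7, dtype=torch.long)}
# Squeeze to (seq_len,) so the collator treats it as one unbatched example.
squeezed = {k: v.squeeze(0) for k, v in single.items()}
assert squeezed["input_ids"].shape == (7,)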
64 changes: 44 additions & 20 deletions ultravox/tools/infer_tool.py
@@ -1,8 +1,6 @@
#!/usr/bin/env python

import argparse
import dataclasses
import json
import os
import time
from typing import IO, List, Optional
@@ -14,6 +12,7 @@
from ultravox.evaluation import eval
from ultravox.evaluation import eval_types
from ultravox.inference import base
from ultravox.inference import infer
from ultravox.tools import infer_api

# There are two default modes for this tool, agent mode and ASR mode.
@@ -77,6 +76,8 @@ class InferArgs:
verbose: bool = simple_parsing.field(default=False, alias="-v")
# JSON output
json: bool = simple_parsing.field(default=False)
# Batch size
batch_size: Optional[int] = simple_parsing.field(default=None, alias="-b")

def __post_init__(self):
if self.prompt and self.prompt.startswith("@"):
@@ -190,25 +191,48 @@ def dataset_infer(inference: base.VoiceInference, args: InferArgs):
if args.seed is not None:
ds_args.shuffle_seed = args.seed
ds = datasets.create_dataset(args.data_sets[0], ds_args)
scores: List[float] = []
for i, sample in enumerate(datasets.Range(ds, args.num_samples)):
# Store the original question and answer for JSON output.
question_text = sample.audio_transcript
expected_answer = sample.messages[-1]["content"]
# Drop any assistant response from the sample.
sample.messages = sample.messages[:-1]
if not args.json:

if args.json and isinstance(inference, infer.LocalInference):
if not args.batch_size:
args.batch_size = 1

# TODO: Add multithreading support for preparing the batch.
start_time = time.time()
current_batch = []
for i, sample in enumerate(datasets.Range(ds, args.num_samples)):
current_batch.append(sample)
if len(current_batch) == args.batch_size:
output = []
for sample in current_batch:
output.append(
{
"index": i,
"question_text": sample.audio_transcript,
"expected_answer": sample.messages[-1]["content"],
}
)
sample.messages = sample.messages[:-1]

output_batch = inference.batch_infer(
current_batch,
max_tokens=args.max_tokens,
temperature=args.temperature,
)
for i, output_text in enumerate(output_batch):
output[i]["output_text"] = output_text
print(output)
current_batch = []
print("Total time", time.time() - start_time)

else:
scores: List[float] = []
for i, sample in enumerate(datasets.Range(ds, args.num_samples)):
# Store the original question and answer for JSON output.
question_text = sample.audio_transcript
expected_answer = sample.messages[-1]["content"]
# Drop any assistant response from the sample.
sample.messages = sample.messages[:-1]
run_tui(i, inference, sample, args, expected_answer, scores)
else:
output = inference.infer(
sample, max_tokens=args.max_tokens, temperature=args.temperature
)
obj = {
"question": question_text,
"generated_answer": output.text,
"expected_answer": expected_answer,
}
print(json.dumps(obj))


def main(args: InferArgs):
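The batch path in dataset_infer accumulates samples until batch_size is reached and then issues one batch_infer call per chunk. Purely as an illustration of that chunking pattern (this helper is not part of the PR), the same grouping can be written as a small generator:

from typing import Iterable, Iterator, List

def chunked(items: Iterable, size: int) -> Iterator[List]:
    # Yield consecutive fixed-size chunks; any final shorter chunk is yielded too.
    batch: List = []
    for item in items:
        batch.append(item)
        if len(batch) == size:
            yield batch
            batch = []
    if batch:
        yield batch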