[Feature] Update benchmark_throughput.py to support image input #9851

Merged 5 commits on Nov 5, 2024
11 changes: 11 additions & 0 deletions benchmarks/README.md
@@ -6,3 +6,14 @@ You can download the dataset by running:
```bash
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
```

## Downloading the ShareGPT4V dataset

The JSON file refers to several image datasets (COCO, LLaVA, etc.). The benchmark scripts
will ignore a datapoint if the referenced image is missing.
```bash
wget https://huggingface.co/datasets/Lin-Chen/ShareGPT4V/resolve/main/sharegpt4v_instruct_gpt4-vision_cap100k.json
mkdir -p coco
wget http://images.cocodataset.org/zips/train2017.zip -O coco/train2017.zip
unzip coco/train2017.zip -d coco/
```
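
With the JSON file and the COCO images in place, the updated script can be run against the multi-modal dataset. The sketch below is illustrative rather than part of this PR: it assumes the benchmark's existing `--backend`, `--model`, `--dataset`, `--num-prompts`, and `--output-len` flags, and uses a Pixtral checkpoint because Pixtral is the only model handled by the new `_get_prompt_for_image_model` helper; the exact model name and any additional engine flags may differ.

```bash
# Illustrative invocation (not taken from this PR); flag names assumed from the
# existing benchmark_throughput.py interface.
python benchmarks/benchmark_throughput.py \
    --backend vllm \
    --model mistralai/Pixtral-12B-2409 \
    --dataset sharegpt4v_instruct_gpt4-vision_cap100k.json \
    --num-prompts 500 \
    --output-len 128
```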
82 changes: 64 additions & 18 deletions benchmarks/benchmark_throughput.py
@@ -8,6 +8,7 @@

import torch
import uvloop
from PIL import Image
from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
PreTrainedTokenizerBase)
@@ -38,12 +39,33 @@ class SampleRequest:
multi_modal_data: Optional[MultiModalDataDict] = None


def sample_requests(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int],
) -> List[SampleRequest]:
def _get_prompt_for_image_model(question: str, *, model: str) -> str:
"""Prepend and append special tokens around the question to form a prompt.

Args:
question: The input question text to wrap with special tokens
model: The name of the model being used, to determine which special
tokens to add

Returns:
The formatted prompt string with appropriate special tokens for the
model

Raises:
ValueError: If an unsupported model name is provided
"""
model = model.lower()
if "pixtral" in model:
return f"<s>[INST]{question}\n[IMG][/INST]"
raise ValueError(f"Unsupported model {model}")


def sample_requests(tokenizer: PreTrainedTokenizerBase,
args: argparse.Namespace) -> List[SampleRequest]:
dataset_path: str = args.dataset
num_requests: int = args.num_prompts
fixed_output_len: Optional[int] = args.output_len
model: str = args.model
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")

@@ -52,23 +74,36 @@ def sample_requests(
dataset = json.load(f)
# Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
# Only keep the first two turns of each conversation.
dataset = [(data["conversations"][0]["value"],
data["conversations"][1]["value"]) for data in dataset]

# Shuffle the dataset.
random.shuffle(dataset)

# Filter out sequences that are too long or too short
filtered_dataset: List[SampleRequest] = []
for i in range(len(dataset)):
for data in dataset:
if len(filtered_dataset) == num_requests:
break

# Only keep the first two turns of each conversation.
prompt = data["conversations"][0]["value"]
completion = data["conversations"][1]["value"]

multi_modal_data: Optional[MultiModalDataDict] = None
if "image" in data:
multi_modal_data = multi_modal_data or {}
image_path = data["image"]
# TODO(vllm-project/vllm/issues/9778): Support multiple images.
assert isinstance(image_path,
str), "Only support single image input"
try:
multi_modal_data["image"] = Image.open(image_path).convert(
"RGB")
except FileNotFoundError:
# Ignore datapoint where asset is missing
continue
prompt = _get_prompt_for_image_model(question=prompt, model=model)

# Tokenize the prompts and completions.
prompt = dataset[i][0]
prompt_token_ids = tokenizer(prompt).input_ids
completion = dataset[i][1]
completion_token_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_token_ids)
output_len = len(completion_token_ids
@@ -82,7 +117,8 @@ def sample_requests(
filtered_dataset.append(
SampleRequest(prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len))
expected_output_len=output_len,
multi_modal_data=multi_modal_data))

return filtered_dataset

@@ -99,7 +135,9 @@ def run_vllm(
prompts: List[TextPrompt] = []
sampling_params: List[SamplingParams] = []
for request in requests:
prompts.append(TextPrompt(prompt=request.prompt))
prompts.append(
TextPrompt(prompt=request.prompt,
multi_modal_data=request.multi_modal_data))
sampling_params.append(
SamplingParams(
n=n,
@@ -148,7 +186,9 @@ async def run_vllm_async(
prompts: List[TextPrompt] = []
sampling_params: List[SamplingParams] = []
for request in requests:
prompts.append(TextPrompt(prompt=request.prompt))
prompts.append(
TextPrompt(prompt=request.prompt,
multi_modal_data=request.multi_modal_data))
sampling_params.append(
SamplingParams(
n=n,
@@ -272,9 +312,10 @@ def main(args: argparse.Namespace):
for _ in range(args.num_prompts)
]
else:
requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
args.output_len)
requests = sample_requests(tokenizer, args)

is_multi_modal = any(request.multi_modal_data is not None
for request in requests)
if args.backend == "vllm":
if args.async_engine:
elapsed_time = uvloop.run(
@@ -300,6 +341,11 @@ def main(args: argparse.Namespace):
for request in requests)
total_output_tokens = sum(request.expected_output_len
for request in requests)
if is_multi_modal:
print("\033[91mWARNING\033[0m: Multi-modal request detected. The "
"following metrics is not accurate because image tokens are not "
"counted. See vllm-project/vllm/issues/9778 for details.")
# TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
f"{total_output_tokens / elapsed_time:.2f} output tokens/s")