
Commit da9c6a4

maxdebayser authored and xuebwang-amd committed

Add support for the /rerank endpoint in vllm bench serve (vllm-project#26602)

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>

1 parent 1518791 commit da9c6a4

3 files changed: +218 -6 lines changed


docs/contributing/benchmarks.md

Lines changed: 46 additions & 0 deletions
@@ -35,6 +35,7 @@ th {
 | Sonnet (deprecated) ||| Local file: `benchmarks/sonnet.txt` |
 | Random ||| `synthetic` |
 | RandomMultiModal (Image/Video) | 🟡 | 🚧 | `synthetic` |
+| RandomForReranking ||| `synthetic` |
 | Prefix Repetition ||| `synthetic` |
 | HuggingFace-VisionArena ||| `lmarena-ai/VisionArena-Chat` |
 | HuggingFace-MMVU ||| `yale-nlp/MMVU` |
@@ -878,6 +879,51 @@ vllm bench serve \
 
 </details>
 
+#### Reranker Benchmark
+
+Benchmark the performance of rerank requests in vLLM.
+
+<details class="admonition abstract" markdown="1">
+<summary>Show more</summary>
+
+Unlike generative models, which use the Completions or Chat Completions API,
+reranker benchmarks require `--backend vllm-rerank` and `--endpoint /v1/rerank` to use the Reranker API.
+
+For reranking, the only supported dataset is `--dataset-name random-rerank`.
+
+Start the server:
+
+```bash
+vllm serve BAAI/bge-reranker-v2-m3
+```
+
+Run the benchmark:
+
+```bash
+vllm bench serve \
+  --model BAAI/bge-reranker-v2-m3 \
+  --backend vllm-rerank \
+  --endpoint /v1/rerank \
+  --dataset-name random-rerank \
+  --tokenizer BAAI/bge-reranker-v2-m3 \
+  --random-input-len 512 \
+  --num-prompts 10 \
+  --random-batch-size 5
+```
+
+For reranker models, this creates `num_prompts / random_batch_size` requests, each
+containing `random_batch_size` "documents" of close to `random_input_len` tokens.
+In the example above, this results in 2 rerank requests with 5 "documents" each,
+where each document has close to 512 tokens.
+
+Note that `/v1/rerank` is also supported by embedding models, so if you are running
+with an embedding model, also set `--no-reranker`. In that case the server treats the
+query as an individual prompt, so `random_batch_size - 1` documents are sent per
+request to account for the extra prompt (the query), and the token accounting used
+to report throughput is adjusted accordingly.
+
+</details>
+
 [](){ #performance-benchmarks }
 
 ## Performance Benchmarks
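For reference, the rerank request body that the new `vllm-rerank` backend sends uses the `model`, `query`, and `documents` fields built in `async_request_vllm_rerank` (third file of this diff). Below is a minimal standalone sketch of one such request; the query, documents, and local URL are illustrative values assuming a server started as in the docs above, not part of the commit.

```python
# Illustrative only: one /v1/rerank request shaped like those the benchmark sends.
import requests

payload = {
    "model": "BAAI/bge-reranker-v2-m3",
    "query": "What is the capital of France?",  # placeholder query
    "documents": [                              # placeholder documents
        "Paris is the capital and largest city of France.",
        "The Great Wall of China is thousands of kilometres long.",
    ],
}

resp = requests.post("http://localhost:8000/v1/rerank", json=payload, timeout=30)
resp.raise_for_status()
print(resp.json())  # per-document relevance scores returned by the server
```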

vllm/benchmarks/datasets.py

Lines changed: 129 additions & 0 deletions
@@ -572,6 +572,7 @@ def get_sampling_params(
         # Ensure the lower bound for output length is at least 1 to
         # prevent sampling 0 tokens.
         output_low = max(output_low, 1)
+        output_high = max(output_high, 1)
 
         if input_low > input_high:
             raise ValueError(
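A minimal sketch of why the upper bound now needs the same clamp, assuming the numpy-style `randint` sampling used elsewhere in `RandomDataset` (the exact call is an assumption): the rerank dataset requests an output length of 0, which would otherwise leave the clamped lower bound above the upper bound.

```python
# Illustrative only, not the actual get_sampling_params implementation.
import numpy as np

output_len = 0                       # rerank requests expect no generated tokens
output_low = max(output_len, 1)      # pre-existing clamp -> 1
output_high = output_len             # stays 0 without the new clamp
# np.random.randint(1, 0 + 1) would raise "low >= high"
output_high = max(output_high, 1)    # the clamp added in this commit
print(np.random.randint(output_low, output_high + 1, size=3))  # -> [1 1 1]
```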
@@ -638,6 +639,112 @@ def generate_token_sequence(
         return prompt, total_input_len, token_mismatch
 
 
+# -----------------------------------------------------------------------------
+# Random Dataset Implementation (Synthetic Data)
+# -----------------------------------------------------------------------------
+
+
+class RandomDatasetForReranking(RandomDataset):
+    """
+    Random dataset specialized for the needs of scoring:
+    - Batches of inputs
+    - Inputs composed of pairs
+    """
+
+    def __init__(self, **kwargs) -> None:
+        super().__init__(**kwargs)
+
+    def sample(
+        self,
+        tokenizer: PreTrainedTokenizerBase,
+        num_requests: int,
+        request_id_prefix: str = "",
+        range_ratio: float = RandomDataset.DEFAULT_RANGE_RATIO,
+        input_len: int = RandomDataset.DEFAULT_INPUT_LEN,
+        batchsize: int = 1,
+        is_reranker: bool = True,
+        **kwargs,
+    ) -> list[SampleRequest]:
+        n_sep_tokens = int(is_reranker)
+
+        query_len_param = (input_len // 2) - n_sep_tokens if is_reranker else input_len
+
+        query_lens, _, query_offsets = self.get_sampling_params(
+            1, range_ratio, query_len_param, 0, tokenizer
+        )
+
+        query_len = int(query_lens[0])
+
+        if not is_reranker:
+            assert num_requests > 1 and batchsize > 1
+            num_requests -= 1
+            batchsize -= 1
+            doc_len_param = input_len
+        else:
+            doc_len_param = input_len - query_len - n_sep_tokens
+
+        doc_lens, _, doc_offsets = self.get_sampling_params(
+            num_requests, range_ratio, doc_len_param, 0, tokenizer
+        )
+        vocab_size = tokenizer.vocab_size
+
+        query_prompt, query_input_len, token_mismatch_total = (
+            self.generate_token_sequence(
+                tokenizer=tokenizer,
+                prefix_token_ids=[],
+                prefix_len=0,
+                vocab_size=vocab_size,
+                input_len=query_len,
+                offset=int(query_offsets[0]),
+                index=0,
+            )
+        )
+
+        requests = []
+        for i in range(num_requests):
+            prompt, total_input_len, token_mismatch = self.generate_token_sequence(  # noqa: E501
+                tokenizer=tokenizer,
+                prefix_token_ids=[],
+                prefix_len=0,
+                vocab_size=vocab_size,
+                input_len=int(doc_lens[i]),
+                offset=int(doc_offsets[i]),
+                index=i + 1,
+            )
+            token_mismatch_total += token_mismatch
+            requests.append((prompt, total_input_len))
+
+        batch_requests = []
+        # Create batched requests
+        for i in range(0, num_requests, batchsize):
+            batch = requests[i : i + batchsize]
+            query_contrib = (
+                (query_input_len + n_sep_tokens) * len(batch)
+                if is_reranker
+                else query_input_len
+            )
+            batch_requests.append(
+                SampleRequest(
+                    prompt=[query_prompt] + [req[0] for req in batch],
+                    prompt_len=query_contrib + sum(req[1] for req in batch),
+                    expected_output_len=0,
+                    request_id=request_id_prefix + str(i // batchsize),
+                )
+            )
+
+        if token_mismatch_total != 0:
+            logger.warning(
+                "Across all generated prompts, there were %d %s tokens "
+                "than expected after decoding and re-encoding. This is "
+                "expected due to the imperfect nature of the sampling "
+                "procedure.",
+                abs(token_mismatch_total),
+                "more" if token_mismatch_total > 0 else "fewer",
+            )
+
+        return batch_requests
+
+
 # -----------------------------------------------------------------------------
 # MultiModalDataset Implementation
 # -----------------------------------------------------------------------------
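To make the batching arithmetic above concrete, here is a toy sketch (placeholder strings, not part of the commit) of how `sample()` groups one shared query with `batchsize` documents per `SampleRequest`, mirroring the `for i in range(0, num_requests, batchsize)` loop:

```python
# With the docs example (--num-prompts 10, --random-batch-size 5) the reranker
# path yields 2 requests, each carrying the shared query plus 5 documents.
num_requests, batchsize = 10, 5
query = "synthetic query"
documents = [f"synthetic document {i}" for i in range(num_requests)]

batches = [
    [query] + documents[i : i + batchsize]
    for i in range(0, num_requests, batchsize)
]
assert len(batches) == 2 and all(len(b) == 1 + batchsize for b in batches)
print(batches[0][:2])  # ['synthetic query', 'synthetic document 0']
```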
@@ -1149,6 +1256,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
             "sonnet",
             "random",
             "random-mm",
+            "random-rerank",
             "hf",
             "custom",
             "prefix_repetition",
@@ -1292,6 +1400,14 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
         default=1,
         help=("Batch size for random sampling. Only used for embeddings benchmark."),
     )
+    random_group.add_argument(
+        "--no-reranker",
+        action="store_true",
+        help=(
+            "Whether the model supports reranking natively."
+            " Only used for reranker benchmark."
+        ),
+    )
 
     # random multimodal dataset options
     random_mm_group = parser.add_argument_group(
@@ -1678,6 +1794,19 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
             request_id_prefix=args.request_id_prefix,
             no_oversample=args.no_oversample,
         ),
+        "random-rerank": lambda: RandomDatasetForReranking(
+            random_seed=args.seed,
+            dataset_path=args.dataset_path,
+            disable_shuffle=args.disable_shuffle,
+        ).sample(
+            tokenizer=tokenizer,
+            num_requests=args.num_prompts,
+            input_len=args.random_input_len,
+            range_ratio=args.random_range_ratio,
+            request_id_prefix=args.request_id_prefix,
+            batchsize=args.random_batch_size,
+            is_reranker=not args.no_reranker,
+        ),
         "prefix_repetition": lambda: PrefixRepetitionRandomDataset(
             random_seed=args.seed,
             dataset_path=args.dataset_path,
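As a worked illustration of the `is_reranker=not args.no_reranker` wiring (hypothetical values mirroring the docs example, not part of the commit): when an embedding model serves `/v1/rerank`, the query occupies one prompt slot, so `sample()` shrinks both counters before batching.

```python
# Hypothetical walk-through of the --no-reranker path in sample().
num_prompts, random_batch_size = 10, 5   # CLI values from the docs example
no_reranker = True                       # embedding model behind /v1/rerank

num_requests, batchsize = num_prompts, random_batch_size
if no_reranker:            # mirrors `if not is_reranker:` in sample()
    num_requests -= 1      # one prompt slot is reserved for the query
    batchsize -= 1         # each request sends random_batch_size - 1 documents
print(num_requests, batchsize)  # 9 documents remain, sent in groups of up to 4
```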

vllm/benchmarks/lib/endpoint_request_func.py

Lines changed: 43 additions & 6 deletions
@@ -64,7 +64,7 @@ def add_chunk(self, chunk_bytes: bytes) -> list[str]:
 class RequestFuncInput:
     """The input for the request function."""
 
-    prompt: str
+    prompt: str | list[str]
     api_url: str
     prompt_len: int
     output_len: int
@@ -484,7 +484,7 @@ def to_bytes(y, sr):
     return output
 
 
-async def _run_openai_embeddings(
+async def _run_pooling_request(
     session: aiohttp.ClientSession,
     api_url: str,
     payload: dict[str, Any],
@@ -497,7 +497,7 @@ async def _run_openai_embeddings(
     try:
         async with session.post(url=api_url, headers=headers, json=payload) as response:
             if response.status == 200:
-                output.latency = time.perf_counter() - st
+                output.ttft = output.latency = time.perf_counter() - st
                 data = await response.json()
                 output.success = True
                 output.generated_text = ""
@@ -536,7 +536,43 @@ async def async_request_openai_embeddings(
     }
     _update_headers_common(headers, request_func_input)
 
-    return await _run_openai_embeddings(
+    return await _run_pooling_request(
+        session,
+        api_url,
+        payload=payload,
+        headers=headers,
+        pbar=pbar,
+    )
+
+
+async def async_request_vllm_rerank(
+    request_func_input: RequestFuncInput,
+    session: aiohttp.ClientSession,
+    pbar: tqdm | None = None,
+) -> RequestFuncOutput:
+    api_url = request_func_input.api_url
+    _validate_api_url(api_url, "vLLM score API", "rerank")
+
+    assert (
+        isinstance(request_func_input.prompt, list)
+        and len(request_func_input.prompt) > 1
+    )
+
+    payload = {
+        "model": request_func_input.model_name
+        if request_func_input.model_name
+        else request_func_input.model,
+        "query": request_func_input.prompt[0],
+        "documents": request_func_input.prompt[1:],
+    }
+
+    headers = {
+        "Content-Type": "application/json",
+        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
+    }
+    _update_headers_common(headers, request_func_input)
+
+    return await _run_pooling_request(
         session,
         api_url,
         payload=payload,
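A small sketch (placeholder strings, not from the commit) of the list-shaped prompt that `async_request_vllm_rerank` relies on now that `RequestFuncInput.prompt` may be a `list[str]`: the first element becomes the `query` field and the rest become `documents`.

```python
# Toy demonstration of the prompt layout consumed above.
prompt: str | list[str] = [
    "what is vllm?",                         # -> payload["query"]
    "vLLM is a fast LLM inference engine.",  # -> payload["documents"][0]
    "Bananas are yellow.",                   # -> payload["documents"][1]
]
assert isinstance(prompt, list) and len(prompt) > 1
query, documents = prompt[0], prompt[1:]
print(query, documents)
```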
@@ -572,7 +608,7 @@ async def async_request_openai_embeddings_chat(
     }
     _update_headers_common(headers, request_func_input)
 
-    return await _run_openai_embeddings(
+    return await _run_pooling_request(
         session,
         api_url,
         payload=payload,
@@ -685,7 +721,7 @@ async def async_request_infinity_embeddings(
     }
     _update_headers_common(headers, request_func_input)
 
-    return await _run_openai_embeddings(
+    return await _run_pooling_request(
         session,
         api_url,
         payload=payload,
@@ -722,6 +758,7 @@ async def async_request_infinity_embeddings_clip(
     "infinity-embeddings": async_request_infinity_embeddings,
     "infinity-embeddings-clip": async_request_infinity_embeddings_clip,
     # (Infinity embedding server does not support vlm2vec)
+    "vllm-rerank": async_request_vllm_rerank,
 }
 
 OPENAI_COMPATIBLE_BACKENDS = [
