
Commit ec8a5e5

[Misc]: Add support for goodput on guided benchmarking + TPOT calculation refactor (#13736)
Signed-off-by: Brayden Zhong <b8zhong@uwaterloo.ca>
1 parent 215bf15 commit ec8a5e5

File tree

1 file changed: +82 -5 lines changed


benchmarks/benchmark_serving_guided.py

Lines changed: 82 additions & 5 deletions
@@ -9,7 +9,7 @@
     ./launch_tgi_server.sh <your_model> <max_batch_total_tokens>
 
 On the client side, run:
-    python benchmarks/benchmark_serving.py \
+    python benchmarks/benchmark_serving_guided.py \
         --backend <backend> \
         --model <your_model> \
         --dataset json \
@@ -31,7 +31,7 @@
 import time
 import warnings
 from dataclasses import dataclass
-from typing import AsyncGenerator, List, Optional, Tuple
+from typing import AsyncGenerator, Dict, List, Optional, Tuple
 
 import datasets
 import numpy as np
@@ -264,6 +264,7 @@ def calculate_metrics(
     tokenizer: PreTrainedTokenizerBase,
     selected_percentile_metrics: List[str],
     selected_percentiles: List[float],
+    goodput_config_dict: Optional[Dict[str, float]] = None,
 ) -> Tuple[BenchmarkMetrics, List[int]]:
     actual_output_lens: List[int] = []
     total_input = 0
@@ -287,10 +288,10 @@ def calculate_metrics(
             total_input += input_requests[i].prompt_len
             tpot = 0
             if output_len > 1:
-                tpot = (outputs[i].latency - outputs[i].ttft) / (output_len -
-                                                                 1)
+                latency_minus_ttft = outputs[i].latency - outputs[i].ttft
+                tpot = latency_minus_ttft / (output_len - 1)
                 tpots.append(tpot)
-            outputs[i].tpot = sum(tpots) / len(tpots) if len(tpots) else 0
+            outputs[i].tpot = tpot
             # Note: if output_len <= 1, we regard tpot as 0 for goodput
             all_tpots.append(tpot)
             itls += outputs[i].itl
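Aside from unwrapping the expression, the hunk above fixes a subtle bug: each request's tpot field previously stored the running mean of all TPOTs accumulated so far (sum(tpots) / len(tpots)), not that request's own value. A minimal standalone sketch of the corrected per-request calculation, using hypothetical timing values:

    # Per-request TPOT (time per output token): decode time divided by the
    # number of decoded tokens. All values below are hypothetical.
    latency = 2.0    # total request latency in seconds
    ttft = 0.5       # time to first token in seconds
    output_len = 31  # output tokens, including the first

    tpot = 0
    if output_len > 1:
        latency_minus_ttft = latency - ttft           # time spent decoding
        tpot = latency_minus_ttft / (output_len - 1)  # 1.5 / 30 = 0.05 s/token
    # if output_len <= 1, tpot stays 0, matching the goodput convention above
    print(f"TPOT: {tpot:.3f} s/token")  # TPOT: 0.050 s/token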
@@ -300,6 +301,28 @@
         else:
             actual_output_lens.append(0)
 
+    if goodput_config_dict:
+        valid_metrics = []
+        slo_values = []
+
+        if "ttft" in goodput_config_dict:
+            valid_metrics.append(ttfts)
+            slo_values.append(goodput_config_dict["ttft"] /
+                              MILLISECONDS_TO_SECONDS_CONVERSION)
+        if "tpot" in goodput_config_dict:
+            valid_metrics.append(all_tpots)
+            slo_values.append(goodput_config_dict["tpot"] /
+                              MILLISECONDS_TO_SECONDS_CONVERSION)
+        if "e2el" in goodput_config_dict:
+            valid_metrics.append(e2els)
+            slo_values.append(goodput_config_dict["e2el"] /
+                              MILLISECONDS_TO_SECONDS_CONVERSION)
+
+        for req_metric in zip(*valid_metrics):
+            is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)])
+            if is_good_req:
+                good_completed += 1
+
     if completed == 0:
         warnings.warn(
             "All requests failed. This is likely due to a misconfiguration "
@@ -356,6 +379,7 @@ async def benchmark(
     max_concurrency: Optional[int],
     guided_decoding_ratio: float,
     guided_decoding_backend: str,
+    goodput_config_dict: Optional[Dict[str, float]] = None,
 ):
     if backend in ASYNC_REQUEST_FUNCS:
         request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -483,6 +507,7 @@ async def limited_request_func(request_func_input, pbar):
         tokenizer=tokenizer,
         selected_percentile_metrics=selected_percentile_metrics,
         selected_percentiles=selected_percentiles,
+        goodput_config_dict=goodput_config_dict,
     )
 
     print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
@@ -494,6 +519,9 @@ async def limited_request_func(request_func_input, pbar):
                                     metrics.total_output))
     print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
                                     metrics.request_throughput))
+    if goodput_config_dict:
+        print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
+                                        metrics.request_goodput))
     print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
                                     metrics.output_throughput))
     print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
@@ -617,6 +645,40 @@ def _eval_correctness(expected, actual):
                 100) if len(not_none_scores) > 0 else None
 
 
+def parse_goodput(slo_pairs):
+    goodput_config_dict = {}
+    try:
+        for slo_pair in slo_pairs:
+            slo_name, slo_val = slo_pair.split(":")
+            goodput_config_dict[slo_name] = float(slo_val)
+    except ValueError as err:
+        raise argparse.ArgumentTypeError(
+            "Invalid format found for service level objectives. "
+            "Specify service level objectives for goodput as \"KEY:VALUE\" "
+            "pairs, where the key is a metric name, and the value is a "
+            "number in milliseconds.") from err
+    return goodput_config_dict
+
+
+def check_goodput_args(args):
+    goodput_config_dict = {}
+    VALID_NAMES = ["ttft", "tpot", "e2el"]
+    if args.goodput:
+        goodput_config_dict = parse_goodput(args.goodput)
+        for slo_name, slo_val in goodput_config_dict.items():
+            if slo_name not in VALID_NAMES:
+                raise ValueError(
+                    f"Invalid metric name found, {slo_name}: {slo_val}. "
+                    "The service level objective name should be one of "
+                    f"{str(VALID_NAMES)}. ")
+            if slo_val < 0:
+                raise ValueError(
+                    f"Invalid value found, {slo_name}: {slo_val}. "
+                    "The service level objective value should be "
+                    "non-negative.")
+    return goodput_config_dict
+
+
 def main(args: argparse.Namespace):
     print(args)
     random.seed(args.seed)
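As a quick, hypothetical check of the helpers above: parse_goodput splits each space-separated CLI token on ":" and converts the value to a float, while check_goodput_args rejects unknown metric names and negative values:

    >>> parse_goodput(["ttft:200", "tpot:50", "e2el:2000"])
    {'ttft': 200.0, 'tpot': 50.0, 'e2el': 2000.0}
    >>> parse_goodput(["ttft=200"])  # wrong separator
    Traceback (most recent call last):
        ...
    argparse.ArgumentTypeError: Invalid format found for service level objectives. ...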
@@ -661,6 +723,8 @@ def main(args: argparse.Namespace):
 
     input_requests = sample_requests(tokenizer, args)
 
+    goodput_config_dict = check_goodput_args(args)
+
     benchmark_result, ret = asyncio.run(
         benchmark(
             backend=backend,
@@ -681,6 +745,7 @@
             max_concurrency=args.max_concurrency,
             guided_decoding_ratio=args.guided_decoding_ratio,
             guided_decoding_backend=args.guided_decoding_backend,
+            goodput_config_dict=goodput_config_dict,
         ))
 
     # Save config and results to json
@@ -865,6 +930,18 @@ def main(args: argparse.Namespace):
         "Default value is \"99\". "
         "Use \"--percentile-metrics\" to select metrics.",
     )
+    parser.add_argument(
+        "--goodput",
+        nargs="+",
+        required=False,
+        help="Specify service level objectives for goodput as \"KEY:VALUE\" "
+        "pairs, where the key is a metric name, and the value is in "
+        "milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, "
+        "separated by spaces. Allowed request level metric names are "
+        "\"ttft\", \"tpot\", \"e2el\". For more context on the definition of "
+        "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
+        "and the blog: https://hao-ai-lab.github.io/blogs/distserve")
+
     parser.add_argument("--no-guided-decoding",
                         action='store_true',
                         default=False,
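Putting the new flag together with the existing options, a hypothetical client invocation (request rate, prompt count, and SLO values are illustrative) that reports goodput against 200 ms TTFT, 50 ms TPOT, and 2 s end-to-end-latency SLOs:

    python benchmarks/benchmark_serving_guided.py \
        --backend vllm \
        --model <your_model> \
        --dataset json \
        --request-rate 10 \
        --num-prompts 1000 \
        --goodput ttft:200 tpot:50 e2el:2000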
