Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Benchmark] More accurate TPOT calc in benchmark_serving.py #12288

Merged
merged 1 commit into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 27 additions & 16 deletions benchmarks/backend_request_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class RequestFuncOutput:
generated_text: str = ""
success: bool = False
latency: float = 0.0
output_tokens: int = 0
ttft: float = 0.0 # Time to first token
itl: List[float] = field(
default_factory=list) # List of inter-token latencies
Expand Down Expand Up @@ -156,7 +157,7 @@ async def async_request_trt_llm(
timestamp = time.perf_counter()
# First token
if ttft == 0.0:
ttft = time.perf_counter() - st
ttft = timestamp - st
output.ttft = ttft

# Decoding phase
Expand Down Expand Up @@ -245,6 +246,9 @@ async def async_request_openai_completions(
"logprobs": request_func_input.logprobs,
"stream": True,
"ignore_eos": request_func_input.ignore_eos,
"stream_options": {
"include_usage": True,
},
}
if request_func_input.extra_body:
payload.update(request_func_input.extra_body)
Expand All @@ -256,7 +260,6 @@ async def async_request_openai_completions(
output.prompt_len = request_func_input.prompt_len

generated_text = ""
ttft = 0.0
st = time.perf_counter()
most_recent_timestamp = st
try:
Expand All @@ -271,15 +274,16 @@ async def async_request_openai_completions(

chunk = chunk_bytes.decode("utf-8").removeprefix(
"data: ")
if chunk == "[DONE]":
latency = time.perf_counter() - st
else:
if chunk != "[DONE]":
data = json.loads(chunk)

# NOTE: Some completion API might have a last
# usage summary response without a token so we
# want to check a token was generated
if data["choices"][0]["text"]:
if choices := data.get("choices"):
# Note that text could be empty here
# e.g. for special tokens
text = choices[0].get("text")
timestamp = time.perf_counter()
# First token
if not first_chunk_received:
Expand All @@ -293,7 +297,10 @@ async def async_request_openai_completions(
most_recent_timestamp)

most_recent_timestamp = timestamp
generated_text += data["choices"][0]["text"]
generated_text += text
elif usage := data.get("usage"):
output.output_tokens = usage.get(
"completion_tokens")
if first_chunk_received:
output.success = True
else:
Expand All @@ -302,7 +309,7 @@ async def async_request_openai_completions(
"Never received a valid chunk to calculate TTFT."
"This response will be marked as failed!")
output.generated_text = generated_text
output.latency = latency
output.latency = most_recent_timestamp - st
else:
output.error = response.reason or ""
output.success = False
Expand Down Expand Up @@ -342,6 +349,9 @@ async def async_request_openai_chat_completions(
"max_completion_tokens": request_func_input.output_len,
"stream": True,
"ignore_eos": request_func_input.ignore_eos,
"stream_options": {
"include_usage": True,
},
}
if request_func_input.extra_body:
payload.update(request_func_input.extra_body)
Expand All @@ -368,31 +378,32 @@ async def async_request_openai_chat_completions(

chunk = chunk_bytes.decode("utf-8").removeprefix(
"data: ")
if chunk == "[DONE]":
latency = time.perf_counter() - st
else:
if chunk != "[DONE]":
timestamp = time.perf_counter()
data = json.loads(chunk)

delta = data["choices"][0]["delta"]
if delta.get("content", None):
if choices := data.get("choices"):
content = choices[0]["delta"].get("content")
# First token
if ttft == 0.0:
ttft = time.perf_counter() - st
ttft = timestamp - st
output.ttft = ttft

# Decoding phase
else:
output.itl.append(timestamp -
most_recent_timestamp)

generated_text += delta["content"]
generated_text += content
elif usage := data.get("usage"):
output.output_tokens = usage.get(
"completion_tokens")

most_recent_timestamp = timestamp

output.generated_text = generated_text
output.success = True
output.latency = latency
output.latency = most_recent_timestamp - st
else:
output.error = response.reason or ""
output.success = False
Expand Down
69 changes: 39 additions & 30 deletions benchmarks/benchmark_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import argparse
import asyncio
import base64
import gc
import io
import json
import os
Expand Down Expand Up @@ -423,7 +424,7 @@ def calculate_metrics(
tokenizer: PreTrainedTokenizerBase,
selected_percentile_metrics: List[str],
selected_percentiles: List[float],
gootput_config_dict: Dict[str, float],
goodput_config_dict: Dict[str, float],
) -> Tuple[BenchmarkMetrics, List[int]]:
actual_output_lens: List[int] = []
total_input = 0
Expand All @@ -436,19 +437,23 @@ def calculate_metrics(
e2els: List[float] = []
for i in range(len(outputs)):
if outputs[i].success:
# We use the tokenizer to count the number of output tokens for all
# serving backends instead of looking at len(outputs[i].itl) since
# multiple output tokens may be bundled together
# Note : this may inflate the output token count slightly
output_len = len(
tokenizer(outputs[i].generated_text,
add_special_tokens=False).input_ids)
output_len = outputs[i].output_tokens

if output_len is None:
# We use the tokenizer to count the number of output tokens
# for some serving backends instead of looking at
# len(outputs[i].itl) since multiple output tokens may be
# bundled together
# Note : this may inflate the output token count slightly
output_len = len(
tokenizer(outputs[i].generated_text,
add_special_tokens=False).input_ids)
actual_output_lens.append(output_len)
total_input += input_requests[i][1]
tpot = 0
if output_len > 1:
tpot = (outputs[i].latency - outputs[i].ttft) / (output_len -
1)
latency_minus_ttft = outputs[i].latency - outputs[i].ttft
tpot = latency_minus_ttft / (output_len - 1)
tpots.append(tpot)
# Note: if output_len <= 1, we regard tpot as 0 for goodput
all_tpots.append(tpot)
Expand All @@ -459,21 +464,21 @@ def calculate_metrics(
else:
actual_output_lens.append(0)

if gootput_config_dict:
if goodput_config_dict:
valid_metrics = []
slo_values = []

if "ttft" in gootput_config_dict:
if "ttft" in goodput_config_dict:
valid_metrics.append(ttfts)
slo_values.append(gootput_config_dict["ttft"] /
slo_values.append(goodput_config_dict["ttft"] /
MILLISECONDS_TO_SECONDS_CONVERSION)
if "tpot" in gootput_config_dict:
if "tpot" in goodput_config_dict:
valid_metrics.append(all_tpots)
slo_values.append(gootput_config_dict["tpot"] /
slo_values.append(goodput_config_dict["tpot"] /
MILLISECONDS_TO_SECONDS_CONVERSION)
if "e2el" in gootput_config_dict:
if "e2el" in goodput_config_dict:
valid_metrics.append(e2els)
slo_values.append(gootput_config_dict["e2el"] /
slo_values.append(goodput_config_dict["e2el"] /
MILLISECONDS_TO_SECONDS_CONVERSION)

for req_metric in zip(*valid_metrics):
Expand Down Expand Up @@ -537,7 +542,7 @@ async def benchmark(
selected_percentile_metrics: List[str],
selected_percentiles: List[str],
ignore_eos: bool,
gootput_config_dict: Dict[str, float],
goodput_config_dict: Dict[str, float],
max_concurrency: Optional[int],
):
if backend in ASYNC_REQUEST_FUNCS:
Expand Down Expand Up @@ -661,7 +666,7 @@ async def limited_request_func(request_func_input, pbar):
tokenizer=tokenizer,
selected_percentile_metrics=selected_percentile_metrics,
selected_percentiles=selected_percentiles,
gootput_config_dict=gootput_config_dict,
goodput_config_dict=goodput_config_dict,
)

print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
Expand All @@ -673,7 +678,7 @@ async def limited_request_func(request_func_input, pbar):
metrics.total_output))
print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
metrics.request_throughput))
if gootput_config_dict:
if goodput_config_dict:
print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
metrics.request_goodput))
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
Expand All @@ -688,7 +693,7 @@ async def limited_request_func(request_func_input, pbar):
"total_output_tokens": metrics.total_output,
"request_throughput": metrics.request_throughput,
"request_goodput:":
metrics.request_goodput if gootput_config_dict else None,
metrics.request_goodput if goodput_config_dict else None,
"output_throughput": metrics.output_throughput,
"total_token_throughput": metrics.total_token_throughput,
"input_lens": [output.prompt_len for output in outputs],
Expand Down Expand Up @@ -744,11 +749,11 @@ def process_one_metric(

def check_goodput_args(args):
# Check and parse goodput arguments
gootput_config_dict = {}
goodput_config_dict = {}
VALID_NAMES = ["ttft", "tpot", "e2el"]
if args.goodput:
gootput_config_dict = parse_goodput(args.goodput)
for slo_name, slo_val in gootput_config_dict.items():
goodput_config_dict = parse_goodput(args.goodput)
for slo_name, slo_val in goodput_config_dict.items():
if slo_name not in VALID_NAMES:
raise ValueError(
f"Invalid metric name found, {slo_name}: {slo_val}. "
Expand All @@ -759,22 +764,22 @@ def check_goodput_args(args):
f"Invalid value found, {slo_name}: {slo_val}. "
"The service level objective value should be "
"non-negative.")
return gootput_config_dict
return goodput_config_dict


def parse_goodput(slo_pairs):
gootput_config_dict = {}
goodput_config_dict = {}
try:
for slo_pair in slo_pairs:
slo_name, slo_val = slo_pair.split(":")
gootput_config_dict[slo_name] = float(slo_val)
goodput_config_dict[slo_name] = float(slo_val)
except ValueError as err:
raise argparse.ArgumentTypeError(
"Invalid format found for service level objectives. "
"Specify service level objectives for goodput as \"KEY:VALUE\" "
"pairs, where the key is a metric name, and the value is a "
"number in milliseconds.") from err
return gootput_config_dict
return goodput_config_dict


def main(args: argparse.Namespace):
Expand Down Expand Up @@ -874,7 +879,11 @@ def main(args: argparse.Namespace):
else:
raise ValueError(f"Unknown dataset: {args.dataset_name}")

gootput_config_dict = check_goodput_args(args)
goodput_config_dict = check_goodput_args(args)

# Avoid GC processing "static" data - reduce pause times.
gc.collect()
gc.freeze()

benchmark_result = asyncio.run(
benchmark(
Expand All @@ -896,7 +905,7 @@ def main(args: argparse.Namespace):
float(p) for p in args.metric_percentiles.split(",")
],
ignore_eos=args.ignore_eos,
gootput_config_dict=gootput_config_dict,
goodput_config_dict=goodput_config_dict,
max_concurrency=args.max_concurrency,
))

Expand Down
Loading