benchmark_serving.py
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Benchmark JetStream online serving.
On the server side, run one of the following commands:
* For a real server, you need to pass the correct server config (including
the model config that is being passed into your engine impl) to the command
below. Refer to config_lib.py and implementations/mock/config.py for
config impl details.
(run with real server)
python -m jetstream.core.implementations.<your_impl>.server \
--config <your_server_config>
(run with mock server)
python -m jetstream.core.implementations.mock.server
On the client side, run:
* For a real server and the ShareGPT dataset, you need to pass the tokenizer,
server config, and dataset flags to the command below, and make some
changes to the tokenizer logic in the benchmark script (the get_tokenizer
and sample_requests functions) to use your tokenizer correctly.
* Add the `--save-result` flag to save the benchmark result to a JSON file in
the current folder.
* You can also add `--run-eval true` if you want to calculate the ROUGE score
on the predicted outputs.
(run with real model and engines)
python -m benchmarks.benchmark_serving \
--tokenizer <your_tokenizer> \
--dataset <target_dataset_name> \
--dataset-path <target_dataset_path> \
--request-rate <request_rate>
(run with mock)
python -m benchmarks.benchmark_serving \
--request-rate 1
e2e example:
python3 benchmark_serving.py \
--tokenizer /home/{username}/maxtext/assets/tokenizer \
--num-prompts 100 \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json
"""
import argparse
import asyncio
from dataclasses import dataclass, field
from datetime import datetime
import gc
import json
import random
import time
from typing import Any, AsyncGenerator, Optional
import os
import grpc
from benchmarks.metrics import EventMetric, CounterMetric
from jetstream.core.proto import jetstream_pb2
from jetstream.core.proto import jetstream_pb2_grpc
from jetstream.engine.token_utils import load_vocab
from jetstream.external_tokenizers.llama3 import llama3_tokenizer
import numpy as np
from tqdm.asyncio import tqdm # pytype: disable=pyi-error
import pandas
from eval_accuracy import eval_accuracy
from transformers import AutoTokenizer
def str2bool(v: str) -> bool:
"""Convert a string of truth to True or False.
Args:
- v (str):
- True values are 'y', 'yes', 't', 'true', and '1';
- False values are 'n', 'no', 'f', 'false', and '0'.
Returns:
bool: True or False
Raises:
ValueError if v is anything else.
"""
v = v.lower()
true_values = ["y", "yes", "t", "true", "1"]
false_values = ["n", "no", "f", "false", "0"]
if v in true_values:
return True
elif v in false_values:
return False
else:
raise ValueError(f"Invalid value '{v}'!")
class AsyncCounter:
"""An counter class for counting and quota management with asycio,
not thread safe. It's safe with asyncio as value changes are done
outside of await statements.
"""
def __init__(self, init_value: int, block_on_zero_seconds=0.002):
"""
Args:
init_value: Initial value for the counter.
block_on_zero_seconds: If greater than 0, the counter will spin when its
value hits 0, so it can be used for quota management.
"""
self._init_value = init_value
self._value = init_value
self._block_on_zero_seconds = block_on_zero_seconds
async def inc(self):
self._value += 1
async def dec(self):
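# When block_on_zero_seconds > 0, dec() spins (sleeping between checks) until
# the counter is positive, so it behaves like acquiring a quota slot that a
# later inc() releases.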
while True:
if self._value > 0 or self._block_on_zero_seconds <= 0.0:
self._value -= 1
return
await asyncio.sleep(self._block_on_zero_seconds)
def value(self):
return self._value
def delta(self):
return self._init_value - self._value
@dataclass
class BenchmarkMetrics:
"""Data class to store benchmark metrics."""
completed: int
total_input: int
total_output: int
request_throughput: float
input_throughput: float
output_throughput: float
ttft: EventMetric # Time-to-first-token
ttst: EventMetric # Time-to-second-token
tpot: EventMetric # Time-per-output-token
@dataclass
class InputRequest:
prompt: str = ""
prompt_len: int = 0
output: str = ""
output_len: int = 0
sample_idx: int = -1
@dataclass
class RequestFuncOutput:
"""Data class to store the response of a request."""
input_request: Optional[InputRequest] = None
generated_token_list: list[int] = field(default_factory=list)
generated_text: str = ""
success: bool = False
latency_sec: float = 0
ttft_sec: float = 0
ttst_sec: float = 0
prompt_len: int = 0
# Flatten the structure and return only the necessary results
def to_dict(self):
if self.input_request:
prompt = self.input_request.prompt
original_output = self.input_request.output
sample_idx = self.input_request.sample_idx
else:
prompt = None
original_output = None
sample_idx = None
return {
"prompt": prompt,
"original_output": original_output,
"generated_text": self.generated_text,
"success": self.success,
"latency_sec": self.latency_sec,
"ttft_sec": self.ttft_sec,
"ttst_sec": self.ttst_sec,
"prompt_len": self.prompt_len,
"sample_idx": sample_idx,
}
def get_tokenizer(
model_id: str,
tokenizer_name: str,
use_hf_tokenizer: bool,
) -> Any:
"""Return a tokenizer or a tokenizer placholder."""
if tokenizer_name == "test":
print("Using test tokenizer")
return "test"
elif use_hf_tokenizer:
# Please accept the agreement to access private/gated models on HF, and
# follow the instructions at the link below to set up an access token:
# https://huggingface.co/docs/transformers.js/en/guides/private
print(f"Using HuggingFace tokenizer: {tokenizer_name}")
return AutoTokenizer.from_pretrained(tokenizer_name)
elif model_id == "llama-3":
# Llama 3 uses a tiktoken tokenizer.
print(f"Using llama-3 tokenizer: {tokenizer_name}")
return llama3_tokenizer.Tokenizer(tokenizer_name)
else:
# Use the JetStream tokenizer util, which wraps sentencepiece via the
# seqio library.
print(f"Using tokenizer: {tokenizer_name}")
vocab = load_vocab(tokenizer_name)
return vocab.tokenizer
def load_sharegpt_dataset(
dataset_path: str,
conversation_starter: str,
) -> list[tuple[Any, Any]]:
# Load the dataset.
with open(dataset_path, "r", encoding="utf-8") as f:
dataset = json.load(f)
# Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
# Filter based on conversation starter
if conversation_starter != "both":
dataset = [
data
for data in dataset
if data["conversations"][0]["from"] == conversation_starter
]
# Only keep the first two turns of each conversation.
dataset = [
(data["conversations"][0]["value"], data["conversations"][1]["value"])
for data in dataset
]
return dataset
def load_openorca_dataset_pkl(
dataset_path: str,
) -> list[tuple[Any, Any]]:
if not dataset_path:
dataset_path = "open_orca_gpt4_tokenized_llama.calibration_1000.pkl"
# read pickle file
samples = pandas.read_pickle(
os.path.join(
os.path.dirname(os.path.relpath(__file__)),
dataset_path,
)
)
prompts = []
outputs = []
for _, row in samples.iterrows():
prompts.append(row["input"])
outputs.append(row["output"])
return [(prompt, output) for prompt, output in zip(prompts, outputs)]
def tokenize_dataset(
dataset: list[tuple[Any, Any, Any]],
tokenizer: Any,
) -> list[tuple[str, Any, str, int, int, int]]:
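"""Tokenize each (prompt, output, index) triple and return tuples of
(prompt, prompt token ids, output, prompt length, output length, index)."""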
n = len(dataset)
prompts = []
outputs = []
indices = []
prompt_token_ids = []
outputs_token_ids = []
for prompt, output, idx in dataset:
prompts.append(prompt)
outputs.append(output)
indices.append(idx)
prompt_token_ids.append(tokenizer.encode(prompt))
outputs_token_ids.append(tokenizer.encode(output))
tokenized_dataset = []
for i in range(n):
prompt_len = len(prompt_token_ids[i])
output_len = len(outputs_token_ids[i])
tokenized_data = (
prompts[i],
prompt_token_ids[i],
outputs[i],
prompt_len,
output_len,
indices[i],
)
tokenized_dataset.append(tokenized_data)
return tokenized_dataset
def filter_dataset(
tokenized_dataset: list[tuple[str, Any, str, int, int, int]],
max_output_length: int = 0,
) -> list[InputRequest]:
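"""Build InputRequest objects from the tokenized dataset, pruning sequences
that are too short or too long."""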
if max_output_length == 0:
print("In InputRequest, pass in actual output_length for each sample")
else:
print(
f"In InputRequest, pass in max_output_length: {max_output_length} for"
" each sample"
)
# Filter out too long sequences.
filtered_dataset: list[InputRequest] = []
for (
prompt,
_,
output,
prompt_len,
output_len,
sample_idx,
) in tokenized_dataset:
if prompt_len < 4 or output_len < 4:
# Prune too short sequences.
# This is because TGI causes errors when the input or output length
# is too short.
continue
if prompt_len > 1024 or prompt_len + output_len > 2048:
# Prune too long sequences.
continue
request = InputRequest(
prompt, prompt_len, output, max_output_length or output_len, sample_idx
)
filtered_dataset.append(request)
print(f"The dataset contains {len(tokenized_dataset)} samples.")
print(f"The filtered dataset contains {len(filtered_dataset)} samples.")
return filtered_dataset
def sample_requests(
dataset: list[tuple[Any, Any]],
tokenizer: Any,
num_requests: int,
max_output_length: int = 0,
oversample_multiplier: float = 1.2,
) -> list[InputRequest]:
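"""Sample and tokenize requests from the dataset, oversampling first so that
roughly num_requests remain after length filtering."""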
# Original dataset size
n = len(dataset)
dataset_indices = range(n)
# Create the necessary number of requests even if it exceeds the dataset size
sampled_indices = random.sample(
dataset_indices, min(int(num_requests * oversample_multiplier), n)
)
if num_requests > len(sampled_indices):
print(
f"Number of requests {num_requests} is larger than size of dataset"
f" {n}.\n",
"Repeating data to meet number of requests.\n",
)
sampled_indices = sampled_indices * int(
np.ceil(num_requests / len(sampled_indices))
)
print(f"{len(sampled_indices)=}")
# some of these will be filtered out, so sample more than we need
sampled_dataset = []
for i in sampled_indices:
sampled_data = dataset[i] + (dataset_indices[i],)
sampled_dataset.append(sampled_data)
tokenized_dataset = tokenize_dataset(sampled_dataset, tokenizer)
input_requests = filter_dataset(tokenized_dataset, max_output_length)
# Sample the requests.
if len(input_requests) > num_requests:
input_requests = random.sample(input_requests, num_requests)
return input_requests
async def get_request(
input_requests: list[InputRequest],
request_rate: float,
) -> AsyncGenerator[InputRequest, None]:
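"""Yield input requests one at a time, sleeping between them with
exponentially distributed intervals to model Poisson arrivals at
request_rate (no delay when request_rate is 0)."""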
input_requests = iter(input_requests)
for request in input_requests:
yield request
if request_rate == 0.0:
# If the request rate is 0, send all requests immediately without waiting.
continue
# Sample the request interval from the exponential distribution.
interval = np.random.exponential(1.0 / request_rate)
# The next request will be sent after the interval.
await asyncio.sleep(interval)
def calculate_metrics(
input_requests: list[InputRequest],
outputs: list[RequestFuncOutput],
dur_s: float,
tokenizer: Any,
) -> BenchmarkMetrics:
total_output = 0
total_input = 0
completed = 0
ttft = EventMetric("ttft", "Time-to-first-token", "ms")
ttst = EventMetric("ttst", "Time-to-second-token", "ms")
per_out_token_lat = EventMetric("TPOT", "Time-per-output-token", "ms")
output_sizes = []
for i in range(len(outputs)):
if outputs[i].success:
completed += 1
output_len = len(
outputs[i].generated_token_list
if tokenizer != "test"
else ["Ċ", "Ō", "Ɵ"]
)
output_sizes.append(output_len)
total_output += output_len
total_input += input_requests[i].prompt_len
if output_len == 0:
print(
f"""-------- output_len is zero for {i}th request:,
output: {outputs[i]}"""
)
continue
ttft.record(outputs[i].ttft_sec * 1000)
ttst.record(outputs[i].ttst_sec * 1000)
per_out_token_lat.record(outputs[i].latency_sec / output_len * 1000)
print("Mean output size:", float(np.mean(output_sizes)))
print("Median output size:", float(np.median(output_sizes)))
print("P99 output size:", float(np.percentile(output_sizes, 99)))
metrics = BenchmarkMetrics(
completed=completed,
total_input=total_input,
total_output=total_output,
request_throughput=completed / dur_s,
input_throughput=total_input / dur_s,
output_throughput=total_output / dur_s,
ttft=ttft,
ttst=ttst,
tpot=per_out_token_lat,
)
return metrics
async def grpc_async_request(
api_url: str,
request: Any,
prefill_quota: AsyncCounter,
active_req_quota: AsyncCounter,
out_token_cnt: CounterMetric,
) -> tuple[list[int], float, float, float]:
"""Send grpc synchronous request since the current grpc server is sync."""
options = [("grpc.keepalive_timeout_ms", 10000)]
async with grpc.aio.insecure_channel(api_url, options=options) as channel:
stub = jetstream_pb2_grpc.OrchestratorStub(channel)
request_start_time = time.perf_counter()
response = stub.Decode(request)
token_list = []
ttft = 0
ttst = 0
stream_resp_cnt = 0
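# The first streamed response marks the end of prefill: release one prefill
# quota slot and record time-to-first-token. The second response is used to
# record time-to-second-token.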
async for resp in response:
stream_resp_cnt += 1
if stream_resp_cnt == 1:
await prefill_quota.inc()
ttft = time.perf_counter() - request_start_time
if ttft > 2.0:
print(datetime.now(), f"slow TTFT {ttft:.2f}", prefill_quota.value())
elif stream_resp_cnt == 2:
ttst = time.perf_counter() - request_start_time
resp_tokens = resp.stream_content.samples[0].token_ids
token_list.extend(resp_tokens)
out_token_cnt.increment(len(resp_tokens))
await active_req_quota.inc()
req_latency = time.perf_counter() - request_start_time
return token_list, ttft, ttst, req_latency
async def send_request(
api_url: str,
tokenizer: Any,
input_request: InputRequest,
prefill_quota: AsyncCounter,
active_req_quota: AsyncCounter,
req_complete_cnt: CounterMetric,
out_token_cnt: CounterMetric,
pbar: tqdm,
) -> RequestFuncOutput:
"""Send the request to JetStream server."""
# Tokenize on client side following MLPerf standard.
token_ids = tokenizer.encode(input_request.prompt)
# Send the request
request = jetstream_pb2.DecodeRequest(
token_content=jetstream_pb2.DecodeRequest.TokenContent(
token_ids=token_ids
),
max_tokens=input_request.output_len,
metadata=jetstream_pb2.DecodeRequest.Metadata(
start_time=time.perf_counter()
),
)
out_tokens, ttft_sec, ttst_sec, latency_sec = await grpc_async_request(
api_url,
request,
prefill_quota,
active_req_quota,
out_token_cnt,
)
req_complete_cnt.increment()
# Collect per-request output and metrics.
output = RequestFuncOutput()
output.input_request = input_request
output.prompt_len = input_request.prompt_len
output.ttft_sec = ttft_sec
output.ttst_sec = ttst_sec
output.latency_sec = latency_sec
output.generated_token_list = out_tokens
# generated_token_list is a list of token ids, decode it to generated_text.
output.generated_text = tokenizer.decode(out_tokens)
output.success = True
if pbar:
pbar.postfix = (
f"#reqs: {active_req_quota.delta()}/"
f"{active_req_quota.value()}; "
f"#prefill: {prefill_quota.delta()}/"
f"{prefill_quota.value()}"
)
pbar.update(1)
return output
async def benchmark(
api_url: str,
tokenizer: Any,
input_requests: list[InputRequest],
request_rate: float,
disable_tqdm: bool,
prefill_quota: AsyncCounter,
active_req_quota: AsyncCounter,
is_warmup: bool = False,
) -> tuple[dict[str, float | int], list[RequestFuncOutput]]:
"""Benchmark the online serving performance.
Args:
api_url: URL (e.g. host:port) of the JetStream server to send requests to.
tokenizer: The tokenizer used to convert texts into tokens that will be set
in requests.
input_requests: A list of requests to send.
request_rate: The number of requests to send per second.
disable_tqdm: Whether the progress bar should be disabled.
prefill_quota: Quota for limiting pending prefill operations.
active_req_quota: Quota for limiting inflight requests.
is_warmup: Whether this run is to warm up the server.
Returns:
A tuple containing the performance statistics for all requests and a list
of responses from the executed requests.
"""
print(f"Benchmarking with a total number of {len(input_requests)} requests")
print(f"Benchmarking with request rate of {request_rate}")
pbar = None if disable_tqdm else tqdm(total=len(input_requests))
req_complete_cnt = CounterMetric(
"ReqCompleteCount", "Request Completion Counter"
)
out_token_cnt = CounterMetric("OutTokenCount", "OutToken Counter")
# Run benchmarking
tasks = []
benchmark_start_time = time.perf_counter()
async for request in get_request(input_requests, request_rate):
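# Block until both a prefill slot and an active-request slot are available,
# so the client never exceeds the quotas configured in main().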
await prefill_quota.dec()
await active_req_quota.dec()
tasks.append(
asyncio.create_task(
send_request(
api_url=api_url,
tokenizer=tokenizer,
input_request=request,
prefill_quota=prefill_quota,
active_req_quota=active_req_quota,
req_complete_cnt=req_complete_cnt,
out_token_cnt=out_token_cnt,
pbar=pbar,
)
)
)
outputs = await asyncio.gather(*tasks)
if pbar is not None:
pbar.close()
# Compute metrics
output_metrics = {}
if not is_warmup:
# No need to calculate metrics when executing warmup requests
benchmark_duration = time.perf_counter() - benchmark_start_time
metrics = calculate_metrics(
input_requests=input_requests,
outputs=outputs,
dur_s=benchmark_duration,
tokenizer=tokenizer,
)
print(f"Successful requests: {metrics.completed}")
print(f"Benchmark duration: {benchmark_duration:2f} s")
print(f"Total input tokens: {metrics.total_input}")
print(f"Total generated tokens: {metrics.total_output}")
print(f"Request throughput: {metrics.request_throughput:.2f} requests/s")
print(f"Input token throughput: {metrics.input_throughput:.2f} tokens/s")
print(f"Output token throughput: {metrics.output_throughput:.2f} tokens/s")
print(f"{metrics.ttft.distribution_summary_str()}")
print(f"{metrics.ttst.distribution_summary_str()}")
print(f"{metrics.tpot.distribution_summary_str()}")
# Calculate one rate for each 10-second window. Adjust the window size if
# needed; the CSV output below can be used to plot the rates over time.
window_size_sec = 10
print(
f"----- Request complete rate time series "
f"(window_size = {window_size_sec} sec) -----"
)
print(f"{req_complete_cnt.rate_over_window_to_csv(window_size_sec)}")
print(
f"----- Output token rate time series "
f"(window_size = {window_size_sec} sec) -----"
)
print(f"{out_token_cnt.rate_over_window_to_csv(window_size_sec)}")
output_metrics = {
"duration": benchmark_duration,
"completed": metrics.completed,
"total_input_tokens": metrics.total_input,
"total_output_tokens": metrics.total_output,
"request_throughput": metrics.request_throughput,
"input_throughput": metrics.input_throughput,
"output_throughput": metrics.output_throughput,
}
output_metrics = {
**output_metrics,
**metrics.ttft.distribution_summary_dict(),
**metrics.ttst.distribution_summary_dict(),
**metrics.tpot.distribution_summary_dict(),
}
return output_metrics, outputs
def mock_requests(total_mock_requests: int):
"""Generates a list of mock requests containing mock data."""
data = []
for _ in range(total_mock_requests):
request = InputRequest()
request.prompt = f"Prompt {random.randint(1, 1000)}"
request.prompt_len = random.randint(10, 100)
request.output = f"Output {random.randint(1, 1000)}"
request.output_len = random.randint(1, 10)
data.append(request)
return data
def sample_warmup_requests(requests):
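"""Yield at most one request per prompt-length bucket so that warmup
covers a range of prefill sizes."""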
interesting_buckets = [
0,
16,
32,
64,
128,
256,
512,
1024,
]
for start, end in zip(interesting_buckets[:-1], interesting_buckets[1:]):
for request in requests:
if start < request.prompt_len <= end:
yield request
break
def main(args: argparse.Namespace):
print(args)
random.seed(args.seed)
np.random.seed(args.seed)
model_id = args.model
tokenizer_id = args.tokenizer
use_hf_tokenizer = args.use_hf_tokenizer
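# Client-side concurrency quotas: at most 3 requests concurrently waiting for
# their first token (prefill) and at most 450 requests in flight overall.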
prefill_quota = AsyncCounter(init_value=3)
active_req_quota = AsyncCounter(init_value=450)
api_url = f"{args.server}:{args.port}"
tokenizer = get_tokenizer(model_id, tokenizer_id, use_hf_tokenizer)
if tokenizer == "test" or args.dataset == "test":
input_requests = mock_requests(
args.total_mock_requests
) # e.g. [("AB", 2, "AB", 3)]
else:
dataset = []
if args.dataset == "openorca":
dataset = load_openorca_dataset_pkl(args.dataset_path)
elif args.dataset == "sharegpt":
dataset = load_sharegpt_dataset(
args.dataset_path,
args.conversation_starter,
)
# A given args.max_output_length value caps the number of generation steps;
# when args.max_output_length is left at its default of 0, each sample's
# golden output length is used to decide the generation steps.
input_requests = sample_requests(
dataset=dataset,
tokenizer=tokenizer,
num_requests=args.num_prompts,
max_output_length=args.max_output_length,
)
warmup_requests = None
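# "full" replays every benchmark request for warmup; "sampled" takes one
# request per prompt-length bucket and sends each of them twice.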
if args.warmup_mode == "full":
warmup_requests = input_requests
elif args.warmup_mode == "sampled":
warmup_requests = list(sample_warmup_requests(input_requests)) * 2
if warmup_requests:
print(f"Warmup (mode: {args.warmup_mode}) is starting.")
_, _ = asyncio.run(
benchmark(
api_url=api_url,
tokenizer=tokenizer,
input_requests=warmup_requests,
request_rate=args.request_rate,
disable_tqdm=args.disable_tqdm,
prefill_quota=prefill_quota,
active_req_quota=active_req_quota,
is_warmup=True,
)
)
print(f"Warmup (mode: {args.warmup_mode}) has completed.")
# TODO: Replace this with warmup complete signal once supported.
# Wait for the server to completely warm up before running the benchmark.
time.sleep(5)
benchmark_result, request_outputs = asyncio.run(
benchmark(
api_url=api_url,
tokenizer=tokenizer,
input_requests=input_requests,
request_rate=args.request_rate,
disable_tqdm=args.disable_tqdm,
prefill_quota=prefill_quota,
active_req_quota=active_req_quota,
)
)
# Process output
output = [output.to_dict() for output in request_outputs]
if args.run_eval:
eval_json = eval_accuracy(output)
# Save config and results to json
if args.save_result:
# dimensions values are strings
dimensions_json = {}
# metrics values are numerical
metrics_json = {}
# Setup
current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
dimensions_json["date"] = current_dt
dimensions_json["model_id"] = model_id
dimensions_json["tokenizer_id"] = tokenizer_id
if args.additional_metadata_metrics_to_save is not None:
dimensions_json = {
**dimensions_json,
**json.loads(args.additional_metadata_metrics_to_save),
}
metrics_json["num_prompts"] = args.num_prompts
# Traffic
metrics_json["request_rate"] = args.request_rate
metrics_json = {**metrics_json, **benchmark_result}
if args.run_eval:
metrics_json = {**metrics_json, **eval_json}
final_json = {}
final_json["metrics"] = metrics_json
final_json["dimensions"] = dimensions_json
# Save to file
base_model_id = model_id.split("/")[-1]
file_name = (
f"JetStream-{args.request_rate}qps-{base_model_id}-{current_dt}.json"
)
with open(file_name, "w", encoding="utf-8") as outfile:
json.dump(final_json, outfile)
if args.save_request_outputs:
file_path = args.request_outputs_file_path
with open(file_path, "w", encoding="utf-8") as output_file:
json.dump(
output,
output_file,
indent=4,
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Benchmark the online serving throughput."
)
parser.add_argument(
"--server",
type=str,
default="0.0.0.0",
help="Server address.",
)
parser.add_argument("--port", type=str, default=9000)
parser.add_argument(
"--dataset",
type=str,
default="test",
choices=["test", "sharegpt", "openorca"],
help="The dataset name.",
)
parser.add_argument("--dataset-path", type=str, help="Path to the dataset.")
parser.add_argument(
"--model",
type=str,
default="no_model",
help=(
"Name of the model like llama-2, llama-3, gemma. (it's just used to"
" label the benchmark, pick the tokenizer, the model config is"
" defined in config_lib, and passed as the server config flag when"
" we run the JetStream server)"
),
)
parser.add_argument(
"--tokenizer",
type=str,
default="test",
help=(
"Name or path of the tokenizer. (For mock model testing, use the"
" default value)"
),
)
parser.add_argument(
"--use-hf-tokenizer",
type=str2bool,
default=False,
help=(
"Whether to use tokenizer from HuggingFace. If so, set this flag"
" to True, and provide name of the tokenizer in the tokenizer flag."
),
)
parser.add_argument(
"--num-prompts",
type=int,
default=1000,
help=(
"Number of prompts to process. (number of sample requests we randomly"
" collect from dataset)"
),
)
parser.add_argument(
"--request-rate",
type=float,
default=0.0,
help=(
"Number of requests per second. If this is 0., "
"then all the requests are sent at time 0. "
"Otherwise, we use Poisson process to synthesize "
"the request arrival times."
),
)
parser.add_argument(
"--total-mock-requests",
type=int,
default=150,
help="The maximum number of mock requests to send for benchmark testing.",
)
parser.add_argument(
"--max-output-length",
type=int,
default=0,
help=(
"The maximum output length for reference request. It would be passed"
" to `max_tokens` parameter of the JetStream's DecodeRequest proto,"
" and used in JetStream to control the output/decode length of a"
" sequence. It would not be used in the engine. We should always set"
" max_tokens <= (max_target_length - max_prefill_predict_length)."
" max_target_length is the maximum length of a sequence;"
" max_prefill_predict_length is the maximum length of the"
" input/prefill of a sequence. Default to 0, in this case, "
"the output length of the golden dataset would be passed."
),
)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument(
"--disable-tqdm",
action="store_true",
help="Specify to disable tqdm progress bar.",
)
parser.add_argument(
"--save-result",
action="store_true",
help="Specify to save benchmark results to a json file",
)
parser.add_argument(
"--additional-metadata-metrics-to-save",
type=str,
help=(
"Additional metadata about the workload. Should be a dictionary in"
" the form of a string."
),
)
parser.add_argument(
"--save-request-outputs",
action="store_true",
help="Specify to store request outputs into a json file",
)
parser.add_argument(
"--request-outputs-file-path",
type=str,
default="/tmp/request-outputs.json",
help="File path to store request outputs",
)
parser.add_argument(
"--run-eval",
type=str2bool,
default=False,
help="Whether to run evaluation script on the saved outputs",
)
parser.add_argument(
"--warmup-mode",
type=str,
default="none",
choices=["none", "sampled", "full"],
help="Whether to warmup first, and set the warmup mode",
)
parser.add_argument(
"--conversation-starter",
type=str,
default="human",
choices=["human", "gpt", "both"],
help="What entity should be the one starting the conversations.",
)
parsed_args = parser.parse_args()
gc.disable()
main(parsed_args)