-
Notifications
You must be signed in to change notification settings - Fork 203
/
benchmark.py
145 lines (121 loc) · 4.29 KB
/
benchmark.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import os
import time
import argparse
from dotenv import load_dotenv
from distutils.util import strtobool
from memory_profiler import memory_usage
from tqdm import tqdm
from llama2_wrapper import LLAMA2_WRAPPER
def run_iteration(
    llama2_wrapper, prompt_example, DEFAULT_SYSTEM_PROMPT, DEFAULT_MAX_NEW_TOKENS
):
    """Run one streamed generation and measure its time, speed and peak memory.

    Args:
        llama2_wrapper: An initialized ``LLAMA2_WRAPPER``; its ``run`` method is
            consumed as a generator and ``get_token_length`` counts the tokens
            of the final response.
        prompt_example: The user prompt to send.
        DEFAULT_SYSTEM_PROMPT: System prompt forwarded to ``run``.
        DEFAULT_MAX_NEW_TOKENS: Max-new-tokens value forwarded to ``run``.

    Returns:
        Tuple ``(generation_time, tokens_per_second, mem_usage, model_response)``:
        wall-clock seconds for the whole profiled call (this includes the
        ``memory_usage`` sampling overhead), output tokens per second (``0.0``
        if the measured time is zero), the value ``memory_usage(...,
        max_usage=True)`` reports as peak usage, and the last value yielded by
        ``run`` (``""`` if it yielded nothing).
    """

    def generation():
        generator = llama2_wrapper.run(
            prompt_example,
            [],  # no prior chat history
            DEFAULT_SYSTEM_PROMPT,
            DEFAULT_MAX_NEW_TOKENS,
            # NOTE(review): presumably temperature, top_p, top_k — confirm
            # against LLAMA2_WRAPPER.run's signature.
            1,
            0.95,
            50,
        )
        # Drain the stream and keep the last yielded value. Assigning every
        # item to model_response also covers a stream of exactly one item.
        # Assumes run() yields the cumulative text so far — TODO confirm.
        model_response = None
        for model_response in generator:
            pass
        if model_response is None:
            # Empty stream: report zero tokens instead of passing None on.
            return 0, ""
        return llama2_wrapper.get_token_length(model_response), model_response

    tic = time.perf_counter()
    mem_usage, (output_token_length, model_response) = memory_usage(
        (generation,), max_usage=True, retval=True
    )
    toc = time.perf_counter()
    generation_time = toc - tic
    # Guard against a zero-duration measurement (e.g. coarse timer resolution).
    tokens_per_second = (
        output_token_length / generation_time if generation_time > 0 else 0.0
    )
    return generation_time, tokens_per_second, mem_usage, model_response
def _str_to_bool(value):
    """Convert a truth-value string to ``bool``; raise ``ValueError`` otherwise.

    Accepts the same spellings as ``distutils.util.strtobool`` (case-insensitive
    ``y/yes/t/true/on/1`` and ``n/no/f/false/off/0``) without relying on
    ``distutils``, which was removed from the standard library in Python 3.12.
    Raising ``ValueError`` lets argparse turn bad input into a usage error.
    """
    normalized = str(value).strip().lower()
    if normalized in ("y", "yes", "t", "true", "on", "1"):
        return True
    if normalized in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError(f"invalid truth value {value!r}")


def _parse_args():
    """Define and parse the benchmark's command-line options."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--iter", type=int, default=5, help="Number of iterations")
    parser.add_argument("--model_path", type=str, default="", help="model path")
    parser.add_argument(
        "--backend_type",
        type=str,
        default="",
        help="Backend options: llama.cpp, gptq, transformers",
    )
    # type=bool would turn any non-empty string (including "False") into True,
    # so parse the text explicitly. None means "not given on the command line".
    parser.add_argument(
        "--load_in_8bit",
        type=_str_to_bool,
        default=None,
        help="Whether to use bitsandbytes 8 bit.",
    )
    args = parser.parse_args()
    if args.iter < 1:
        parser.error(f"--iter must be at least 1, got: {args.iter}")
    return args


def main():
    """Benchmark LLAMA2_WRAPPER initialization and generation.

    Settings are read from the environment, optionally populated from a
    ``.env`` file by ``load_dotenv``. ``--model_path``, ``--backend_type``
    and ``--load_in_8bit`` override the corresponding variables. One untimed
    warm-up run is followed by ``--iter`` timed runs. Averages are printed
    over the timed runs that completed. A failing or interrupted run stops
    the loop early but still reports the runs finished before it.
    """
    args = _parse_args()

    load_dotenv()
    system_prompt = os.getenv("DEFAULT_SYSTEM_PROMPT", "")
    max_new_tokens = int(os.getenv("DEFAULT_MAX_NEW_TOKENS", 1024))
    max_input_token_length = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", 4000))

    # Apply command-line overrides BEFORE validating, so a CLI value can stand
    # in for a missing environment variable.
    model_path = args.model_path or os.getenv("MODEL_PATH")
    backend_type = args.backend_type or os.getenv("BACKEND_TYPE")
    if args.load_in_8bit is not None:
        load_in_8bit = args.load_in_8bit
    else:
        load_in_8bit = _str_to_bool(os.getenv("LOAD_IN_8BIT", "True"))

    # Explicit checks instead of assert: asserts are stripped under `python -O`.
    if not model_path:
        raise SystemExit(
            "MODEL_PATH is required: set it in the environment/.env "
            "or pass --model_path."
        )
    if not backend_type:
        raise SystemExit(
            "BACKEND_TYPE is required: set it in the environment/.env "
            "or pass --backend_type."
        )

    # Initialization
    init_tic = time.perf_counter()
    llama2_wrapper = LLAMA2_WRAPPER(
        model_path=model_path,
        backend_type=backend_type,
        max_tokens=max_input_token_length,
        load_in_8bit=load_in_8bit,
    )
    init_toc = time.perf_counter()
    initialization_time = init_toc - init_tic

    prompt_example = (
        "Can you explain briefly to me what is the Python programming language?"
    )

    # Cold run: warm-up whose measurements are discarded.
    print("Performing cold run...")
    run_iteration(llama2_wrapper, prompt_example, system_prompt, max_new_tokens)

    # Timed runs
    print(f"Performing {args.iter} timed runs...")
    completed = 0
    total_time = 0.0
    total_tokens_per_second = 0.0
    total_memory_gen = 0.0
    model_response = None
    for _ in tqdm(range(args.iter)):
        try:
            gen_time, tokens_per_sec, mem_gen, model_response = run_iteration(
                llama2_wrapper,
                prompt_example,
                system_prompt,
                max_new_tokens,
            )
        except KeyboardInterrupt:
            # Allow Ctrl-C to stop early and still report the finished runs.
            print("Interrupted; reporting completed runs.")
            break
        except Exception as err:
            # Best effort: keep the partial results, but say why we stopped.
            print(f"Timed run {completed + 1} failed, stopping early: {err!r}")
            break
        total_time += gen_time
        total_tokens_per_second += tokens_per_sec
        total_memory_gen += mem_gen
        completed += 1

    if completed == 0:
        print("No timed run completed successfully; nothing to report.")
        return

    # Average only over runs that actually finished.
    avg_time = total_time / completed
    avg_tokens_per_second = total_tokens_per_second / completed
    # Each run contributes its peak usage, so this is the mean of per-run peaks.
    avg_memory_gen = total_memory_gen / completed

    print(f"Last model response: {model_response}")
    print(f"Initialization time: {initialization_time:0.4f} seconds.")
    print(
        f"Average generation time over {completed} iterations: {avg_time:0.4f} seconds."
    )
    print(
        f"Average speed over {completed} iterations: {avg_tokens_per_second:0.4f} tokens/sec."
    )
    print(f"Average memory usage during generation: {avg_memory_gen:.2f} MiB")
# Script entry point: run the benchmark only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()