evaluate.py
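"""Evaluation entry point: run one reasoning task (GSM8K, date, sports, or coin
flip) with a chosen model and prompting strategy, print accuracy, token-usage,
and latency statistics, and save them to a CSV file under ./results."""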
import argparse
import csv
import os

from llm_client import LLMClient
from tasks.coin_flip import CoinFlip
from tasks.date import DateUnderstanding
from tasks.gsm8k import GSM8K
from tasks.sports import SportsUnderstanding
from utils import average, nth_percentile

# Map short model aliases to fully qualified, dated model identifiers.
MODEL_MAPPING = {
    "gpt-4o": "gpt-4o-2024-08-06",
    "gpt-4o-mini": "gpt-4o-mini-2024-07-18",
    "sonnet": "claude-3-5-sonnet-20240620",
    "haiku": "claude-3-5-haiku-20241022",
}

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", choices=["gsm8k", "date", "sports", "coin_flip"])
    parser.add_argument("--model", default="claude3.5")
    parser.add_argument(
        "--prompt",
        choices=["baseline", "cod", "cot"],
        default="cod",
        help="Prompting strategy",
    )
    parser.add_argument(
        "--shot",
        type=int,
        default=None,
        help="Number of few-shot examples to include; by default, all few-shot examples are included",
    )
    parser.add_argument(
        "--url",
        default=None,
        help="Base URL for the LLM model endpoint",
    )
    parser.add_argument(
        "--api-key",
        default=None,
        help="API key for model access; if omitted, the API keys from environment variables are used for OpenAI and Claude models.",
    )
    args = parser.parse_args()

    llm_client = LLMClient(args.url, args.api_key)

    # Instantiate the selected evaluation task.
    match args.task:
        case "gsm8k":
            task = GSM8K(llm_client)
        case "date":
            task = DateUnderstanding(llm_client)
        case "sports":
            task = SportsUnderstanding(llm_client)
        case "coin_flip":
            task = CoinFlip(llm_client)
        case _:
            raise ValueError("Invalid task")

    # Resolve model aliases; unknown names are passed through unchanged.
    model = MODEL_MAPPING.get(args.model, args.model)
    accuracy = task.evaluate(model, args.prompt, args.shot)

    # Collect accuracy, token-usage, and latency statistics gathered during evaluation.
    results = [
        [
            "Accuracy",
            "Avg Token #",
            "Average Latency (s)",
            "P90 Latency (s)",
            "P95 Latency (s)",
            "P99 Latency (s)",
        ],
        [
            accuracy,
            average(task.token_count_tracker),
            average(task.latency_tracker),
            nth_percentile(task.latency_tracker, 0.9),
            nth_percentile(task.latency_tracker, 0.95),
            nth_percentile(task.latency_tracker, 0.99),
        ],
    ]
    for i in range(len(results[0])):
        print(f"{results[0][i]}: {results[1][i]}")

    # Write the metrics to a CSV named after the task, model, prompt, and shot count.
    if not os.path.exists("./results"):
        os.makedirs("./results")
    model_name = args.model.split(":")[1] if ":" in args.model else args.model
    model_name = model_name.replace("/", "_")
    fname = (
        f"{args.task}-{model_name}-{args.prompt}-{args.shot}"
        if args.shot
        else f"{args.task}-{model_name}-{args.prompt}"
    )
    with open(f"./results/{fname}.csv", "w") as f:
        writer = csv.writer(f)
        writer.writerows(results)
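
A typical invocation, for example (the flag values below are illustrative; any task/model/prompt combination accepted by the argument parser works the same way):

    python evaluate.py --task gsm8k --model sonnet --prompt cod --shot 4

This resolves the "sonnet" alias to claude-3-5-sonnet-20240620 via MODEL_MAPPING, evaluates the GSM8K task with the chain-of-draft prompt and 4 few-shot examples, prints the accuracy, token, and latency metrics, and writes them to ./results/gsm8k-sonnet-cod-4.csv.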