-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathutils.py
57 lines (40 loc) · 1.64 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import re
from typing import List, Literal, Union
import yaml
from pydantic import BaseModel
class Example(BaseModel):
    """A single few-shot demonstration: a question paired with its answer."""

    question: str  # substituted for {question} in Config.format
    answer: str  # substituted for {answer} in Config.format
class Config(BaseModel):
    """Prompt configuration validated from a YAML file (see load_config)."""

    system_prompt: str  # text placed at the very top of every composed prompt
    format: str  # str.format template with {question} and {answer} placeholders
    fewshot: List[Example]  # demonstrations; compose_request uses a prefix of this list
def load_config(task: Literal["gsm8k", "date"], config: Literal["baseline", "cot", "cod"]) -> Config:
    """Load and validate the prompt configuration for a task / prompting style.

    Reads ``./configs/{task}_{config}.yaml`` (resolved relative to the current
    working directory, not to this module) and validates its contents into a
    :class:`Config`.

    Args:
        task: Task name; selects the first part of the config file name.
        config: Prompting style; selects the second part of the file name.

    Returns:
        The validated configuration.

    Raises:
        FileNotFoundError: If the config file does not exist.
        yaml.YAMLError: If the file is not valid YAML.
        pydantic.ValidationError: If the YAML does not match the Config schema.
    """
    path = f"./configs/{task}_{config}.yaml"
    # YAML is UTF-8 by specification; without an explicit encoding, open()
    # uses the platform locale encoding and can mis-decode non-ASCII prompts.
    with open(path, encoding="utf-8") as f:
        data = yaml.safe_load(f)
    return Config.model_validate(data)
def compose_request(config: Config, shot: Union[int, None], question: str) -> str:
    """Build the complete prompt text for one question.

    The prompt consists of the system prompt on its own line, then the first
    ``shot`` few-shot examples rendered through ``config.format`` (one per
    line), then ``question`` rendered through the same template with an empty
    answer so the model continues from there.

    Args:
        config: Prompt configuration (system prompt, template, examples).
        shot: How many examples to take from the start of ``config.fewshot``.
            ``None`` means all of them; ``0`` means none.
        question: The question to place at the end of the prompt.

    Returns:
        The assembled prompt string.
    """
    request = config.system_prompt + "\n"
    if shot is None:
        shot = len(config.fewshot)
    if shot != 0:
        examples = [
            config.format.format(question=ex.question, answer=ex.answer)
            for ex in config.fewshot[:shot]
        ]
        request += "\n".join(examples) + "\n"
    # An empty answer leaves the template open for the model to complete.
    request += config.format.format(question=question, answer="")
    return request
def nth_percentile(values: list[float], percentile: float) -> float:
    """Return the value at ``percentile`` using a nearest-rank lookup.

    The 1-based rank is ``round(percentile * len(values))``, clamped to
    ``[1, len(values)]``. The function returns the element at that rank in
    sorted order.

    Args:
        values: Sample values. The input list is not modified.
        percentile: Fraction in ``[0, 1]`` (e.g. ``0.9`` for p90). Values
            below 0 select the minimum; values above 1 select the maximum.

    Returns:
        The selected element of ``values``.

    Raises:
        ValueError: If ``values`` is empty.
    """
    if not values:
        raise ValueError("nth_percentile() requires at least one value")
    ordered = sorted(values)
    # round() (rather than ceil) tolerates tiny float error in the product;
    # note that it rounds exact halves to the nearest even integer.
    rank = round(percentile * len(ordered))
    # Clamp to a valid 1-based rank. Without the lower bound, rank 0 turns
    # into index -1 and silently returns the maximum instead of the minimum.
    rank = max(1, min(rank, len(ordered)))
    return ordered[rank - 1]
def average(values: list[float]) -> float:
    """Compute the arithmetic mean of ``values``.

    An empty list raises ZeroDivisionError.
    """
    total, count = sum(values), len(values)
    return total / count
def trimmed_average(values: list[float], percentile: float) -> float:
    """Compute the mean after discarding the extremes at both ends.

    ``round(len(values) * percentile)`` of the smallest values, and the same
    number of the largest values, are dropped before averaging. If nothing
    is left after trimming (a percentile of roughly 0.5 or more), the
    division raises ZeroDivisionError.
    """
    ordered = sorted(values)
    n = len(ordered)
    cut = round(n * percentile)
    kept = ordered[cut : n - cut]
    return sum(kept) / len(kept)
def extract_number_from_string(s: str) -> Union[int, float, None]:
    """Return the first non-negative number that appears in ``s``.

    Recognizes plain digit runs (``12345``), comma thousands separators
    (``1,234``) and a decimal part (``1,234.5``). No sign is recognized, so
    ``"-5"`` yields ``5``.

    Args:
        s: Free-form text, e.g. a model's answer.

    Returns:
        An ``int`` when the match has no decimal point, a ``float`` when it
        does, or ``None`` if ``s`` contains no digits.
    """
    # Leading \d+ (not \d{1,3}) so that comma-free numbers longer than three
    # digits are captured whole: "1000" -> 1000, not 100.
    match = re.search(r"\d+(?:,\d{3})*(?:\.\d+)?", s)
    if match is None:
        return None
    number_str = match.group().replace(",", "")  # drop thousands separators
    return float(number_str) if "." in number_str else int(number_str)