15 changes: 11 additions & 4 deletions common.py
@@ -7,7 +7,7 @@
 import numpy as np
 from tqdm import tqdm
 
-from .types import EvalResult, Message, SingleEvalResult
+from .eval_types import EvalResult, Message, SingleEvalResult
 
 
 def _compute_stat(values: list, stat: str):
@@ -49,7 +49,10 @@ def aggregate_results(
             key = name if stat == "mean" else f"{name}:{stat}"
             final_metrics[key] = _compute_stat(values, stat)
     return EvalResult(
-        score=final_metrics.pop("score", None), metrics=final_metrics, htmls=htmls, convos=convos
+        score=final_metrics.pop("score", None),
+        metrics=final_metrics,
+        htmls=htmls,
+        convos=convos,
     )
 
 
@@ -87,7 +90,9 @@ def message_to_html(message: Message) -> str:
     Generate HTML snippet (inside a <div>) for a message.
     """
     return jinja_env.from_string(_message_template).render(
-        role=message["role"], content=message["content"], variant=message.get("variant", None)
+        role=message["role"],
+        content=message["content"],
+        variant=message.get("variant", None),
     )
 
 
@@ -175,4 +180,6 @@ def make_report_from_example_htmls(htmls: list[str]):
     """
     Create a standalone HTML report from a list of example htmls
     """
-    return jinja_env.from_string(_report_template).render(score=None, metrics={}, htmls=htmls)
+    return jinja_env.from_string(_report_template).render(
+        score=None, metrics={}, htmls=htmls
+    )
37 changes: 23 additions & 14 deletions drop_eval.py
@@ -17,7 +17,7 @@
 
 from . import common
 from .mmlu_eval import HTML_JINJA
-from .types import Eval, EvalResult, SamplerBase, SingleEvalResult
+from .eval_types import Eval, EvalResult, SamplerBase, SingleEvalResult
 
 """
 From here through _normalize_answer was originally copied from:
@@ -62,7 +62,9 @@ def _normalize_answer(text: str) -> str:
     """Lower text and remove punctuation, articles and extra whitespace."""
 
     parts = [
-        _white_space_fix(_remove_articles(_normalize_number(_remove_punc(_lower(token)))))
+        _white_space_fix(
+            _remove_articles(_normalize_number(_remove_punc(_lower(token))))
+        )
         for token in _tokenize(text)
     ]
     parts = [part for part in parts if part.strip()]
@@ -152,7 +154,8 @@ def _match_numbers_if_present(gold_bag: Set[str], predicted_bag: Set[str]) -> bo
 
 
 def get_drop_metrics(
-    predicted: Union[str, List[str], Tuple[str, ...]], gold: Union[str, List[str], Tuple[str, ...]]
+    predicted: Union[str, List[str], Tuple[str, ...]],
+    gold: Union[str, List[str], Tuple[str, ...]],
 ) -> Tuple[float, float]:
     """
     Takes a predicted answer and a gold answer (that are both either a string or a list of
@@ -164,7 +167,9 @@ def get_drop_metrics(
     predicted_bags = _answer_to_bags(predicted)
     gold_bags = _answer_to_bags(gold)
 
-    if set(predicted_bags[0]) == set(gold_bags[0]) and len(predicted_bags[0]) == len(gold_bags[0]):
+    if set(predicted_bags[0]) == set(gold_bags[0]) and len(predicted_bags[0]) == len(
+        gold_bags[0]
+    ):
         exact_match = 1.0
     else:
         exact_match = 0.0
@@ -189,7 +194,9 @@ def answer_json_to_strings(answer: Dict[str, Any]) -> Tuple[Tuple[str, ...], str
             tuple(
                 [
                     "{0} {1} {2}".format(
-                        answer["date"]["day"], answer["date"]["month"], answer["date"]["year"]
+                        answer["date"]["day"],
+                        answer["date"]["month"],
+                        answer["date"]["year"],
                     ).strip()
                 ]
             ),
@@ -237,16 +244,14 @@ def drop_metric(sample: str, reference: list[str]) -> Tuple[float, float]:
 
 
 class DropEval(Eval):
-    def __init__(self, num_examples: int | None = None, train_samples_per_prompt: int = 3):
+    def __init__(
+        self, num_examples: int | None = None, train_samples_per_prompt: int = 3
+    ):
         self.seed = 42
         self._num_examples = num_examples
        self._train_samples_per_prompt = train_samples_per_prompt
-        self.train_jsonl = (
-            "https://openaipublic.blob.core.windows.net/simple-evals/drop_v0_train.jsonl.gz"
-        )
-        self.test_jsonl = (
-            "https://openaipublic.blob.core.windows.net/simple-evals/drop_v0_dev.jsonl.gz"
-        )
+        self.train_jsonl = "https://openaipublic.blob.core.windows.net/simple-evals/drop_v0_train.jsonl.gz"
+        self.test_jsonl = "https://openaipublic.blob.core.windows.net/simple-evals/drop_v0_dev.jsonl.gz"
         with gzip.GzipFile(fileobj=bf.BlobFile(self.train_jsonl, "rb"), mode="rb") as f:
             self.train_samples = list(map(json.loads, f.readlines()))
         with gzip.GzipFile(fileobj=bf.BlobFile(self.test_jsonl, "rb"), mode="rb") as f:
@@ -292,7 +297,9 @@ def fn(example: dict[str, str]):
                 for correct_answer in correct_answers
             ]
             extracted_answers = [
-                extracted_answer for i in range(len(correct_answers)) if matches[i]
+                extracted_answer
+                for i in range(len(correct_answers))
+                if matches[i]
             ]
             score = True in matches
             html = common.jinja_env.from_string(HTML_JINJA).render(
@@ -302,7 +309,9 @@ def fn(example: dict[str, str]):
                 correct_answer=correct_answers,
                 extracted_answer=extracted_answers,
             )
-            convo = prompt_messages + [dict(content=extracted_answer, role="assistant")]
+            convo = prompt_messages + [
+                dict(content=extracted_answer, role="assistant")
+            ]
             return SingleEvalResult(
                 html=html,
                 score=score,
types.py → eval_types.py: file renamed without changes.
21 changes: 16 additions & 5 deletions gpqa_eval.py
@@ -12,7 +12,7 @@
 
 from . import common
 from .mmlu_eval import ANSWER_PATTERN, HTML_JINJA, QUERY_TEMPLATE
-from .types import Eval, EvalResult, MessageList, SamplerBase, SingleEvalResult
+from .eval_types import Eval, EvalResult, MessageList, SamplerBase, SingleEvalResult
 
 
 def format_question(row):
@@ -24,7 +24,9 @@ def __init__(
         self,
         n_repeats: int = 4,
         variant: str = "diamond",
-        num_examples: int | None = None,  # restrict to a subset of the data for debugging
+        num_examples: (
+            int | None
+        ) = None,  # restrict to a subset of the data for debugging
     ):
         df = pandas.read_csv(
             bf.BlobFile(
@@ -37,7 +39,9 @@ def __init__(
             assert n_repeats == 1, "n_repeats only supported for num_examples"
             examples = rng.sample(examples, num_examples)
         examples = examples * n_repeats
-        examples = [example | {"permutation": rng.sample(range(4), 4)} for example in examples]
+        examples = [
+            example | {"permutation": rng.sample(range(4), 4)} for example in examples
+        ]
         self.examples = examples
         self.n_repeats = n_repeats
 
@@ -53,7 +57,11 @@ def fn(row: dict):
             correct_index = choices.index(row["Correct Answer"])
             correct_answer = "ABCD"[correct_index]
             choices_dict = dict(
-                A=choices[0], B=choices[1], C=choices[2], D=choices[3], Question=row["Question"]
+                A=choices[0],
+                B=choices[1],
+                C=choices[2],
+                D=choices[3],
+                Question=row["Question"],
             )
             prompt_messages = [dict(content=format_question(choices_dict), role="user")]
             response_text = sampler(prompt_messages)
@@ -69,7 +77,10 @@ def fn(row: dict):
             )
             convo = prompt_messages + [dict(content=response_text, role="assistant")]
             return SingleEvalResult(
-                html=html, score=score, convo=convo, metrics={"chars": len(response_text)}
+                html=html,
+                score=score,
+                convo=convo,
+                metrics={"chars": len(response_text)},
             )
 
         results = common.map_with_progress(fn, self.examples)
9 changes: 6 additions & 3 deletions humaneval_eval.py
@@ -22,7 +22,7 @@
 
 from . import common
 from .mmlu_eval import HTML_JINJA
-from .types import Eval, EvalResult, SamplerBase, SingleEvalResult
+from .eval_types import Eval, EvalResult, SamplerBase, SingleEvalResult
 
 
 def evaluate_functional_correctness(
@@ -84,9 +84,12 @@ def find_code(completion):
             return extracted_answer
 
         def fn(sample: dict[str, str]):
-            prompt_messages = [{"role": "user", "content": instruction + sample["prompt"]}]
+            prompt_messages = [
+                {"role": "user", "content": instruction + sample["prompt"]}
+            ]
             completions = [
-                find_code(sampler(prompt_messages)) for _ in range(self._num_samples_per_task)
+                find_code(sampler(prompt_messages))
+                for _ in range(self._num_samples_per_task)
             ]
             results = evaluate_functional_correctness(sample, completions)
             total = len(results)
6 changes: 4 additions & 2 deletions math_eval.py
@@ -11,7 +11,7 @@
 import pandas
 
 from . import common
-from .types import Eval, EvalResult, SamplerBase, SingleEvalResult
+from .eval_types import Eval, EvalResult, SamplerBase, SingleEvalResult
 
 QUERY_TEMPLATE = """
 Solve the following math problem step by step. The last line of your response should be of the form ANSWER: $ANSWER (without quotes) where $ANSWER is the answer to the problem.
@@ -105,7 +105,9 @@ def fn(row: dict):
             response_text = sampler(prompt_messages)
             match = re.search(ANSWER_PATTERN, response_text)
             extracted_answer = match.group(1) if match else None
-            score = float(check_equality(self.equality_checker, row["Answer"], extracted_answer))
+            score = float(
+                check_equality(self.equality_checker, row["Answer"], extracted_answer)
+            )
             html = common.jinja_env.from_string(HTML_JINJA).render(
                 prompt_messages=prompt_messages,
                 next_message=dict(content=response_text, role="assistant"),
6 changes: 4 additions & 2 deletions mgsm_eval.py
@@ -12,7 +12,7 @@
 
 from . import common
 from .mmlu_eval import HTML_JINJA
-from .types import Eval, EvalResult, SamplerBase, SingleEvalResult
+from .eval_types import Eval, EvalResult, SamplerBase, SingleEvalResult
 
 ALL_LANGUAGES = ["bn", "de", "en", "es", "fr", "ja", "ru", "sw", "te", "th", "zh"]
 LATIN_LANGUAGES = ["de", "en", "es", "fr", "sw"]
@@ -155,7 +155,9 @@ def __init__(
     def __call__(self, sampler: SamplerBase) -> EvalResult:
         def fn(example: dict[str, str]):
             language = example["lang"]
-            latin_language = "group_latin" if language in LATIN_LANGUAGES else "group_non_latin"
+            latin_language = (
+                "group_latin" if language in LATIN_LANGUAGES else "group_non_latin"
+            )
             correct_answer = example["targets"]
             instructoin = LANG_TO_INSTRUCTIONS[language]
             prompt_messages = [
10 changes: 7 additions & 3 deletions mmlu_eval.py
@@ -11,7 +11,7 @@
 import pandas
 
 from . import common
-from .types import Eval, EvalResult, SamplerBase, SingleEvalResult
+from .eval_types import Eval, EvalResult, SamplerBase, SingleEvalResult
 
 QUERY_TEMPLATE = """
 Answer the following multiple choice question. The last line of your response should be of the following format: 'ANSWER: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.
@@ -95,7 +95,9 @@ def format_question(row):
 class MMLUEval(Eval):
     def __init__(self, num_examples: int | None = None):
         df = pandas.read_csv(
-            bf.BlobFile("https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv")
+            bf.BlobFile(
+                "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv"
+            )
         )
         examples = [row.to_dict() for _, row in df.iterrows()]
         if num_examples:
@@ -118,7 +120,9 @@ def fn(row: dict):
             )
             convo = prompt_messages + [dict(content=response_text, role="assistant")]
             category = subject2category.get(row["Subject"], "other")
-            return SingleEvalResult(html=html, score=score, metrics={category: score}, convo=convo)
+            return SingleEvalResult(
+                html=html, score=score, metrics={category: score}, convo=convo
+            )
 
         results = common.map_with_progress(fn, self.examples)
         return common.aggregate_results(results)
6 changes: 4 additions & 2 deletions sampler/chat_completion_sampler.py
@@ -3,7 +3,7 @@
 import openai
 from openai import OpenAI
 
-from ..types import MessageList, SamplerBase
+from ..eval_types import MessageList, SamplerBase
 
 OPENAI_SYSTEM_MESSAGE_API = "You are a helpful assistant."
 OPENAI_SYSTEM_MESSAGE_CHATGPT = (
@@ -35,7 +35,9 @@ def __init__(
 
     def __call__(self, message_list: MessageList) -> str:
         if self.system_message:
-            message_list = [{"role": "system", "content": self.system_message}] + message_list
+            message_list = [
+                {"role": "system", "content": self.system_message}
+            ] + message_list
         trial = 0
         while True:
             try:
2 changes: 1 addition & 1 deletion sampler/claude_sampler.py
@@ -2,7 +2,7 @@
 
 import anthropic
 
-from ..types import MessageList, SamplerBase
+from ..eval_types import MessageList, SamplerBase
 
 CLAUDE_SYSTEM_MESSAGE_LMSYS = (
     "The assistant is Claude, created by Anthropic. The current date is "
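
Note on the rename: `types.py` becomes `eval_types.py` (presumably to stop shadowing Python's standard-library `types` module), and the file itself is unchanged. Since the diff never shows its contents, here is a minimal sketch of the interface the renamed module is assumed to expose, reconstructed purely from how the names are used above: field names follow the visible constructor calls (`EvalResult(score=..., metrics=..., htmls=..., convos=...)`, `SingleEvalResult(html=..., score=..., metrics=..., convo=...)`), and everything else is illustrative.

# eval_types.py -- assumed interface, inferred from usage in the diffs above
from dataclasses import dataclass, field
from typing import Any

Message = dict[str, Any]  # e.g. {"role": "user", "content": "..."}
MessageList = list[Message]


class SamplerBase:
    # Samplers are invoked directly on a message list and return completion
    # text, as in `response_text = sampler(prompt_messages)` above.
    def __call__(self, message_list: MessageList) -> str:
        raise NotImplementedError


@dataclass
class SingleEvalResult:
    # Result for one example; field names match the SingleEvalResult(...)
    # calls in drop_eval.py, gpqa_eval.py, and mmlu_eval.py.
    score: float | None
    metrics: dict[str, float] = field(default_factory=dict)
    html: str | None = None
    convo: MessageList | None = None


@dataclass
class EvalResult:
    # Aggregated result; field names match the EvalResult(...) call in
    # common.aggregate_results.
    score: float | None
    metrics: dict[str, float] | None
    htmls: list[str]
    convos: list[MessageList]


class Eval:
    # Evals are also callables: an instance applied to a sampler yields an
    # EvalResult, e.g. `MMLUEval(num_examples=10)(sampler)`.
    def __call__(self, sampler: SamplerBase) -> EvalResult:
        raise NotImplementedError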