6 changes: 6 additions & 0 deletions .gitignore
@@ -127,3 +127,9 @@ dmypy.json

# Pyre type checker
.pyre/

*.DS_Store
.vscode/settings.json
*.tsv
*.csv
*.gz
86 changes: 48 additions & 38 deletions README.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion common.py
@@ -7,7 +7,7 @@
import numpy as np
from tqdm import tqdm

from .types import EvalResult, Message, SamplerBase, SingleEvalResult
from eval_types import EvalResult, Message, SamplerBase, SingleEvalResult

QUERY_TEMPLATE_MULTICHOICE = """
Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.
9 changes: 9 additions & 0 deletions constants.py
@@ -0,0 +1,9 @@
import os

EXA_API_KEY = os.getenv("EXA_API_KEY")
PERPLEXITY_API_KEY = os.getenv("PERPLEXITY_API_KEY")
YOU_API_KEY = os.getenv("YOU_API_KEY")
BRAVE_API_KEY = os.getenv("BRAVE_API_KEY")
BING_API_KEY = os.getenv("BING_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
SERPER_API_KEY = os.getenv("SERPER_API_KEY")
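
These constants are read once at import time, so the corresponding variables must be exported before Python starts. A minimal sanity-check sketch (the set of required keys below is illustrative, not part of this PR):

    import constants

    # Check only the providers you intend to use; names mirror constants.py
    required = {
        "EXA_API_KEY": constants.EXA_API_KEY,
        "SERPER_API_KEY": constants.SERPER_API_KEY,
    }
    missing = [name for name, value in required.items() if not value]
    if missing:
        raise RuntimeError(f"Unset environment variables: {', '.join(missing)}")
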
27 changes: 27 additions & 0 deletions data_utils.py
@@ -0,0 +1,27 @@
import os
import gzip
import requests
from pathlib import Path

def download_file(url: str, cache_dir: str = "data") -> str:
"""Download a file from URL and cache it locally"""
# Create cache directory if it doesn't exist
Path(cache_dir).mkdir(parents=True, exist_ok=True)

# Get filename from URL
filename = url.split("/")[-1]
cache_path = os.path.join(cache_dir, filename)

# Return cached file if it exists
if os.path.exists(cache_path):
return cache_path

# Download file
response = requests.get(url)
response.raise_for_status()

# Save to cache
with open(cache_path, "wb") as f:
f.write(response.content)

return cache_path
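
A usage sketch for download_file, using a URL that appears later in this diff; the first call writes the file into data/, and later calls return the cached path without re-downloading:

    from data_utils import download_file

    path = download_file("https://openaipublic.blob.core.windows.net/simple-evals/gpqa_diamond.csv")
    print(path)  # data/gpqa_diamond.csv

Note that response.content buffers the whole file in memory before writing; that is fine for these small eval files, but requests.get(url, stream=True) with response.iter_content would suit larger artifacts.
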
23 changes: 13 additions & 10 deletions drop_eval.py
@@ -9,15 +9,14 @@
import random
import re
import string
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Dict, List, Set, Tuple, Union

import blobfile as bf
import numpy as np
from scipy.optimize import linear_sum_assignment

from . import common
from .common import ANSWER_PATTERN, HTML_JINJA
from .types import Eval, EvalResult, SamplerBase, SingleEvalResult
from common import ANSWER_PATTERN, HTML_JINJA, jinja_env, map_with_progress, aggregate_results
from eval_types import Eval, EvalResult, SamplerBase, SingleEvalResult
from data_utils import download_file

"""
From here through _normalize_answer was originally copied from:
@@ -28,6 +27,7 @@
/eval/drop_eval.py
"""

# hello world

def _remove_articles(text: str) -> str:
regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
@@ -245,9 +245,12 @@ def __init__(self, num_examples: int | None = None, train_samples_per_prompt: in
self.test_jsonl = (
"https://openaipublic.blob.core.windows.net/simple-evals/drop_v0_dev.jsonl.gz"
)
with gzip.GzipFile(fileobj=bf.BlobFile(self.train_jsonl, "rb"), mode="rb") as f:
train_path = download_file(self.train_jsonl)
test_path = download_file(self.test_jsonl)

with gzip.open(train_path, 'rt') as f:
self.train_samples = list(map(json.loads, f.readlines()))
with gzip.GzipFile(fileobj=bf.BlobFile(self.test_jsonl, "rb"), mode="rb") as f:
with gzip.open(test_path, 'rt') as f:
self.test_samples = list(map(json.loads, f.readlines()))
if self._num_examples:
self.test_samples = random.Random(self.seed).sample(
@@ -293,7 +296,7 @@ def fn(example: dict[str, str]):
extracted_answer for i in range(len(correct_answers)) if matches[i]
]
score = True in matches
html = common.jinja_env.from_string(HTML_JINJA).render(
html = jinja_env.from_string(HTML_JINJA).render(
prompt_messages=prompt_messages,
next_message=dict(content=extracted_answer, role="assistant"),
score=score,
@@ -308,5 +311,5 @@ def fn(example: dict[str, str]):
metrics={"em_score": em_score, "f1_score": f1_score},
)

results = common.map_with_progress(fn, self.test_samples)
return common.aggregate_results(results)
results = map_with_progress(fn, self.test_samples)
return aggregate_results(results)
26 changes: 26 additions & 0 deletions types.py → eval_types.py
@@ -4,6 +4,7 @@
Message = dict[str, Any] # keys role, content
MessageList = list[Message]

SearchResult = list[dict[str, Any]]

class SamplerBase:
"""
@@ -13,7 +14,32 @@ class SamplerBase:

def __call__(self, message_list: MessageList) -> str:
raise NotImplementedError

def __extract_query_from_messages__(self, message_list: MessageList) -> str:
"""Extract the last user message as the query"""

for message in reversed(message_list):
if message["role"] == "user":
if isinstance(message["content"], str):
return message["content"]
elif isinstance(message["content"], list):
return " ".join(
part["text"] for part in message["content"]
if isinstance(part, dict) and "text" in part
)

raise ValueError("No user message found in message list")

class SearchResultProvider:
"""
Base class for defining a search result provider.
"""

def __call__(self, query: str) -> str:
raise NotImplementedError

def __format_context__(self, results: SearchResult) -> str:
raise NotImplementedError

@dataclass
class EvalResult:
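
The diff shows only the abstract interface for SearchResultProvider; a concrete provider would implement __call__ and __format_context__ along these lines. This is a sketch only: the endpoint, headers, and response shape below are assumptions about the Serper API, not something this diff establishes.

    import requests

    from constants import SERPER_API_KEY
    from eval_types import SearchResult, SearchResultProvider

    class SerperProvider(SearchResultProvider):
        def __call__(self, query: str) -> str:
            response = requests.post(
                "https://google.serper.dev/search",  # assumed endpoint
                headers={"X-API-KEY": SERPER_API_KEY},
                json={"q": query},
            )
            response.raise_for_status()
            results = response.json().get("organic", [])  # assumed response key
            return self.__format_context__(results)

        def __format_context__(self, results: SearchResult) -> str:
            # One "title: snippet" line per result, matching the SearchResult alias above
            return "\n".join(
                f"{r.get('title', '')}: {r.get('snippet', '')}" for r in results
            )
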
25 changes: 15 additions & 10 deletions gpqa_eval.py
@@ -7,12 +7,18 @@
import random
import re

import blobfile as bf
import pandas

from . import common
from .common import ANSWER_PATTERN_MULTICHOICE, HTML_JINJA, format_multichoice_question
from .types import Eval, EvalResult, MessageList, SamplerBase, SingleEvalResult
from data_utils import download_file
from common import (
HTML_JINJA,
ANSWER_PATTERN_MULTICHOICE,
format_multichoice_question,
jinja_env,
map_with_progress,
aggregate_results,
)
from eval_types import Eval, EvalResult, MessageList, SamplerBase, SingleEvalResult


class GPQAEval(Eval):
@@ -22,9 +28,8 @@ def __init__(
variant: str = "diamond",
num_examples: int | None = None, # restrict to a subset of the data for debugging
):
df = pandas.read_csv(
bf.BlobFile(f"https://openaipublic.blob.core.windows.net/simple-evals/gpqa_{variant}.csv")
)
filepath = download_file(f"https://openaipublic.blob.core.windows.net/simple-evals/gpqa_{variant}.csv")
df = pandas.read_csv(filepath)
examples = [row.to_dict() for _, row in df.iterrows()]
rng = random.Random(0)
if num_examples:
@@ -58,7 +63,7 @@ def fn(row: dict):
match = re.search(ANSWER_PATTERN_MULTICHOICE, response_text)
extracted_answer = match.group(1) if match else None
score = 1.0 if extracted_answer == correct_answer else 0.0
html = common.jinja_env.from_string(HTML_JINJA).render(
html = jinja_env.from_string(HTML_JINJA).render(
prompt_messages=prompt_messages,
next_message=dict(content=response_text, role="assistant"),
score=score,
@@ -70,5 +75,5 @@ def fn(row: dict):
html=html, score=score, convo=convo, metrics={"chars": len(response_text)}
)

results = common.map_with_progress(fn, self.examples)
return common.aggregate_results(results)
results = map_with_progress(fn, self.examples)
return aggregate_results(results)
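
With blobfile gone, constructing the eval triggers the cached download instead. A smoke-test sketch; the sampler import path and the Eval.__call__(sampler) signature are assumptions carried over from upstream simple-evals, not shown in this diff:

    from gpqa_eval import GPQAEval
    from sampler.chat_completion_sampler import ChatCompletionSampler  # assumed path

    gpqa = GPQAEval(variant="diamond", num_examples=5)  # fetches gpqa_diamond.csv into data/ on first run
    result = gpqa(ChatCompletionSampler(model="gpt-4o-mini"))
    print(result.score)
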
Empty file added haha.txt
Empty file.
27 changes: 11 additions & 16 deletions humaneval_eval.py
@@ -4,26 +4,21 @@
https://arxiv.org/abs/2107.03374 https://github.com/openai/human-eval/
"""

import json
import logging
import multiprocessing
import random
import re
from collections import Counter, defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from io import BytesIO
from typing import Any, Tuple

import blobfile as bf
import tqdm
from human_eval.data import HUMAN_EVAL, read_problems
from human_eval.data import read_problems
from human_eval.evaluation import estimate_pass_at_k
from human_eval.execution import check_correctness # , unsafe_execute

from . import common
from .common import HTML_JINJA
from .types import Eval, EvalResult, SamplerBase, SingleEvalResult

from common import (
HTML_JINJA,
jinja_env,
map_with_progress,
aggregate_results,
)
from eval_types import Eval, EvalResult, SamplerBase, SingleEvalResult

def evaluate_functional_correctness(
sample: dict[str, str],
@@ -94,7 +89,7 @@ def fn(sample: dict[str, str]):
total = len(results)
correct = sum(results)
score = sum(results) / len(results)
html = common.jinja_env.from_string(HTML_JINJA).render(
html = jinja_env.from_string(HTML_JINJA).render(
prompt_messages=prompt_messages,
next_message=dict(content=completions[0], role="assistant"),
score=score,
@@ -116,5 +111,5 @@ def fn(sample: dict[str, str]):
},
)

results = common.map_with_progress(fn, self.examples, num_threads=3)
return common.aggregate_results(results)
results = map_with_progress(fn, self.examples, num_threads=3)
return aggregate_results(results)
26 changes: 15 additions & 11 deletions math_eval.py
@@ -7,13 +7,18 @@
import random
import re
from typing import Literal

import blobfile as bf
import pandas

from . import common
from .common import ANSWER_PATTERN, HTML_JINJA, check_equality
from .types import Eval, EvalResult, SamplerBase, SingleEvalResult
from data_utils import download_file
from common import (
HTML_JINJA,
ANSWER_PATTERN,
check_equality,
jinja_env,
map_with_progress,
aggregate_results,
)
from eval_types import Eval, EvalResult, SamplerBase, SingleEvalResult

QUERY_TEMPLATE = """
Solve the following math problem step by step. The last line of your response should be of the form Answer: $ANSWER (without quotes) where $ANSWER is the answer to the problem.
@@ -32,9 +37,8 @@ def __init__(
n_repeats: int = 16,
split: Literal["math_test", "math_500_test"] = "math_test",
):
df = pandas.read_csv(
bf.BlobFile(f"https://openaipublic.blob.core.windows.net/simple-evals/{split}.csv")
)
filepath = download_file(f"https://openaipublic.blob.core.windows.net/simple-evals/{split}.csv")
df = pandas.read_csv(filepath)
examples = [row.to_dict() for _, row in df.iterrows()]
if num_examples:
assert n_repeats == 1, "n_repeats only supported for num_examples = None"
Expand All @@ -52,7 +56,7 @@ def fn(row: dict):
match = re.search(ANSWER_PATTERN, response_text)
extracted_answer = match.group(1) if match else None
score = float(check_equality(self.equality_checker, row["Answer"], extracted_answer))
html = common.jinja_env.from_string(HTML_JINJA).render(
html = jinja_env.from_string(HTML_JINJA).render(
prompt_messages=prompt_messages,
next_message=dict(content=response_text, role="assistant"),
score=score,
@@ -62,5 +66,5 @@ def fn(row: dict):
convo = prompt_messages + [dict(content=response_text, role="assistant")]
return SingleEvalResult(html=html, score=score, convo=convo)

results = common.map_with_progress(fn, self.examples)
return common.aggregate_results(results)
results = map_with_progress(fn, self.examples)
return aggregate_results(results)
25 changes: 15 additions & 10 deletions mgsm_eval.py
@@ -8,11 +8,15 @@
import re
from typing import Optional

import blobfile as bf

from . import common
from .mmlu_eval import HTML_JINJA
from .types import Eval, EvalResult, SamplerBase, SingleEvalResult
from data_utils import download_file
from eval_types import Eval, EvalResult, SamplerBase, SingleEvalResult
from common import (
HTML_JINJA,
jinja_env,
map_with_progress,
aggregate_results,
)

ALL_LANGUAGES = ["bn", "de", "en", "es", "fr", "ja", "ru", "sw", "te", "th", "zh"]
LATIN_LANGUAGES = ["de", "en", "es", "fr", "sw"]
@@ -109,12 +113,13 @@ def score_mgsm(target: str, prediction: str) -> bool:
def get_lang_examples(lang: str) -> list[dict[str, str]]:
fpath = LANG_TO_FPATH[lang]
examples = []
with bf.BlobFile(fpath, "r") as f:

path = download_file(fpath)
with open(path, "r", encoding="utf-8") as f:
for line in f:
inputs, targets = line.strip().split("\t")
if "." in targets:
raise ValueError(f"targets {targets} contains a decimal point.")
# targets = int(targets.replace(",", ""))
examples.append({"inputs": inputs, "targets": targets, "lang": lang})
return examples

@@ -172,7 +177,7 @@ def fn(example: dict[str, str]):
extracted_answer = parse_answer(response_text, answer_prefix)

score = score_mgsm(correct_answer, extracted_answer)
html = common.jinja_env.from_string(HTML_JINJA).render(
html = jinja_env.from_string(HTML_JINJA).render(
prompt_messages=prompt_messages,
next_message=dict(content=response_text, role="assistant"),
score=score,
@@ -187,5 +192,5 @@ def fn(example: dict[str, str]):
metrics={language: score, latin_language: score},
)

results = common.map_with_progress(fn, self.examples)
return common.aggregate_results(results, default_stats=("mean", "std"))
results = map_with_progress(fn, self.examples)
return aggregate_results(results, default_stats=("mean", "std"))