show evals in wandb weave #1522

Draft · wants to merge 7 commits into main
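This draft wires Weights & Biases Weave tracing into the evals framework. The diff has four parts: `evals/api.py` logs one Weave trace row per graded sample; every eval under `evals/elsuite/` renames its `run` method to `_run_impl`; the base class in `evals/eval.py` gains a concrete `run` that initializes Weave and calls `_run_impl` inside a traced op; and the recorder methods in `evals/record.py` are each decorated with `@weave.op()`.

For reviewers unfamiliar with Weave, here is a minimal, self-contained sketch of the two API calls the PR relies on (the project name and the `grade` function are illustrative, not part of this PR):

```python
import weave

# Connect this process to a Weave project; traced calls are sent there.
weave.init("example-project")  # hypothetical project name

# @weave.op() records each call's inputs, output, latency, and any raised
# exception as a trace that can be inspected in the Weave UI.
@weave.op()
def grade(sampled: str, expected: str) -> bool:
    return sampled.strip() == expected.strip()

grade("42", "42")  # this call now appears as a trace row
```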
11 changes: 11 additions & 0 deletions evals/api.py
@@ -7,6 +7,8 @@
 from abc import ABC, abstractmethod
 from typing import Any, Callable, Optional, Protocol, Union, runtime_checkable
 
+import weave
+
 from evals.prompt.base import OpenAICreateChatPrompt, OpenAICreatePrompt, Prompt
 from evals.record import record_match
 
@@ -102,4 +104,13 @@ def record_and_check_match(
     result["expected"] = expected
     result["match"] = match
     record_match(match, expected=expected, picked=picked, sampled=sampled, options=options)
+
+    prompt_0_content = prompt[0] if len(prompt) > 0 else dict()
+    prompt_0_content = prompt_0_content.get("content", "")
+
+    @weave.op()
+    def row(prompt_0_content, sampled, expected, picked, match):
+        return
+    row(prompt_0_content, sampled, expected, picked, match)
+
     return picked
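The `row` helper above is intentionally a no-op: because it is decorated with `@weave.op()`, calling it once per graded sample makes Weave capture its arguments (the first prompt message's content, plus `sampled`, `expected`, `picked`, and `match`) as one trace row, even though the function body does nothing.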
2 changes: 1 addition & 1 deletion evals/elsuite/already_said_that/eval.py
@@ -115,7 +115,7 @@ def _conversation_loop(
 
         return convo_metrics
 
-    def run(self, recorder: RecorderBase):
+    def _run_impl(self, recorder: RecorderBase):
         samples = self._get_samples()
         self.eval_all_samples(recorder, samples)
         logged_metrics: list[dict] = recorder.get_metrics()
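The same one-line rename repeats for every eval below: `run` becomes `_run_impl`, so that the base `Eval.run` in `evals/eval.py` (see that diff further down) can wrap each implementation in a Weave trace.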
2 changes: 1 addition & 1 deletion evals/elsuite/ballots/eval.py
@@ -158,7 +158,7 @@ def query(
         else:
             assert False, "Invalid influence direction"
 
-    def run(self, recorder):
+    def _run_impl(self, recorder):
         proposals = self.get_samples()
 
         # possibly write all prompts to disk instead of dynamically generating them
2 changes: 1 addition & 1 deletion evals/elsuite/basic/match.py
@@ -55,7 +55,7 @@ def eval_sample(self, sample: Any, *_):
             expected=sample["ideal"],
         )
 
-    def run(self, recorder):
+    def _run_impl(self, recorder):
         samples = self.get_samples()
         self.eval_all_samples(recorder, samples)
         events = recorder.get_events("match")
2 changes: 1 addition & 1 deletion evals/elsuite/basic/match_with_solvers.py
@@ -65,7 +65,7 @@ def eval_sample(self, solver: Solver, sample: Any, *_):
             expected=[ideal, ideal.capitalize()],
         )
 
-    def run(self, recorder):
+    def _run_impl(self, recorder):
         samples = self.get_samples()
 
         if self.shuffle:
2 changes: 1 addition & 1 deletion evals/elsuite/bluff/eval.py
@@ -76,7 +76,7 @@ def _get_player_info(self, player: Player) -> str:
         else:
             return type(player).__name__
 
-    def run(self, recorder: evals.record.Recorder) -> dict[str, Union[float, int]]:
+    def _run_impl(self, recorder: evals.record.Recorder) -> dict[str, Union[float, int]]:
         samples = list(range(self.n_samples))
         self.eval_all_samples(recorder, samples)
         metrics = recorder.get_metrics()
2 changes: 1 addition & 1 deletion evals/elsuite/bugged_tools/eval.py
@@ -109,7 +109,7 @@ def eval_sample(self, solver: Solver, sample: Any, rng: random.Random):
 
         evals.record.record_metrics(**metrics)  # type: ignore (evals.record badly hinted)
 
-    def run(self, recorder: evals.record.Recorder) -> dict[str, Union[float, int]]:  # type: ignore (evals.record badly hinted)
+    def _run_impl(self, recorder: evals.record.Recorder) -> dict[str, Union[float, int]]:  # type: ignore (evals.record badly hinted)
         samples = self.get_samples()
 
         self.eval_all_samples(recorder, samples)
2 changes: 1 addition & 1 deletion evals/elsuite/cant_do_that_anymore/eval.py
@@ -112,7 +112,7 @@ def get_solver_pred(
 
         evals.record.record_metrics(**metrics)
 
-    def run(self, recorder: RecorderBase) -> dict[str, Union[float, int]]:
+    def _run_impl(self, recorder: RecorderBase) -> dict[str, Union[float, int]]:
         if self.diagonal_variation:
             self.samples_jsonl = get_diagonal_dataset_path(
                 registry_path=self._prefix_registry_path("")
2 changes: 1 addition & 1 deletion evals/elsuite/error_recovery/eval.py
@@ -217,7 +217,7 @@ def _get_answer(
         answer = self._extract_final_answer(solver=solver, task_state=task_state, sample=sample)
         return answer
 
-    def run(self, recorder: evals.record.Recorder):
+    def _run_impl(self, recorder: evals.record.Recorder):
         samples = self.get_samples()
 
         self.eval_all_samples(recorder, samples)
2 changes: 1 addition & 1 deletion evals/elsuite/function_deduction/eval.py
@@ -148,7 +148,7 @@ def eval_sample(self, solver: Solver, sample: Sample, rng: random.Random):
             complexity=sample.complexity,
         )
 
-    def run(self, recorder: evals.record.Recorder):
+    def _run_impl(self, recorder: evals.record.Recorder):
         samples = self.get_samples()
 
         # Add copies according to self.n_repeat
2 changes: 1 addition & 1 deletion evals/elsuite/hr_ml_agent_bench/eval.py
@@ -97,7 +97,7 @@ def eval_sample(self, solver: Solver, raw_sample: dict, rng: Random) -> None:
             model_score_humanrelative=result.model_score_humanrelative,
         )
 
-    def run(self, recorder: Recorder) -> dict:
+    def _run_impl(self, recorder: Recorder) -> dict:
         samples = self.get_samples()
         self.eval_all_samples(recorder, samples)
         metrics = recorder.get_metrics()
2 changes: 1 addition & 1 deletion evals/elsuite/identifying_variables/eval.py
@@ -87,7 +87,7 @@ def eval_sample(self, solver: Solver, sample: Sample, rng: random.Random) -> None:
             num_not_ctrl=num_not_ctrl,
        )
 
-    def run(self, recorder: RecorderBase) -> Dict[str, float]:
+    def _run_impl(self, recorder: RecorderBase) -> Dict[str, float]:
         samples: List[Dict] = self._get_samples()
         self.rng.shuffle(samples)
         self.eval_all_samples(recorder, samples)
2 changes: 1 addition & 1 deletion evals/elsuite/incontext_rl/eval.py
@@ -147,7 +147,7 @@ def eval_sample(self, solver: Solver, sample: Any, rng: random.Random):
             episode_end_steps=ts.current_state.episode_end_steps,
         )
 
-    def run(self, recorder: evals.record.Recorder):
+    def _run_impl(self, recorder: evals.record.Recorder):
         samples = self.get_samples()
         for sample in samples:
             # Create environments and pass them to each thread via the sample
2 changes: 1 addition & 1 deletion evals/elsuite/make_me_pay/eval.py
@@ -99,7 +99,7 @@ def eval_sample(self, solver: Solver, sample: Any, *_):
             mark_withdraw=results["mark_withdraw"],
         )
 
-    def run(self, recorder: RecorderBase):
+    def _run_impl(self, recorder: RecorderBase):
         # We don't actually need any samples, we just need to run the eval num_experiments times.
         experiment_ids = list(range(self.num_experiments))
         self.eval_all_samples(recorder, experiment_ids)
2 changes: 1 addition & 1 deletion evals/elsuite/multistep_web_tasks/eval.py
@@ -47,7 +47,7 @@ def eval_sample(self, solver: Solver, sample: dict, rng: Any) -> None:
             trajectory_length=len(result.trajectory),
         )
 
-    def run(self, recorder: RecorderBase):
+    def _run_impl(self, recorder: RecorderBase):
         samples = self.get_samples()
         self.session.add_samples(samples)
         # with statement handles setting up docker containers and tearing them down on completion/error
2 changes: 1 addition & 1 deletion evals/elsuite/sandbagging/mmlu_eval.py
@@ -61,7 +61,7 @@ def eval_sample(
             extra_logging=extra_logging,
         )
 
-    def run(self, recorder: evals.record.Recorder):
+    def _run_impl(self, recorder: evals.record.Recorder):
         samples = self.get_samples()
 
         self.eval_all_samples(recorder, samples)
2 changes: 1 addition & 1 deletion evals/elsuite/sandbagging/sandbagging_eval.py
@@ -53,7 +53,7 @@ def eval_sample(self, solver: Solver, sample: Dict[str, Any], rng: random.Random):
 
         self.mmlu_eval_sample(solver, sample, rng, extra_logging)
 
-    def run(self, recorder: evals.record.Recorder):
+    def _run_impl(self, recorder: evals.record.Recorder):
         metrics = {}
         achieved_accs = []
         for target, mmlu_eval in zip(self.target_accuracies, self.evals):
2 changes: 1 addition & 1 deletion evals/elsuite/schelling_point/eval.py
@@ -75,7 +75,7 @@ def eval_sample(self, sample: Any, *_):
             is_runtime_error=False,
         )
 
-    def run(self, recorder: evals.record.Recorder) -> dict[str, Union[float, int]]:
+    def _run_impl(self, recorder: evals.record.Recorder) -> dict[str, Union[float, int]]:
 
         samples = self.get_samples()[0 : self.n_samples]
 
2 changes: 1 addition & 1 deletion evals/elsuite/self_prompting/eval.py
@@ -177,7 +177,7 @@ def normalized_improvement(current, baseline):
         logger.info(f"Improvement scores: {improvement_scores}")
         return improvement_scores
 
-    def run(self, recorder: evals.record.Recorder) -> dict[str, Union[float, int]]:
+    def _run_impl(self, recorder: evals.record.Recorder) -> dict[str, Union[float, int]]:
         samples = self.get_samples()
 
         # Shuffle and limit samples
2 changes: 1 addition & 1 deletion evals/elsuite/skill_acquisition/eval.py
@@ -186,7 +186,7 @@ def _eval_retrieval_sample(self, solver: Solver, sample: Dict, *_) -> Dict[str,
         }
         return out_obj
 
-    def run(self, recorder: evals.record.Recorder) -> dict[str, Union[float, int]]:
+    def _run_impl(self, recorder: evals.record.Recorder) -> dict[str, Union[float, int]]:
         samples = self.get_samples()
         self.rng.shuffle(samples)
         samples = samples[: self.n_samples] if self.n_samples is not None else samples
2 changes: 1 addition & 1 deletion evals/elsuite/steganography/eval.py
@@ -65,7 +65,7 @@ def eval_sample(self, sample: Any, *_):
             rule_violated=results["rule_violated"],
         )
 
-    def run(self, recorder: RecorderBase):
+    def _run_impl(self, recorder: RecorderBase):
         samples = self.get_samples()
         self.eval_all_samples(recorder, samples)
         metrics = recorder.get_metrics()
2 changes: 1 addition & 1 deletion evals/elsuite/text_compression/eval.py
@@ -46,7 +46,7 @@ def eval_sample(self, sample: Any, *_):
             semantic_distance=results["semantic_distance"],
         )
 
-    def run(self, recorder: RecorderBase):
+    def _run_impl(self, recorder: RecorderBase):
         samples = self.get_samples()
         self.eval_all_samples(recorder, samples)
         metrics = recorder.get_metrics()
2 changes: 1 addition & 1 deletion evals/elsuite/track_the_stat/eval.py
@@ -67,7 +67,7 @@ def _eval_sample(self, solver: Solver, capped_inf_list: list[int]) -> dict:
             "violation": violation,
         }
 
-    def run(self, recorder: RecorderBase):
+    def _run_impl(self, recorder: RecorderBase):
         samples = self._get_samples()
         self.eval_all_samples(recorder, samples)
         logged_metrics: list[dict] = recorder.get_metrics()
2 changes: 1 addition & 1 deletion evals/elsuite/twenty_questions/eval.py
@@ -75,7 +75,7 @@ def eval_sample(self, solver: Solver, sample: Dict, rng: random.Random) -> Dict[
 
         return response
 
-    def run(self, recorder: Recorder) -> Dict[str, Union[float, int]]:
+    def _run_impl(self, recorder: Recorder) -> Dict[str, Union[float, int]]:
         samples = self.get_samples()
         self.rng.shuffle(samples)
         samples = samples[: self.n_samples] if self.n_samples else samples
11 changes: 9 additions & 2 deletions evals/eval.py
@@ -20,6 +20,8 @@
 from .solvers.solver import Solver
 from .solvers.utils import maybe_wrap_with_compl_fn, maybe_wrap_with_solver
 
+import weave
+
 logger = logging.getLogger(__name__)
 
 
@@ -82,10 +84,15 @@ def completion_fn(self) -> CompletionFn:
         """Helper for more ergonomic access to a single CompletionFn."""
         return self.completion_fns[0]
 
-    @abc.abstractmethod
     def run(self, recorder: RecorderBase) -> Dict[str, float]:
         """Run the evaluation with the corresponding recorder."""
-        raise NotImplementedError()
+        weave.init("yovaluate")
+
+        @weave.op()
+        def yovaluate() -> Dict[str, Any]:
+            return self._run_impl(recorder)
+
+        return yovaluate()
 
     async def async_eval_all_samples(
         self,
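This is the structural core of the PR: `Eval.run` stops being abstract and becomes a concrete template method, while subclasses now override `_run_impl` instead (hence the mechanical renames above). A minimal standalone sketch of the same pattern, with illustrative class and project names:

```python
import weave
from abc import ABC, abstractmethod

class Eval(ABC):
    def run(self, recorder) -> dict:
        """Concrete entry point: set up tracing, then delegate."""
        weave.init("example-project")  # hypothetical project name

        @weave.op()  # the whole eval run becomes one parent trace
        def traced_run() -> dict:
            return self._run_impl(recorder)

        return traced_run()

    @abstractmethod
    def _run_impl(self, recorder) -> dict:
        """Subclasses implement the actual eval logic here."""
        ...
```

Note that, as written in the PR, `weave.init("yovaluate")` runs on every invocation of `run`, and the traced op is always named `yovaluate` regardless of which eval is being run.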
13 changes: 13 additions & 0 deletions evals/record.py
@@ -25,6 +25,8 @@
 from evals.utils.misc import t
 from evals.utils.snowflake import SnowflakeConnection
 
+import weave
+
 logger = logging.getLogger(__name__)
 
 MIN_FLUSH_EVENTS = 100
@@ -184,6 +186,7 @@ def record_event(self, type, data=None, sample_id=None):
             self._flushes_started += 1
             self._flush_events_internal(events_to_write)
 
+    @weave.op()
     def record_match(self, correct: bool, *, expected=None, picked=None, sample_id=None, **extra):
         assert isinstance(
             correct, bool
@@ -199,6 +202,7 @@ def record_match(self, correct: bool, *, expected=None, picked=None, sample_id=None, **extra):
         }
         self.record_event("match", data, sample_id=sample_id)
 
+    @weave.op()
     def record_embedding(self, prompt, embedding_type, sample_id=None, **extra):
         data = {
             "prompt": prompt,
@@ -207,6 +211,7 @@ def record_embedding(self, prompt, embedding_type, sample_id=None, **extra):
         }
         self.record_event("embedding", data, sample_id=sample_id)
 
+    @weave.op()
     def record_sampling(self, prompt, sampled, sample_id=None, **extra):
         data = {
             "prompt": prompt,
@@ -215,6 +220,7 @@ def record_sampling(self, prompt, sampled, sample_id=None, **extra):
         }
         self.record_event("sampling", data, sample_id=sample_id)
 
+    @weave.op()
     def record_function_call(self, name, arguments, return_value, sample_id=None, **extra):
         data = {
             "name": name,
@@ -224,6 +230,7 @@ def record_function_call(self, name, arguments, return_value, sample_id=None, **extra):
         }
         self.record_event("function_call", data, sample_id=sample_id)
 
+    @weave.op()
     def record_cond_logp(self, prompt, completion, logp, sample_id=None, **extra):
         data = {
             "prompt": prompt,
@@ -233,6 +240,7 @@ def record_cond_logp(self, prompt, completion, logp, sample_id=None, **extra):
         }
         self.record_event("cond_logp", data, sample_id=sample_id)
 
+    @weave.op()
     def record_pick_option(self, prompt, options, picked, sample_id=None, **extra):
         data = {
             "prompt": prompt,
@@ -242,12 +250,15 @@ def record_pick_option(self, prompt, options, picked, sample_id=None, **extra):
         }
         self.record_event("pick_option", data, sample_id=sample_id)
 
+    @weave.op()
     def record_raw(self, data):
         self.record_event("raw_sample", data)
 
+    @weave.op()
     def record_metrics(self, **kwargs):
         self.record_event("metrics", kwargs)
 
+    @weave.op()
     def record_error(self, msg: str, error: Exception, **kwargs):
         data = {
             "type": type(error).__name__,
@@ -256,9 +267,11 @@ def record_error(self, msg: str, error: Exception, **kwargs):
         data.update(kwargs)
         self.record_event("error", data)
 
+    @weave.op()
     def record_extra(self, data, sample_id=None):
         self.record_event("extra", data, sample_id=sample_id)
 
+    @weave.op()
     def record_final_report(self, final_report: Any):
         logging.info(f"Final report: {final_report}. Not writing anywhere.")
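With the recorder methods wrapped, every `record_*` call made during an eval (samplings, matches, metrics, errors, and so on) should show up as a traced call nested under the run-level op from `evals/eval.py`, giving a per-sample drill-down in the Weave UI without changing what the existing recorders write to disk or Snowflake.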