diff --git a/evals/api.py b/evals/api.py index bbb6b7c728..a1e6628d3c 100644 --- a/evals/api.py +++ b/evals/api.py @@ -7,6 +7,8 @@ from abc import ABC, abstractmethod from typing import Any, Callable, Optional, Protocol, Union, runtime_checkable +import weave + from evals.prompt.base import OpenAICreateChatPrompt, OpenAICreatePrompt, Prompt from evals.record import record_match @@ -102,4 +104,13 @@ def record_and_check_match( result["expected"] = expected result["match"] = match record_match(match, expected=expected, picked=picked, sampled=sampled, options=options) + + prompt_0_content = prompt[0] if len(prompt) > 0 else dict() + prompt_0_content = prompt_0_content.get("content", "") + + @weave.op() + def row(prompt_0_content, sampled, expected, picked, match): + return + row(prompt_0_content, sampled, expected, picked, match) + return picked diff --git a/evals/elsuite/already_said_that/eval.py b/evals/elsuite/already_said_that/eval.py index 2fa495c702..c6d78b1100 100644 --- a/evals/elsuite/already_said_that/eval.py +++ b/evals/elsuite/already_said_that/eval.py @@ -115,7 +115,7 @@ def _conversation_loop( return convo_metrics - def run(self, recorder: RecorderBase): + def _run_impl(self, recorder: RecorderBase): samples = self._get_samples() self.eval_all_samples(recorder, samples) logged_metrics: list[dict] = recorder.get_metrics() diff --git a/evals/elsuite/ballots/eval.py b/evals/elsuite/ballots/eval.py index 67c44567b6..96cb0c0ec9 100644 --- a/evals/elsuite/ballots/eval.py +++ b/evals/elsuite/ballots/eval.py @@ -158,7 +158,7 @@ def query( else: assert False, "Invalid influence direction" - def run(self, recorder): + def _run_impl(self, recorder): proposals = self.get_samples() # possibly write all prompts to disk instead of dynamically generating them diff --git a/evals/elsuite/basic/match.py b/evals/elsuite/basic/match.py index ac72f72b37..e9f347dc3d 100644 --- a/evals/elsuite/basic/match.py +++ b/evals/elsuite/basic/match.py @@ -55,7 +55,7 @@ def eval_sample(self, sample: Any, *_): expected=sample["ideal"], ) - def run(self, recorder): + def _run_impl(self, recorder): samples = self.get_samples() self.eval_all_samples(recorder, samples) events = recorder.get_events("match") diff --git a/evals/elsuite/basic/match_with_solvers.py b/evals/elsuite/basic/match_with_solvers.py index 2feb57658d..b2ae8ee29d 100644 --- a/evals/elsuite/basic/match_with_solvers.py +++ b/evals/elsuite/basic/match_with_solvers.py @@ -65,7 +65,7 @@ def eval_sample(self, solver: Solver, sample: Any, *_): expected=[ideal, ideal.capitalize()], ) - def run(self, recorder): + def _run_impl(self, recorder): samples = self.get_samples() if self.shuffle: diff --git a/evals/elsuite/bluff/eval.py b/evals/elsuite/bluff/eval.py index 29d7e9cd92..84b8d25063 100644 --- a/evals/elsuite/bluff/eval.py +++ b/evals/elsuite/bluff/eval.py @@ -76,7 +76,7 @@ def _get_player_info(self, player: Player) -> str: else: return type(player).__name__ - def run(self, recorder: evals.record.Recorder) -> dict[str, Union[float, int]]: + def _run_impl(self, recorder: evals.record.Recorder) -> dict[str, Union[float, int]]: samples = list(range(self.n_samples)) self.eval_all_samples(recorder, samples) metrics = recorder.get_metrics() diff --git a/evals/elsuite/bugged_tools/eval.py b/evals/elsuite/bugged_tools/eval.py index 38cbccd594..cd12b98403 100644 --- a/evals/elsuite/bugged_tools/eval.py +++ b/evals/elsuite/bugged_tools/eval.py @@ -109,7 +109,7 @@ def eval_sample(self, solver: Solver, sample: Any, rng: random.Random): evals.record.record_metrics(**metrics) # type: ignore (evals.record badly hinted) - def run(self, recorder: evals.record.Recorder) -> dict[str, Union[float, int]]: # type: ignore (evals.record badly hinted) + def _run_impl(self, recorder: evals.record.Recorder) -> dict[str, Union[float, int]]: # type: ignore (evals.record badly hinted) samples = self.get_samples() self.eval_all_samples(recorder, samples) diff --git a/evals/elsuite/cant_do_that_anymore/eval.py b/evals/elsuite/cant_do_that_anymore/eval.py index 0ca6df5b0b..c6d1ea5402 100644 --- a/evals/elsuite/cant_do_that_anymore/eval.py +++ b/evals/elsuite/cant_do_that_anymore/eval.py @@ -112,7 +112,7 @@ def get_solver_pred( evals.record.record_metrics(**metrics) - def run(self, recorder: RecorderBase) -> dict[str, Union[float, int]]: + def _run_impl(self, recorder: RecorderBase) -> dict[str, Union[float, int]]: if self.diagonal_variation: self.samples_jsonl = get_diagonal_dataset_path( registry_path=self._prefix_registry_path("") diff --git a/evals/elsuite/error_recovery/eval.py b/evals/elsuite/error_recovery/eval.py index 89512179fe..f1b3e44795 100644 --- a/evals/elsuite/error_recovery/eval.py +++ b/evals/elsuite/error_recovery/eval.py @@ -217,7 +217,7 @@ def _get_answer( answer = self._extract_final_answer(solver=solver, task_state=task_state, sample=sample) return answer - def run(self, recorder: evals.record.Recorder): + def _run_impl(self, recorder: evals.record.Recorder): samples = self.get_samples() self.eval_all_samples(recorder, samples) diff --git a/evals/elsuite/function_deduction/eval.py b/evals/elsuite/function_deduction/eval.py index 6542852153..e3afa7556c 100644 --- a/evals/elsuite/function_deduction/eval.py +++ b/evals/elsuite/function_deduction/eval.py @@ -148,7 +148,7 @@ def eval_sample(self, solver: Solver, sample: Sample, rng: random.Random): complexity=sample.complexity, ) - def run(self, recorder: evals.record.Recorder): + def _run_impl(self, recorder: evals.record.Recorder): samples = self.get_samples() # Add copies according to self.n_repeat diff --git a/evals/elsuite/hr_ml_agent_bench/eval.py b/evals/elsuite/hr_ml_agent_bench/eval.py index 611be17790..d34fd195ef 100644 --- a/evals/elsuite/hr_ml_agent_bench/eval.py +++ b/evals/elsuite/hr_ml_agent_bench/eval.py @@ -97,7 +97,7 @@ def eval_sample(self, solver: Solver, raw_sample: dict, rng: Random) -> None: model_score_humanrelative=result.model_score_humanrelative, ) - def run(self, recorder: Recorder) -> dict: + def _run_impl(self, recorder: Recorder) -> dict: samples = self.get_samples() self.eval_all_samples(recorder, samples) metrics = recorder.get_metrics() diff --git a/evals/elsuite/identifying_variables/eval.py b/evals/elsuite/identifying_variables/eval.py index 31b3b743e0..c221ec686f 100644 --- a/evals/elsuite/identifying_variables/eval.py +++ b/evals/elsuite/identifying_variables/eval.py @@ -87,7 +87,7 @@ def eval_sample(self, solver: Solver, sample: Sample, rng: random.Random) -> Non num_not_ctrl=num_not_ctrl, ) - def run(self, recorder: RecorderBase) -> Dict[str, float]: + def _run_impl(self, recorder: RecorderBase) -> Dict[str, float]: samples: List[Dict] = self._get_samples() self.rng.shuffle(samples) self.eval_all_samples(recorder, samples) diff --git a/evals/elsuite/incontext_rl/eval.py b/evals/elsuite/incontext_rl/eval.py index a1fac2101e..346930f5f5 100644 --- a/evals/elsuite/incontext_rl/eval.py +++ b/evals/elsuite/incontext_rl/eval.py @@ -147,7 +147,7 @@ def eval_sample(self, solver: Solver, sample: Any, rng: random.Random): episode_end_steps=ts.current_state.episode_end_steps, ) - def run(self, recorder: evals.record.Recorder): + def _run_impl(self, recorder: evals.record.Recorder): samples = self.get_samples() for sample in samples: # Create environments and pass them to each thread via the sample diff --git a/evals/elsuite/make_me_pay/eval.py b/evals/elsuite/make_me_pay/eval.py index 9b2b8b1275..1012971f34 100644 --- a/evals/elsuite/make_me_pay/eval.py +++ b/evals/elsuite/make_me_pay/eval.py @@ -99,7 +99,7 @@ def eval_sample(self, solver: Solver, sample: Any, *_): mark_withdraw=results["mark_withdraw"], ) - def run(self, recorder: RecorderBase): + def _run_impl(self, recorder: RecorderBase): # We don't actually need any samples, we just need to run the eval num_experiments times. experiment_ids = list(range(self.num_experiments)) self.eval_all_samples(recorder, experiment_ids) diff --git a/evals/elsuite/multistep_web_tasks/eval.py b/evals/elsuite/multistep_web_tasks/eval.py index 2cd7289e76..0aacc2d391 100644 --- a/evals/elsuite/multistep_web_tasks/eval.py +++ b/evals/elsuite/multistep_web_tasks/eval.py @@ -47,7 +47,7 @@ def eval_sample(self, solver: Solver, sample: dict, rng: Any) -> None: trajectory_length=len(result.trajectory), ) - def run(self, recorder: RecorderBase): + def _run_impl(self, recorder: RecorderBase): samples = self.get_samples() self.session.add_samples(samples) # with statement handles setting up docker containers and tearing them down on completion/error diff --git a/evals/elsuite/sandbagging/mmlu_eval.py b/evals/elsuite/sandbagging/mmlu_eval.py index ae421d8f62..68f9c55676 100644 --- a/evals/elsuite/sandbagging/mmlu_eval.py +++ b/evals/elsuite/sandbagging/mmlu_eval.py @@ -61,7 +61,7 @@ def eval_sample( extra_logging=extra_logging, ) - def run(self, recorder: evals.record.Recorder): + def _run_impl(self, recorder: evals.record.Recorder): samples = self.get_samples() self.eval_all_samples(recorder, samples) diff --git a/evals/elsuite/sandbagging/sandbagging_eval.py b/evals/elsuite/sandbagging/sandbagging_eval.py index 675341a207..88cdfaf847 100644 --- a/evals/elsuite/sandbagging/sandbagging_eval.py +++ b/evals/elsuite/sandbagging/sandbagging_eval.py @@ -53,7 +53,7 @@ def eval_sample(self, solver: Solver, sample: Dict[str, Any], rng: random.Random self.mmlu_eval_sample(solver, sample, rng, extra_logging) - def run(self, recorder: evals.record.Recorder): + def _run_impl(self, recorder: evals.record.Recorder): metrics = {} achieved_accs = [] for target, mmlu_eval in zip(self.target_accuracies, self.evals): diff --git a/evals/elsuite/schelling_point/eval.py b/evals/elsuite/schelling_point/eval.py index 46d5371af1..233971040b 100644 --- a/evals/elsuite/schelling_point/eval.py +++ b/evals/elsuite/schelling_point/eval.py @@ -75,7 +75,7 @@ def eval_sample(self, sample: Any, *_): is_runtime_error=False, ) - def run(self, recorder: evals.record.Recorder) -> dict[str, Union[float, int]]: + def _run_impl(self, recorder: evals.record.Recorder) -> dict[str, Union[float, int]]: samples = self.get_samples()[0 : self.n_samples] diff --git a/evals/elsuite/self_prompting/eval.py b/evals/elsuite/self_prompting/eval.py index 7db858f5d4..90c9485170 100644 --- a/evals/elsuite/self_prompting/eval.py +++ b/evals/elsuite/self_prompting/eval.py @@ -177,7 +177,7 @@ def normalized_improvement(current, baseline): logger.info(f"Improvement scores: {improvement_scores}") return improvement_scores - def run(self, recorder: evals.record.Recorder) -> dict[str, Union[float, int]]: + def _run_impl(self, recorder: evals.record.Recorder) -> dict[str, Union[float, int]]: samples = self.get_samples() # Shuffle and limit samples diff --git a/evals/elsuite/skill_acquisition/eval.py b/evals/elsuite/skill_acquisition/eval.py index 52c770db7d..919cea30ed 100644 --- a/evals/elsuite/skill_acquisition/eval.py +++ b/evals/elsuite/skill_acquisition/eval.py @@ -186,7 +186,7 @@ def _eval_retrieval_sample(self, solver: Solver, sample: Dict, *_) -> Dict[str, } return out_obj - def run(self, recorder: evals.record.Recorder) -> dict[str, Union[float, int]]: + def _run_impl(self, recorder: evals.record.Recorder) -> dict[str, Union[float, int]]: samples = self.get_samples() self.rng.shuffle(samples) samples = samples[: self.n_samples] if self.n_samples is not None else samples diff --git a/evals/elsuite/steganography/eval.py b/evals/elsuite/steganography/eval.py index e25e1bc551..0eaa074109 100644 --- a/evals/elsuite/steganography/eval.py +++ b/evals/elsuite/steganography/eval.py @@ -65,7 +65,7 @@ def eval_sample(self, sample: Any, *_): rule_violated=results["rule_violated"], ) - def run(self, recorder: RecorderBase): + def _run_impl(self, recorder: RecorderBase): samples = self.get_samples() self.eval_all_samples(recorder, samples) metrics = recorder.get_metrics() diff --git a/evals/elsuite/text_compression/eval.py b/evals/elsuite/text_compression/eval.py index d2a620941b..603367f008 100644 --- a/evals/elsuite/text_compression/eval.py +++ b/evals/elsuite/text_compression/eval.py @@ -46,7 +46,7 @@ def eval_sample(self, sample: Any, *_): semantic_distance=results["semantic_distance"], ) - def run(self, recorder: RecorderBase): + def _run_impl(self, recorder: RecorderBase): samples = self.get_samples() self.eval_all_samples(recorder, samples) metrics = recorder.get_metrics() diff --git a/evals/elsuite/track_the_stat/eval.py b/evals/elsuite/track_the_stat/eval.py index d1ca65d719..b86de83c45 100644 --- a/evals/elsuite/track_the_stat/eval.py +++ b/evals/elsuite/track_the_stat/eval.py @@ -67,7 +67,7 @@ def _eval_sample(self, solver: Solver, capped_inf_list: list[int]) -> dict: "violation": violation, } - def run(self, recorder: RecorderBase): + def _run_impl(self, recorder: RecorderBase): samples = self._get_samples() self.eval_all_samples(recorder, samples) logged_metrics: list[dict] = recorder.get_metrics() diff --git a/evals/elsuite/twenty_questions/eval.py b/evals/elsuite/twenty_questions/eval.py index 3cb0d5c857..ddaa51076a 100644 --- a/evals/elsuite/twenty_questions/eval.py +++ b/evals/elsuite/twenty_questions/eval.py @@ -75,7 +75,7 @@ def eval_sample(self, solver: Solver, sample: Dict, rng: random.Random) -> Dict[ return response - def run(self, recorder: Recorder) -> Dict[str, Union[float, int]]: + def _run_impl(self, recorder: Recorder) -> Dict[str, Union[float, int]]: samples = self.get_samples() self.rng.shuffle(samples) samples = samples[: self.n_samples] if self.n_samples else samples diff --git a/evals/eval.py b/evals/eval.py index cce0c75c3f..8420333b2f 100644 --- a/evals/eval.py +++ b/evals/eval.py @@ -20,6 +20,8 @@ from .solvers.solver import Solver from .solvers.utils import maybe_wrap_with_compl_fn, maybe_wrap_with_solver +import weave + logger = logging.getLogger(__name__) @@ -82,10 +84,15 @@ def completion_fn(self) -> CompletionFn: """Helper for more ergonomic access to a single CompletionFn.""" return self.completion_fns[0] - @abc.abstractmethod def run(self, recorder: RecorderBase) -> Dict[str, float]: """Run the evaluation with the corresponding recorder.""" - raise NotImplementedError() + weave.init("yovaluate") + + @weave.op() + def yovaluate() -> Dict[str, Any]: + return self._run_impl(recorder) + + return yovaluate() async def async_eval_all_samples( self, diff --git a/evals/record.py b/evals/record.py index 8e8ebe9ae6..8ed95816a0 100644 --- a/evals/record.py +++ b/evals/record.py @@ -25,6 +25,8 @@ from evals.utils.misc import t from evals.utils.snowflake import SnowflakeConnection +import weave + logger = logging.getLogger(__name__) MIN_FLUSH_EVENTS = 100 @@ -184,6 +186,7 @@ def record_event(self, type, data=None, sample_id=None): self._flushes_started += 1 self._flush_events_internal(events_to_write) + @weave.op() def record_match(self, correct: bool, *, expected=None, picked=None, sample_id=None, **extra): assert isinstance( correct, bool @@ -199,6 +202,7 @@ def record_match(self, correct: bool, *, expected=None, picked=None, sample_id=N } self.record_event("match", data, sample_id=sample_id) + @weave.op() def record_embedding(self, prompt, embedding_type, sample_id=None, **extra): data = { "prompt": prompt, @@ -207,6 +211,7 @@ def record_embedding(self, prompt, embedding_type, sample_id=None, **extra): } self.record_event("embedding", data, sample_id=sample_id) + @weave.op() def record_sampling(self, prompt, sampled, sample_id=None, **extra): data = { "prompt": prompt, @@ -215,6 +220,7 @@ def record_sampling(self, prompt, sampled, sample_id=None, **extra): } self.record_event("sampling", data, sample_id=sample_id) + @weave.op() def record_function_call(self, name, arguments, return_value, sample_id=None, **extra): data = { "name": name, @@ -224,6 +230,7 @@ def record_function_call(self, name, arguments, return_value, sample_id=None, ** } self.record_event("function_call", data, sample_id=sample_id) + @weave.op() def record_cond_logp(self, prompt, completion, logp, sample_id=None, **extra): data = { "prompt": prompt, @@ -233,6 +240,7 @@ def record_cond_logp(self, prompt, completion, logp, sample_id=None, **extra): } self.record_event("cond_logp", data, sample_id=sample_id) + @weave.op() def record_pick_option(self, prompt, options, picked, sample_id=None, **extra): data = { "prompt": prompt, @@ -242,12 +250,15 @@ def record_pick_option(self, prompt, options, picked, sample_id=None, **extra): } self.record_event("pick_option", data, sample_id=sample_id) + @weave.op() def record_raw(self, data): self.record_event("raw_sample", data) + @weave.op() def record_metrics(self, **kwargs): self.record_event("metrics", kwargs) + @weave.op() def record_error(self, msg: str, error: Exception, **kwargs): data = { "type": type(error).__name__, @@ -256,9 +267,11 @@ def record_error(self, msg: str, error: Exception, **kwargs): data.update(kwargs) self.record_event("error", data) + @weave.op() def record_extra(self, data, sample_id=None): self.record_event("extra", data, sample_id=sample_id) + @weave.op() def record_final_report(self, final_report: Any): logging.info(f"Final report: {final_report}. Not writing anywhere.")