From 64e01ae9d6fcf125a4ea6516263fa062b2aafeef Mon Sep 17 00:00:00 2001
From: Kenneth Enevoldsen
Date: Sat, 14 Sep 2024 09:38:53 +0200
Subject: [PATCH] fix: Ensure that results are returned even when hitting cache (#1215)

Fixes #1122
---
 mteb/evaluation/MTEB.py                |  8 +++++---
 tests/test_benchmark/test_benchmark.py | 22 ++++++++++++++++++++++
 2 files changed, 27 insertions(+), 3 deletions(-)

diff --git a/mteb/evaluation/MTEB.py b/mteb/evaluation/MTEB.py
index 0ac12d4bd2..88e6b48910 100644
--- a/mteb/evaluation/MTEB.py
+++ b/mteb/evaluation/MTEB.py
@@ -278,7 +278,7 @@ def run(
         co2_tracker: bool = False,
         encode_kwargs: dict[str, Any] = {},
         **kwargs,
-    ):
+    ) -> list[MTEBResults]:
         """Run the evaluation pipeline on the selected tasks.
 
         Args:
@@ -335,9 +335,11 @@
             save_path = output_path / f"{task.metadata.name}{task.save_suffix}.json"
             if save_path.exists() and not overwrite_results:
                 logger.info(
-                    f"{task.metadata.name} results already exists. Skipping. Set overwrite_results=True to overwrite."
+                    f"{task.metadata.name} results already exists. Loading results from disk. Set overwrite_results=True to overwrite."
                 )
-                del self.tasks[0]
+                mteb_results = MTEBResults.from_disk(save_path)
+                evaluation_results.append(mteb_results)
+                del self.tasks[0]  # empty memory
                 continue
             try:
                 task_eval_splits = (
diff --git a/tests/test_benchmark/test_benchmark.py b/tests/test_benchmark/test_benchmark.py
index 742c7930e9..ef9875f2aa 100644
--- a/tests/test_benchmark/test_benchmark.py
+++ b/tests/test_benchmark/test_benchmark.py
@@ -62,6 +62,28 @@ def test_benchmark_encoders_on_task(task: str | mteb.AbsTask, model: mteb.Encode
     eval.run(model, output_folder="tests/results", overwrite_results=True)
 
 
+@pytest.mark.parametrize("task", MOCK_TASK_TEST_GRID[:1])
+@pytest.mark.parametrize("model", [MockNumpyEncoder()])
+def test_reload_results(task: str | mteb.AbsTask, model: mteb.Encoder, tmp_path: Path):
+    """Test that when rerunning the results are reloaded correctly"""
+    if isinstance(task, str):
+        tasks = mteb.get_tasks(tasks=[task])
+    else:
+        tasks = [task]
+
+    eval = mteb.MTEB(tasks=tasks)
+    results = eval.run(model, output_folder=str(tmp_path), overwrite_results=True)
+
+    assert isinstance(results, list)
+    assert isinstance(results[0], mteb.MTEBResults)
+
+    # reload the results
+    results = eval.run(model, output_folder=str(tmp_path), overwrite_results=False)
+
+    assert isinstance(results, list)
+    assert isinstance(results[0], mteb.MTEBResults)
+
+
 @pytest.mark.parametrize("task_name", MOCK_TASK_TEST_GRID)
 def test_prompt_name_passed_to_all_encodes(task_name: str | mteb.AbsTask):
     """Test that all tasks correctly pass down the task_name to the encoder which supports it, and that the encoder which does not support it does not
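
Usage note (editorial, not part of the patch): the sketch below mirrors the new test and shows the behavior this change enables from a user's perspective. The model and task names are illustrative assumptions, not taken from the patch; any encoder and task supported by mteb should behave the same way.

import mteb
from sentence_transformers import SentenceTransformer

# Illustrative model and task, chosen only for the example.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
tasks = mteb.get_tasks(tasks=["Banking77Classification"])

evaluation = mteb.MTEB(tasks=tasks)

# First run: computes scores and writes one JSON file per task to the output folder.
results = evaluation.run(model, output_folder="results")

# Second run against the same folder with overwrite_results=False: previously the
# cached tasks were skipped and omitted from the return value; with this patch the
# saved results are loaded from disk and returned as MTEBResults objects.
cached_results = evaluation.run(model, output_folder="results", overwrite_results=False)

assert all(isinstance(r, mteb.MTEBResults) for r in cached_results)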