fix: Ensure that results are returned even when hitting cache (#1215)
Fixes #1122
KennethEnevoldsen authored Sep 14, 2024
1 parent 88b4f6e commit 64e01ae
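
In practice the fix means that calling `run()` a second time over the same output folder now returns the cached scores instead of skipping them. A minimal sketch of the behaviour, assuming a SentenceTransformer-compatible model; the task and model names below are illustrative, not taken from the commit:

```python
import mteb
from sentence_transformers import SentenceTransformer

# Illustrative task/model pair; any supported combination behaves the same.
tasks = mteb.get_tasks(tasks=["Banking77Classification"])
model = SentenceTransformer("all-MiniLM-L6-v2")

evaluation = mteb.MTEB(tasks=tasks)

# First run: evaluates the model and writes one JSON result file per task.
results = evaluation.run(model, output_folder="results")

# Second run hits the cache. Before this commit the cached tasks were
# skipped and nothing was returned for them; now the saved results are
# loaded from disk and returned.
cached = evaluation.run(model, output_folder="results", overwrite_results=False)
assert all(isinstance(r, mteb.MTEBResults) for r in cached)
```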
Showing 2 changed files with 27 additions and 3 deletions.
mteb/evaluation/MTEB.py: 5 additions & 3 deletions
```diff
@@ -278,7 +278,7 @@ def run(
         co2_tracker: bool = False,
         encode_kwargs: dict[str, Any] = {},
         **kwargs,
-    ):
+    ) -> list[MTEBResults]:
         """Run the evaluation pipeline on the selected tasks.

         Args:
@@ -335,9 +335,11 @@ def run(
             save_path = output_path / f"{task.metadata.name}{task.save_suffix}.json"
             if save_path.exists() and not overwrite_results:
                 logger.info(
-                    f"{task.metadata.name} results already exists. Skipping. Set overwrite_results=True to overwrite."
+                    f"{task.metadata.name} results already exists. Loading results from disk. Set overwrite_results=True to overwrite."
                 )
-                del self.tasks[0]
+                mteb_results = MTEBResults.from_disk(save_path)
+                evaluation_results.append(mteb_results)
+                del self.tasks[0]  # empty memory
                 continue
             try:
                 task_eval_splits = (
```
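
The cache-hit branch above reuses `MTEBResults.from_disk` instead of re-running the task. The same loading step can be done by hand; a minimal sketch, assuming a result file already exists at `<output_folder>/.../<task_name><save_suffix>.json` (the concrete path below is illustrative and depends on how the run was configured):

```python
from pathlib import Path

from mteb import MTEBResults

# Illustrative location; the real path depends on output_folder,
# the model, and the task name.
save_path = Path("results") / "Banking77Classification.json"

if save_path.exists():
    cached = MTEBResults.from_disk(save_path)  # same call as in the diff
    print(cached)
```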
tests/test_benchmark/test_benchmark.py: 22 additions & 0 deletions
```diff
@@ -62,6 +62,28 @@ def test_benchmark_encoders_on_task(task: str | mteb.AbsTask, model: mteb.Encoder):
     eval.run(model, output_folder="tests/results", overwrite_results=True)


+@pytest.mark.parametrize("task", MOCK_TASK_TEST_GRID[:1])
+@pytest.mark.parametrize("model", [MockNumpyEncoder()])
+def test_reload_results(task: str | mteb.AbsTask, model: mteb.Encoder, tmp_path: Path):
+    """Test that when rerunning the results are reloaded correctly"""
+    if isinstance(task, str):
+        tasks = mteb.get_tasks(tasks=[task])
+    else:
+        tasks = [task]
+
+    eval = mteb.MTEB(tasks=tasks)
+    results = eval.run(model, output_folder=str(tmp_path), overwrite_results=True)
+
+    assert isinstance(results, list)
+    assert isinstance(results[0], mteb.MTEBResults)
+
+    # reload the results
+    results = eval.run(model, output_folder=str(tmp_path), overwrite_results=False)
+
+    assert isinstance(results, list)
+    assert isinstance(results[0], mteb.MTEBResults)
+
+
 @pytest.mark.parametrize("task_name", MOCK_TASK_TEST_GRID)
 def test_prompt_name_passed_to_all_encodes(task_name: str | mteb.AbsTask):
     """Test that all tasks correctly pass down the task_name to the encoder which supports it, and that the encoder which does not support it does not
```
