Skip to content

Commit

Permalink
Add better formatting for Eleuther eval results (#986)
Browse files: browse the repository at this point in the history
  • Loading branch information
joecummings authored May 16, 2024
1 parent 618e80a commit a6f4cc5
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
8 changes: 5 additions & 3 deletions recipes/eleuther_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from lm_eval.evaluator import evaluate
from lm_eval.models.huggingface import HFLM
from lm_eval.tasks import get_task_dict
from lm_eval.utils import make_table
except ImportError:
logger.error(
"Recipe requires EleutherAI Eval Harness v0.4. Please install with `pip install lm_eval==0.4.*`"
Expand Down Expand Up @@ -187,15 +188,16 @@ def evaluate(self) -> None:

task_dict = get_task_dict(self._tasks)
logger.info(f"Running evaluation on {self._tasks} tasks.")
eleuther_output = evaluate(
output = evaluate(
model_eval_wrapper,
task_dict,
limit=self._limit,
)

logger.info(f"Eval completed in {time.time() - t1:.02f} seconds.")
for task, res in eleuther_output["results"].items():
logger.info(f"{task}: {res}")

formatted_output = make_table(output)
print(formatted_output)


@config.parse
Expand Down
11 changes: 6 additions & 5 deletions tests/recipes/test_eleuther_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

class TestEleutherEval:
@pytest.mark.integration_test
def test_torchtune_checkpoint_eval_results(self, caplog, monkeypatch, tmpdir):
def test_torchtune_checkpoint_eval_results(self, capsys, monkeypatch, tmpdir):
ckpt = "small_test_ckpt_tune"
ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
ckpt_dir = ckpt_path.parent
Expand All @@ -47,10 +47,11 @@ def test_torchtune_checkpoint_eval_results(self, caplog, monkeypatch, tmpdir):
with pytest.raises(SystemExit, match=""):
runpy.run_path(TUNE_PATH, run_name="__main__")

err_log = caplog.messages[-1]
log_search_results = re.search(r"'acc,none': (\d+\.\d+)", err_log)
assert log_search_results is not None
acc_result = float(log_search_results.group(1))
out = capsys.readouterr().out

search_results = re.search(r"acc(?:_norm)?\s*\|?\s*([\d.]+)", out.strip())
assert search_results is not None
acc_result = float(search_results.group(1))
assert math.isclose(acc_result, 0.3, abs_tol=0.05)

@pytest.fixture
Expand Down

0 comments on commit a6f4cc5

Please sign in to comment.