Skip to content

Commit

Permalink
Add saving predictions to jsonl
Browse files Browse the repository at this point in the history
  • Loading branch information
lsz05 committed Aug 7, 2024
1 parent 7881bd7 commit 7744381
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/jmteb/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ def main(
dataset_name=eval_name,
task_name=evaluator.__class__.__name__.replace("Evaluator", ""),
)
if getattr(evaluator, "log_predictions", False):
score_recorder.record_predictions(
metrics, eval_name, evaluator.__class__.__name__.replace("Evaluator", "")
)

logger.info(f"Results for {eval_name}\n{json.dumps(metrics.as_dict(), indent=4, ensure_ascii=False)}")

Expand Down
14 changes: 14 additions & 0 deletions src/jmteb/utils/score_recorder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from dataclasses import asdict
import json
from abc import ABC, abstractmethod
from collections import defaultdict
Expand Down Expand Up @@ -30,6 +31,12 @@ def save_to_json(scores: EvaluationResults | dict[Any, Any], filename: str | Pat
with open(filename, "w") as fout:
json.dump(scores, fout, indent=4, ensure_ascii=False)

@staticmethod
def save_prediction_to_jsonl(predictions: list[Any], filename: str | PathLike[str]) -> None:
with open(filename, "w") as fout:
for prediction in predictions:
fout.write(json.dumps(asdict(prediction), ensure_ascii=False) + "\n")

def record_task_scores(self, scores: EvaluationResults, dataset_name: str, task_name: str) -> None:
if not self.save_dir:
return
Expand All @@ -39,6 +46,13 @@ def record_task_scores(self, scores: EvaluationResults, dataset_name: str, task_
self.scores[task_name][dataset_name] = scores
self.save_to_json(self.scores[task_name][dataset_name].as_dict(), save_filename)

def record_predictions(self, results: EvaluationResults, dataset_name: str, task_name: str) -> None:
    """Save the predictions of one evaluation run as a JSONL file.

    The file is written to ``<save_dir>/<task_name>/predictions_<dataset_name>.jsonl``;
    missing parent directories are created. Nothing happens when
    ``self.save_dir`` is falsy.

    Args:
        results: Evaluation results whose ``predictions`` attribute is saved.
        dataset_name: Dataset name, used in the output file name.
        task_name: Task name, used as the sub-directory under ``save_dir``.
    """
    if self.save_dir:
        target = Path(self.save_dir, task_name, f"predictions_{dataset_name}.jsonl")
        # Create the containing directory before handing the path to the writer.
        target.parent.mkdir(parents=True, exist_ok=True)
        self.save_prediction_to_jsonl(results.predictions, target)

def record_summary(self):
if not self.save_dir:
return
Expand Down

0 comments on commit 7744381

Please sign in to comment.