From 068511c6437b03be50793c162754c83b3b4befe4 Mon Sep 17 00:00:00 2001
From: Szymon Palucha
Date: Tue, 3 Dec 2024 16:19:43 +0000
Subject: [PATCH] refactor: save out predictions in eval script

---
 kazu/training/evaluate_script.py                       | 9 +++++++++
 .../examples/conf/multilabel_ner_evaluate/default.yaml | 1 +
 2 files changed, 10 insertions(+)

diff --git a/kazu/training/evaluate_script.py b/kazu/training/evaluate_script.py
index 1ebe25ed..1c441d8f 100644
--- a/kazu/training/evaluate_script.py
+++ b/kazu/training/evaluate_script.py
@@ -11,6 +11,7 @@
 from hydra.utils import instantiate
 from omegaconf import DictConfig
 
+from kazu.data import Document
 from kazu.pipeline import Pipeline
 from kazu.steps.ner.hf_token_classification import (
     TransformersModelForTokenClassificationNerStep,
@@ -30,6 +31,13 @@
 from kazu.utils.constants import HYDRA_VERSION_BASE
 
 
+def save_out_predictions(output_dir: Path, documents: list[Document]) -> None:
+    for doc in documents:
+        file_path = output_dir / f"{doc.idx}.json"
+        with file_path.open("w") as f:
+            f.write(doc.to_json())
+
+
 @hydra.main(
     version_base=HYDRA_VERSION_BASE,
     config_path=str(
@@ -64,6 +72,7 @@ def main(cfg: DictConfig) -> None:
     pipeline(documents)
     print(f"Predicted {len(documents)} documents in {time.time() - start:.2f} seconds.")
 
+    save_out_predictions(Path(cfg.predictions_dir), documents)
     print("Calculating metrics")
     metrics, _ = calculate_metrics(0, documents, label_list)
     with open(Path(prediction_config.path) / "test_metrics.json", "w") as file:
diff --git a/scripts/examples/conf/multilabel_ner_evaluate/default.yaml b/scripts/examples/conf/multilabel_ner_evaluate/default.yaml
index 4b7f3064..6b7357d9 100644
--- a/scripts/examples/conf/multilabel_ner_evaluate/default.yaml
+++ b/scripts/examples/conf/multilabel_ner_evaluate/default.yaml
@@ -8,6 +8,7 @@ prediction_config:
   device: cpu
   architecture: bert
   use_multilabel: true
+predictions_dir: ???
 css_colors:
   - "#000000" # Black
   - "#FF0000" # Red
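
Usage note: predictions_dir is declared as "???", OmegaConf's mandatory
missing value, so the evaluate script will abort with a missing mandatory
value error unless the path is supplied at runtime, for example via a
Hydra command-line override:

    python kazu/training/evaluate_script.py predictions_dir=/tmp/eval_preds

Below is a minimal sketch of loading the saved predictions back for later
inspection. It assumes Document.from_json is the inverse of the to_json
call used in save_out_predictions (worth checking against kazu.data); the
/tmp/eval_preds path is a hypothetical value matching the override above.

    from pathlib import Path

    from kazu.data import Document

    # Hypothetical directory passed as predictions_dir when the script ran.
    predictions_dir = Path("/tmp/eval_preds")

    # save_out_predictions wrote one <doc.idx>.json file per document,
    # so globbing for *.json recovers the full prediction set.
    documents = []
    for file_path in sorted(predictions_dir.glob("*.json")):
        documents.append(Document.from_json(file_path.read_text()))

    print(f"Loaded {len(documents)} predicted documents")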