From 975418ee3ef5263461152fa0b4244163f7791070 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A1rton=20Kardos?= Date: Fri, 20 Oct 2023 11:45:20 +0200 Subject: [PATCH] CLI now accepts a list of languages as its input; if none are passed, the benchmark will be run on all languages --- src/seb/cli.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/src/seb/cli.py b/src/seb/cli.py index 45a5e0e6..d2bc98b3 100644 --- a/src/seb/cli.py +++ b/src/seb/cli.py @@ -4,6 +4,7 @@ from functools import partial from pathlib import Path from statistics import mean +from typing import Optional import tabulate from sentence_transformers import SentenceTransformer @@ -42,8 +43,10 @@ def pretty_print(results: seb.BenchmarkResults): ) -def run_benchmark(model_name_or_path: str) -> seb.BenchmarkResults: - """Runs benchmark on a given model.""" +def run_benchmark( + model_name_or_path: str, languages: Optional[list[str]] +) -> seb.BenchmarkResults: + """Runs benchmark on a given model and languages.""" meta = seb.ModelMeta( name=Path(model_name_or_path).stem, ) @@ -51,7 +54,7 @@ def run_benchmark(model_name_or_path: str) -> seb.BenchmarkResults: meta=meta, loader=partial(SentenceTransformer, model_name_or_path=model_name_or_path), # type: ignore ) - benchmark = seb.Benchmark() + benchmark = seb.Benchmark(languages) res = benchmark.evaluate_model(model, raise_errors=False) return res @@ -63,15 +66,23 @@ def main(): "model_name_or_path", help="Name of the model on HuggingFace hub, or path to the model.", ) + parser.add_argument( + "languages", + nargs="*", + help="List of language codes to evaluate the model on.", + ) parser.add_argument( "--save_path", + "-o", default="benchmark_results.json", help="File to store benchmark results in.", ) args = parser.parse_args() logging.info(f"Running benchmark with {args.model_name_or_path}...") - results = run_benchmark(args.model_name_or_path) + if not args.languages: + args.languages = None + results = 
run_benchmark(args.model_name_or_path, args.languages) logging.info("Saving results...") save_path = Path(args.save_path) with save_path.open("w") as save_file: