Skip to content

Commit

Permalink
CLI now accepts a list of languages as its input if none are passed t…
Browse files Browse the repository at this point in the history
…he benchmark will be run on all languages
  • Loading branch information
x-tabdeveloping committed Oct 20, 2023
1 parent b5d2f6b commit 975418e
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions src/seb/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from functools import partial
from pathlib import Path
from statistics import mean
from typing import Optional

import tabulate
from sentence_transformers import SentenceTransformer
Expand Down Expand Up @@ -42,16 +43,18 @@ def pretty_print(results: seb.BenchmarkResults):
)


def run_benchmark(model_name_or_path: str) -> seb.BenchmarkResults:
"""Runs benchmark on a given model."""
def run_benchmark(
model_name_or_path: str, languages: Optional[list[str]]
) -> seb.BenchmarkResults:
"""Runs benchmark on a given model and languages."""
meta = seb.ModelMeta(
name=Path(model_name_or_path).stem,
)
model = seb.SebModel(
meta=meta,
loader=partial(SentenceTransformer, model_name_or_path=model_name_or_path), # type: ignore
)
benchmark = seb.Benchmark()
benchmark = seb.Benchmark(languages)
res = benchmark.evaluate_model(model, raise_errors=False)
return res

Expand All @@ -63,15 +66,23 @@ def main():
"model_name_or_path",
help="Name of the model on HuggingFace hub, or path to the model.",
)
parser.add_argument(
"languages",
nargs="*",
help="List of language codes to evaluate the model on.",
)
parser.add_argument(
"--save_path",
"-o",
default="benchmark_results.json",
help="File to store benchmark results in.",
)

args = parser.parse_args()
logging.info(f"Running benchmark with {args.model_name_or_path}...")
results = run_benchmark(args.model_name_or_path)
if not args.languages:
args.languages = None
results = run_benchmark(args.model_name_or_path, args.languages)
logging.info("Saving results...")
save_path = Path(args.save_path)
with save_path.open("w") as save_file:
Expand Down

0 comments on commit 975418e

Please sign in to comment.