Commit

KennethEnevoldsen committed Sep 25, 2023
2 parents aebaa4e + 0ef4b34 commit 2e7bb7e
Showing 20 changed files with 210 additions and 128 deletions.
25 changes: 12 additions & 13 deletions docs/run_benchmark.py
@@ -2,17 +2,15 @@
Script for running the benchmark and pushing the results to Datawrapper.
Example:
    python run_benchmark.py --data-wrapper-api-token <token>
"""
import argparse
-from typing import List

import numpy as np
import pandas as pd
+from datawrapper import Datawrapper

import seb
-from datawrapper import Datawrapper
from seb.full_benchmark import BENCHMARKS

subset_to_chart_id = {
@@ -30,35 +28,33 @@
}


-def get_main_score(task: seb.TaskResult, langs: List[str]) -> float:
+def get_main_score(task: seb.TaskResult, langs: list[str]) -> float:
    _langs = set(langs) & set(task.languages)
    return task.get_main_score(_langs) * 100


-def create_mdl_name(mdl: seb.ModelMeta):
+def create_mdl_name(mdl: seb.ModelMeta) -> str:
    reference = mdl.reference
    name = mdl.name

-    if reference:
-        mdl_name = f"[{name}]({reference})"
-    else:
-        mdl_name = name
+    mdl_name = f"[{name}]({reference})" if reference else name

    if mdl.languages:
        lang_code = " ".join(
            [
                f":{datawrapper_lang_codes[l]}:"
                for l in mdl.languages
                if l in datawrapper_lang_codes
-            ]
+            ],
        )
        mdl_name = f"{mdl_name} {lang_code}"

    return mdl_name


def benchmark_result_to_row(
-    result: seb.BenchmarkResults, langs: List[str]
+    result: seb.BenchmarkResults,
+    langs: list[str],
) -> pd.DataFrame:
    mdl_name = create_mdl_name(result.meta)
    # sort by task name
@@ -73,7 +69,10 @@ def benchmark_result_to_row(
    return df


-def convert_to_table(results: List[seb.BenchmarkResults], langs: List[str]):
+def convert_to_table(
+    results: list[seb.BenchmarkResults],
+    langs: list[str],
+) -> pd.DataFrame:
    rows = [benchmark_result_to_row(result, langs) for result in results]
    df = pd.concat(rows)
    df = df.sort_values(by="Average", ascending=False)
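For orientation, a minimal sketch of what the refactored create_mdl_name produces. The model name, reference, and flag codes below are made up; the actual emoji depend on the collapsed datawrapper_lang_codes mapping, so the output shown is an assumption:

import seb

meta = seb.ModelMeta(
    name="my-encoder",                        # hypothetical model, not in the SEB registry
    reference="https://example.com/my-encoder",
    languages=["da", "sv"],
)
# create_mdl_name(meta) builds a markdown link plus language flags, e.g.
# "[my-encoder](https://example.com/my-encoder) :dk: :se:"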
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -23,6 +23,7 @@ dependencies = [
    "typer>=0.7.0",
    "pydantic>=2.1.0",
    "catalogue>=2.0.8",
+    "openai>=0.27.4",
]

[project.license]
@@ -105,7 +106,7 @@ select = [
    "SIM",
    "W",
]
-ignore = ["ANN101", "ANN401", "E402", "E501", "F401", "F841", "RET504"]
+ignore = ["ANN101", "ANN401", "E402", "E501", "E741", "F401", "F841", "RET504"]
ignore-init-module-imports = true
# Allow autofix for all enabled rules (when `--fix`) is provided.
unfixable = ["ERA"]
34 changes: 22 additions & 12 deletions src/seb/benchmark.py
@@ -1,7 +1,7 @@
import logging
from datetime import datetime
from pathlib import Path
-from typing import List, Optional, Union
+from typing import Optional, Union

from tqdm import tqdm

@@ -41,7 +41,9 @@ def run_task(
    except Exception as e:
        logger.error(f"Error when running {task.name} on {model.meta.name}: {e}")
        return TaskError(
-            task_name=task.name, error=str(e), time_of_run=datetime.now()
+            task_name=task.name,
+            error=str(e),
+            time_of_run=datetime.now(),
        )

    cache_path = get_cache_path(task, model)
@@ -64,8 +66,8 @@ class Benchmark:

    def __init__(
        self,
-        languages: Optional[List[str]] = None,
-        tasks: Optional[List[str]] = None,
+        languages: Optional[list[str]] = None,
+        tasks: Optional[list[str]] = None,
    ) -> None:
        """
        Initialize the benchmark.
@@ -78,7 +80,7 @@ def __init__(
        self.tasks_names = tasks
        self.tasks = self.get_tasks()

-    def get_tasks(self) -> List[Task]:
+    def get_tasks(self) -> list[Task]:
        """
        Get the tasks for the benchmark.
@@ -88,9 +90,9 @@ def get_tasks(self) -> list[Task]:
        tasks = []

        if self.tasks_names is not None:
-            tasks: List[Task] = [get_task(task_name) for task_name in self.tasks_names]
+            tasks: list[Task] = [get_task(task_name) for task_name in self.tasks_names]
        else:
-            tasks: List[Task] = get_all_tasks()
+            tasks: list[Task] = get_all_tasks()

        if self.languages is not None:
            langs = set(self.languages)
@@ -99,7 +101,10 @@ def get_tasks(self) -> list[Task]:
        return tasks

    def evaluate_model(
-        self, model: SebModel, use_cache: bool = True, raise_errors: bool = True
+        self,
+        model: SebModel,
+        use_cache: bool = True,
+        raise_errors: bool = True,
    ) -> BenchmarkResults:
        """
        Evaluate a model on the benchmark.
@@ -122,8 +127,11 @@ def evaluate_model(
        return BenchmarkResults(meta=model.meta, task_results=task_results)

    def evaluate_models(
-        self, models: List[SebModel], use_cache: bool = True, raise_errors: bool = True
-    ) -> List[BenchmarkResults]:
+        self,
+        models: list[SebModel],
+        use_cache: bool = True,
+        raise_errors: bool = True,
+    ) -> list[BenchmarkResults]:
        """
        Evaluate a list of models on the benchmark.
@@ -141,7 +149,9 @@ def evaluate_models(
        for model in pbar:
            results.append(
                self.evaluate_model(
-                    model, use_cache=use_cache, raise_errors=raise_errors
-                )
+                    model,
+                    use_cache=use_cache,
+                    raise_errors=raise_errors,
+                ),
            )
        return results
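A hedged usage sketch of the Benchmark API after this change. The module paths follow the file layout in this diff, the language codes are assumptions, and which models are registered depends on the rest of the package:

from seb.benchmark import Benchmark
from seb.registries import get_all_models

benchmark = Benchmark(languages=["da", "nb", "sv"])   # keep only tasks covering these languages
models = get_all_models()                             # every registered SebModel
results = benchmark.evaluate_models(models, use_cache=True, raise_errors=False)
for res in results:
    print(res.meta.name, len(res.task_results))       # one TaskResult (or TaskError) per task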
3 changes: 1 addition & 2 deletions src/seb/full_benchmark.py
@@ -2,7 +2,6 @@
This is the specification for the full benchmark. Running the code here will reproduce the results.
"""

-from typing import List

from seb.model_interface import SebModel

@@ -22,7 +21,7 @@ def run_benchmark(use_cache: bool = True, raise_errors: bool=True) -> dict[str,
    """
    Run the full SEB benchmark.
    """
-    models: List[SebModel] = get_all_models()
+    models: list[SebModel] = get_all_models()

    results = {}
    for subset, langs in BENCHMARKS.items():
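A small sketch of consuming run_benchmark. The value type of the returned dict is cut off in the hunk header above, so treating each entry as a per-subset collection of model results is an assumption:

from seb.full_benchmark import run_benchmark

results = run_benchmark(use_cache=True, raise_errors=False)
for subset, subset_results in results.items():   # one entry per subset defined in BENCHMARKS
    print(subset, len(subset_results))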
25 changes: 16 additions & 9 deletions src/seb/model_interface.py
@@ -1,4 +1,4 @@
-from typing import Callable, List, Optional, Protocol, Union, runtime_checkable
+from typing import Callable, Optional, Protocol, Union, runtime_checkable

from numpy import ndarray
from pydantic import BaseModel
@@ -16,12 +16,16 @@ class ModelInterface(Protocol):
    """

    def encode(
-        self, sentences: List[str], batch_size: int = 32, **kwargs
-    ) -> List[ArrayLike]:
+        self,
+        sentences: list[str],
+        batch_size: int = 32,
+        **kwargs: dict,
+    ) -> list[ArrayLike]:
        """Returns a list of embeddings for the given sentences.
        Args:
            sentences: List of sentences to encode
            batch_size: Batch size for the encoding
+            kwargs: arguments to pass to the models encode method
        Returns:
            List of embeddings for the given sentences
@@ -34,9 +38,9 @@ class ModelMeta(BaseModel):
    description: Optional[str] = None
    huggingface_name: Optional[str] = None
    reference: Optional[str] = None
-    languages: List[str] = []
+    languages: list[str] = []

-    def get_path_name(self):
+    def get_path_name(self) -> str:
        if self.huggingface_name is None:
            return name_to_path(self.name)
        return name_to_path(self.huggingface_name)
@@ -68,17 +72,20 @@ def number_of_parameters(self) -> Optional[int]:
        """
        if hasattr(self.model, "num_parameters"):
            return sum(p.numel() for p in self.model.parameters() if p.requires_grad)  # type: ignore
-        else:
-            return None
+        return None

    def encode(
-        self, sentences: List[str], batch_size: int = 32, **kwargs
-    ) -> List[ArrayLike]:
+        self,
+        sentences: list[str],
+        batch_size: int = 32,
+        **kwargs: dict,
+    ) -> list[ArrayLike]:
        """
        Returns a list of embeddings for the given sentences.
        Args:
            sentences: List of sentences to encode
            batch_size: Batch size for the encoding
+            kwargs: arguments to pass to the models encode method
        Returns:
            List of embeddings for the given sentences
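To illustrate the updated ModelInterface protocol, a toy encoder that satisfies the new encode signature. It is not part of SEB, and returning plain numpy arrays is assumed to be an acceptable ArrayLike:

import numpy as np

class BagOfCharsEncoder:
    """Toy encoder used only to illustrate the protocol."""

    def encode(
        self,
        sentences: list[str],
        batch_size: int = 32,
        **kwargs: dict,
    ) -> list[np.ndarray]:
        # one 256-dimensional character-count vector per sentence
        embeddings = []
        for sentence in sentences:
            vec = np.zeros(256)
            for ch in sentence:
                vec[ord(ch) % 256] += 1
            embeddings.append(vec)
        return embeddings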
8 changes: 4 additions & 4 deletions src/seb/mteb_tasks.py
@@ -1,12 +1,12 @@
-from typing import Any, Dict
+from typing import Any

import datasets
from mteb.abstasks.AbsTaskRetrieval import AbsTaskRetrieval


class SweFaqRetrieval(AbsTaskRetrieval):
    @property
-    def description(self) -> Dict[str, Any]:
+    def description(self) -> dict[str, Any]:
        return {
            "name": "swefaq",
            "hf_hub_name": "AI-Sweden/SuperLim",
@@ -20,7 +20,7 @@ def description(self) -> dict[str, Any]:
            "revision": "7ebf0b4caa7b2ae39698a889de782c09e6f5ee56",
        }

-    def load_data(self, **kwargs):
+    def load_data(self, **kwargs: dict):  # noqa: ARG002
        """
        Load dataset from HuggingFace hub
        """
@@ -75,5 +75,5 @@ def dataset_transform(self) -> None:
            cor_n = text2id[co]

            self.relevant_docs[split][str(q_n)] = {
-                str(cor_n): 1
+                str(cor_n): 1,
            }  # only one correct match
6 changes: 2 additions & 4 deletions src/seb/registries.py
@@ -1,5 +1,3 @@
-from typing import List
-
import catalogue

from .model_interface import SebModel
@@ -35,7 +33,7 @@ def get_task(name: str) -> Task:
    return tasks.get(name)()


-def get_all_tasks() -> List[Task]:
+def get_all_tasks() -> list[Task]:
    """
    Returns all tasks implemented in SEB.
@@ -45,7 +43,7 @@ def get_all_tasks() -> list[Task]:
    return [get_task(task_name) for task_name in tasks.get_all()]


-def get_all_models() -> List[SebModel]:
+def get_all_models() -> list[SebModel]:
    """
    Get all the models implemented in SEB.
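A short sketch of the registry helpers after the change. What the calls return depends on which tasks and models are registered elsewhere in the package:

from seb.registries import get_all_models, get_all_tasks

print([task.name for task in get_all_tasks()])          # Task exposes a name, as used in benchmark.py
print([model.meta.name for model in get_all_models()])  # each SebModel carries its ModelMeta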