Skip to content

Commit

Permalink
Merge pull request #73 from KennethEnevoldsen/bug-scala-missing-task-…
Browse files Browse the repository at this point in the history
…encode-wrapper

Wraps ScaLA models in MTEBTaskModel
  • Loading branch information
KennethEnevoldsen authored Jan 22, 2024
2 parents b030aef + a70c950 commit e2eee05
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/seb/registered_tasks/multilingual.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from datasets import DatasetDict, concatenate_datasets

from seb.interfaces.model import Encoder
from seb.interfaces.mteb_task import MTEBTask
from seb.interfaces.mteb_task import MTEBTask, MTEBTaskModel
from seb.interfaces.task import Task
from seb.registries import tasks
from seb.result_dataclasses import TaskResult
Expand Down Expand Up @@ -84,7 +84,7 @@ def get_descriptive_stats(self) -> dict[str, Any]:
for text_column in self._text_columns:
texts += ds[split][text_column]

document_lengths = [len(text) for text in texts]
document_lengths = np.array([len(text) for text in texts])

mean = np.mean(document_lengths)
std = np.std(document_lengths)
Expand All @@ -96,9 +96,10 @@ def get_descriptive_stats(self) -> dict[str, Any]:

def evaluate(self, model: Encoder) -> TaskResult:
scores = {}
_model = MTEBTaskModel(model, self)
for lang, mteb_task in self.mteb_tasks.items():
mteb_task.load_data()
score = mteb_task.evaluate(model)
score = mteb_task.evaluate(_model)
scores[lang] = score

return TaskResult(
Expand Down

0 comments on commit e2eee05

Please sign in to comment.