Skip to content

Commit

Permalink
fix: limit the size of STS
Browse files Browse the repository at this point in the history
  • Loading branch information
KennethEnevoldsen committed Jan 25, 2024
1 parent ec412da commit 0d1b659
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions src/seb/registered_tasks/mteb_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def dataset_transform(self) -> None:
self.dataset = self.dataset.rename_column("summary", "sentence2")
self.dataset = self.dataset.rename_column("article", "sentence1")
self.dataset = self.dataset.remove_columns(["id", "headline", "article_category"])
random.seed(42)
self.dataset.shuffle(seed=42)

# add score column
for split in self.dataset:
Expand All @@ -143,10 +143,11 @@ def dataset_transform(self) -> None:
{
"sentence1": articles,
"sentence2": mismatched_summaries,
"score": [0] * len(articles),
"score": ([0] * len(articles)),
}
)
self.dataset[split] = datasets.concatenate_datasets([ds_split, mismatched_ds])
mismatched_ds.shuffle(seed=42)
self.dataset[split] = datasets.concatenate_datasets([ds_split.select(range(1024)), mismatched_ds.select(range(1024))])

@property
def description(self) -> dict[str, Any]:
Expand All @@ -171,9 +172,9 @@ def sattolo_cycle(items: list[T]) -> list[T]:
The Sattolo cycle is a simple algorithm for randomly shuffling an array in-place.
It ensures that the element i, will not be in the ith position of the result.
"""

rng = random.Random(42)
for i in range(len(items) - 1, 0, -1):
j = random.randint(0, i - 1)
j = rng.randint(0, i - 1)
items[i], items[j] = items[j], items[i]
return items

Expand Down

0 comments on commit 0d1b659

Please sign in to comment.