Skip to content

Commit

Permalink
Allow an empty train partition to support test-set only (e.g. zero-sh…
Browse files Browse the repository at this point in the history
…ot) benchmarks (#135)

* Allow train partition to be empty in `_validate_split`

* ruff

* log msg if empty train, add test case

* check return value of `get_train_test_split`

---------

Co-authored-by: Cas Wognum <caswognum@outlook.com>
  • Loading branch information
fteufel and cwognum authored Jul 21, 2024
1 parent 1278da0 commit 5cf6092
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
19 changes: 12 additions & 7 deletions polaris/benchmark/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
import pandas as pd
from datamol.utils import fs
from loguru import logger
from pydantic import (
Field,
ValidationInfo,
Expand Down Expand Up @@ -166,19 +167,23 @@ def _validate_main_metric(cls, v):
def _validate_split(cls, v, info: ValidationInfo):
"""
Verifies that:
1) There is at least two, non-empty partitions
1) There are no empty test partitions
2) All indices are valid given the dataset
3) There are no duplicate indices in any of the sets
4) There is no overlap between the train and test set
"""

# There is at least two, non-empty partitions
if (
len(v[0]) == 0
or (isinstance(v[1], dict) and any(len(v) == 0 for v in v[1].values()))
or (not isinstance(v[1], dict) and len(v[1]) == 0)
# Train partition can be empty (zero-shot)
# Test partitions cannot be empty
if (isinstance(v[1], dict) and any(len(v) == 0 for v in v[1].values())) or (
not isinstance(v[1], dict) and len(v[1]) == 0
):
raise InvalidBenchmarkError("The predefined split contains empty partitions")
raise InvalidBenchmarkError("The predefined split contains empty test partitions")

if len(v[0]) == 0:
logger.info(
"This benchmark only specifies a test set. It will return an empty train set in `get_train_test_split()`"
)

train_indices = v[0]
test_indices = [i for part in v[1].values() for i in part] if isinstance(v[1], dict) else v[1]
Expand Down
6 changes: 5 additions & 1 deletion tests/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_split_verification(is_single_task, test_single_task_benchmark, test_mul
train_split = obj.split[0]
test_split = obj.split[1]

# One or more empty partitions
# One or more empty test partitions
with pytest.raises(ValidationError):
cls(split=(train_split,), **default_kwargs)
with pytest.raises(ValidationError):
Expand Down Expand Up @@ -55,6 +55,10 @@ def test_split_verification(is_single_task, test_single_task_benchmark, test_mul
cls(split=(train_split, test_split + test_split[:1]), **default_kwargs)
# It should _not_ fail with missing indices
cls(split=(train_split[:-1], test_split), **default_kwargs)
# It should _not_ fail with an empty train set
benchmark = cls(split=([], test_split), **default_kwargs)
train, _ = benchmark.get_train_test_split()
assert len(train) == 0


@pytest.mark.parametrize("cls", [SingleTaskBenchmarkSpecification, MultiTaskBenchmarkSpecification])
Expand Down

0 comments on commit 5cf6092

Please sign in to comment.