From bbaa1069669ebc7f704b063a27ca6dda4c8f7428 Mon Sep 17 00:00:00 2001
From: Kira McLean
Date: Wed, 26 Jun 2024 00:13:58 -0300
Subject: [PATCH] test that all rows of a competition test set will have at least a value

---
 polaris/competition/_competition.py | 48 ++++++++++++++++++-
 polaris/utils/errors.py             |  4 ++
 tests/test_competition.py           | 70 ++++++++++++++++++++++++++-
 3 files changed, 120 insertions(+), 2 deletions(-)

diff --git a/polaris/competition/_competition.py b/polaris/competition/_competition.py
index 12498774..6f463997 100644
--- a/polaris/competition/_competition.py
+++ b/polaris/competition/_competition.py
@@ -2,10 +2,15 @@
 import os
 from typing import Optional, Union
 
-from pydantic import field_serializer
+from pydantic import (
+    field_serializer,
+    field_validator,
+    ValidationInfo
+)
 from polaris.benchmark import BenchmarkSpecification
 from polaris.hub.settings import PolarisHubSettings
 from polaris.utils.types import AccessType, HubOwner, PredictionsType, TimeoutTypes, ZarrConflictResolution
+from polaris.utils.errors import InvalidCompetitionError
 
 
 class CompetitionSpecification(BenchmarkSpecification):
@@ -29,6 +34,47 @@ class CompetitionSpecification(BenchmarkSpecification):
     scheduled_end_time: datetime | None = None
     actual_end_time: datetime | None = None
 
+    @field_validator("split")
+    @classmethod
+    def _validate_test_set(cls, split, info: ValidationInfo):
+        """Verify that every test row has at least one target value.
+
+        A test row whose target columns are all missing is rejected:
+        there must be at least one value per row in the test set across
+        the target columns.
+
+        Args:
+            split: The `(train, test)` split under validation; `split[1]`
+                is used as the test indices.
+            info: Pydantic validation context; `info.data` supplies the
+                `dataset` and `target_cols` fields.
+
+        Returns:
+            The split, unchanged.
+
+        Raises:
+            InvalidCompetitionError: If any test row has no value in any
+                target column.
+        """
+        dataset = info.data.get("dataset")
+        target_cols = info.data.get("target_cols")
+
+        # Pydantic omits fields that failed their own validation from
+        # `info.data`. Skip the check in that case so the original error is
+        # reported instead of an AttributeError on `None.table`.
+        if dataset is None or target_cols is None:
+            return split
+
+        # NOTE(review): assumes split[1] is a flat collection of row labels;
+        # confirm that dict-style (multiple test set) splits cannot reach here.
+        test_indices = split[1]
+        test_targets = dataset.table.loc[test_indices, target_cols]
+
+        if not test_targets.notna().any(axis=1).all():
+            raise InvalidCompetitionError("All rows of the test set must have at least one value.")
+
+        return split
+
     def evaluate(
         self,
         y_pred: PredictionsType,
diff --git a/polaris/utils/errors.py b/polaris/utils/errors.py
index 1cb46581..103c0d62 100644
--- a/polaris/utils/errors.py
+++ b/polaris/utils/errors.py
@@ -5,6 +5,10 @@ class InvalidDatasetError(ValueError):
 class InvalidBenchmarkError(ValueError):
     pass
 
 
+class InvalidCompetitionError(ValueError):
+    pass
+
+
 class InvalidResultError(ValueError):
     pass
diff --git a/tests/test_competition.py b/tests/test_competition.py
index 1de4747f..db9b6c0e 100644
--- a/tests/test_competition.py
+++ b/tests/test_competition.py
@@ -1,8 +1,13 @@
 import numpy as np
 import pandas as pd
-from polaris.evaluate.utils import evaluate_benchmark
+import pytest
+from pydantic import ValidationError
 
+from polaris.evaluate.utils import evaluate_benchmark
 from polaris.competition import CompetitionSpecification
+from polaris.dataset import Dataset
+from polaris.utils.types import HubOwner
+from polaris.utils.errors import InvalidCompetitionError
 
 def test_competition_from_json(test_competition, tmpdir):
     """Test whether we can successfully save and load a competition from JSON."""
@@ -54,3 +59,66 @@ def test_single_col_competition_evaluation(test_competition):
         "Metric",
         "Score",
     }
+
+
+def test_invalid_competition_creation():
+    """Test whether a competition rejects test rows without any target value."""
+    data = {
+        "col a": [1, 2, None],
+        "col b": [4, None, 6],
+        "col c": [7, 8, None],
+    }
+
+    df = pd.DataFrame(data, index=range(3))
+    dataset = Dataset(
+        table=df,
+        name="test-dataset"
+    )
+
+    # Check that creating a competition where there is at least one value per test row works
+    CompetitionSpecification(
+        name="test-competition",
+        owner=HubOwner(organizationId="test-org", slug="test-org"),
+        dataset=dataset,
+        metrics=["mean_absolute_error", "mean_squared_error"],
+        split=([0, 1], [2]),
+        target_cols=["col b"],
+        input_cols=["col a", "col c"]
+    )
+
+    CompetitionSpecification(
+        name="test-competition",
+        owner=HubOwner(organizationId="test-org", slug="test-org"),
+        dataset=dataset,
+        metrics=["mean_absolute_error", "mean_squared_error"],
+        split=([0, 1], [2]),
+        target_cols=["col a", "col b"],
+        input_cols=["col c"]
+    )
+
+    # The only test row (index 2) has no value in any of the target columns
+    with pytest.raises(ValidationError) as ex_info:
+        CompetitionSpecification(
+            name="test-competition",
+            owner=HubOwner(organizationId="test-org", slug="test-org"),
+            dataset=dataset,
+            metrics=["mean_absolute_error", "mean_squared_error"],
+            split=([0, 1], [2]),
+            target_cols=["col a"],
+            input_cols=["col b", "col c"]
+        )
+
+    assert ex_info.match("All rows of the test set must have at least one value")
+
+    with pytest.raises(ValidationError) as ex_info_2:
+        CompetitionSpecification(
+            name="test-competition",
+            owner=HubOwner(organizationId="test-org", slug="test-org"),
+            dataset=dataset,
+            metrics=["mean_absolute_error", "mean_squared_error"],
+            split=([0, 1], [2]),
+            target_cols=["col a", "col c"],
+            input_cols=["col b"]
+        )
+
+    assert ex_info_2.match("All rows of the test set must have at least one value")