From 8eea3dd31dab49b4d9371f61f02ace9fdca25394 Mon Sep 17 00:00:00 2001
From: Alexander <47296670+Marsmaennchen221@users.noreply.github.com>
Date: Fri, 28 Apr 2023 21:11:24 +0200
Subject: [PATCH] feat: Raise error if an untagged table is used instead of a
 `TaggedTable` (#234)

Closes #192.

### Summary of Changes

Added `UntaggedTableError` which is raised when an untagged table is used
instead of a `TaggedTable`
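### Example

For reference, a minimal sketch of the new behavior (`RandomForest` stands in
for any safe-ds classifier or regressor, and the column names are made up):

```python
from safeds.data.tabular.containers import Table
from safeds.ml.classical.classification import RandomForest
from safeds.ml.exceptions import UntaggedTableError

# A plain table: nothing marks which columns are features and which is the target.
table = Table.from_dict(
    {
        "a": [1.0, 0.0, 0.0, 0.0],
        "b": [0.0, 1.0, 1.0, 0.0],
        "target": [0, 1, 1, 0],
    },
)

try:
    RandomForest().fit(table)  # type: ignore[arg-type]
except UntaggedTableError as error:
    print(error)  # the message points the user to Table.tag_columns()

# Tagging the table resolves the error.
RandomForest().fit(table.tag_columns(target_name="target"))
```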
""" + if not isinstance(validation_or_test_set, TaggedTable) and isinstance(validation_or_test_set, Table): + raise UntaggedTableError expected = validation_or_test_set.target predicted = self.predict(validation_or_test_set.features).target diff --git a/src/safeds/ml/classical/regression/_regressor.py b/src/safeds/ml/classical/regression/_regressor.py index 8bcb7b423..13685abf6 100644 --- a/src/safeds/ml/classical/regression/_regressor.py +++ b/src/safeds/ml/classical/regression/_regressor.py @@ -1,15 +1,13 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING from sklearn.metrics import mean_absolute_error as sk_mean_absolute_error from sklearn.metrics import mean_squared_error as sk_mean_squared_error +from safeds.data.tabular.containers import Column, Table, TaggedTable from safeds.data.tabular.exceptions import ColumnLengthMismatchError - -if TYPE_CHECKING: - from safeds.data.tabular.containers import Column, Table, TaggedTable +from safeds.ml.exceptions import UntaggedTableError class Regressor(ABC): @@ -90,7 +88,14 @@ def mean_squared_error(self, validation_or_test_set: TaggedTable) -> float: ------- mean_squared_error : float The calculated mean squared error (the average of the distance of each individual row squared). + + Raises + ------ + UntaggedTableError + If the table is untagged. """ + if not isinstance(validation_or_test_set, TaggedTable) and isinstance(validation_or_test_set, Table): + raise UntaggedTableError expected = validation_or_test_set.target predicted = self.predict(validation_or_test_set.features).target @@ -111,7 +116,14 @@ def mean_absolute_error(self, validation_or_test_set: TaggedTable) -> float: ------- mean_absolute_error : float The calculated mean absolute error (the average of the distance of each individual row). + + Raises + ------ + UntaggedTableError + If the table is untagged. """ + if not isinstance(validation_or_test_set, TaggedTable) and isinstance(validation_or_test_set, Table): + raise UntaggedTableError expected = validation_or_test_set.target predicted = self.predict(validation_or_test_set.features).target diff --git a/src/safeds/ml/exceptions/__init__.py b/src/safeds/ml/exceptions/__init__.py index c84e54ae2..5cde6adfd 100644 --- a/src/safeds/ml/exceptions/__init__.py +++ b/src/safeds/ml/exceptions/__init__.py @@ -6,6 +6,7 @@ LearningError, ModelNotFittedError, PredictionError, + UntaggedTableError, ) __all__ = [ @@ -14,4 +15,5 @@ "LearningError", "ModelNotFittedError", "PredictionError", + "UntaggedTableError", ] diff --git a/src/safeds/ml/exceptions/_exceptions.py b/src/safeds/ml/exceptions/_exceptions.py index c101b7e39..bf82d7b9c 100644 --- a/src/safeds/ml/exceptions/_exceptions.py +++ b/src/safeds/ml/exceptions/_exceptions.py @@ -59,3 +59,15 @@ class PredictionError(Exception): def __init__(self, reason: str): super().__init__(f"Error occurred while predicting: {reason}") + + +class UntaggedTableError(Exception): + """Raised when an untagged table is used instead of a TaggedTable in a regression or classification.""" + + def __init__(self) -> None: + super().__init__( + ( + "This method needs a tagged table.\nA tagged table is a table that additionally knows which columns are" + " features and which are the target to predict.\nUse Table.tag_column() to create a tagged table." 
+ ), + ) diff --git a/tests/safeds/ml/classical/classification/test_classifier.py b/tests/safeds/ml/classical/classification/test_classifier.py index 1be9af161..c0ebb0df9 100644 --- a/tests/safeds/ml/classical/classification/test_classifier.py +++ b/tests/safeds/ml/classical/classification/test_classifier.py @@ -20,6 +20,7 @@ LearningError, ModelNotFittedError, PredictionError, + UntaggedTableError, ) if TYPE_CHECKING: @@ -93,6 +94,23 @@ def test_should_raise_on_invalid_data(self, classifier: Classifier, invalid_data with pytest.raises(LearningError): classifier.fit(invalid_data) + @pytest.mark.parametrize( + "table", + [ + Table.from_dict( + { + "a": [1.0, 0.0, 0.0, 0.0], + "b": [0.0, 1.0, 1.0, 0.0], + "c": [0.0, 0.0, 0.0, 1.0], + }, + ), + ], + ids=["untagged_table"], + ) + def test_should_raise_if_table_is_not_tagged(self, classifier: Classifier, table: Table) -> None: + with pytest.raises(UntaggedTableError): + classifier.fit(table) # type: ignore[arg-type] + @pytest.mark.parametrize("classifier", classifiers(), ids=lambda x: x.__class__.__name__) class TestPredict: @@ -200,3 +218,20 @@ def test_with_different_types(self) -> None: ).tag_columns(target_name="expected") assert DummyClassifier().accuracy(table) == 0.0 + + @pytest.mark.parametrize( + "table", + [ + Table.from_dict( + { + "a": [1.0, 0.0, 0.0, 0.0], + "b": [0.0, 1.0, 1.0, 0.0], + "c": [0.0, 0.0, 0.0, 1.0], + }, + ), + ], + ids=["untagged_table"], + ) + def test_should_raise_if_table_is_not_tagged(self, table: Table) -> None: + with pytest.raises(UntaggedTableError): + DummyClassifier().accuracy(table) # type: ignore[arg-type] diff --git a/tests/safeds/ml/classical/regression/test_regressor.py b/tests/safeds/ml/classical/regression/test_regressor.py index 029fd4fa5..6d1b24f24 100644 --- a/tests/safeds/ml/classical/regression/test_regressor.py +++ b/tests/safeds/ml/classical/regression/test_regressor.py @@ -28,6 +28,7 @@ LearningError, ModelNotFittedError, PredictionError, + UntaggedTableError, ) if TYPE_CHECKING: @@ -104,6 +105,23 @@ def test_should_raise_on_invalid_data(self, regressor: Regressor, invalid_data: with pytest.raises(LearningError): regressor.fit(invalid_data) + @pytest.mark.parametrize( + "table", + [ + Table.from_dict( + { + "a": [1.0, 0.0, 0.0, 0.0], + "b": [0.0, 1.0, 1.0, 0.0], + "c": [0.0, 0.0, 0.0, 1.0], + }, + ), + ], + ids=["untagged_table"], + ) + def test_should_raise_if_table_is_not_tagged(self, regressor: Regressor, table: Table) -> None: + with pytest.raises(UntaggedTableError): + regressor.fit(table) # type: ignore[arg-type] + @pytest.mark.parametrize("regressor", regressors(), ids=lambda x: x.__class__.__name__) class TestPredict: @@ -214,6 +232,23 @@ def test_valid_data(self, predicted: list[float], expected: list[float], result: assert DummyRegressor().mean_absolute_error(table) == result + @pytest.mark.parametrize( + "table", + [ + Table.from_dict( + { + "a": [1.0, 0.0, 0.0, 0.0], + "b": [0.0, 1.0, 1.0, 0.0], + "c": [0.0, 0.0, 0.0, 1.0], + }, + ), + ], + ids=["untagged_table"], + ) + def test_should_raise_if_table_is_not_tagged(self, table: Table) -> None: + with pytest.raises(UntaggedTableError): + DummyRegressor().mean_absolute_error(table) # type: ignore[arg-type] + class TestMeanSquaredError: @pytest.mark.parametrize( @@ -227,6 +262,23 @@ def test_valid_data(self, predicted: list[float], expected: list[float], result: assert DummyRegressor().mean_squared_error(table) == result + @pytest.mark.parametrize( + "table", + [ + Table.from_dict( + { + "a": [1.0, 0.0, 0.0, 0.0], + "b": [0.0, 1.0, 
1.0, 0.0], + "c": [0.0, 0.0, 0.0, 1.0], + }, + ), + ], + ids=["untagged_table"], + ) + def test_should_raise_if_table_is_not_tagged(self, table: Table) -> None: + with pytest.raises(UntaggedTableError): + DummyRegressor().mean_squared_error(table) # type: ignore[arg-type] + class TestCheckMetricsPreconditions: @pytest.mark.parametrize( @@ -243,7 +295,7 @@ def test_should_raise_if_validation_fails( expected: list[str | int], error: type[Exception], ) -> None: - actual_column = Column("actual", pd.Series(actual)) - expected_column = Column("expected", pd.Series(expected)) + actual_column: Column = Column("actual", pd.Series(actual)) + expected_column: Column = Column("expected", pd.Series(expected)) with pytest.raises(error): _check_metrics_preconditions(actual_column, expected_column)
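
Note on the guard itself: it raises only when the argument is a `Table` that is
not a `TaggedTable`; arguments of any other type fall through to the method's
existing behavior. The standalone sketch below replicates the added check to
make the three cases explicit (it assumes `TaggedTable` is a subclass of
`Table`, which the ordering of the two `isinstance` checks implies):

```python
from safeds.data.tabular.containers import Table, TaggedTable
from safeds.ml.exceptions import UntaggedTableError


def require_tagged_table(validation_or_test_set: TaggedTable) -> None:
    """Standalone copy of the guard added by this patch, for illustration only."""
    if not isinstance(validation_or_test_set, TaggedTable) and isinstance(validation_or_test_set, Table):
        raise UntaggedTableError


plain = Table.from_dict({"feature": [1.0, 2.0], "target": [0, 1]})
tagged = plain.tag_columns(target_name="target")

require_tagged_table(tagged)  # TaggedTable: passes
require_tagged_table(None)  # type: ignore[arg-type]  # not a table: passes, existing errors still apply
try:
    require_tagged_table(plain)  # plain Table: raises the new error
except UntaggedTableError:
    pass
```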