From 617d6f01c28723078e295f65dcf802bf8b3c64d0 Mon Sep 17 00:00:00 2001 From: a-szulc Date: Thu, 29 Aug 2024 15:03:04 +0200 Subject: [PATCH] add catching of expected errors in tests --- .../preprocessing/exclude_missing_target.py | 3 ++- supervised/validation/validator_kfold.py | 4 +++- supervised/validation/validator_split.py | 4 +++- tests/tests_automl/test_targets.py | 23 +++++++++++++++++-- .../tests_validation/test_validator_kfold.py | 11 ++++++++- .../tests_validation/test_validator_split.py | 11 ++++++++- 6 files changed, 49 insertions(+), 7 deletions(-) diff --git a/supervised/preprocessing/exclude_missing_target.py b/supervised/preprocessing/exclude_missing_target.py index f9e8c9593..28105218f 100644 --- a/supervised/preprocessing/exclude_missing_target.py +++ b/supervised/preprocessing/exclude_missing_target.py @@ -23,7 +23,8 @@ def transform( logger.debug("Exclude rows with missing target values") if warn: warnings.warn( - "There are samples with missing target values in the data which will be excluded for further analysis" + "There are samples with missing target values in the data which will be excluded for further analysis", + UserWarning ) y = y.drop(y.index[y_missing]) y.reset_index(drop=True, inplace=True) diff --git a/supervised/validation/validator_kfold.py b/supervised/validation/validator_kfold.py index 82104cb8a..71e8dbcad 100644 --- a/supervised/validation/validator_kfold.py +++ b/supervised/validation/validator_kfold.py @@ -25,7 +25,9 @@ def __init__(self, params): self.repeats = self.params.get("repeats", 1) if not self.shuffle and self.repeats > 1: - warnings.warn("Disable repeats in validation because shuffle is disabled") + warnings.warn( + "Disable repeats in validation because shuffle is disabled", UserWarning + ) self.repeats = 1 self.skf = [] diff --git a/supervised/validation/validator_split.py b/supervised/validation/validator_split.py index c0358f12a..da372badd 100644 --- a/supervised/validation/validator_split.py +++ b/supervised/validation/validator_split.py @@ -24,7 +24,9 @@ def __init__(self, params): self.repeats = self.params.get("repeats", 1) if not self.shuffle and self.repeats > 1: - warnings.warn("Disable repeats in validation because shuffle is disabled") + warnings.warn( + "Disable repeats in validation because shuffle is disabled", UserWarning + ) self.repeats = 1 self._results_path = self.params.get("results_path") diff --git a/tests/tests_automl/test_targets.py b/tests/tests_automl/test_targets.py index 1b0e545ac..1a2eb8aa8 100644 --- a/tests/tests_automl/test_targets.py +++ b/tests/tests_automl/test_targets.py @@ -1,5 +1,6 @@ import shutil import unittest +import pytest import numpy as np import pandas as pd @@ -100,7 +101,16 @@ def test_bin_class_AB_missing_targets(self): explain_level=0, start_random_models=1, ) - automl.fit(X, y) + + with pytest.warns( + expected_warning=UserWarning, + match="There are samples with missing target values in the data which will be excluded for further analysis", + ) as record: + automl.fit(X, y) + + # check that only one warning was raised + self.assertEqual(len(record), 1) + p = automl.predict(X) pred = automl.predict(X) @@ -256,7 +266,16 @@ def test_multi_class_abcd_missing_target(self): explain_level=0, start_random_models=1, ) - automl.fit(X, y) + + with pytest.warns( + expected_warning=UserWarning, + match="There are samples with missing target values in the data which will be excluded for further analysis", + ) as record: + automl.fit(X, y) + + # check that only one warning was raised + self.assertEqual(len(record), 1) + pred = automl.predict(X) u = np.unique(pred) diff --git a/tests/tests_validation/test_validator_kfold.py b/tests/tests_validation/test_validator_kfold.py index 42f4c5c43..9e87cd964 100644 --- a/tests/tests_validation/test_validator_kfold.py +++ b/tests/tests_validation/test_validator_kfold.py @@ -1,6 +1,7 @@ import os import tempfile import unittest +import pytest import numpy as np import pandas as pd @@ -194,7 +195,15 @@ def test_disable_repeats_when_disabled_shuffle(self): "y_path": y_path, "random_seed": 1, } - vl = KFoldValidator(params) + + with pytest.warns( + expected_warning=UserWarning, + match="Disable repeats in validation because shuffle is disabled", + ) as record: + vl = KFoldValidator(params) + + # check that only one warning was raised + self.assertEqual(len(record), 1) self.assertEqual(params["k_folds"], vl.get_n_splits()) self.assertEqual(1, vl.get_repeats()) diff --git a/tests/tests_validation/test_validator_split.py b/tests/tests_validation/test_validator_split.py index 1d7eed6fa..2dbcbe70d 100644 --- a/tests/tests_validation/test_validator_split.py +++ b/tests/tests_validation/test_validator_split.py @@ -1,6 +1,7 @@ import os import tempfile import unittest +import pytest import numpy as np import pandas as pd @@ -211,7 +212,15 @@ def test_disable_repeats_when_disabled_shuffle(self): "y_path": y_path, "repeats": 3, } - vl = SplitValidator(params) + + with pytest.warns( + expected_warning=UserWarning, + match="Disable repeats in validation because shuffle is disabled", + ) as record: + vl = SplitValidator(params) + + # check that only one warning was raised + self.assertEqual(len(record), 1) self.assertEqual(1, vl.get_n_splits()) self.assertEqual(1, vl.get_repeats())