diff --git a/supervised/preprocessing/exclude_missing_target.py b/supervised/preprocessing/exclude_missing_target.py index f9e8c959..28105218 100644 --- a/supervised/preprocessing/exclude_missing_target.py +++ b/supervised/preprocessing/exclude_missing_target.py @@ -23,7 +23,8 @@ def transform( logger.debug("Exclude rows with missing target values") if warn: warnings.warn( - "There are samples with missing target values in the data which will be excluded for further analysis" + "There are samples with missing target values in the data which will be excluded for further analysis", + UserWarning ) y = y.drop(y.index[y_missing]) y.reset_index(drop=True, inplace=True) diff --git a/supervised/validation/validator_kfold.py b/supervised/validation/validator_kfold.py index 82104cb8..71e8dbca 100644 --- a/supervised/validation/validator_kfold.py +++ b/supervised/validation/validator_kfold.py @@ -25,7 +25,9 @@ def __init__(self, params): self.repeats = self.params.get("repeats", 1) if not self.shuffle and self.repeats > 1: - warnings.warn("Disable repeats in validation because shuffle is disabled") + warnings.warn( + "Disable repeats in validation because shuffle is disabled", UserWarning + ) self.repeats = 1 self.skf = [] diff --git a/supervised/validation/validator_split.py b/supervised/validation/validator_split.py index c0358f12..da372bad 100644 --- a/supervised/validation/validator_split.py +++ b/supervised/validation/validator_split.py @@ -24,7 +24,9 @@ def __init__(self, params): self.repeats = self.params.get("repeats", 1) if not self.shuffle and self.repeats > 1: - warnings.warn("Disable repeats in validation because shuffle is disabled") + warnings.warn( + "Disable repeats in validation because shuffle is disabled", UserWarning + ) self.repeats = 1 self._results_path = self.params.get("results_path") diff --git a/tests/tests_automl/test_targets.py b/tests/tests_automl/test_targets.py index 1b0e545a..1a2eb8aa 100644 --- a/tests/tests_automl/test_targets.py +++ b/tests/tests_automl/test_targets.py @@ -1,5 +1,6 @@ import shutil import unittest +import pytest import numpy as np import pandas as pd @@ -100,7 +101,16 @@ def test_bin_class_AB_missing_targets(self): explain_level=0, start_random_models=1, ) - automl.fit(X, y) + + with pytest.warns( + expected_warning=UserWarning, + match="There are samples with missing target values in the data which will be excluded for further analysis", + ) as record: + automl.fit(X, y) + + # check that only one warning was raised + self.assertEqual(len(record), 1) + p = automl.predict(X) pred = automl.predict(X) @@ -256,7 +266,16 @@ def test_multi_class_abcd_missing_target(self): explain_level=0, start_random_models=1, ) - automl.fit(X, y) + + with pytest.warns( + expected_warning=UserWarning, + match="There are samples with missing target values in the data which will be excluded for further analysis", + ) as record: + automl.fit(X, y) + + # check that only one warning was raised + self.assertEqual(len(record), 1) + pred = automl.predict(X) u = np.unique(pred) diff --git a/tests/tests_validation/test_validator_kfold.py b/tests/tests_validation/test_validator_kfold.py index 42f4c5c4..9e87cd96 100644 --- a/tests/tests_validation/test_validator_kfold.py +++ b/tests/tests_validation/test_validator_kfold.py @@ -1,6 +1,7 @@ import os import tempfile import unittest +import pytest import numpy as np import pandas as pd @@ -194,7 +195,15 @@ def test_disable_repeats_when_disabled_shuffle(self): "y_path": y_path, "random_seed": 1, } - vl = KFoldValidator(params) + + with pytest.warns( + expected_warning=UserWarning, + match="Disable repeats in validation because shuffle is disabled", + ) as record: + vl = KFoldValidator(params) + + # check that only one warning was raised + self.assertEqual(len(record), 1) self.assertEqual(params["k_folds"], vl.get_n_splits()) self.assertEqual(1, vl.get_repeats()) diff --git a/tests/tests_validation/test_validator_split.py b/tests/tests_validation/test_validator_split.py index 1d7eed6f..2dbcbe70 100644 --- a/tests/tests_validation/test_validator_split.py +++ b/tests/tests_validation/test_validator_split.py @@ -1,6 +1,7 @@ import os import tempfile import unittest +import pytest import numpy as np import pandas as pd @@ -211,7 +212,15 @@ def test_disable_repeats_when_disabled_shuffle(self): "y_path": y_path, "repeats": 3, } - vl = SplitValidator(params) + + with pytest.warns( + expected_warning=UserWarning, + match="Disable repeats in validation because shuffle is disabled", + ) as record: + vl = SplitValidator(params) + + # check that only one warning was raised + self.assertEqual(len(record), 1) self.assertEqual(1, vl.get_n_splits()) self.assertEqual(1, vl.get_repeats())