diff --git a/doc/whats_new/v0.10.rst b/doc/whats_new/v0.10.rst index 8a6f647f8..00809a1d4 100644 --- a/doc/whats_new/v0.10.rst +++ b/doc/whats_new/v0.10.rst @@ -19,6 +19,9 @@ Compatibility - Maintenance release for be compatible with scikit-learn >= 1.0.2. :pr:`946`, :pr:`947`, :pr:`949` by :user:`Guillaume Lemaitre `. +- Add support for automatic parameters validation as in scikit-learn >= 1.2. + :pr:`955` by :user:`Guillaume Lemaitre `. + Deprecation ........... diff --git a/imblearn/base.py b/imblearn/base.py index 0a4265e3f..4241d0db3 100644 --- a/imblearn/base.py +++ b/imblearn/base.py @@ -12,6 +12,7 @@ from sklearn.utils.multiclass import check_classification_targets from .utils import check_sampling_strategy, check_target_type +from .utils._param_validation import validate_parameter_constraints from .utils._validation import ArraysTransformer @@ -113,7 +114,26 @@ def _fit_resample(self, X, y): pass -class BaseSampler(SamplerMixin): +class _ParamsValidationMixin: + """Mixin class to validate parameters.""" + + def _validate_params(self): + """Validate types and values of constructor parameters. + + The expected type and values must be defined in the `_parameter_constraints` + class attribute, which is a dictionary `param_name: list of constraints`. See + the docstring of `validate_parameter_constraints` for a description of the + accepted constraints. + """ + if hasattr(self, "_parameter_constraints"): + validate_parameter_constraints( + self._parameter_constraints, + self.get_params(deep=False), + caller_name=self.__class__.__name__, + ) + + +class BaseSampler(SamplerMixin, _ParamsValidationMixin): """Base class for sampling algorithms. Warning: This class should not be used directly. Use the derive classes @@ -130,6 +150,52 @@ def _check_X_y(self, X, y, accept_sparse=None): X, y = self._validate_data(X, y, reset=True, accept_sparse=accept_sparse) return X, y, binarize_y + def fit(self, X, y): + """Check inputs and statistics of the sampler. + + You should use ``fit_resample`` in all cases. + + Parameters + ---------- + X : {array-like, dataframe, sparse matrix} of shape \ + (n_samples, n_features) + Data array. + + y : array-like of shape (n_samples,) + Target array. + + Returns + ------- + self : object + Return the instance itself. + """ + self._validate_params() + return super().fit(X, y) + + def fit_resample(self, X, y): + """Resample the dataset. + + Parameters + ---------- + X : {array-like, dataframe, sparse matrix} of shape \ + (n_samples, n_features) + Matrix containing the data which have to be sampled. + + y : array-like of shape (n_samples,) + Corresponding label for each sample in X. + + Returns + ------- + X_resampled : {array-like, dataframe, sparse matrix} of shape \ + (n_samples_new, n_features) + The array containing the resampled data. + + y_resampled : array-like of shape (n_samples_new,) + The corresponding label of `X_resampled`. + """ + self._validate_params() + return super().fit_resample(X, y) + def _more_tags(self): return {"X_types": ["2darray", "sparse", "dataframe"]} @@ -241,6 +307,13 @@ class FunctionSampler(BaseSampler): _sampling_type = "bypass" + _parameter_constraints: dict = { + "func": [callable, None], + "accept_sparse": ["boolean"], + "kw_args": [dict, None], + "validate": ["boolean"], + } + def __init__(self, *, func=None, accept_sparse=True, kw_args=None, validate=True): super().__init__() self.func = func @@ -267,6 +340,7 @@ def fit(self, X, y): self : object Return the instance itself. 
""" + self._validate_params() # we need to overwrite SamplerMixin.fit to bypass the validation if self.validate: check_classification_targets(y) @@ -298,6 +372,7 @@ def fit_resample(self, X, y): y_resampled : array-like of shape (n_samples_new,) The corresponding label of `X_resampled`. """ + self._validate_params() arrays_transformer = ArraysTransformer(X, y) if self.validate: diff --git a/imblearn/combine/_smote_enn.py b/imblearn/combine/_smote_enn.py index 333bf90b0..241fc0f70 100644 --- a/imblearn/combine/_smote_enn.py +++ b/imblearn/combine/_smote_enn.py @@ -4,6 +4,8 @@ # Christos Aridas # License: MIT +import numbers + from sklearn.base import clone from sklearn.utils import check_X_y @@ -102,6 +104,13 @@ class SMOTEENN(BaseSampler): _sampling_type = "over-sampling" + _parameter_constraints: dict = { + **BaseOverSampler._parameter_constraints, + "smote": [SMOTE, None], + "enn": [EditedNearestNeighbours, None], + "n_jobs": [numbers.Integral, None], + } + def __init__( self, *, @@ -121,14 +130,7 @@ def __init__( def _validate_estimator(self): "Private function to validate SMOTE and ENN objects" if self.smote is not None: - if isinstance(self.smote, SMOTE): - self.smote_ = clone(self.smote) - else: - raise ValueError( - f"smote needs to be a SMOTE object." - f"Got {type(self.smote)} instead." - ) - # Otherwise create a default SMOTE + self.smote_ = clone(self.smote) else: self.smote_ = SMOTE( sampling_strategy=self.sampling_strategy, @@ -137,14 +139,7 @@ def _validate_estimator(self): ) if self.enn is not None: - if isinstance(self.enn, EditedNearestNeighbours): - self.enn_ = clone(self.enn) - else: - raise ValueError( - f"enn needs to be an EditedNearestNeighbours." - f" Got {type(self.enn)} instead." - ) - # Otherwise create a default EditedNearestNeighbours + self.enn_ = clone(self.enn) else: self.enn_ = EditedNearestNeighbours( sampling_strategy="all", n_jobs=self.n_jobs diff --git a/imblearn/combine/_smote_tomek.py b/imblearn/combine/_smote_tomek.py index 495a4b246..9a4bc13e6 100644 --- a/imblearn/combine/_smote_tomek.py +++ b/imblearn/combine/_smote_tomek.py @@ -5,6 +5,8 @@ # Christos Aridas # License: MIT +import numbers + from sklearn.base import clone from sklearn.utils import check_X_y @@ -100,6 +102,13 @@ class SMOTETomek(BaseSampler): _sampling_type = "over-sampling" + _parameter_constraints: dict = { + **BaseOverSampler._parameter_constraints, + "smote": [SMOTE, None], + "tomek": [TomekLinks, None], + "n_jobs": [numbers.Integral, None], + } + def __init__( self, *, @@ -120,14 +129,7 @@ def _validate_estimator(self): "Private function to validate SMOTE and ENN objects" if self.smote is not None: - if isinstance(self.smote, SMOTE): - self.smote_ = clone(self.smote) - else: - raise ValueError( - f"smote needs to be a SMOTE object." - f"Got {type(self.smote)} instead." - ) - # Otherwise create a default SMOTE + self.smote_ = clone(self.smote) else: self.smote_ = SMOTE( sampling_strategy=self.sampling_strategy, @@ -136,14 +138,7 @@ def _validate_estimator(self): ) if self.tomek is not None: - if isinstance(self.tomek, TomekLinks): - self.tomek_ = clone(self.tomek) - else: - raise ValueError( - f"tomek needs to be a TomekLinks object." - f"Got {type(self.tomek)} instead." 
- ) - # Otherwise create a default TomekLinks + self.tomek_ = clone(self.tomek) else: self.tomek_ = TomekLinks(sampling_strategy="all", n_jobs=self.n_jobs) diff --git a/imblearn/combine/tests/test_smote_enn.py b/imblearn/combine/tests/test_smote_enn.py index 97cb6bead..df72cc749 100644 --- a/imblearn/combine/tests/test_smote_enn.py +++ b/imblearn/combine/tests/test_smote_enn.py @@ -4,7 +4,6 @@ # License: MIT import numpy as np -import pytest from sklearn.utils._testing import assert_allclose, assert_array_equal from imblearn.combine import SMOTEENN @@ -156,16 +155,3 @@ def test_parallelisation(): assert smt.n_jobs == 8 assert smt.smote_.n_jobs == 8 assert smt.enn_.n_jobs == 8 - - -@pytest.mark.parametrize( - "smote_params, err_msg", - [ - ({"smote": "rnd"}, "smote needs to be a SMOTE"), - ({"enn": "rnd"}, "enn needs to be an "), - ], -) -def test_error_wrong_object(smote_params, err_msg): - smt = SMOTEENN(**smote_params) - with pytest.raises(ValueError, match=err_msg): - smt.fit_resample(X, Y) diff --git a/imblearn/combine/tests/test_smote_tomek.py b/imblearn/combine/tests/test_smote_tomek.py index ca3ce98b6..2ca3e38c2 100644 --- a/imblearn/combine/tests/test_smote_tomek.py +++ b/imblearn/combine/tests/test_smote_tomek.py @@ -4,7 +4,6 @@ # License: MIT import numpy as np -import pytest from sklearn.utils._testing import assert_allclose, assert_array_equal from imblearn.combine import SMOTETomek @@ -166,16 +165,3 @@ def test_parallelisation(): assert smt.n_jobs == 8 assert smt.smote_.n_jobs == 8 assert smt.tomek_.n_jobs == 8 - - -@pytest.mark.parametrize( - "smote_params, err_msg", - [ - ({"smote": "rnd"}, "smote needs to be a SMOTE"), - ({"tomek": "rnd"}, "tomek needs to be a TomekLinks"), - ], -) -def test_error_wrong_object(smote_params, err_msg): - smt = SMOTETomek(**smote_params) - with pytest.raises(ValueError, match=err_msg): - smt.fit_resample(X, Y) diff --git a/imblearn/ensemble/_bagging.py b/imblearn/ensemble/_bagging.py index 13707abfb..e35343f27 100644 --- a/imblearn/ensemble/_bagging.py +++ b/imblearn/ensemble/_bagging.py @@ -4,6 +4,7 @@ # Christos Aridas # License: MIT +import copy import inspect import numbers import warnings @@ -18,13 +19,15 @@ from sklearn.utils.fixes import delayed from sklearn.utils.validation import check_is_fitted +from ..base import _ParamsValidationMixin from ..pipeline import Pipeline from ..under_sampling import RandomUnderSampler from ..under_sampling.base import BaseUnderSampler from ..utils import Substitution, check_sampling_strategy, check_target_type from ..utils._available_if import available_if from ..utils._docstring import _n_jobs_docstring, _random_state_docstring -from ._common import _estimator_has +from ..utils._param_validation import HasMethods, Interval, StrOptions +from ._common import _bagging_parameter_constraints, _estimator_has @Substitution( @@ -32,7 +35,7 @@ n_jobs=_n_jobs_docstring, random_state=_random_state_docstring, ) -class BalancedBaggingClassifier(BaggingClassifier): +class BalancedBaggingClassifier(BaggingClassifier, _ParamsValidationMixin): """A Bagging classifier with additional balancing. 
This implementation of Bagging is similar to the scikit-learn @@ -252,6 +255,26 @@ class BalancedBaggingClassifier(BaggingClassifier): [ 2 225]] """ + # make a deepcopy to not modify the original dictionary + if hasattr(BaggingClassifier, "_parameter_constraints"): + # scikit-learn >= 1.2 + _parameter_constraints = copy.deepcopy(BaggingClassifier._parameter_constraints) + else: + _parameter_constraints = copy.deepcopy(_bagging_parameter_constraints) + + _parameter_constraints.update( + { + "sampling_strategy": [ + Interval(numbers.Real, 0, 1, closed="right"), + StrOptions({"auto", "majority", "not minority", "not majority", "all"}), + dict, + callable, + ], + "replacement": ["boolean"], + "sampler": [HasMethods(["fit_resample"]), None], + } + ) + def __init__( self, estimator=None, @@ -316,17 +339,7 @@ def _validate_y(self, y): def _validate_estimator(self, default=DecisionTreeClassifier()): """Check the estimator and the n_estimator attribute, set the - `base_estimator_` attribute.""" - if not isinstance(self.n_estimators, (numbers.Integral, np.integer)): - raise ValueError( - f"n_estimators must be an integer, " f"got {type(self.n_estimators)}." - ) - - if self.n_estimators <= 0: - raise ValueError( - f"n_estimators must be greater than zero, " f"got {self.n_estimators}." - ) - + `estimator_` attribute.""" if self.estimator is not None and ( self.base_estimator not in [None, "deprecated"] ): @@ -395,6 +408,7 @@ def fit(self, X, y): Fitted estimator. """ # overwrite the base class method by disallowing `sample_weight` + self._validate_params() return super().fit(X, y) def _fit(self, X, y, max_samples=None, max_depth=None, sample_weight=None): diff --git a/imblearn/ensemble/_common.py b/imblearn/ensemble/_common.py index eb24e737d..32f5fb1cc 100644 --- a/imblearn/ensemble/_common.py +++ b/imblearn/ensemble/_common.py @@ -1,3 +1,10 @@ +from numbers import Integral, Real + +from sklearn.tree._criterion import Criterion + +from ..utils._param_validation import HasMethods, Hidden, Interval, StrOptions + + def _estimator_has(attr): """Check if we can delegate a method to the underlying estimator. 
First, we check the first fitted estimator if available, otherwise we @@ -13,3 +20,79 @@ def check(self): return hasattr(self.base_estimator, attr) return check + + +_bagging_parameter_constraints = { + "estimator": [HasMethods(["fit", "predict"]), None], + "n_estimators": [Interval(Integral, 1, None, closed="left")], + "max_samples": [ + Interval(Integral, 1, None, closed="left"), + Interval(Real, 0, 1, closed="right"), + ], + "max_features": [ + Interval(Integral, 1, None, closed="left"), + Interval(Real, 0, 1, closed="right"), + ], + "bootstrap": ["boolean"], + "bootstrap_features": ["boolean"], + "oob_score": ["boolean"], + "warm_start": ["boolean"], + "n_jobs": [None, Integral], + "random_state": ["random_state"], + "verbose": ["verbose"], + "base_estimator": [ + HasMethods(["fit", "predict"]), + StrOptions({"deprecated"}), + None, + ], +} + +_adaboost_classifier_parameter_constraints = { + "estimator": [HasMethods(["fit", "predict"]), None], + "n_estimators": [Interval(Integral, 1, None, closed="left")], + "learning_rate": [Interval(Real, 0, None, closed="neither")], + "random_state": ["random_state"], + "base_estimator": [HasMethods(["fit", "predict"]), StrOptions({"deprecated"})], + "algorithm": [StrOptions({"SAMME", "SAMME.R"})], +} + +_random_forest_classifier_parameter_constraints = { + "n_estimators": [Interval(Integral, 1, None, closed="left")], + "bootstrap": ["boolean"], + "oob_score": ["boolean"], + "n_jobs": [Integral, None], + "random_state": ["random_state"], + "verbose": ["verbose"], + "warm_start": ["boolean"], + "criterion": [StrOptions({"gini", "entropy", "log_loss"}), Hidden(Criterion)], + "max_samples": [ + None, + Interval(Real, 0.0, 1.0, closed="right"), + Interval(Integral, 1, None, closed="left"), + ], + "max_depth": [Interval(Integral, 1, None, closed="left"), None], + "min_samples_split": [ + Interval(Integral, 2, None, closed="left"), + Interval(Real, 0.0, 1.0, closed="right"), + ], + "min_samples_leaf": [ + Interval(Integral, 1, None, closed="left"), + Interval(Real, 0.0, 1.0, closed="neither"), + ], + "min_weight_fraction_leaf": [Interval(Real, 0.0, 0.5, closed="both")], + "max_features": [ + Interval(Integral, 1, None, closed="left"), + Interval(Real, 0.0, 1.0, closed="right"), + StrOptions({"auto", "sqrt", "log2"}, deprecated={"auto"}), + None, + ], + "max_leaf_nodes": [Interval(Integral, 2, None, closed="left"), None], + "min_impurity_decrease": [Interval(Real, 0.0, None, closed="left")], + "ccp_alpha": [Interval(Real, 0.0, None, closed="left")], + "class_weight": [ + StrOptions({"balanced_subsample", "balanced"}), + dict, + list, + None, + ], +} diff --git a/imblearn/ensemble/_easy_ensemble.py b/imblearn/ensemble/_easy_ensemble.py index 9ef457f82..2ec31e55d 100644 --- a/imblearn/ensemble/_easy_ensemble.py +++ b/imblearn/ensemble/_easy_ensemble.py @@ -4,6 +4,7 @@ # Christos Aridas # License: MIT +import copy import inspect import numbers import warnings @@ -17,13 +18,15 @@ from sklearn.utils.fixes import delayed from sklearn.utils.validation import check_is_fitted +from ..base import _ParamsValidationMixin from ..pipeline import Pipeline from ..under_sampling import RandomUnderSampler from ..under_sampling.base import BaseUnderSampler from ..utils import Substitution, check_sampling_strategy, check_target_type from ..utils._available_if import available_if from ..utils._docstring import _n_jobs_docstring, _random_state_docstring -from ._common import _estimator_has +from ..utils._param_validation import Interval, StrOptions +from ._common import 
_bagging_parameter_constraints, _estimator_has MAX_INT = np.iinfo(np.int32).max @@ -33,7 +36,7 @@ n_jobs=_n_jobs_docstring, random_state=_random_state_docstring, ) -class EasyEnsembleClassifier(BaggingClassifier): +class EasyEnsembleClassifier(BaggingClassifier, _ParamsValidationMixin): """Bag of balanced boosted learners also known as EasyEnsemble. This algorithm is known as EasyEnsemble [1]_. The classifier is an @@ -177,6 +180,35 @@ class EasyEnsembleClassifier(BaggingClassifier): [ 2 225]] """ + # make a deepcopy to not modify the original dictionary + if hasattr(BaggingClassifier, "_parameter_constraints"): + # scikit-learn >= 1.2 + _parameter_constraints = copy.deepcopy(BaggingClassifier._parameter_constraints) + else: + _parameter_constraints = copy.deepcopy(_bagging_parameter_constraints) + + excluded_params = { + "bootstrap", + "bootstrap_features", + "max_features", + "oob_score", + "max_samples", + } + for param in excluded_params: + _parameter_constraints.pop(param, None) + + _parameter_constraints.update( + { + "sampling_strategy": [ + Interval(numbers.Real, 0, 1, closed="right"), + StrOptions({"auto", "majority", "not minority", "not majority", "all"}), + dict, + callable, + ], + "replacement": ["boolean"], + } + ) + def __init__( self, n_estimators=10, @@ -231,17 +263,7 @@ def _validate_y(self, y): def _validate_estimator(self, default=AdaBoostClassifier()): """Check the estimator and the n_estimator attribute, set the - `base_estimator_` attribute.""" - if not isinstance(self.n_estimators, (numbers.Integral, np.integer)): - raise ValueError( - f"n_estimators must be an integer, " f"got {type(self.n_estimators)}." - ) - - if self.n_estimators <= 0: - raise ValueError( - f"n_estimators must be greater than zero, " f"got {self.n_estimators}." - ) - + `estimator_` attribute.""" if self.estimator is not None and ( self.base_estimator not in [None, "deprecated"] ): @@ -310,6 +332,7 @@ def fit(self, X, y): self : object Fitted estimator. """ + self._validate_params() # overwrite the base class method by disallowing `sample_weight` return super().fit(X, y) diff --git a/imblearn/ensemble/_forest.py b/imblearn/ensemble/_forest.py index c05511328..4714b703a 100644 --- a/imblearn/ensemble/_forest.py +++ b/imblearn/ensemble/_forest.py @@ -28,12 +28,15 @@ from sklearn.utils.multiclass import type_of_target from sklearn.utils.validation import _check_sample_weight +from ..base import _ParamsValidationMixin from ..pipeline import make_pipeline from ..under_sampling import RandomUnderSampler from ..under_sampling.base import BaseUnderSampler from ..utils import Substitution from ..utils._docstring import _n_jobs_docstring, _random_state_docstring +from ..utils._param_validation import Interval, StrOptions from ..utils._validation import check_sampling_strategy +from ._common import _random_forest_classifier_parameter_constraints MAX_INT = np.iinfo(np.int32).max sklearn_version = parse_version(sklearn.__version__) @@ -95,7 +98,7 @@ def _local_parallel_build_trees( n_jobs=_n_jobs_docstring, random_state=_random_state_docstring, ) -class BalancedRandomForestClassifier(RandomForestClassifier): +class BalancedRandomForestClassifier(RandomForestClassifier, _ParamsValidationMixin): """A balanced random forest classifier. A balanced random forest randomly under-samples each boostrap sample to @@ -352,6 +355,27 @@ class labels (multi-output problem). 
[1] """ + # make a deepcopy to not modify the original dictionary + if hasattr(RandomForestClassifier, "_parameter_constraints"): + # scikit-learn >= 1.2 + _parameter_constraints = deepcopy(RandomForestClassifier._parameter_constraints) + else: + _parameter_constraints = deepcopy( + _random_forest_classifier_parameter_constraints + ) + + _parameter_constraints.update( + { + "sampling_strategy": [ + Interval(numbers.Real, 0, 1, closed="right"), + StrOptions({"auto", "majority", "not minority", "not majority", "all"}), + dict, + callable, + ], + "replacement": ["boolean"], + } + ) + def __init__( self, n_estimators=100, @@ -402,17 +426,7 @@ def __init__( def _validate_estimator(self, default=DecisionTreeClassifier()): """Check the estimator and the n_estimator attribute, set the - `base_estimator_` attribute.""" - if not isinstance(self.n_estimators, (numbers.Integral, np.integer)): - raise ValueError( - f"n_estimators must be an integer, " f"got {type(self.n_estimators)}." - ) - - if self.n_estimators <= 0: - raise ValueError( - f"n_estimators must be greater than zero, " f"got {self.n_estimators}." - ) - + `estimator_` attribute.""" if hasattr(self, "estimator"): base_estimator = self.estimator else: @@ -475,7 +489,7 @@ def fit(self, X, y, sample_weight=None): self : object The fitted instance. """ - + self._validate_params() # Validate or convert input data if issparse(y): raise ValueError("sparse multilabel-indicator for y is not supported.") diff --git a/imblearn/ensemble/_weight_boosting.py b/imblearn/ensemble/_weight_boosting.py index 4e0d1e5c5..7ebc4ae7c 100644 --- a/imblearn/ensemble/_weight_boosting.py +++ b/imblearn/ensemble/_weight_boosting.py @@ -1,3 +1,4 @@ +import copy import inspect import numbers import warnings @@ -11,18 +12,21 @@ from sklearn.utils import _safe_indexing from sklearn.utils.validation import has_fit_parameter +from ..base import _ParamsValidationMixin from ..pipeline import make_pipeline from ..under_sampling import RandomUnderSampler from ..under_sampling.base import BaseUnderSampler from ..utils import Substitution, check_target_type from ..utils._docstring import _random_state_docstring +from ..utils._param_validation import Interval, StrOptions +from ._common import _adaboost_classifier_parameter_constraints @Substitution( sampling_strategy=BaseUnderSampler._sampling_strategy_docstring, random_state=_random_state_docstring, ) -class RUSBoostClassifier(AdaBoostClassifier): +class RUSBoostClassifier(AdaBoostClassifier, _ParamsValidationMixin): """Random under-sampling integrated in the learning of AdaBoost. During learning, the problem of class balancing is alleviated by random @@ -163,6 +167,29 @@ class RUSBoostClassifier(AdaBoostClassifier): array([...]) """ + # make a deepcopy to not modify the original dictionary + if hasattr(AdaBoostClassifier, "_parameter_constraints"): + # scikit-learn >= 1.2 + _parameter_constraints = copy.deepcopy( + AdaBoostClassifier._parameter_constraints + ) + else: + _parameter_constraints = copy.deepcopy( + _adaboost_classifier_parameter_constraints + ) + + _parameter_constraints.update( + { + "sampling_strategy": [ + Interval(numbers.Real, 0, 1, closed="right"), + StrOptions({"auto", "majority", "not minority", "not majority", "all"}), + dict, + callable, + ], + "replacement": ["boolean"], + } + ) + def __init__( self, estimator=None, @@ -214,6 +241,7 @@ def fit(self, X, y, sample_weight=None): self : object Returns self. 
""" + self._validate_params() check_target_type(y) self.samplers_ = [] self.pipelines_ = [] @@ -225,20 +253,6 @@ def _validate_estimator(self): Sets the `estimator_` attributes. """ - if not isinstance(self.n_estimators, numbers.Integral): - raise ValueError( - "n_estimators must be an integer, got {0}.".format( - type(self.n_estimators) - ) - ) - - if self.n_estimators <= 0: - raise ValueError( - "n_estimators must be greater than zero, got {0}.".format( - self.n_estimators - ) - ) - if self.estimator is not None and ( self.base_estimator not in [None, "deprecated"] ): diff --git a/imblearn/ensemble/tests/test_bagging.py b/imblearn/ensemble/tests/test_bagging.py index 546ccde17..d1c22bfb3 100644 --- a/imblearn/ensemble/tests/test_bagging.py +++ b/imblearn/ensemble/tests/test_bagging.py @@ -252,37 +252,6 @@ def test_single_estimator(): assert_array_equal(clf1.predict(X_test), clf2.predict(X_test)) -@pytest.mark.parametrize( - "params", - [ - {"n_estimators": 1.5}, - {"n_estimators": -1}, - {"max_samples": -1}, - {"max_samples": 0.0}, - {"max_samples": 2.0}, - {"max_samples": 1000}, - {"max_samples": "foobar"}, - {"max_features": -1}, - {"max_features": 0.0}, - {"max_features": 2.0}, - {"max_features": 5}, - {"max_features": "foobar"}, - ], -) -def test_balanced_bagging_classifier_error(params): - # Test that it gives proper exception on deficient input. - X, y = make_imbalance( - iris.data, iris.target, sampling_strategy={0: 20, 1: 25, 2: 50} - ) - base = DecisionTreeClassifier() - clf = BalancedBaggingClassifier(estimator=base, **params) - with pytest.raises(ValueError): - clf.fit(X, y) - - # Test support of decision_function - assert not (hasattr(BalancedBaggingClassifier(base).fit(X, y), "decision_function")) - - def test_gridsearch(): # Check that bagging ensembles can be grid-searched. 
# Transform iris into a binary classification task diff --git a/imblearn/ensemble/tests/test_easy_ensemble.py b/imblearn/ensemble/tests/test_easy_ensemble.py index 34bf72208..a8574e4d6 100644 --- a/imblearn/ensemble/tests/test_easy_ensemble.py +++ b/imblearn/ensemble/tests/test_easy_ensemble.py @@ -184,19 +184,6 @@ def test_warm_start_equivalence(): assert_allclose(y1, y2) -@pytest.mark.parametrize("n_estimators", [1.0, -10]) -def test_easy_ensemble_classifier_error(n_estimators): - X, y = make_imbalance( - iris.data, - iris.target, - sampling_strategy={0: 20, 1: 25, 2: 50}, - random_state=0, - ) - with pytest.raises(ValueError): - eec = EasyEnsembleClassifier(n_estimators=n_estimators) - eec.fit(X, y) - - def test_easy_ensemble_classifier_single_estimator(): X, y = make_imbalance( iris.data, diff --git a/imblearn/ensemble/tests/test_forest.py b/imblearn/ensemble/tests/test_forest.py index 8a5077fa8..c7ae65f85 100644 --- a/imblearn/ensemble/tests/test_forest.py +++ b/imblearn/ensemble/tests/test_forest.py @@ -27,23 +27,6 @@ def imbalanced_dataset(): ) -@pytest.mark.parametrize( - "forest_params, err_msg", - [ - ({"n_estimators": "whatever"}, "n_estimators must be an integer"), - ({"n_estimators": -100}, "n_estimators must be greater than zero"), - ( - {"bootstrap": False, "oob_score": True}, - "Out of bag estimation only", - ), - ], -) -def test_balanced_random_forest_error(imbalanced_dataset, forest_params, err_msg): - brf = BalancedRandomForestClassifier(**forest_params) - with pytest.raises(ValueError, match=err_msg): - brf.fit(*imbalanced_dataset) - - def test_balanced_random_forest_error_warning_warm_start(imbalanced_dataset): brf = BalancedRandomForestClassifier(n_estimators=5) brf.fit(*imbalanced_dataset) diff --git a/imblearn/ensemble/tests/test_weight_boosting.py b/imblearn/ensemble/tests/test_weight_boosting.py index 93ac8e13d..a36395e55 100644 --- a/imblearn/ensemble/tests/test_weight_boosting.py +++ b/imblearn/ensemble/tests/test_weight_boosting.py @@ -28,15 +28,6 @@ def imbalanced_dataset(): ) -@pytest.mark.parametrize( - "boosting_params", [{"n_estimators": "whatever"}, {"n_estimators": -100}] -) -def test_rusboost_error(imbalanced_dataset, boosting_params): - rusboost = RUSBoostClassifier(**boosting_params) - with pytest.raises((ValueError, TypeError)): - rusboost.fit(*imbalanced_dataset) - - @pytest.mark.parametrize("algorithm", ["SAMME", "SAMME.R"]) def test_rusboost(imbalanced_dataset, algorithm): X, y = imbalanced_dataset diff --git a/imblearn/metrics/pairwise.py b/imblearn/metrics/pairwise.py index 2fa784fa2..ceec92802 100644 --- a/imblearn/metrics/pairwise.py +++ b/imblearn/metrics/pairwise.py @@ -3,6 +3,8 @@ # Authors: Guillaume Lemaitre # License: MIT +import numbers + import numpy as np from scipy.spatial import distance_matrix from sklearn.base import BaseEstimator @@ -10,8 +12,11 @@ from sklearn.utils.multiclass import unique_labels from sklearn.utils.validation import check_is_fitted +from ..base import _ParamsValidationMixin +from ..utils._param_validation import StrOptions + -class ValueDifferenceMetric(BaseEstimator): +class ValueDifferenceMetric(BaseEstimator, _ParamsValidationMixin): r"""Class implementing the Value Difference Metric. This metric computes the distance between samples containing only @@ -102,6 +107,11 @@ class ValueDifferenceMetric(BaseEstimator): [0.04, 0. , 1.44], [1.96, 1.44, 0. 
]]) """ + _parameter_constraints: dict = { + "n_categories": [StrOptions({"auto"}), "array-like"], + "k": [numbers.Integral], + "r": [numbers.Integral], + } def __init__(self, *, n_categories="auto", k=1, r=2): self.n_categories = n_categories @@ -125,6 +135,7 @@ def fit(self, X, y): self : object Return the instance itself. """ + self._validate_params() check_consistent_length(X, y) X, y = self._validate_data(X, y, reset=True, dtype=np.int32) diff --git a/imblearn/over_sampling/_adasyn.py b/imblearn/over_sampling/_adasyn.py index 7417bc655..6f4b81fd5 100644 --- a/imblearn/over_sampling/_adasyn.py +++ b/imblearn/over_sampling/_adasyn.py @@ -4,6 +4,7 @@ # Christos Aridas # License: MIT +import numbers import warnings import numpy as np @@ -12,6 +13,7 @@ from ..utils import Substitution, check_neighbors_object from ..utils._docstring import _n_jobs_docstring, _random_state_docstring +from ..utils._param_validation import HasMethods, Interval from .base import BaseOverSampler @@ -114,6 +116,15 @@ class ADASYN(BaseOverSampler): Resampled dataset shape Counter({{0: 904, 1: 900}}) """ + _parameter_constraints: dict = { + **BaseOverSampler._parameter_constraints, + "n_neighbors": [ + Interval(numbers.Integral, 1, None, closed="left"), + HasMethods(["kneighbors", "kneighbors_graph"]), + ], + "n_jobs": [numbers.Integral, None], + } + def __init__( self, *, diff --git a/imblearn/over_sampling/_random_over_sampler.py b/imblearn/over_sampling/_random_over_sampler.py index 7093ac88e..1f12619dd 100644 --- a/imblearn/over_sampling/_random_over_sampler.py +++ b/imblearn/over_sampling/_random_over_sampler.py @@ -14,6 +14,7 @@ from ..utils import Substitution, check_target_type from ..utils._docstring import _random_state_docstring +from ..utils._param_validation import Interval from .base import BaseOverSampler @@ -129,6 +130,11 @@ class RandomOverSampler(BaseOverSampler): Resampled dataset shape Counter({{0: 900, 1: 900}}) """ + _parameter_constraints: dict = { + **BaseOverSampler._parameter_constraints, + "shrinkage": [Interval(Real, 0, None, closed="left"), dict, None], + } + def __init__( self, *, @@ -161,12 +167,6 @@ def _fit_resample(self, X, y): } elif self.shrinkage is None or isinstance(self.shrinkage, Mapping): self.shrinkage_ = self.shrinkage - else: - raise ValueError( - f"`shrinkage` should either be a positive floating number or " - f"a dictionary mapping a class to a positive floating number. " - f"Got {repr(self.shrinkage)} instead." 
- ) if self.shrinkage_ is not None: missing_shrinkage_keys = ( diff --git a/imblearn/over_sampling/_smote/base.py b/imblearn/over_sampling/_smote/base.py index c18515072..15b4664c8 100644 --- a/imblearn/over_sampling/_smote/base.py +++ b/imblearn/over_sampling/_smote/base.py @@ -7,6 +7,7 @@ # License: MIT import math +import numbers import warnings from collections import Counter @@ -22,6 +23,7 @@ from ...metrics.pairwise import ValueDifferenceMetric from ...utils import Substitution, check_neighbors_object, check_target_type from ...utils._docstring import _n_jobs_docstring, _random_state_docstring +from ...utils._param_validation import HasMethods, Interval from ...utils.fixes import _mode from ..base import BaseOverSampler @@ -29,6 +31,15 @@ class BaseSMOTE(BaseOverSampler): """Base class for the different SMOTE algorithms.""" + _parameter_constraints: dict = { + **BaseOverSampler._parameter_constraints, + "k_neighbors": [ + Interval(numbers.Integral, 1, None, closed="left"), + HasMethods(["kneighbors", "kneighbors_graph"]), + ], + "n_jobs": [numbers.Integral, None], + } + def __init__( self, sampling_strategy="auto", @@ -193,11 +204,9 @@ def _in_danger_noise(self, nn_estimator, samples, target_class, y, kind="danger" n_maj >= (nn_estimator.n_neighbors - 1) / 2, n_maj < nn_estimator.n_neighbors - 1, ) - elif kind == "noise": + else: # kind == "noise": # Samples are noise for m = m' return n_maj == nn_estimator.n_neighbors - 1 - else: - raise NotImplementedError @Substitution( @@ -371,7 +380,7 @@ class SMOTENC(SMOTE): Parameters ---------- - categorical_features : ndarray of shape (n_cat_features,) or (n_features,) + categorical_features : array-like of shape (n_cat_features,) or (n_features,) Specified which features are categorical. Can either be: - array of indices specifying the categorical features; @@ -489,6 +498,11 @@ class SMOTENC(SMOTE): _required_parameters = ["categorical_features"] + _parameter_constraints: dict = { + **SMOTE._parameter_constraints, + "categorical_features": ["array-like"], + } + def __init__( self, categorical_features, @@ -502,6 +516,7 @@ def __init__( sampling_strategy=sampling_strategy, random_state=random_state, k_neighbors=k_neighbors, + n_jobs=n_jobs, ) self.categorical_features = categorical_features diff --git a/imblearn/over_sampling/_smote/cluster.py b/imblearn/over_sampling/_smote/cluster.py index 46d6c7405..ccfe07a7e 100644 --- a/imblearn/over_sampling/_smote/cluster.py +++ b/imblearn/over_sampling/_smote/cluster.py @@ -6,6 +6,7 @@ # License: MIT import math +import numbers import numpy as np from scipy import sparse @@ -16,6 +17,7 @@ from ...utils import Substitution from ...utils._docstring import _n_jobs_docstring, _random_state_docstring +from ...utils._param_validation import HasMethods, Interval, StrOptions from ..base import BaseOverSampler from .base import BaseSMOTE @@ -138,6 +140,17 @@ class KMeansSMOTE(BaseSMOTE): More 0 samples: True """ + _parameter_constraints: dict = { + **BaseSMOTE._parameter_constraints, + "kmeans_estimator": [ + HasMethods(["fit", "predict"]), + Interval(numbers.Integral, 1, None, closed="left"), + None, + ], + "cluster_balance_threshold": [StrOptions({"auto"}), numbers.Real], + "density_exponent": [StrOptions({"auto"}), numbers.Real], + } + def __init__( self, *, @@ -171,15 +184,6 @@ def _validate_estimator(self): else: self.kmeans_estimator_ = clone(self.kmeans_estimator) - # validate the parameters - for param_name in ("cluster_balance_threshold", "density_exponent"): - param = getattr(self, param_name) - if 
isinstance(param, str) and param != "auto": - raise ValueError( - f"'{param_name}' should be 'auto' when a string is passed." - f" Got {repr(param)} instead." - ) - self.cluster_balance_threshold_ = ( self.cluster_balance_threshold if self.kmeans_estimator_.n_clusters != 1 diff --git a/imblearn/over_sampling/_smote/filter.py b/imblearn/over_sampling/_smote/filter.py index 857bb417f..cf014b9ea 100644 --- a/imblearn/over_sampling/_smote/filter.py +++ b/imblearn/over_sampling/_smote/filter.py @@ -6,6 +6,7 @@ # Dzianis Dudnik # License: MIT +import numbers import warnings import numpy as np @@ -16,6 +17,7 @@ from ...utils import Substitution, check_neighbors_object from ...utils._docstring import _n_jobs_docstring, _random_state_docstring +from ...utils._param_validation import HasMethods, Interval, StrOptions from ..base import BaseOverSampler from .base import BaseSMOTE @@ -144,6 +146,15 @@ class BorderlineSMOTE(BaseSMOTE): Resampled dataset shape Counter({{0: 900, 1: 900}}) """ + _parameter_constraints: dict = { + **BaseSMOTE._parameter_constraints, + "m_neighbors": [ + Interval(numbers.Integral, 1, None, closed="left"), + HasMethods(["kneighbors", "kneighbors_graph"]), + ], + "kind": [StrOptions({"borderline-1", "borderline-2"})], + } + def __init__( self, *, @@ -168,12 +179,6 @@ def _validate_estimator(self): self.nn_m_ = check_neighbors_object( "m_neighbors", self.m_neighbors, additional_neighbor=1 ) - if self.kind not in ("borderline-1", "borderline-2"): - raise ValueError( - f'The possible "kind" of algorithm are ' - f'"borderline-1" and "borderline-2".' - f"Got {self.kind} instead." - ) def _fit_resample(self, X, y): # FIXME: to be removed in 0.12 @@ -396,6 +401,16 @@ class SVMSMOTE(BaseSMOTE): Resampled dataset shape Counter({{0: 900, 1: 900}}) """ + _parameter_constraints: dict = { + **BaseSMOTE._parameter_constraints, + "m_neighbors": [ + Interval(numbers.Integral, 1, None, closed="left"), + HasMethods(["kneighbors", "kneighbors_graph"]), + ], + "svm_estimator": [HasMethods(["fit", "predict"]), None], + "out_step": [Interval(numbers.Real, 0, 1, closed="both")], + } + def __init__( self, *, diff --git a/imblearn/over_sampling/_smote/tests/test_borderline_smote.py b/imblearn/over_sampling/_smote/tests/test_borderline_smote.py index f15420384..7519fcaab 100644 --- a/imblearn/over_sampling/_smote/tests/test_borderline_smote.py +++ b/imblearn/over_sampling/_smote/tests/test_borderline_smote.py @@ -36,12 +36,6 @@ def data(): return X, y -def test_borderline_smote_wrong_kind(data): - bsmote = BorderlineSMOTE(kind="rand") - with pytest.raises(ValueError, match='The possible "kind" of algorithm'): - bsmote.fit_resample(*data) - - @pytest.mark.parametrize("kind", ["borderline-1", "borderline-2"]) def test_borderline_smote(kind, data): bsmote = BorderlineSMOTE(kind=kind, random_state=42) diff --git a/imblearn/over_sampling/_smote/tests/test_kmeans_smote.py b/imblearn/over_sampling/_smote/tests/test_kmeans_smote.py index 033d9b3a0..71fa47c66 100644 --- a/imblearn/over_sampling/_smote/tests/test_kmeans_smote.py +++ b/imblearn/over_sampling/_smote/tests/test_kmeans_smote.py @@ -106,17 +106,3 @@ def test_sample_kmeans_density_estimation(density_exponent, cluster_balance_thre cluster_balance_threshold=cluster_balance_threshold, ) smote.fit_resample(X, y) - - -@pytest.mark.parametrize( - "density_exponent, cluster_balance_threshold", - [("xxx", "auto"), ("auto", "xxx")], -) -def test_kmeans_smote_param_error(data, density_exponent, cluster_balance_threshold): - X, y = data - kmeans_smote = 
KMeansSMOTE( - density_exponent=density_exponent, - cluster_balance_threshold=cluster_balance_threshold, - ) - with pytest.raises(ValueError, match="should be 'auto' when a string"): - kmeans_smote.fit_resample(X, y) diff --git a/imblearn/over_sampling/base.py b/imblearn/over_sampling/base.py index 7165ab9c9..4bc08e91a 100644 --- a/imblearn/over_sampling/base.py +++ b/imblearn/over_sampling/base.py @@ -5,7 +5,10 @@ # Christos Aridas # License: MIT +import numbers + from ..base import BaseSampler +from ..utils._param_validation import Interval, StrOptions class BaseOverSampler(BaseSampler): @@ -53,3 +56,13 @@ class BaseOverSampler(BaseSampler): correspond to the targeted classes. The values correspond to the desired number of samples for each class. """.strip() # noqa: E501 + + _parameter_constraints: dict = { + "sampling_strategy": [ + Interval(numbers.Real, 0, 1, closed="right"), + StrOptions({"auto", "majority", "not minority", "not majority", "all"}), + dict, + callable, + ], + "random_state": ["random_state"], + } diff --git a/imblearn/over_sampling/tests/test_adasyn.py b/imblearn/over_sampling/tests/test_adasyn.py index b6397041c..4df636273 100644 --- a/imblearn/over_sampling/tests/test_adasyn.py +++ b/imblearn/over_sampling/tests/test_adasyn.py @@ -4,7 +4,6 @@ # License: MIT import numpy as np -import pytest from sklearn.neighbors import NearestNeighbors from sklearn.utils._testing import assert_allclose, assert_array_equal @@ -120,23 +119,3 @@ def test_ada_fit_resample_nn_obj(): ) assert_allclose(X_resampled, X_gt, rtol=R_TOL) assert_array_equal(y_resampled, y_gt) - - -@pytest.mark.parametrize( - "adasyn_params, err_msg", - [ - ( - {"sampling_strategy": {0: 9, 1: 12}}, - "No samples will be generated.", - ), - ( - {"n_neighbors": "rnd"}, - "n_neighbors must be an interger or an object compatible with the " - "KNeighborsMixin API of scikit-learn", - ), - ], -) -def test_adasyn_error(adasyn_params, err_msg): - adasyn = ADASYN(**adasyn_params) - with pytest.raises(ValueError, match=err_msg): - adasyn.fit_resample(X, Y) diff --git a/imblearn/over_sampling/tests/test_random_over_sampler.py b/imblearn/over_sampling/tests/test_random_over_sampler.py index e7663d168..2db808f5b 100644 --- a/imblearn/over_sampling/tests/test_random_over_sampler.py +++ b/imblearn/over_sampling/tests/test_random_over_sampler.py @@ -44,7 +44,9 @@ def test_ros_init(): assert ros.random_state == RND_SEED -@pytest.mark.parametrize("params", [{"shrinkage": None}, {"shrinkage": 0}]) +@pytest.mark.parametrize( + "params", [{"shrinkage": None}, {"shrinkage": 0}, {"shrinkage": {0: 0}}] +) @pytest.mark.parametrize("X_type", ["array", "dataframe"]) def test_ros_fit_resample(X_type, data, params): X, Y = data @@ -244,14 +246,7 @@ def test_random_over_sampler_shrinkage_behaviour(data): "shrinkage, err_msg", [ ({}, "`shrinkage` should contain a shrinkage factor for each class"), - (-1, "The shrinkage factor needs to be >= 0"), ({0: -1}, "The shrinkage factor needs to be >= 0"), - ( - [ - 1, - ], - "`shrinkage` should either be a positive floating number or", - ), ], ) def test_random_over_sampler_shrinkage_error(data, shrinkage, err_msg): diff --git a/imblearn/tests/test_common.py b/imblearn/tests/test_common.py index f1d780ba2..9ec5764d3 100644 --- a/imblearn/tests/test_common.py +++ b/imblearn/tests/test_common.py @@ -7,7 +7,7 @@ from sklearn.base import clone from sklearn.exceptions import ConvergenceWarning from sklearn.utils._testing import SkipTest, ignore_warnings, set_random_state -from sklearn.utils.estimator_checks 
import _construct_instance +from sklearn.utils.estimator_checks import _construct_instance, _get_check_estimator_ids from sklearn.utils.estimator_checks import ( parametrize_with_checks as parametrize_with_checks_sklearn, ) @@ -15,6 +15,7 @@ from imblearn.under_sampling import NearMiss from imblearn.utils.estimator_checks import ( _set_checking_parameters, + check_param_validation, parametrize_with_checks, ) from imblearn.utils.testing import all_estimators @@ -62,3 +63,13 @@ def test_estimators_imblearn(estimator, check, request): ): _set_checking_parameters(estimator) check(estimator) + + +@pytest.mark.parametrize( + "estimator", _tested_estimators(), ids=_get_check_estimator_ids +) +def test_check_param_validation(estimator): + name = estimator.__class__.__name__ + print(name) + _set_checking_parameters(estimator) + check_param_validation(name, estimator) diff --git a/imblearn/under_sampling/_prototype_generation/_cluster_centroids.py b/imblearn/under_sampling/_prototype_generation/_cluster_centroids.py index faa558600..5be949ed5 100644 --- a/imblearn/under_sampling/_prototype_generation/_cluster_centroids.py +++ b/imblearn/under_sampling/_prototype_generation/_cluster_centroids.py @@ -15,6 +15,7 @@ from ...utils import Substitution from ...utils._docstring import _random_state_docstring +from ...utils._param_validation import HasMethods, StrOptions from ..base import BaseUnderSampler VOTING_KIND = ("auto", "hard", "soft") @@ -107,6 +108,13 @@ class ClusterCentroids(BaseUnderSampler): Resampled dataset shape Counter({{...}}) """ + _parameter_constraints: dict = { + **BaseUnderSampler._parameter_constraints, + "estimator": [HasMethods(["fit", "predict"]), None], + "voting": [StrOptions({"auto", "hard", "soft"})], + "random_state": ["random_state"], + } + def __init__( self, *, @@ -151,18 +159,9 @@ def _fit_resample(self, X, y): self._validate_estimator() if self.voting == "auto": - if sparse.issparse(X): - self.voting_ = "hard" - else: - self.voting_ = "soft" + self.voting_ = "hard" if sparse.issparse(X) else "soft" else: - if self.voting in VOTING_KIND: - self.voting_ = self.voting - else: - raise ValueError( - f"'voting' needs to be one of {VOTING_KIND}. " - f"Got {self.voting} instead." 
- ) + self.voting_ = self.voting X_resampled, y_resampled = [], [] for target_class in np.unique(y): diff --git a/imblearn/under_sampling/_prototype_generation/tests/test_cluster_centroids.py b/imblearn/under_sampling/_prototype_generation/tests/test_cluster_centroids.py index 86a221606..b51e3501c 100644 --- a/imblearn/under_sampling/_prototype_generation/tests/test_cluster_centroids.py +++ b/imblearn/under_sampling/_prototype_generation/tests/test_cluster_centroids.py @@ -103,18 +103,6 @@ def test_fit_hard_voting(): assert np.any(np.all(x == X, axis=1)) -@pytest.mark.parametrize( - "cluster_centroids_params, err_msg", - [ - ({"voting": "unknown"}, "needs to be one of"), - ], -) -def test_fit_resample_error(cluster_centroids_params, err_msg): - cc = ClusterCentroids(**cluster_centroids_params) - with pytest.raises(ValueError, match=err_msg): - cc.fit_resample(X, Y) - - @pytest.mark.filterwarnings("ignore:The default value of `n_init` will change") def test_cluster_centroids_hard_target_class(): # check that the samples selecting by the hard voting corresponds to the @@ -150,19 +138,26 @@ def test_cluster_centroids_hard_target_class(): assert sum(sample_from_minority_in_majority) == 0 -def test_cluster_centroids_error_estimator(): - """Check that an error is raised when estimator does not have a cluster API.""" - - err_msg = ( - "`estimator` should be a clustering estimator exposing a parameter " - "`n_clusters` and a fitted parameter `cluster_centers_`." - ) - with pytest.raises(ValueError, match=err_msg): - ClusterCentroids(estimator=LogisticRegression()).fit_resample(X, Y) +def test_cluster_centroids_custom_clusterer(): + clusterer = _CustomClusterer() + cc = ClusterCentroids(estimator=clusterer, random_state=RND_SEED) + cc.fit_resample(X, Y) + assert isinstance(cc.estimator_.cluster_centers_, np.ndarray) + clusterer = _CustomClusterer(expose_cluster_centers=False) + cc = ClusterCentroids(estimator=clusterer, random_state=RND_SEED) err_msg = ( "`estimator` should be a clustering estimator exposing a fitted parameter " "`cluster_centers_`." ) with pytest.raises(RuntimeError, match=err_msg): - ClusterCentroids(estimator=_CustomClusterer()).fit_resample(X, Y) + cc.fit_resample(X, Y) + + clusterer = LogisticRegression() + cc = ClusterCentroids(estimator=clusterer, random_state=RND_SEED) + err_msg = ( + "`estimator` should be a clustering estimator exposing a parameter " + "`n_clusters` and a fitted parameter `cluster_centers_`." 
+ ) + with pytest.raises(ValueError, match=err_msg): + cc.fit_resample(X, Y) diff --git a/imblearn/under_sampling/_prototype_selection/_condensed_nearest_neighbour.py b/imblearn/under_sampling/_prototype_selection/_condensed_nearest_neighbour.py index c089a6493..b0d9109cf 100644 --- a/imblearn/under_sampling/_prototype_selection/_condensed_nearest_neighbour.py +++ b/imblearn/under_sampling/_prototype_selection/_condensed_nearest_neighbour.py @@ -5,6 +5,7 @@ # Christos Aridas # License: MIT +import numbers from collections import Counter import numpy as np @@ -15,6 +16,7 @@ from ...utils import Substitution from ...utils._docstring import _n_jobs_docstring, _random_state_docstring +from ...utils._param_validation import HasMethods, Interval from ..base import BaseCleaningSampler @@ -104,6 +106,18 @@ class CondensedNearestNeighbour(BaseCleaningSampler): Resampled dataset shape Counter({{-1: 268, 1: 227}}) # doctest: +SKIP """ + _parameter_constraints: dict = { + **BaseCleaningSampler._parameter_constraints, + "n_neighbors": [ + Interval(numbers.Integral, 1, None, closed="left"), + HasMethods(["kneighbors", "kneighbors_graph"]), + None, + ], + "n_seeds_S": [Interval(numbers.Integral, 1, None, closed="left")], + "n_jobs": [numbers.Integral, None], + "random_state": ["random_state"], + } + def __init__( self, *, @@ -123,18 +137,12 @@ def _validate_estimator(self): """Private function to create the NN estimator""" if self.n_neighbors is None: self.estimator_ = KNeighborsClassifier(n_neighbors=1, n_jobs=self.n_jobs) - elif isinstance(self.n_neighbors, int): + elif isinstance(self.n_neighbors, numbers.Integral): self.estimator_ = KNeighborsClassifier( n_neighbors=self.n_neighbors, n_jobs=self.n_jobs ) elif isinstance(self.n_neighbors, KNeighborsClassifier): self.estimator_ = clone(self.n_neighbors) - else: - raise ValueError( - f"`n_neighbors` has to be a int or an object" - f" inhereited from KNeighborsClassifier." - f" Got {type(self.n_neighbors)} instead." 
- ) def _fit_resample(self, X, y): self._validate_estimator() diff --git a/imblearn/under_sampling/_prototype_selection/_edited_nearest_neighbours.py b/imblearn/under_sampling/_prototype_selection/_edited_nearest_neighbours.py index e016ea118..84694e746 100644 --- a/imblearn/under_sampling/_prototype_selection/_edited_nearest_neighbours.py +++ b/imblearn/under_sampling/_prototype_selection/_edited_nearest_neighbours.py @@ -6,6 +6,7 @@ # Christos Aridas # License: MIT +import numbers from collections import Counter import numpy as np @@ -13,6 +14,7 @@ from ...utils import Substitution, check_neighbors_object from ...utils._docstring import _n_jobs_docstring +from ...utils._param_validation import HasMethods, Interval, StrOptions from ...utils.fixes import _mode from ..base import BaseCleaningSampler @@ -112,6 +114,16 @@ class EditedNearestNeighbours(BaseCleaningSampler): Resampled dataset shape Counter({{1: 887, 0: 100}}) """ + _parameter_constraints: dict = { + **BaseCleaningSampler._parameter_constraints, + "n_neighbors": [ + Interval(numbers.Integral, 1, None, closed="left"), + HasMethods(["kneighbors", "kneighbors_graph"]), + ], + "kind_sel": [StrOptions({"all", "mode"})], + "n_jobs": [numbers.Integral, None], + } + def __init__( self, *, @@ -132,9 +144,6 @@ def _validate_estimator(self): ) self.nn_.set_params(**{"n_jobs": self.n_jobs}) - if self.kind_sel not in SEL_KIND: - raise NotImplementedError - def _fit_resample(self, X, y): self._validate_estimator() @@ -279,6 +288,17 @@ class RepeatedEditedNearestNeighbours(BaseCleaningSampler): Resampled dataset shape Counter({{1: 887, 0: 100}}) """ + _parameter_constraints: dict = { + **BaseCleaningSampler._parameter_constraints, + "n_neighbors": [ + Interval(numbers.Integral, 1, None, closed="left"), + HasMethods(["kneighbors", "kneighbors_graph"]), + ], + "max_iter": [Interval(numbers.Integral, 1, None, closed="left")], + "kind_sel": [StrOptions({"all", "mode"})], + "n_jobs": [numbers.Integral, None], + } + def __init__( self, *, @@ -296,12 +316,6 @@ def __init__( def _validate_estimator(self): """Private function to create the NN estimator""" - if self.max_iter < 2: - raise ValueError( - f"max_iter must be greater than 1." - f" Got {type(self.max_iter)} instead." - ) - self.nn_ = check_neighbors_object( "n_neighbors", self.n_neighbors, additional_neighbor=1 ) @@ -477,6 +491,17 @@ class without early stopping. 
Resampled dataset shape Counter({{1: 887, 0: 100}}) """ + _parameter_constraints: dict = { + **BaseCleaningSampler._parameter_constraints, + "n_neighbors": [ + Interval(numbers.Integral, 1, None, closed="left"), + HasMethods(["kneighbors", "kneighbors_graph"]), + ], + "kind_sel": [StrOptions({"all", "mode"})], + "allow_minority": ["boolean"], + "n_jobs": [numbers.Integral, None], + } + def __init__( self, *, @@ -494,9 +519,6 @@ def __init__( def _validate_estimator(self): """Create objects required by AllKNN""" - if self.kind_sel not in SEL_KIND: - raise NotImplementedError - self.nn_ = check_neighbors_object( "n_neighbors", self.n_neighbors, additional_neighbor=1 ) diff --git a/imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py b/imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py index 777720056..b1a6e1150 100644 --- a/imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py +++ b/imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py @@ -6,6 +6,7 @@ # Christos Aridas # License: MIT +import numbers from collections import Counter import numpy as np @@ -17,6 +18,7 @@ from ...utils import Substitution from ...utils._docstring import _n_jobs_docstring, _random_state_docstring +from ...utils._param_validation import HasMethods from ..base import BaseUnderSampler @@ -100,6 +102,17 @@ class InstanceHardnessThreshold(BaseUnderSampler): Resampled dataset shape Counter({{1: 5..., 0: 100}}) """ + _parameter_constraints: dict = { + **BaseUnderSampler._parameter_constraints, + "estimator": [ + HasMethods(["fit", "predict_proba"]), + None, + ], + "cv": ["cv_object"], + "n_jobs": [numbers.Integral, None], + "random_state": ["random_state"], + } + def __init__( self, *, @@ -132,10 +145,6 @@ def _validate_estimator(self, random_state): random_state=self.random_state, n_jobs=self.n_jobs, ) - else: - raise ValueError( - f"Invalid parameter `estimator`. Got {type(self.estimator)}." 
- ) def _fit_resample(self, X, y): random_state = check_random_state(self.random_state) diff --git a/imblearn/under_sampling/_prototype_selection/_nearmiss.py b/imblearn/under_sampling/_prototype_selection/_nearmiss.py index 7eead5f90..83f94d890 100644 --- a/imblearn/under_sampling/_prototype_selection/_nearmiss.py +++ b/imblearn/under_sampling/_prototype_selection/_nearmiss.py @@ -4,6 +4,7 @@ # Christos Aridas # License: MIT +import numbers import warnings from collections import Counter @@ -12,6 +13,7 @@ from ...utils import Substitution, check_neighbors_object from ...utils._docstring import _n_jobs_docstring +from ...utils._param_validation import HasMethods, Interval from ..base import BaseUnderSampler @@ -104,6 +106,20 @@ class NearMiss(BaseUnderSampler): Resampled dataset shape Counter({{0: 100, 1: 100}}) """ + _parameter_constraints: dict = { + **BaseUnderSampler._parameter_constraints, + "version": [Interval(numbers.Integral, 1, 3, closed="both")], + "n_neighbors": [ + Interval(numbers.Integral, 1, None, closed="left"), + HasMethods(["kneighbors", "kneighbors_graph"]), + ], + "n_neighbors_ver3": [ + Interval(numbers.Integral, 1, None, closed="left"), + HasMethods(["kneighbors", "kneighbors_graph"]), + ], + "n_jobs": [numbers.Integral, None], + } + def __init__( self, *, @@ -166,10 +182,8 @@ def _selection_dist_based( # Sort the list of distance and get the index if sel_strategy == "nearest": sort_way = False - elif sel_strategy == "farthest": + else: # sel_strategy == "farthest": sort_way = True - else: - raise NotImplementedError sorted_idx = sorted( range(len(dist_avg_vec)), @@ -202,11 +216,6 @@ def _validate_estimator(self): ) self.nn_ver3_.set_params(**{"n_jobs": self.n_jobs}) - if self.version not in (1, 2, 3): - raise ValueError( - f"Parameter `version` must be 1, 2 or 3, got {self.version}" - ) - def _fit_resample(self, X, y): self._validate_estimator() diff --git a/imblearn/under_sampling/_prototype_selection/_neighbourhood_cleaning_rule.py b/imblearn/under_sampling/_prototype_selection/_neighbourhood_cleaning_rule.py index 80b9aa6b1..00be9ca71 100644 --- a/imblearn/under_sampling/_prototype_selection/_neighbourhood_cleaning_rule.py +++ b/imblearn/under_sampling/_prototype_selection/_neighbourhood_cleaning_rule.py @@ -4,6 +4,7 @@ # Christos Aridas # License: MIT +import numbers from collections import Counter import numpy as np @@ -11,6 +12,7 @@ from ...utils import Substitution, check_neighbors_object from ...utils._docstring import _n_jobs_docstring +from ...utils._param_validation import HasMethods, Interval, StrOptions from ...utils.fixes import _mode from ..base import BaseCleaningSampler from ._edited_nearest_neighbours import EditedNearestNeighbours @@ -113,6 +115,17 @@ class NeighbourhoodCleaningRule(BaseCleaningSampler): Resampled dataset shape Counter({{1: 877, 0: 100}}) """ + _parameter_constraints: dict = { + **BaseCleaningSampler._parameter_constraints, + "n_neighbors": [ + Interval(numbers.Integral, 1, None, closed="left"), + HasMethods(["kneighbors", "kneighbors_graph"]), + ], + "kind_sel": [StrOptions({"all", "mode"})], + "threshold_cleaning": [Interval(numbers.Real, 0, 1, closed="neither")], + "n_jobs": [numbers.Integral, None], + } + def __init__( self, *, @@ -135,15 +148,6 @@ def _validate_estimator(self): ) self.nn_.set_params(**{"n_jobs": self.n_jobs}) - if self.kind_sel not in SEL_KIND: - raise NotImplementedError - - if self.threshold_cleaning > 1 or self.threshold_cleaning < 0: - raise ValueError( - f"'threshold_cleaning' is a value between 0 and 
1." - f" Got {self.threshold_cleaning} instead." - ) - def _fit_resample(self, X, y): self._validate_estimator() enn = EditedNearestNeighbours( @@ -179,11 +183,9 @@ def _fit_resample(self, X, y): if self.kind_sel == "mode": nnhood_label_majority, _ = _mode(nnhood_label, axis=1) nnhood_bool = np.ravel(nnhood_label_majority) == y_class - elif self.kind_sel == "all": + else: # self.kind_sel == "all": nnhood_label_majority = nnhood_label == class_minority nnhood_bool = np.all(nnhood_label, axis=1) - else: - raise NotImplementedError # compute a2 group index_a2 = np.ravel(nnhood_idx[~nnhood_bool]) index_a2 = np.unique( diff --git a/imblearn/under_sampling/_prototype_selection/_one_sided_selection.py b/imblearn/under_sampling/_prototype_selection/_one_sided_selection.py index d714e2971..0a1866075 100644 --- a/imblearn/under_sampling/_prototype_selection/_one_sided_selection.py +++ b/imblearn/under_sampling/_prototype_selection/_one_sided_selection.py @@ -4,6 +4,7 @@ # Christos Aridas # License: MIT +import numbers from collections import Counter import numpy as np @@ -13,6 +14,7 @@ from ...utils import Substitution from ...utils._docstring import _n_jobs_docstring, _random_state_docstring +from ...utils._param_validation import HasMethods, Interval from ..base import BaseCleaningSampler from ._tomek_links import TomekLinks @@ -100,6 +102,18 @@ class OneSidedSelection(BaseCleaningSampler): Resampled dataset shape Counter({{1: 496, 0: 100}}) """ + _parameter_constraints: dict = { + **BaseCleaningSampler._parameter_constraints, + "n_neighbors": [ + Interval(numbers.Integral, 1, None, closed="left"), + HasMethods(["kneighbors", "kneighbors_graph"]), + None, + ], + "n_seeds_S": [Interval(numbers.Integral, 1, None, closed="left")], + "n_jobs": [numbers.Integral, None], + "random_state": ["random_state"], + } + def __init__( self, *, @@ -125,12 +139,6 @@ def _validate_estimator(self): ) elif isinstance(self.n_neighbors, KNeighborsClassifier): self.estimator_ = clone(self.n_neighbors) - else: - raise ValueError( - f"`n_neighbors` has to be a int or an object" - f" inherited from KNeighborsClassifier." - f" Got {type(self.n_neighbors)} instead." 
- ) def _fit_resample(self, X, y): self._validate_estimator() diff --git a/imblearn/under_sampling/_prototype_selection/_random_under_sampler.py b/imblearn/under_sampling/_prototype_selection/_random_under_sampler.py index c575ba400..a7c735fa6 100644 --- a/imblearn/under_sampling/_prototype_selection/_random_under_sampler.py +++ b/imblearn/under_sampling/_prototype_selection/_random_under_sampler.py @@ -76,6 +76,12 @@ class RandomUnderSampler(BaseUnderSampler): Resampled dataset shape Counter({{0: 100, 1: 100}}) """ + _parameter_constraints: dict = { + **BaseUnderSampler._parameter_constraints, + "replacement": ["boolean"], + "random_state": ["random_state"], + } + def __init__( self, *, sampling_strategy="auto", random_state=None, replacement=False ): diff --git a/imblearn/under_sampling/_prototype_selection/_tomek_links.py b/imblearn/under_sampling/_prototype_selection/_tomek_links.py index 3bfb61f7d..31d62675b 100644 --- a/imblearn/under_sampling/_prototype_selection/_tomek_links.py +++ b/imblearn/under_sampling/_prototype_selection/_tomek_links.py @@ -5,6 +5,8 @@ # Christos Aridas # License: MIT +import numbers + import numpy as np from sklearn.neighbors import NearestNeighbors from sklearn.utils import _safe_indexing @@ -82,6 +84,11 @@ class TomekLinks(BaseCleaningSampler): Resampled dataset shape Counter({{1: 897, 0: 100}}) """ + _parameter_constraints: dict = { + **BaseCleaningSampler._parameter_constraints, + "n_jobs": [numbers.Integral, None], + } + def __init__(self, *, sampling_strategy="auto", n_jobs=None): super().__init__(sampling_strategy=sampling_strategy) self.n_jobs = n_jobs diff --git a/imblearn/under_sampling/_prototype_selection/tests/test_condensed_nearest_neighbour.py b/imblearn/under_sampling/_prototype_selection/tests/test_condensed_nearest_neighbour.py index ea98906b8..5b41b8777 100644 --- a/imblearn/under_sampling/_prototype_selection/tests/test_condensed_nearest_neighbour.py +++ b/imblearn/under_sampling/_prototype_selection/tests/test_condensed_nearest_neighbour.py @@ -68,9 +68,9 @@ def test_cnn_fit_resample(): assert_array_equal(y_resampled, y_gt) -def test_cnn_fit_resample_with_object(): - knn = KNeighborsClassifier(n_neighbors=1) - cnn = CondensedNearestNeighbour(random_state=RND_SEED, n_neighbors=knn) +@pytest.mark.parametrize("n_neighbors", [1, KNeighborsClassifier(n_neighbors=1)]) +def test_cnn_fit_resample_with_object(n_neighbors): + cnn = CondensedNearestNeighbour(random_state=RND_SEED, n_neighbors=n_neighbors) X_resampled, y_resampled = cnn.fit_resample(X, Y) X_gt = np.array( @@ -95,10 +95,3 @@ def test_cnn_fit_resample_with_object(): X_resampled, y_resampled = cnn.fit_resample(X, Y) assert_array_equal(X_resampled, X_gt) assert_array_equal(y_resampled, y_gt) - - -def test_cnn_fit_resample_with_wrong_object(): - knn = "rnd" - cnn = CondensedNearestNeighbour(random_state=RND_SEED, n_neighbors=knn) - with pytest.raises(ValueError, match="has to be a int or an "): - cnn.fit_resample(X, Y) diff --git a/imblearn/under_sampling/_prototype_selection/tests/test_edited_nearest_neighbours.py b/imblearn/under_sampling/_prototype_selection/tests/test_edited_nearest_neighbours.py index 914cfd7ec..00a0ce599 100644 --- a/imblearn/under_sampling/_prototype_selection/tests/test_edited_nearest_neighbours.py +++ b/imblearn/under_sampling/_prototype_selection/tests/test_edited_nearest_neighbours.py @@ -4,7 +4,6 @@ # License: MIT import numpy as np -import pytest from sklearn.datasets import make_classification from sklearn.neighbors import NearestNeighbors from 
sklearn.utils._testing import assert_array_equal @@ -121,17 +120,6 @@ def test_enn_fit_resample_with_nn_object(): assert_array_equal(y_resampled, y_gt) -def test_enn_not_good_object(): - nn = "rnd" - enn = EditedNearestNeighbours(n_neighbors=nn, kind_sel="mode") - err_msg = ( - "n_neighbors must be an interger or an object compatible with the " - "KNeighborsMixin API of scikit-learn" - ) - with pytest.raises(ValueError, match=err_msg): - enn.fit_resample(X, Y) - - def test_enn_check_kind_selection(): """Check that `check_sel="all"` is more conservative than `check_sel="mode"`.""" diff --git a/imblearn/under_sampling/_prototype_selection/tests/test_instance_hardness_threshold.py b/imblearn/under_sampling/_prototype_selection/tests/test_instance_hardness_threshold.py index 3d815d197..5d7008747 100644 --- a/imblearn/under_sampling/_prototype_selection/tests/test_instance_hardness_threshold.py +++ b/imblearn/under_sampling/_prototype_selection/tests/test_instance_hardness_threshold.py @@ -4,7 +4,6 @@ # License: MIT import numpy as np -import pytest from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier from sklearn.naive_bayes import GaussianNB as NB from sklearn.utils._testing import assert_array_equal @@ -74,15 +73,6 @@ def test_iht_fit_resample_class_obj(): assert y_resampled.shape == (12,) -def test_iht_fit_resample_wrong_class_obj(): - from sklearn.cluster import KMeans - - est = KMeans() - iht = InstanceHardnessThreshold(estimator=est, random_state=RND_SEED) - with pytest.raises(ValueError, match="Invalid parameter `estimator`"): - iht.fit_resample(X, Y) - - def test_iht_reproducibility(): from sklearn.datasets import load_digits @@ -95,3 +85,11 @@ def test_iht_reproducibility(): idx_sampled.append(iht.sample_indices_.copy()) for idx_1, idx_2 in zip(idx_sampled, idx_sampled[1:]): assert_array_equal(idx_1, idx_2) + + +def test_iht_fit_resample_default_estimator(): + iht = InstanceHardnessThreshold(estimator=None, random_state=RND_SEED) + X_resampled, y_resampled = iht.fit_resample(X, Y) + assert isinstance(iht.estimator_, RandomForestClassifier) + assert X_resampled.shape == (12, 2) + assert y_resampled.shape == (12,) diff --git a/imblearn/under_sampling/_prototype_selection/tests/test_nearmiss.py b/imblearn/under_sampling/_prototype_selection/tests/test_nearmiss.py index cc8678a59..9ab0da4f3 100644 --- a/imblearn/under_sampling/_prototype_selection/tests/test_nearmiss.py +++ b/imblearn/under_sampling/_prototype_selection/tests/test_nearmiss.py @@ -4,7 +4,6 @@ # License: MIT import numpy as np -import pytest from sklearn.neighbors import NearestNeighbors from sklearn.utils._testing import assert_array_equal @@ -34,30 +33,6 @@ VERSION_NEARMISS = (1, 2, 3) -@pytest.mark.parametrize( - "nearmiss_params, err_msg", - [ - ({"version": 1000}, "must be 1, 2 or 3"), - ( - {"version": 1, "n_neighbors": "rnd"}, - "n_neighbors must be an interger or an object compatible", - ), - ( - { - "version": 3, - "n_neighbors": NearestNeighbors(n_neighbors=3), - "n_neighbors_ver3": "rnd", - }, - "n_neighbors_ver3 must be an interger or an object compatible", - ), - ], -) -def test_nearmiss_error(nearmiss_params, err_msg): - nm = NearMiss(**nearmiss_params) - with pytest.raises(ValueError, match=err_msg): - nm.fit_resample(X, Y) - - def test_nm_fit_resample_auto(): sampling_strategy = "auto" X_gt = [ diff --git a/imblearn/under_sampling/_prototype_selection/tests/test_neighbourhood_cleaning_rule.py b/imblearn/under_sampling/_prototype_selection/tests/test_neighbourhood_cleaning_rule.py 
index a37c61953..971d5b559 100644 --- a/imblearn/under_sampling/_prototype_selection/tests/test_neighbourhood_cleaning_rule.py +++ b/imblearn/under_sampling/_prototype_selection/tests/test_neighbourhood_cleaning_rule.py @@ -4,7 +4,6 @@ # License: MIT import numpy as np -import pytest from sklearn.utils._testing import assert_array_equal from imblearn.under_sampling import NeighbourhoodCleaningRule @@ -31,23 +30,6 @@ Y = np.array([1, 2, 1, 1, 2, 1, 2, 2, 1, 2, 0, 0, 2, 1, 2]) -@pytest.mark.parametrize( - "ncr_params, err_msg", - [ - ({"threshold_cleaning": -10}, "value between 0 and 1"), - ({"threshold_cleaning": 10}, "value between 0 and 1"), - ( - {"n_neighbors": "rnd"}, - "n_neighbors must be an interger or an object compatible", - ), - ], -) -def test_ncr_error(ncr_params, err_msg): - ncr = NeighbourhoodCleaningRule(**ncr_params) - with pytest.raises(ValueError, match=err_msg): - ncr.fit_resample(X, Y) - - def test_ncr_fit_resample(): ncr = NeighbourhoodCleaningRule() X_resampled, y_resampled = ncr.fit_resample(X, Y) diff --git a/imblearn/under_sampling/_prototype_selection/tests/test_one_sided_selection.py b/imblearn/under_sampling/_prototype_selection/tests/test_one_sided_selection.py index 3bbdb736a..7d3adde0f 100644 --- a/imblearn/under_sampling/_prototype_selection/tests/test_one_sided_selection.py +++ b/imblearn/under_sampling/_prototype_selection/tests/test_one_sided_selection.py @@ -66,9 +66,9 @@ def test_oss_fit_resample(): assert_array_equal(y_resampled, y_gt) -def test_oss_with_object(): - knn = KNeighborsClassifier(n_neighbors=1) - oss = OneSidedSelection(random_state=RND_SEED, n_neighbors=knn) +@pytest.mark.parametrize("n_neighbors", [1, KNeighborsClassifier(n_neighbors=1)]) +def test_oss_with_object(n_neighbors): + oss = OneSidedSelection(random_state=RND_SEED, n_neighbors=n_neighbors) X_resampled, y_resampled = oss.fit_resample(X, Y) X_gt = np.array( @@ -95,10 +95,3 @@ def test_oss_with_object(): X_resampled, y_resampled = oss.fit_resample(X, Y) assert_array_equal(X_resampled, X_gt) assert_array_equal(y_resampled, y_gt) - - -def test_oss_with_wrong_object(): - knn = "rnd" - oss = OneSidedSelection(random_state=RND_SEED, n_neighbors=knn) - with pytest.raises(ValueError, match="has to be a int"): - oss.fit_resample(X, Y) diff --git a/imblearn/under_sampling/_prototype_selection/tests/test_repeated_edited_nearest_neighbours.py b/imblearn/under_sampling/_prototype_selection/tests/test_repeated_edited_nearest_neighbours.py index b7c7301a2..edd3a9132 100644 --- a/imblearn/under_sampling/_prototype_selection/tests/test_repeated_edited_nearest_neighbours.py +++ b/imblearn/under_sampling/_prototype_selection/tests/test_repeated_edited_nearest_neighbours.py @@ -328,13 +328,6 @@ def test_renn_fit_resample_mode(): assert 0 < renn.n_iter_ <= renn.max_iter -def test_renn_not_good_object(): - nn = "rnd" - renn = RepeatedEditedNearestNeighbours(n_neighbors=nn, kind_sel="mode") - with pytest.raises(ValueError): - renn.fit_resample(X, Y) - - @pytest.mark.parametrize( "max_iter, n_iter", [(2, 2), (5, 3)], diff --git a/imblearn/under_sampling/base.py b/imblearn/under_sampling/base.py index 82db6bdd0..e36d8c31f 100644 --- a/imblearn/under_sampling/base.py +++ b/imblearn/under_sampling/base.py @@ -4,7 +4,10 @@ # Authors: Guillaume Lemaitre # License: MIT +import numbers + from ..base import BaseSampler +from ..utils._param_validation import Interval, StrOptions class BaseUnderSampler(BaseSampler): @@ -54,6 +57,15 @@ class BaseUnderSampler(BaseSampler): desired number of samples for each class. 
""".rstrip() # noqa: E501 + _parameter_constraints: dict = { + "sampling_strategy": [ + Interval(numbers.Real, 0, 1, closed="right"), + StrOptions({"auto", "majority", "not minority", "not majority", "all"}), + dict, + callable, + ], + } + class BaseCleaningSampler(BaseSampler): """Base class for under-sampling algorithms. @@ -88,3 +100,12 @@ class BaseCleaningSampler(BaseSampler): correspond to the targeted classes. The values correspond to the desired number of samples for each class. """.rstrip() + + _parameter_constraints: dict = { + "sampling_strategy": [ + Interval(numbers.Real, 0, 1, closed="right"), + StrOptions({"auto", "majority", "not minority", "not majority", "all"}), + list, + callable, + ], + } diff --git a/imblearn/utils/_param_validation.py b/imblearn/utils/_param_validation.py new file mode 100644 index 000000000..1883d2bac --- /dev/null +++ b/imblearn/utils/_param_validation.py @@ -0,0 +1,926 @@ +"""This is a copy of sklearn/utils/_param_validation.py. It can be removed when +we support scikit-learn >= 1.2. +""" +# mypy: ignore-errors +import functools +import math +import operator +import warnings +from abc import ABC, abstractmethod +from collections.abc import Iterable +from inspect import signature +from numbers import Integral, Real + +import numpy as np +import sklearn +from scipy.sparse import csr_matrix, issparse +from sklearn.utils.fixes import parse_version + +from ..utils.fixes import _is_arraylike_not_scalar + +sklearn_version = parse_version(sklearn.__version__) + +if sklearn_version < parse_version("1.2"): + + def validate_parameter_constraints(parameter_constraints, params, caller_name): + """Validate types and values of given parameters. + + Parameters + ---------- + parameter_constraints : dict or {"no_validation"} + If "no_validation", validation is skipped for this parameter. + + If a dict, it must be a dictionary `param_name: list of constraints`. + A parameter is valid if it satisfies one of the constraints from the list. + Constraints can be: + - an Interval object, representing a continuous or discrete range of numbers + - the string "array-like" + - the string "sparse matrix" + - the string "random_state" + - callable + - None, meaning that None is a valid value for the parameter + - any type, meaning that any instance of this type is valid + - an Options object, representing a set of elements of a given type + - a StrOptions object, representing a set of strings + - the string "boolean" + - the string "verbose" + - the string "cv_object" + - the string "missing_values" + - a HasMethods object, representing method(s) an object must have + - a Hidden object, representing a constraint not meant to be exposed to the + user + + params : dict + A dictionary `param_name: param_value`. The parameters to validate against + the constraints. + + caller_name : str + The name of the estimator or function or method that called this function. + """ + for param_name, param_val in params.items(): + # We allow parameters to not have a constraint so that third party + # estimators can inherit from sklearn estimators without having to + # necessarily use the validation tools. + if param_name not in parameter_constraints: + continue + + constraints = parameter_constraints[param_name] + + if constraints == "no_validation": + continue + + constraints = [make_constraint(constraint) for constraint in constraints] + + for constraint in constraints: + if constraint.is_satisfied_by(param_val): + # this constraint is satisfied, no need to check further. 
+ break + else: + # No constraint is satisfied, raise with an informative message. + + # Ignore constraints that we don't want to expose in the error message, + # i.e. options that are for internal purpose or not officially + # supported. + constraints = [ + constraint for constraint in constraints if not constraint.hidden + ] + + if len(constraints) == 1: + constraints_str = f"{constraints[0]}" + else: + constraints_str = ( + f"{', '.join([str(c) for c in constraints[:-1]])} or" + f" {constraints[-1]}" + ) + + raise ValueError( + f"The {param_name!r} parameter of {caller_name} must be" + f" {constraints_str}. Got {param_val!r} instead." + ) + + def make_constraint(constraint): + """Convert the constraint into the appropriate Constraint object. + + Parameters + ---------- + constraint : object + The constraint to convert. + + Returns + ------- + constraint : instance of _Constraint + The converted constraint. + """ + if isinstance(constraint, str) and constraint == "array-like": + return _ArrayLikes() + if isinstance(constraint, str) and constraint == "sparse matrix": + return _SparseMatrices() + if isinstance(constraint, str) and constraint == "random_state": + return _RandomStates() + if constraint is callable: + return _Callables() + if constraint is None: + return _NoneConstraint() + if isinstance(constraint, type): + return _InstancesOf(constraint) + if isinstance(constraint, (Interval, StrOptions, Options, HasMethods)): + return constraint + if isinstance(constraint, str) and constraint == "boolean": + return _Booleans() + if isinstance(constraint, str) and constraint == "verbose": + return _VerboseHelper() + if isinstance(constraint, str) and constraint == "missing_values": + return _MissingValues() + if isinstance(constraint, str) and constraint == "cv_object": + return _CVObjects() + if isinstance(constraint, Hidden): + constraint = make_constraint(constraint.constraint) + constraint.hidden = True + return constraint + raise ValueError(f"Unknown constraint type: {constraint}") + + def validate_params(parameter_constraints): + """Decorator to validate types and values of functions and methods. + + Parameters + ---------- + parameter_constraints : dict + A dictionary `param_name: list of constraints`. See the docstring of + `validate_parameter_constraints` for a description of the accepted + constraints. + + Note that the *args and **kwargs parameters are not validated and must not + be present in the parameter_constraints dictionary. + + Returns + ------- + decorated_function : function or method + The decorated function. + """ + + def decorator(func): + # The dict of parameter constraints is set as an attribute of the function + # to make it possible to dynamically introspect the constraints for + # automatic testing. 
+ setattr(func, "_skl_parameter_constraints", parameter_constraints) + + @functools.wraps(func) + def wrapper(*args, **kwargs): + + func_sig = signature(func) + + # Map *args/**kwargs to the function signature + params = func_sig.bind(*args, **kwargs) + params.apply_defaults() + + # ignore self/cls and positional/keyword markers + to_ignore = [ + p.name + for p in func_sig.parameters.values() + if p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD) + ] + to_ignore += ["self", "cls"] + params = { + k: v for k, v in params.arguments.items() if k not in to_ignore + } + + validate_parameter_constraints( + parameter_constraints, params, caller_name=func.__qualname__ + ) + return func(*args, **kwargs) + + return wrapper + + return decorator + + def _type_name(t): + """Convert type into human readable string.""" + module = t.__module__ + qualname = t.__qualname__ + if module == "builtins": + return qualname + elif t == Real: + return "float" + elif t == Integral: + return "int" + return f"{module}.{qualname}" + + class _Constraint(ABC): + """Base class for the constraint objects.""" + + def __init__(self): + self.hidden = False + + @abstractmethod + def is_satisfied_by(self, val): + """Whether or not a value satisfies the constraint. + + Parameters + ---------- + val : object + The value to check. + + Returns + ------- + is_satisfied : bool + Whether or not the constraint is satisfied by this value. + """ + + @abstractmethod + def __str__(self): + """A human readable representational string of the constraint.""" + + class _InstancesOf(_Constraint): + """Constraint representing instances of a given type. + + Parameters + ---------- + type : type + The valid type. + """ + + def __init__(self, type): + super().__init__() + self.type = type + + def is_satisfied_by(self, val): + return isinstance(val, self.type) + + def __str__(self): + return f"an instance of {_type_name(self.type)!r}" + + class _NoneConstraint(_Constraint): + """Constraint representing the None singleton.""" + + def is_satisfied_by(self, val): + return val is None + + def __str__(self): + return "None" + + class _NanConstraint(_Constraint): + """Constraint representing the indicator `np.nan`.""" + + def is_satisfied_by(self, val): + return isinstance(val, Real) and math.isnan(val) + + def __str__(self): + return "numpy.nan" + + class _PandasNAConstraint(_Constraint): + """Constraint representing the indicator `pd.NA`.""" + + def is_satisfied_by(self, val): + try: + import pandas as pd + + return isinstance(val, type(pd.NA)) and pd.isna(val) + except ImportError: + return False + + def __str__(self): + return "pandas.NA" + + class Options(_Constraint): + """Constraint representing a finite set of instances of a given type. + + Parameters + ---------- + type : type + + options : set + The set of valid scalars. + + deprecated : set or None, default=None + A subset of the `options` to mark as deprecated in the string + representation of the constraint. + """ + + def __init__(self, type, options, *, deprecated=None): + super().__init__() + self.type = type + self.options = options + self.deprecated = deprecated or set() + + if self.deprecated - self.options: + raise ValueError( + "The deprecated options must be a subset of the options." 
+ ) + + def is_satisfied_by(self, val): + return isinstance(val, self.type) and val in self.options + + def _mark_if_deprecated(self, option): + """Add a deprecated mark to an option if needed.""" + option_str = f"{option!r}" + if option in self.deprecated: + option_str = f"{option_str} (deprecated)" + return option_str + + def __str__(self): + options_str = ( + f"{', '.join([self._mark_if_deprecated(o) for o in self.options])}" + ) + return f"a {_type_name(self.type)} among {{{options_str}}}" + + class StrOptions(Options): + """Constraint representing a finite set of strings. + + Parameters + ---------- + options : set of str + The set of valid strings. + + deprecated : set of str or None, default=None + A subset of the `options` to mark as deprecated in the string + representation of the constraint. + """ + + def __init__(self, options, *, deprecated=None): + super().__init__(type=str, options=options, deprecated=deprecated) + + class Interval(_Constraint): + """Constraint representing a typed interval. + + Parameters + ---------- + type : {numbers.Integral, numbers.Real} + The set of numbers in which to set the interval. + + left : float or int or None + The left bound of the interval. None means left bound is -∞. + + right : float, int or None + The right bound of the interval. None means right bound is +∞. + + closed : {"left", "right", "both", "neither"} + Whether the interval is open or closed. Possible choices are: + + - `"left"`: the interval is closed on the left and open on the right. + It is equivalent to the interval `[ left, right )`. + - `"right"`: the interval is closed on the right and open on the left. + It is equivalent to the interval `( left, right ]`. + - `"both"`: the interval is closed. + It is equivalent to the interval `[ left, right ]`. + - `"neither"`: the interval is open. + It is equivalent to the interval `( left, right )`. + + Notes + ----- + Setting a bound to `None` and setting the interval closed is valid. For + instance, strictly speaking, `Interval(Real, 0, None, closed="both")` + corresponds to `[0, +∞) U {+∞}`. + """ + + @validate_params( + { + "type": [type], + "left": [Integral, Real, None], + "right": [Integral, Real, None], + "closed": [StrOptions({"left", "right", "both", "neither"})], + } + ) + def __init__(self, type, left, right, *, closed): + super().__init__() + self.type = type + self.left = left + self.right = right + self.closed = closed + + self._check_params() + + def _check_params(self): + if self.type is Integral: + suffix = "for an interval over the integers." + if self.left is not None and not isinstance(self.left, Integral): + raise TypeError(f"Expecting left to be an int {suffix}") + if self.right is not None and not isinstance(self.right, Integral): + raise TypeError(f"Expecting right to be an int {suffix}") + if self.left is None and self.closed in ("left", "both"): + raise ValueError( + f"left can't be None when closed == {self.closed} {suffix}" + ) + if self.right is None and self.closed in ("right", "both"): + raise ValueError( + f"right can't be None when closed == {self.closed} {suffix}" + ) + + if ( + self.right is not None + and self.left is not None + and self.right <= self.left + ): + raise ValueError( + f"right can't be less than left. 
Got left={self.left} and " + f"right={self.right}" + ) + + def __contains__(self, val): + if np.isnan(val): + return False + + left_cmp = operator.lt if self.closed in ("left", "both") else operator.le + right_cmp = operator.gt if self.closed in ("right", "both") else operator.ge + + left = -np.inf if self.left is None else self.left + right = np.inf if self.right is None else self.right + + if left_cmp(val, left): + return False + if right_cmp(val, right): + return False + return True + + def is_satisfied_by(self, val): + if not isinstance(val, self.type): + return False + + return val in self + + def __str__(self): + type_str = "an int" if self.type is Integral else "a float" + left_bracket = "[" if self.closed in ("left", "both") else "(" + left_bound = "-inf" if self.left is None else self.left + right_bound = "inf" if self.right is None else self.right + right_bracket = "]" if self.closed in ("right", "both") else ")" + return ( + f"{type_str} in the range " + f"{left_bracket}{left_bound}, {right_bound}{right_bracket}" + ) + + class _ArrayLikes(_Constraint): + """Constraint representing array-likes""" + + def is_satisfied_by(self, val): + return _is_arraylike_not_scalar(val) + + def __str__(self): + return "an array-like" + + class _SparseMatrices(_Constraint): + """Constraint representing sparse matrices.""" + + def is_satisfied_by(self, val): + return issparse(val) + + def __str__(self): + return "a sparse matrix" + + class _Callables(_Constraint): + """Constraint representing callables.""" + + def is_satisfied_by(self, val): + return callable(val) + + def __str__(self): + return "a callable" + + class _RandomStates(_Constraint): + """Constraint representing random states. + + Convenience class for + [Interval(Integral, 0, 2**32 - 1, closed="both"), np.random.RandomState, None] + """ + + def __init__(self): + super().__init__() + self._constraints = [ + Interval(Integral, 0, 2**32 - 1, closed="both"), + _InstancesOf(np.random.RandomState), + _NoneConstraint(), + ] + + def is_satisfied_by(self, val): + return any(c.is_satisfied_by(val) for c in self._constraints) + + def __str__(self): + return ( + f"{', '.join([str(c) for c in self._constraints[:-1]])} or" + f" {self._constraints[-1]}" + ) + + class _Booleans(_Constraint): + """Constraint representing boolean likes. + + Convenience class for + [bool, np.bool_, Integral (deprecated)] + """ + + def __init__(self): + super().__init__() + self._constraints = [ + _InstancesOf(bool), + _InstancesOf(np.bool_), + _InstancesOf(Integral), + ] + + def is_satisfied_by(self, val): + # TODO(1.4) remove support for Integral. + if isinstance(val, Integral) and not isinstance(val, bool): + warnings.warn( + "Passing an int for a boolean parameter is deprecated in version" + " 1.2 and won't be supported anymore in version 1.4.", + FutureWarning, + ) + + return any(c.is_satisfied_by(val) for c in self._constraints) + + def __str__(self): + return ( + f"{', '.join([str(c) for c in self._constraints[:-1]])} or" + f" {self._constraints[-1]}" + ) + + class _VerboseHelper(_Constraint): + """Helper constraint for the verbose parameter. 
+ + Convenience class for + [Interval(Integral, 0, None, closed="left"), bool, numpy.bool_] + """ + + def __init__(self): + super().__init__() + self._constraints = [ + Interval(Integral, 0, None, closed="left"), + _InstancesOf(bool), + _InstancesOf(np.bool_), + ] + + def is_satisfied_by(self, val): + return any(c.is_satisfied_by(val) for c in self._constraints) + + def __str__(self): + return ( + f"{', '.join([str(c) for c in self._constraints[:-1]])} or" + f" {self._constraints[-1]}" + ) + + class _MissingValues(_Constraint): + """Helper constraint for the `missing_values` parameters. + + Convenience for + [ + Integral, + Interval(Real, None, None, closed="both"), + str, + None, + _NanConstraint(), + _PandasNAConstraint(), + ] + """ + + def __init__(self): + super().__init__() + self._constraints = [ + _InstancesOf(Integral), + # we use an interval of Real to ignore np.nan that has its own + # constraint + Interval(Real, None, None, closed="both"), + _InstancesOf(str), + _NoneConstraint(), + _NanConstraint(), + _PandasNAConstraint(), + ] + + def is_satisfied_by(self, val): + return any(c.is_satisfied_by(val) for c in self._constraints) + + def __str__(self): + return ( + f"{', '.join([str(c) for c in self._constraints[:-1]])} or" + f" {self._constraints[-1]}" + ) + + class HasMethods(_Constraint): + """Constraint representing objects that expose specific methods. + + It is useful for parameters following a protocol and where we don't want to + impose an affiliation to a specific module or class. + + Parameters + ---------- + methods : str or list of str + The method(s) that the object is expected to expose. + """ + + @validate_params({"methods": [str, list]}) + def __init__(self, methods): + super().__init__() + if isinstance(methods, str): + methods = [methods] + self.methods = methods + + def is_satisfied_by(self, val): + return all(callable(getattr(val, method, None)) for method in self.methods) + + def __str__(self): + if len(self.methods) == 1: + methods = f"{self.methods[0]!r}" + else: + methods = ( + f"{', '.join([repr(m) for m in self.methods[:-1]])} and" + f" {self.methods[-1]!r}" + ) + return f"an object implementing {methods}" + + class _IterablesNotString(_Constraint): + """Constraint representing iterables that are not strings.""" + + def is_satisfied_by(self, val): + return isinstance(val, Iterable) and not isinstance(val, str) + + def __str__(self): + return "an iterable" + + class _CVObjects(_Constraint): + """Constraint representing cv objects. + + Convenient class for + [ + Interval(Integral, 2, None, closed="left"), + HasMethods(["split", "get_n_splits"]), + _IterablesNotString(), + None, + ] + """ + + def __init__(self): + super().__init__() + self._constraints = [ + Interval(Integral, 2, None, closed="left"), + HasMethods(["split", "get_n_splits"]), + _IterablesNotString(), + _NoneConstraint(), + ] + + def is_satisfied_by(self, val): + return any(c.is_satisfied_by(val) for c in self._constraints) + + def __str__(self): + return ( + f"{', '.join([str(c) for c in self._constraints[:-1]])} or" + f" {self._constraints[-1]}" + ) + + class Hidden: + """Class encapsulating a constraint not meant to be exposed to the user. + + Parameters + ---------- + constraint : str or _Constraint instance + The constraint to be used internally. + """ + + def __init__(self, constraint): + self.constraint = constraint + + def generate_invalid_param_val(constraint, constraints=None): + """Return a value that does not satisfy the constraint. 
+ + Raises a NotImplementedError if there exists no invalid value for this + constraint. + + This is only useful for testing purpose. + + Parameters + ---------- + constraint : _Constraint instance + The constraint to generate a value for. + + constraints : list of _Constraint instances or None, default=None + The list of all constraints for this parameter. If None, the list only + containing `constraint` is used. + + Returns + ------- + val : object + A value that does not satisfy the constraint. + """ + if isinstance(constraint, StrOptions): + return f"not {' or '.join(constraint.options)}" + + if isinstance(constraint, _MissingValues): + return np.array([1, 2, 3]) + + if isinstance(constraint, _VerboseHelper): + return -1 + + if isinstance(constraint, HasMethods): + return type("HasNotMethods", (), {})() + + if isinstance(constraint, _IterablesNotString): + return "a string" + + if isinstance(constraint, _CVObjects): + return "not a cv object" + + if not isinstance(constraint, Interval): + raise NotImplementedError + + # constraint is an interval + constraints = [constraint] if constraints is None else constraints + return _generate_invalid_param_val_interval(constraint, constraints) + + def _generate_invalid_param_val_interval(interval, constraints): + """Return a value that does not satisfy an interval constraint. + + Generating an invalid value for an integer interval depends on the other + constraints since an int is a real, meaning that it can be valid for a real + interval. Assumes that there can be at most 2 interval constraints: one integer + interval and/or one real interval. + + This is only useful for testing purpose. + + Parameters + ---------- + interval : Interval instance + The interval to generate a value for. + + constraints : list of _Constraint instances + The list of all constraints for this parameter. + + Returns + ------- + val : object + A value that does not satisfy the interval constraint. + """ + if interval.type is Real: + # generate a non-integer value such that it can't be valid even if there's + # also an integer interval constraint. + if interval.left is None and interval.right is None: + if interval.closed in ("left", "neither"): + return np.inf + elif interval.closed in ("right", "neither"): + return -np.inf + else: + raise NotImplementedError + + if interval.left is not None: + return np.floor(interval.left) - 0.5 + else: # right is not None + return np.ceil(interval.right) + 0.5 + + else: # interval.type is Integral + if interval.left is None and interval.right is None: + raise NotImplementedError + + # We need to check if there's also a real interval constraint to generate a + # value that is not valid for any of the 2 interval constraints. + real_intervals = [ + i for i in constraints if isinstance(i, Interval) and i.type is Real + ] + real_interval = real_intervals[0] if real_intervals else None + + if real_interval is None: + # Only the integer interval constraint -> easy + if interval.left is not None: + return interval.left - 1 + else: # interval.right is not None + return interval.right + 1 + + # There's also a real interval constraint. Try to find a value left to both + # or right to both or in between them. + + # redefine left and right bounds to be smallest and largest valid integers + # in both intervals. 
+ int_left = interval.left + if int_left is not None and interval.closed in ("right", "neither"): + int_left = int_left + 1 + + int_right = interval.right + if int_right is not None and interval.closed in ("left", "neither"): + int_right = int_right - 1 + + real_left = real_interval.left + if real_interval.left is not None: + real_left = int(np.ceil(real_interval.left)) + if real_interval.closed in ("right", "neither"): + real_left = real_left + 1 + + real_right = real_interval.right + if real_interval.right is not None: + real_right = int(np.floor(real_interval.right)) + if real_interval.closed in ("left", "neither"): + real_right = real_right - 1 + + if int_left is not None and real_left is not None: + # there exists an int left to both intervals + return min(int_left, real_left) - 1 + + if int_right is not None and real_right is not None: + # there exists an int right to both intervals + return max(int_right, real_right) + 1 + + if int_left is not None: + if real_right is not None and int_left - real_right >= 2: + # there exists an int between the 2 intervals + return int_left - 1 + else: + raise NotImplementedError + else: # int_right is not None + if real_left is not None and real_left - int_right >= 2: + # there exists an int between the 2 intervals + return int_right + 1 + else: + raise NotImplementedError + + def generate_valid_param(constraint): + """Return a value that does satisfy a constraint. + + This is only useful for testing purpose. + + Parameters + ---------- + constraint : Constraint instance + The constraint to generate a value for. + + Returns + ------- + val : object + A value that does satisfy the constraint. + """ + if isinstance(constraint, _ArrayLikes): + return np.array([1, 2, 3]) + + if isinstance(constraint, _SparseMatrices): + return csr_matrix([[0, 1], [1, 0]]) + + if isinstance(constraint, _RandomStates): + return np.random.RandomState(42) + + if isinstance(constraint, _Callables): + return lambda x: x + + if isinstance(constraint, _NoneConstraint): + return None + + if isinstance(constraint, _InstancesOf): + return constraint.type() + + if isinstance(constraint, _Booleans): + return True + + if isinstance(constraint, _VerboseHelper): + return 1 + + if isinstance(constraint, _MissingValues): + return np.nan + + if isinstance(constraint, HasMethods): + return type( + "ValidHasMethods", + (), + {m: lambda self: None for m in constraint.methods}, + )() + + if isinstance(constraint, _IterablesNotString): + return [1, 2, 3] + + if isinstance(constraint, _CVObjects): + return 5 + + if isinstance(constraint, Options): # includes StrOptions + for option in constraint.options: + return option + + if isinstance(constraint, Interval): + interval = constraint + if interval.left is None and interval.right is None: + return 0 + elif interval.left is None: + return interval.right - 1 + elif interval.right is None: + return interval.left + 1 + else: + if interval.type is Real: + return (interval.left + interval.right) / 2 + else: + return interval.left + 1 + + raise ValueError(f"Unknown constraint type: {constraint}") + +else: + from sklearn.utils._param_validation import generate_invalid_param_val # noqa + from sklearn.utils._param_validation import generate_valid_param # noqa + from sklearn.utils._param_validation import validate_parameter_constraints # noqa + from sklearn.utils._param_validation import ( + HasMethods, + Hidden, + Interval, + Options, + StrOptions, + _ArrayLikes, + _Booleans, + _Callables, + _CVObjects, + _InstancesOf, + _IterablesNotString, + 
_MissingValues, + _NoneConstraint, + _PandasNAConstraint, + _RandomStates, + _SparseMatrices, + _VerboseHelper, + make_constraint, + validate_params, + ) diff --git a/imblearn/utils/_validation.py b/imblearn/utils/_validation.py index 007f90e02..b8df4b926 100644 --- a/imblearn/utils/_validation.py +++ b/imblearn/utils/_validation.py @@ -109,14 +109,8 @@ def check_neighbors_object(nn_name, nn_object, additional_neighbor=0): """ if isinstance(nn_object, Integral): return NearestNeighbors(n_neighbors=nn_object + additional_neighbor) - elif _is_neighbors_object(nn_object): - return clone(nn_object) - else: - raise ValueError( - f"{nn_name} must be an interger or an object compatible with the " - "KNeighborsMixin API of scikit-learn (i.e. implementing `kneighbors` " - "method)." - ) + # _is_neighbors_object(nn_object) + return clone(nn_object) def _count_class_sample(y): diff --git a/imblearn/utils/estimator_checks.py b/imblearn/utils/estimator_checks.py index 871f7ac26..4c4c72741 100644 --- a/imblearn/utils/estimator_checks.py +++ b/imblearn/utils/estimator_checks.py @@ -23,18 +23,25 @@ ) from sklearn.exceptions import SkipTestWarning from sklearn.preprocessing import label_binarize +from sklearn.utils._tags import _safe_tags from sklearn.utils._testing import ( assert_allclose, assert_array_equal, assert_raises_regex, + raises, +) +from sklearn.utils.estimator_checks import ( + _enforce_estimator_tags_y, + _get_check_estimator_ids, + _maybe_mark_xfail, ) -from sklearn.utils.estimator_checks import _get_check_estimator_ids, _maybe_mark_xfail from sklearn.utils.fixes import parse_version from sklearn.utils.multiclass import type_of_target from imblearn.datasets import make_imbalance from imblearn.over_sampling.base import BaseOverSampler from imblearn.under_sampling.base import BaseCleaningSampler, BaseUnderSampler +from imblearn.utils._param_validation import generate_invalid_param_val, make_constraint sklearn_version = parse_version(sklearn.__version__) @@ -471,3 +478,92 @@ def check_classifiers_with_encoded_labels(name, classifier_orig): assert set(classifier.classes_) == set(y.cat.categories.tolist()) y_pred = classifier.predict(df) assert set(y_pred) == set(y.cat.categories.tolist()) + + +def check_param_validation(name, estimator_orig): + # Check that an informative error is raised when the value of a constructor + # parameter does not have an appropriate type or value. + rng = np.random.RandomState(0) + X = rng.uniform(size=(20, 5)) + y = rng.randint(0, 2, size=20) + y = _enforce_estimator_tags_y(estimator_orig, y) + + estimator_params = estimator_orig.get_params(deep=False).keys() + + # check that there is a constraint for each parameter + if estimator_params: + validation_params = estimator_orig._parameter_constraints.keys() + unexpected_params = set(validation_params) - set(estimator_params) + missing_params = set(estimator_params) - set(validation_params) + err_msg = ( + f"Mismatch between _parameter_constraints and the parameters of {name}." 
+ f"\nConsider the unexpected parameters {unexpected_params} and expected but" + f" missing parameters {missing_params}" + ) + assert validation_params == estimator_params, err_msg + + # this object does not have a valid type for sure for all params + param_with_bad_type = type("BadType", (), {})() + + fit_methods = ["fit", "partial_fit", "fit_transform", "fit_predict", "fit_resample"] + + for param_name in estimator_params: + constraints = estimator_orig._parameter_constraints[param_name] + + if constraints == "no_validation": + # This parameter is not validated + continue # pragma: no cover + + match = rf"The '{param_name}' parameter of {name} must be .* Got .* instead." + err_msg = ( + f"{name} does not raise an informative error message when the " + f"parameter {param_name} does not have a valid type or value." + ) + + estimator = clone(estimator_orig) + + # First, check that the error is raised if param doesn't match any valid type. + estimator.set_params(**{param_name: param_with_bad_type}) + + for method in fit_methods: + if not hasattr(estimator, method): + # the method is not accessible with the current set of parameters + continue + + with raises(ValueError, match=match, err_msg=err_msg): + if any( + isinstance(X_type, str) and X_type.endswith("labels") + for X_type in _safe_tags(estimator, key="X_types") + ): + # The estimator is a label transformer and take only `y` + getattr(estimator, method)(y) # pragma: no cover + else: + getattr(estimator, method)(X, y) + + # Then, for constraints that are more than a type constraint, check that the + # error is raised if param does match a valid type but does not match any valid + # value for this type. + constraints = [make_constraint(constraint) for constraint in constraints] + + for constraint in constraints: + try: + bad_value = generate_invalid_param_val(constraint, constraints) + except NotImplementedError: + continue + + estimator.set_params(**{param_name: bad_value}) + + for method in fit_methods: + if not hasattr(estimator, method): + # the method is not accessible with the current set of parameters + continue + + with raises(ValueError, match=match, err_msg=err_msg): + if any( + X_type.endswith("labels") + for X_type in _safe_tags(estimator, key="X_types") + ): + # The estimator is a label transformer and take only `y` + getattr(estimator, method)(y) # pragma: no cover + else: + getattr(estimator, method)(X, y) diff --git a/imblearn/utils/fixes.py b/imblearn/utils/fixes.py index 94e6825b5..c2db6adc4 100644 --- a/imblearn/utils/fixes.py +++ b/imblearn/utils/fixes.py @@ -5,11 +5,14 @@ which the fix is no longer needed. 
""" +import numpy as np import scipy import scipy.stats +import sklearn from sklearn.utils.fixes import parse_version sp_version = parse_version(scipy.__version__) +sklearn_version = parse_version(sklearn.__version__) # TODO: Remove when SciPy 1.9 is the minimum supported version @@ -17,3 +20,14 @@ def _mode(a, axis=0): if sp_version >= parse_version("1.9.0"): return scipy.stats.mode(a, axis=axis, keepdims=True) return scipy.stats.mode(a, axis=axis) + + +# TODO: Remove when scikit-learn 1.1 is the minimum supported version +if sklearn_version >= parse_version("1.1"): + from sklearn.utils.validation import _is_arraylike_not_scalar +else: + from sklearn.utils.validation import _is_arraylike + + def _is_arraylike_not_scalar(array): + """Return True if array is array-like and not a scalar""" + return _is_arraylike(array) and not np.isscalar(array) diff --git a/imblearn/utils/testing.py b/imblearn/utils/testing.py index 357af7283..8c19d6101 100644 --- a/imblearn/utils/testing.py +++ b/imblearn/utils/testing.py @@ -10,6 +10,7 @@ from operator import itemgetter from pathlib import Path +import numpy as np from scipy import sparse from sklearn.base import BaseEstimator from sklearn.neighbors import KDTree @@ -143,8 +144,14 @@ def kneighbors_graph(X=None, n_neighbors=None, mode="connectivity"): class _CustomClusterer(BaseEstimator): """Class that mimics a cluster that does not expose `cluster_centers_`.""" - def __init__(self, n_clusters=1): + def __init__(self, n_clusters=1, expose_cluster_centers=True): self.n_clusters = n_clusters + self.expose_cluster_centers = expose_cluster_centers def fit(self, X, y=None): + if self.expose_cluster_centers: + self.cluster_centers_ = np.random.randn(self.n_clusters, X.shape[1]) return self + + def predict(self, X): + return np.zeros(len(X), dtype=int) diff --git a/imblearn/utils/tests/test_estimator_checks.py b/imblearn/utils/tests/test_estimator_checks.py index f8ebc4701..dbc337dd1 100644 --- a/imblearn/utils/tests/test_estimator_checks.py +++ b/imblearn/utils/tests/test_estimator_checks.py @@ -63,6 +63,8 @@ def fit(self, X, y): class NotPreservingDtypeSampler(BaseSampler): _sampling_type = "bypass" + _parameter_constraints: dict = {"sampling_strategy": "no_validation"} + def _fit_resample(self, X, y): return X.astype(np.float64), y.astype(np.int64) diff --git a/imblearn/utils/tests/test_param_validation.py b/imblearn/utils/tests/test_param_validation.py new file mode 100644 index 000000000..dae58a790 --- /dev/null +++ b/imblearn/utils/tests/test_param_validation.py @@ -0,0 +1,646 @@ +"""This is a copy of sklearn/utils/tests/test_param_validation.py. It can be +removed when we support scikit-learn >= 1.2. 
+""" +from numbers import Integral, Real + +import numpy as np +import pytest +from scipy.sparse import csr_matrix +from sklearn.base import BaseEstimator +from sklearn.model_selection import LeaveOneOut +from sklearn.utils import deprecated + +from imblearn.base import _ParamsValidationMixin +from imblearn.utils._param_validation import ( + HasMethods, + Hidden, + Interval, + Options, + StrOptions, + _ArrayLikes, + _Booleans, + _Callables, + _CVObjects, + _InstancesOf, + _IterablesNotString, + _MissingValues, + _NoneConstraint, + _PandasNAConstraint, + _RandomStates, + _SparseMatrices, + _VerboseHelper, + generate_invalid_param_val, + generate_valid_param, + make_constraint, + validate_params, +) + + +# Some helpers for the tests +@validate_params({"a": [Real], "b": [Real], "c": [Real], "d": [Real]}) +def _func(a, b=0, *args, c, d=0, **kwargs): + """A function to test the validation of functions.""" + + +class _Class: + """A class to test the _InstancesOf constraint and the validation of methods.""" + + @validate_params({"a": [Real]}) + def _method(self, a): + """A validated method""" + + @deprecated() + @validate_params({"a": [Real]}) + def _deprecated_method(self, a): + """A deprecated validated method""" + + +class _Estimator(BaseEstimator, _ParamsValidationMixin): + """An estimator to test the validation of estimator parameters.""" + + _parameter_constraints: dict = {"a": [Real]} + + def __init__(self, a): + self.a = a + + def fit(self, X=None, y=None): + self._validate_params() + + +@pytest.mark.parametrize("interval_type", [Integral, Real]) +def test_interval_range(interval_type): + """Check the range of values depending on closed.""" + interval = Interval(interval_type, -2, 2, closed="left") + assert -2 in interval and 2 not in interval + + interval = Interval(interval_type, -2, 2, closed="right") + assert -2 not in interval and 2 in interval + + interval = Interval(interval_type, -2, 2, closed="both") + assert -2 in interval and 2 in interval + + interval = Interval(interval_type, -2, 2, closed="neither") + assert -2 not in interval and 2 not in interval + + +def test_interval_inf_in_bounds(): + """Check that inf is included iff a bound is closed and set to None. + + Only valid for real intervals. 
+ """ + interval = Interval(Real, 0, None, closed="right") + assert np.inf in interval + + interval = Interval(Real, None, 0, closed="left") + assert -np.inf in interval + + interval = Interval(Real, None, None, closed="neither") + assert np.inf not in interval + assert -np.inf not in interval + + +@pytest.mark.parametrize( + "interval", + [Interval(Real, 0, 1, closed="left"), Interval(Real, None, None, closed="both")], +) +def test_nan_not_in_interval(interval): + """Check that np.nan is not in any interval.""" + assert np.nan not in interval + + +@pytest.mark.parametrize( + "params, error, match", + [ + ( + {"type": Integral, "left": 1.0, "right": 2, "closed": "both"}, + TypeError, + r"Expecting left to be an int for an interval over the integers", + ), + ( + {"type": Integral, "left": 1, "right": 2.0, "closed": "neither"}, + TypeError, + "Expecting right to be an int for an interval over the integers", + ), + ( + {"type": Integral, "left": None, "right": 0, "closed": "left"}, + ValueError, + r"left can't be None when closed == left", + ), + ( + {"type": Integral, "left": 0, "right": None, "closed": "right"}, + ValueError, + r"right can't be None when closed == right", + ), + ( + {"type": Integral, "left": 1, "right": -1, "closed": "both"}, + ValueError, + r"right can't be less than left", + ), + ], +) +def test_interval_errors(params, error, match): + """Check that informative errors are raised for invalid combination of parameters""" + with pytest.raises(error, match=match): + Interval(**params) + + +def test_stroptions(): + """Sanity check for the StrOptions constraint""" + options = StrOptions({"a", "b", "c"}, deprecated={"c"}) + assert options.is_satisfied_by("a") + assert options.is_satisfied_by("c") + assert not options.is_satisfied_by("d") + + assert "'c' (deprecated)" in str(options) + + +def test_options(): + """Sanity check for the Options constraint""" + options = Options(Real, {-0.5, 0.5, np.inf}, deprecated={-0.5}) + assert options.is_satisfied_by(-0.5) + assert options.is_satisfied_by(np.inf) + assert not options.is_satisfied_by(1.23) + + assert "-0.5 (deprecated)" in str(options) + + +@pytest.mark.parametrize( + "type, expected_type_name", + [ + (int, "int"), + (Integral, "int"), + (Real, "float"), + (np.ndarray, "numpy.ndarray"), + ], +) +def test_instances_of_type_human_readable(type, expected_type_name): + """Check the string representation of the _InstancesOf constraint.""" + constraint = _InstancesOf(type) + assert str(constraint) == f"an instance of '{expected_type_name}'" + + +def test_hasmethods(): + """Check the HasMethods constraint.""" + constraint = HasMethods(["a", "b"]) + + class _Good: + def a(self): + pass # pragma: no cover + + def b(self): + pass # pragma: no cover + + class _Bad: + def a(self): + pass # pragma: no cover + + assert constraint.is_satisfied_by(_Good()) + assert not constraint.is_satisfied_by(_Bad()) + assert str(constraint) == "an object implementing 'a' and 'b'" + + +@pytest.mark.parametrize( + "constraint", + [ + Interval(Real, None, 0, closed="left"), + Interval(Real, 0, None, closed="left"), + Interval(Real, None, None, closed="neither"), + StrOptions({"a", "b", "c"}), + _MissingValues(), + _VerboseHelper(), + HasMethods("fit"), + _IterablesNotString(), + _CVObjects(), + ], +) +def test_generate_invalid_param_val(constraint): + """Check that the value generated does not satisfy the constraint""" + bad_value = generate_invalid_param_val(constraint) + assert not constraint.is_satisfied_by(bad_value) + + +@pytest.mark.parametrize( + 
"integer_interval, real_interval", + [ + ( + Interval(Integral, None, 3, closed="right"), + Interval(Real, -5, 5, closed="both"), + ), + ( + Interval(Integral, None, 3, closed="right"), + Interval(Real, -5, 5, closed="neither"), + ), + ( + Interval(Integral, None, 3, closed="right"), + Interval(Real, 4, 5, closed="both"), + ), + ( + Interval(Integral, None, 3, closed="right"), + Interval(Real, 5, None, closed="left"), + ), + ( + Interval(Integral, None, 3, closed="right"), + Interval(Real, 4, None, closed="neither"), + ), + ( + Interval(Integral, 3, None, closed="left"), + Interval(Real, -5, 5, closed="both"), + ), + ( + Interval(Integral, 3, None, closed="left"), + Interval(Real, -5, 5, closed="neither"), + ), + ( + Interval(Integral, 3, None, closed="left"), + Interval(Real, 1, 2, closed="both"), + ), + ( + Interval(Integral, 3, None, closed="left"), + Interval(Real, None, -5, closed="left"), + ), + ( + Interval(Integral, 3, None, closed="left"), + Interval(Real, None, -4, closed="neither"), + ), + ( + Interval(Integral, -5, 5, closed="both"), + Interval(Real, None, 1, closed="right"), + ), + ( + Interval(Integral, -5, 5, closed="both"), + Interval(Real, 1, None, closed="left"), + ), + ( + Interval(Integral, -5, 5, closed="both"), + Interval(Real, -10, -4, closed="neither"), + ), + ( + Interval(Integral, -5, 5, closed="both"), + Interval(Real, -10, -4, closed="right"), + ), + ( + Interval(Integral, -5, 5, closed="neither"), + Interval(Real, 6, 10, closed="neither"), + ), + ( + Interval(Integral, -5, 5, closed="neither"), + Interval(Real, 6, 10, closed="left"), + ), + ( + Interval(Integral, 2, None, closed="left"), + Interval(Real, 0, 1, closed="both"), + ), + ( + Interval(Integral, 1, None, closed="left"), + Interval(Real, 0, 1, closed="both"), + ), + ], +) +def test_generate_invalid_param_val_2_intervals(integer_interval, real_interval): + """Check that the value generated for an interval constraint does not satisfy any of + the interval constraints. + """ + bad_value = generate_invalid_param_val( + real_interval, constraints=[real_interval, integer_interval] + ) + assert not real_interval.is_satisfied_by(bad_value) + assert not integer_interval.is_satisfied_by(bad_value) + + bad_value = generate_invalid_param_val( + integer_interval, constraints=[real_interval, integer_interval] + ) + assert not real_interval.is_satisfied_by(bad_value) + assert not integer_interval.is_satisfied_by(bad_value) + + +@pytest.mark.parametrize( + "constraints", + [ + [_ArrayLikes()], + [_InstancesOf(list)], + [_Callables()], + [_NoneConstraint()], + [_RandomStates()], + [_SparseMatrices()], + [_Booleans()], + [Interval(Real, None, None, closed="both")], + [ + Interval(Integral, 0, None, closed="left"), + Interval(Real, None, 0, closed="neither"), + ], + ], +) +def test_generate_invalid_param_val_all_valid(constraints): + """Check that the function raises NotImplementedError when there's no invalid value + for the constraint. 
+ """ + with pytest.raises(NotImplementedError): + generate_invalid_param_val(constraints[0], constraints=constraints) + + +@pytest.mark.parametrize( + "constraint", + [ + _ArrayLikes(), + _Callables(), + _InstancesOf(list), + _NoneConstraint(), + _RandomStates(), + _SparseMatrices(), + _Booleans(), + _VerboseHelper(), + _MissingValues(), + StrOptions({"a", "b", "c"}), + Options(Integral, {1, 2, 3}), + Interval(Integral, None, None, closed="neither"), + Interval(Integral, 0, 10, closed="neither"), + Interval(Integral, 0, None, closed="neither"), + Interval(Integral, None, 0, closed="neither"), + Interval(Real, 0, 1, closed="neither"), + Interval(Real, 0, None, closed="both"), + Interval(Real, None, 0, closed="right"), + HasMethods("fit"), + _IterablesNotString(), + _CVObjects(), + ], +) +def test_generate_valid_param(constraint): + """Check that the value generated does satisfy the constraint.""" + value = generate_valid_param(constraint) + assert constraint.is_satisfied_by(value) + + +@pytest.mark.parametrize( + "constraint_declaration, value", + [ + (Interval(Real, 0, 1, closed="both"), 0.42), + (Interval(Integral, 0, None, closed="neither"), 42), + (StrOptions({"a", "b", "c"}), "b"), + (Options(type, {np.float32, np.float64}), np.float64), + (callable, lambda x: x + 1), + (None, None), + ("array-like", [[1, 2], [3, 4]]), + ("array-like", np.array([[1, 2], [3, 4]])), + ("sparse matrix", csr_matrix([[1, 2], [3, 4]])), + ("random_state", 0), + ("random_state", np.random.RandomState(0)), + ("random_state", None), + (_Class, _Class()), + (int, 1), + (Real, 0.5), + ("boolean", False), + ("verbose", 1), + ("missing_values", -1), + ("missing_values", -1.0), + ("missing_values", None), + ("missing_values", float("nan")), + ("missing_values", np.nan), + ("missing_values", "missing"), + (HasMethods("fit"), _Estimator(a=0)), + ("cv_object", 5), + ], +) +def test_is_satisfied_by(constraint_declaration, value): + """Sanity check for the is_satisfied_by method""" + constraint = make_constraint(constraint_declaration) + assert constraint.is_satisfied_by(value) + + +@pytest.mark.parametrize( + "constraint_declaration, expected_constraint_class", + [ + (Interval(Real, 0, 1, closed="both"), Interval), + (StrOptions({"option1", "option2"}), StrOptions), + (Options(Real, {0.42, 1.23}), Options), + ("array-like", _ArrayLikes), + ("sparse matrix", _SparseMatrices), + ("random_state", _RandomStates), + (None, _NoneConstraint), + (callable, _Callables), + (int, _InstancesOf), + ("boolean", _Booleans), + ("verbose", _VerboseHelper), + ("missing_values", _MissingValues), + (HasMethods("fit"), HasMethods), + ("cv_object", _CVObjects), + ], +) +def test_make_constraint(constraint_declaration, expected_constraint_class): + """Check that make_constraint dispaches to the appropriate constraint class""" + constraint = make_constraint(constraint_declaration) + assert constraint.__class__ is expected_constraint_class + + +def test_make_constraint_unknown(): + """Check that an informative error is raised when an unknown constraint is passed""" + with pytest.raises(ValueError, match="Unknown constraint"): + make_constraint("not a valid constraint") + + +def test_validate_params(): + """Check that validate_params works no matter how the arguments are passed""" + with pytest.raises(ValueError, match="The 'a' parameter of _func must be"): + _func("wrong", c=1) + + with pytest.raises(ValueError, match="The 'b' parameter of _func must be"): + _func(*[1, "wrong"], c=1) + + with pytest.raises(ValueError, match="The 'c' parameter 
+        _func(1, **{"c": "wrong"})
+
+    with pytest.raises(ValueError, match="The 'd' parameter of _func must be"):
+        _func(1, c=1, d="wrong")
+
+    # check in the presence of extra positional and keyword args
+    with pytest.raises(ValueError, match="The 'b' parameter of _func must be"):
+        _func(0, *["wrong", 2, 3], c=4, **{"e": 5})
+
+    with pytest.raises(ValueError, match="The 'c' parameter of _func must be"):
+        _func(0, *[1, 2, 3], c="four", **{"e": 5})
+
+
+def test_validate_params_missing_params():
+    """Check that no error is raised when there are parameters without
+    constraints
+    """
+
+    @validate_params({"a": [int]})
+    def func(a, b):
+        pass
+
+    func(1, 2)
+
+
+def test_decorate_validated_function():
+    """Check that validate_params functions can be decorated"""
+    decorated_function = deprecated()(_func)
+
+    with pytest.warns(FutureWarning, match="Function _func is deprecated"):
+        decorated_function(1, 2, c=3)
+
+    # outer decorator does not interfere with validation
+    with pytest.warns(FutureWarning, match="Function _func is deprecated"):
+        with pytest.raises(ValueError, match=r"The 'c' parameter of _func must be"):
+            decorated_function(1, 2, c="wrong")
+
+
+def test_validate_params_method():
+    """Check that validate_params works with methods"""
+    with pytest.raises(ValueError, match="The 'a' parameter of _Class._method must be"):
+        _Class()._method("wrong")
+
+    # validated method can be decorated
+    with pytest.warns(FutureWarning, match="Function _deprecated_method is deprecated"):
+        with pytest.raises(
+            ValueError, match="The 'a' parameter of _Class._deprecated_method must be"
+        ):
+            _Class()._deprecated_method("wrong")
+
+
+def test_validate_params_estimator():
+    """Check that validate_params works with Estimator instances"""
+    # no validation in init
+    est = _Estimator("wrong")
+
+    with pytest.raises(ValueError, match="The 'a' parameter of _Estimator must be"):
+        est.fit()
+
+
+def test_stroptions_deprecated_subset():
+    """Check that the deprecated parameter must be a subset of options."""
+    with pytest.raises(ValueError, match="deprecated options must be a subset"):
+        StrOptions({"a", "b", "c"}, deprecated={"a", "d"})
+
+
+def test_hidden_constraint():
+    """Check that internal constraints are not exposed in the error message."""
+
+    @validate_params({"param": [Hidden(list), dict]})
+    def f(param):
+        pass
+
+    # list and dict are valid params
+    f({"a": 1, "b": 2, "c": 3})
+    f([1, 2, 3])
+
+    with pytest.raises(ValueError, match="The 'param' parameter") as exc_info:
+        f(param="bad")
+
+    # the list option is not exposed in the error message
+    err_msg = str(exc_info.value)
+    assert "an instance of 'dict'" in err_msg
+    assert "an instance of 'list'" not in err_msg
+
+
+def test_hidden_stroptions():
+    """Check that we can have 2 StrOptions constraints, one being hidden."""
+
+    @validate_params({"param": [StrOptions({"auto"}), Hidden(StrOptions({"warn"}))]})
+    def f(param):
+        pass
+
+    # "auto" and "warn" are valid params
+    f("auto")
+    f("warn")
+
+    with pytest.raises(ValueError, match="The 'param' parameter") as exc_info:
+        f(param="bad")
+
+    # the "warn" option is not exposed in the error message
+    err_msg = str(exc_info.value)
+    assert "auto" in err_msg
+    assert "warn" not in err_msg
+
+
+def test_validate_params_set_param_constraints_attribute():
+    """Check that the validate_params decorator properly sets the parameter constraints
+    as attribute of the decorated function/method.
+ """ + assert hasattr(_func, "_skl_parameter_constraints") + assert hasattr(_Class()._method, "_skl_parameter_constraints") + + +def test_boolean_constraint_deprecated_int(): + """Check that validate_params raise a deprecation message but still passes + validation when using an int for a parameter accepting a boolean. + """ + + @validate_params({"param": ["boolean"]}) + def f(param): + pass + + # True/False and np.bool_(True/False) are valid params + f(True) + f(np.bool_(False)) + + # an int is also valid but deprecated + with pytest.warns( + FutureWarning, match="Passing an int for a boolean parameter is deprecated" + ): + f(1) + + +def test_no_validation(): + """Check that validation can be skipped for a parameter.""" + + @validate_params({"param1": [int, None], "param2": "no_validation"}) + def f(param1=None, param2=None): + pass + + # param1 is validated + with pytest.raises(ValueError, match="The 'param1' parameter"): + f(param1="wrong") + + # param2 is not validated: any type is valid. + class SomeType: + pass + + f(param2=SomeType) + f(param2=SomeType()) + + +def test_pandas_na_constraint_with_pd_na(): + """Add a specific test for checking support for `pandas.NA`.""" + pd = pytest.importorskip("pandas") + + na_constraint = _PandasNAConstraint() + assert na_constraint.is_satisfied_by(pd.NA) + assert not na_constraint.is_satisfied_by(np.array([1, 2, 3])) + + +def test_iterable_not_string(): + """Check that a string does not satisfy the _IterableNotString constraint.""" + constraint = _IterablesNotString() + assert constraint.is_satisfied_by([1, 2, 3]) + assert constraint.is_satisfied_by(range(10)) + assert not constraint.is_satisfied_by("some string") + + +def test_cv_objects(): + """Check that the _CVObjects constraint accepts all current ways + to pass cv objects.""" + constraint = _CVObjects() + assert constraint.is_satisfied_by(5) + assert constraint.is_satisfied_by(LeaveOneOut()) + assert constraint.is_satisfied_by([([1, 2], [3, 4]), ([3, 4], [1, 2])]) + assert constraint.is_satisfied_by(None) + assert not constraint.is_satisfied_by("not a CV object") + + +def test_third_party_estimator(): + """Check that the validation from a scikit-learn estimator inherited by a third + party estimator does not impose a match between the dict of constraints and the + parameters of the estimator. + """ + + class ThirdPartyEstimator(_Estimator): + def __init__(self, b): + self.b = b + super().__init__(a=0) + + def fit(self, X=None, y=None): + super().fit(X, y) + + # does not raise, even though "b" is not in the constraints dict and "a" is not + # a parameter of the estimator. 
+    ThirdPartyEstimator(b=0).fit()
diff --git a/imblearn/utils/tests/test_validation.py b/imblearn/utils/tests/test_validation.py
index 6a40ca171..587b5e278 100644
--- a/imblearn/utils/tests/test_validation.py
+++ b/imblearn/utils/tests/test_validation.py
@@ -7,6 +7,7 @@
 
 import numpy as np
 import pytest
+from sklearn.cluster import KMeans
 from sklearn.neighbors import NearestNeighbors
 from sklearn.neighbors._base import KNeighborsMixin
 from sklearn.utils._testing import assert_array_equal
@@ -16,7 +17,11 @@
     check_sampling_strategy,
     check_target_type,
 )
-from imblearn.utils._validation import ArraysTransformer, _deprecate_positional_args
+from imblearn.utils._validation import (
+    ArraysTransformer,
+    _deprecate_positional_args,
+    _is_neighbors_object,
+)
 from imblearn.utils.testing import _CustomNearestNeighbors
 
 multiclass_target = np.array([1] * 50 + [2] * 100 + [3] * 25)
@@ -38,13 +43,6 @@ def test_check_neighbors_object():
     estimator = _CustomNearestNeighbors()
     estimator_cloned = check_neighbors_object(name, estimator)
     assert isinstance(estimator_cloned, _CustomNearestNeighbors)
-    n_neighbors = "rnd"
-    err_msg = (
-        "n_neighbors must be an interger or an object compatible with the "
-        "KNeighborsMixin API of scikit-learn"
-    )
-    with pytest.raises(ValueError, match=err_msg):
-        check_neighbors_object(name, n_neighbors)
 
 
 @pytest.mark.parametrize(
@@ -383,3 +381,10 @@ def f3(a, *, b, c=1, d=1):
 
     with pytest.warns(FutureWarning, match=r"Pass b=2 as keyword args"):
         f3(1, 2)
+
+
+@pytest.mark.parametrize(
+    "estimator, is_neighbor_estimator", [(NearestNeighbors(), True), (KMeans(), False)]
+)
+def test_is_neighbors_object(estimator, is_neighbor_estimator):
+    assert _is_neighbors_object(estimator) == is_neighbor_estimator