MAINT add parameter validation framework #955

Merged: 18 commits, Dec 4, 2022
3 changes: 3 additions & 0 deletions doc/whats_new/v0.10.rst
@@ -19,6 +19,9 @@ Compatibility
- Maintenance release to be compatible with scikit-learn >= 1.0.2.
:pr:`946`, :pr:`947`, :pr:`949` by :user:`Guillaume Lemaitre <glemaitre>`.

- Add support for automatic parameter validation as in scikit-learn >= 1.2.
:pr:`955` by :user:`Guillaume Lemaitre <glemaitre>`.

Deprecation
...........

77 changes: 76 additions & 1 deletion imblearn/base.py
@@ -12,6 +12,7 @@
from sklearn.utils.multiclass import check_classification_targets

from .utils import check_sampling_strategy, check_target_type
from .utils._param_validation import validate_parameter_constraints
from .utils._validation import ArraysTransformer


@@ -113,7 +114,26 @@ def _fit_resample(self, X, y):
pass


-class BaseSampler(SamplerMixin):
class _ParamsValidationMixin:
"""Mixin class to validate parameters."""

def _validate_params(self):
"""Validate types and values of constructor parameters.

The expected type and values must be defined in the `_parameter_constraints`
class attribute, which is a dictionary `param_name: list of constraints`. See
the docstring of `validate_parameter_constraints` for a description of the
accepted constraints.
"""
if hasattr(self, "_parameter_constraints"):
validate_parameter_constraints(
self._parameter_constraints,
self.get_params(deep=False),
caller_name=self.__class__.__name__,
)
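
To make the contract concrete, here is a minimal sketch (the `DemoSampler` below is hypothetical, not part of this PR): a sampler opts in by declaring `_parameter_constraints` as a class attribute, and `_validate_params` checks the values returned by `get_params(deep=False)` against it.

import numbers

from imblearn.base import BaseSampler


class DemoSampler(BaseSampler):
    """Hypothetical sampler, shown only to illustrate the mixin."""

    _sampling_type = "clean-sampling"

    # each constructor parameter maps to a list of accepted constraints
    _parameter_constraints: dict = {
        "k": [numbers.Integral, None],
    }

    def __init__(self, k=1):
        self.k = k

    def _fit_resample(self, X, y):
        return X, y


DemoSampler(k=1)._validate_params()    # passes silently
DemoSampler(k="1")._validate_params()  # raises a ValueError naming 'k'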


class BaseSampler(SamplerMixin, _ParamsValidationMixin):
"""Base class for sampling algorithms.

Warning: This class should not be used directly. Use the derived classes
@@ -130,6 +150,52 @@ def _check_X_y(self, X, y, accept_sparse=None):
X, y = self._validate_data(X, y, reset=True, accept_sparse=accept_sparse)
return X, y, binarize_y

def fit(self, X, y):
"""Check inputs and statistics of the sampler.

You should use ``fit_resample`` in all cases.

Parameters
----------
X : {array-like, dataframe, sparse matrix} of shape \
(n_samples, n_features)
Data array.

y : array-like of shape (n_samples,)
Target array.

Returns
-------
self : object
Return the instance itself.
"""
self._validate_params()
return super().fit(X, y)

def fit_resample(self, X, y):
"""Resample the dataset.

Parameters
----------
X : {array-like, dataframe, sparse matrix} of shape \
(n_samples, n_features)
Matrix containing the data which have to be sampled.

y : array-like of shape (n_samples,)
Corresponding label for each sample in X.

Returns
-------
X_resampled : {array-like, dataframe, sparse matrix} of shape \
(n_samples_new, n_features)
The array containing the resampled data.

y_resampled : array-like of shape (n_samples_new,)
The corresponding label of `X_resampled`.
"""
self._validate_params()
return super().fit_resample(X, y)

def _more_tags(self):
return {"X_types": ["2darray", "sparse", "dataframe"]}
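
With both entry points funneling through `_validate_params`, a sampler that declares constraints now fails fast on a bad constructor argument. A sketch of the intended behavior, assuming `RandomUnderSampler` declares a `"boolean"` constraint for `replacement` like the estimators touched in this PR:

from sklearn.datasets import make_classification

from imblearn.under_sampling import RandomUnderSampler

X, y = make_classification(weights=[0.8, 0.2], random_state=0)

# rejected before any resampling happens; the message is generated from the
# constraint, roughly: "The 'replacement' parameter of RandomUnderSampler
# must be an instance of 'bool'. Got 'yes' instead."
RandomUnderSampler(replacement="yes").fit_resample(X, y)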

@@ -241,6 +307,13 @@ class FunctionSampler(BaseSampler):

_sampling_type = "bypass"

_parameter_constraints: dict = {
"func": [callable, None],
"accept_sparse": ["boolean"],
"kw_args": [dict, None],
"validate": ["boolean"],
}

def __init__(self, *, func=None, accept_sparse=True, kw_args=None, validate=True):
super().__init__()
self.func = func
@@ -267,6 +340,7 @@ def fit(self, X, y):
self : object
Return the instance itself.
"""
self._validate_params()
# we need to overwrite SamplerMixin.fit to bypass the validation
if self.validate:
check_classification_targets(y)
@@ -298,6 +372,7 @@ def fit_resample(self, X, y):
y_resampled : array-like of shape (n_samples_new,)
The corresponding label of `X_resampled`.
"""
self._validate_params()
arrays_transformer = ArraysTransformer(X, y)

if self.validate:
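Taken together with the `_parameter_constraints` declared above, `FunctionSampler` now rejects a non-callable `func` or a non-dict `kw_args` before the user-supplied function is ever invoked. A small usage sketch (`keep_first` is an illustrative helper, not library code):

import numpy as np

from imblearn import FunctionSampler


def keep_first(X, y, n=10):
    """Toy resampling function: keep the first n samples."""
    return X[:n], y[:n]


X = np.arange(40).reshape(20, 2)
y = np.array([0, 1] * 10)

sampler = FunctionSampler(func=keep_first, kw_args={"n": 4})
X_res, y_res = sampler.fit_resample(X, y)  # keeps 4 samples

# fails _validate_params: func must be callable or None
FunctionSampler(func="not-a-function").fit_resample(X, y)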
27 changes: 11 additions & 16 deletions imblearn/combine/_smote_enn.py
@@ -4,6 +4,8 @@
# Christos Aridas
# License: MIT

import numbers

from sklearn.base import clone
from sklearn.utils import check_X_y

@@ -102,6 +104,13 @@ class SMOTEENN(BaseSampler):

_sampling_type = "over-sampling"

_parameter_constraints: dict = {
**BaseOverSampler._parameter_constraints,
"smote": [SMOTE, None],
"enn": [EditedNearestNeighbours, None],
"n_jobs": [numbers.Integral, None],
}

def __init__(
self,
*,
@@ -121,14 +130,7 @@ def __init__(
def _validate_estimator(self):
"Private function to validate SMOTE and ENN objects"
if self.smote is not None:
-            if isinstance(self.smote, SMOTE):
-                self.smote_ = clone(self.smote)
-            else:
-                raise ValueError(
-                    f"smote needs to be a SMOTE object."
-                    f"Got {type(self.smote)} instead."
-                )
-        # Otherwise create a default SMOTE
            self.smote_ = clone(self.smote)
else:
self.smote_ = SMOTE(
sampling_strategy=self.sampling_strategy,
@@ -137,14 +139,7 @@
)

if self.enn is not None:
-            if isinstance(self.enn, EditedNearestNeighbours):
-                self.enn_ = clone(self.enn)
-            else:
-                raise ValueError(
-                    f"enn needs to be an EditedNearestNeighbours."
-                    f" Got {type(self.enn)} instead."
-                )
-        # Otherwise create a default EditedNearestNeighbours
            self.enn_ = clone(self.enn)
else:
self.enn_ = EditedNearestNeighbours(
sampling_strategy="all", n_jobs=self.n_jobs
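These hand-written isinstance checks can go because the class-level constraints `"smote": [SMOTE, None]` and `"enn": [EditedNearestNeighbours, None]` now reject wrong types centrally, before `_validate_estimator` runs. A sketch (the exact wording comes from the vendored validator):

from sklearn.datasets import make_classification

from imblearn.combine import SMOTEENN

X, y = make_classification(weights=[0.9, 0.1], random_state=0)

# formerly a bespoke ValueError raised in _validate_estimator; now roughly:
# "The 'smote' parameter of SMOTEENN must be an instance of SMOTE or None.
# Got 'rnd' instead."
SMOTEENN(smote="rnd").fit_resample(X, y)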
27 changes: 11 additions & 16 deletions imblearn/combine/_smote_tomek.py
@@ -5,6 +5,8 @@
# Christos Aridas
# License: MIT

import numbers

from sklearn.base import clone
from sklearn.utils import check_X_y

@@ -100,6 +102,13 @@ class SMOTETomek(BaseSampler):

_sampling_type = "over-sampling"

_parameter_constraints: dict = {
**BaseOverSampler._parameter_constraints,
"smote": [SMOTE, None],
"tomek": [TomekLinks, None],
"n_jobs": [numbers.Integral, None],
}

def __init__(
self,
*,
@@ -120,14 +129,7 @@ def _validate_estimator(self):
        "Private function to validate SMOTE and Tomek objects"

if self.smote is not None:
-            if isinstance(self.smote, SMOTE):
-                self.smote_ = clone(self.smote)
-            else:
-                raise ValueError(
-                    f"smote needs to be a SMOTE object."
-                    f"Got {type(self.smote)} instead."
-                )
-        # Otherwise create a default SMOTE
            self.smote_ = clone(self.smote)
else:
self.smote_ = SMOTE(
sampling_strategy=self.sampling_strategy,
@@ -136,14 +138,7 @@
)

if self.tomek is not None:
-            if isinstance(self.tomek, TomekLinks):
-                self.tomek_ = clone(self.tomek)
-            else:
-                raise ValueError(
-                    f"tomek needs to be a TomekLinks object."
-                    f"Got {type(self.tomek)} instead."
-                )
-        # Otherwise create a default TomekLinks
            self.tomek_ = clone(self.tomek)
else:
self.tomek_ = TomekLinks(sampling_strategy="all", n_jobs=self.n_jobs)

14 changes: 0 additions & 14 deletions imblearn/combine/tests/test_smote_enn.py
@@ -4,7 +4,6 @@
# License: MIT

import numpy as np
-import pytest
from sklearn.utils._testing import assert_allclose, assert_array_equal

from imblearn.combine import SMOTEENN
@@ -156,16 +155,3 @@ def test_parallelisation():
assert smt.n_jobs == 8
assert smt.smote_.n_jobs == 8
assert smt.enn_.n_jobs == 8


-@pytest.mark.parametrize(
-    "smote_params, err_msg",
-    [
-        ({"smote": "rnd"}, "smote needs to be a SMOTE"),
-        ({"enn": "rnd"}, "enn needs to be an "),
-    ],
-)
-def test_error_wrong_object(smote_params, err_msg):
-    smt = SMOTEENN(**smote_params)
-    with pytest.raises(ValueError, match=err_msg):
-        smt.fit_resample(X, Y)
14 changes: 0 additions & 14 deletions imblearn/combine/tests/test_smote_tomek.py
@@ -4,7 +4,6 @@
# License: MIT

import numpy as np
-import pytest
from sklearn.utils._testing import assert_allclose, assert_array_equal

from imblearn.combine import SMOTETomek
@@ -166,16 +165,3 @@ def test_parallelisation():
assert smt.n_jobs == 8
assert smt.smote_.n_jobs == 8
assert smt.tomek_.n_jobs == 8


-@pytest.mark.parametrize(
-    "smote_params, err_msg",
-    [
-        ({"smote": "rnd"}, "smote needs to be a SMOTE"),
-        ({"tomek": "rnd"}, "tomek needs to be a TomekLinks"),
-    ],
-)
-def test_error_wrong_object(smote_params, err_msg):
-    smt = SMOTETomek(**smote_params)
-    with pytest.raises(ValueError, match=err_msg):
-        smt.fit_resample(X, Y)
40 changes: 27 additions & 13 deletions imblearn/ensemble/_bagging.py
@@ -4,6 +4,7 @@
# Christos Aridas
# License: MIT

import copy
import inspect
import numbers
import warnings
@@ -18,21 +19,23 @@
from sklearn.utils.fixes import delayed
from sklearn.utils.validation import check_is_fitted

from ..base import _ParamsValidationMixin
from ..pipeline import Pipeline
from ..under_sampling import RandomUnderSampler
from ..under_sampling.base import BaseUnderSampler
from ..utils import Substitution, check_sampling_strategy, check_target_type
from ..utils._available_if import available_if
from ..utils._docstring import _n_jobs_docstring, _random_state_docstring
-from ._common import _estimator_has
from ..utils._param_validation import HasMethods, Interval, StrOptions
from ._common import _bagging_parameter_constraints, _estimator_has


@Substitution(
sampling_strategy=BaseUnderSampler._sampling_strategy_docstring,
n_jobs=_n_jobs_docstring,
random_state=_random_state_docstring,
)
-class BalancedBaggingClassifier(BaggingClassifier):
class BalancedBaggingClassifier(BaggingClassifier, _ParamsValidationMixin):
"""A Bagging classifier with additional balancing.

This implementation of Bagging is similar to the scikit-learn
@@ -252,6 +255,26 @@ class BalancedBaggingClassifier(BaggingClassifier):
[ 2 225]]
"""

# make a deepcopy to not modify the original dictionary
if hasattr(BaggingClassifier, "_parameter_constraints"):
# scikit-learn >= 1.2
_parameter_constraints = copy.deepcopy(BaggingClassifier._parameter_constraints)
else:
_parameter_constraints = copy.deepcopy(_bagging_parameter_constraints)

_parameter_constraints.update(
{
"sampling_strategy": [
Interval(numbers.Real, 0, 1, closed="right"),
StrOptions({"auto", "majority", "not minority", "not majority", "all"}),
dict,
callable,
],
"replacement": ["boolean"],
"sampler": [HasMethods(["fit_resample"]), None],
}
)
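
The `copy.deepcopy` guard keeps the class importable with scikit-learn < 1.2, where `BaggingClassifier` has no `_parameter_constraints`, by falling back to the vendored `_bagging_parameter_constraints`. To make the added constraints concrete, a short sketch with the vendored helpers (`is_satisfied_by` is the check the validator applies to each candidate value):

import numbers

from imblearn.utils._param_validation import HasMethods, Interval, StrOptions

# a real number in (0, 1]; closed="right" means 1 is allowed, 0 is not
interval = Interval(numbers.Real, 0, 1, closed="right")
print(interval.is_satisfied_by(0.5))  # True
print(interval.is_satisfied_by(0.0))  # False

# one of a fixed set of strings
print(StrOptions({"auto", "majority"}).is_satisfied_by("auto"))  # True

# duck-typed check used for `sampler`: any object exposing fit_resample
print(HasMethods(["fit_resample"]).is_satisfied_by(object()))  # False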

def __init__(
self,
estimator=None,
@@ -316,17 +339,7 @@ def _validate_y(self, y):

    def _validate_estimator(self, default=DecisionTreeClassifier()):
        """Check the estimator and the n_estimators attribute, set the
-        `base_estimator_` attribute."""
-        if not isinstance(self.n_estimators, (numbers.Integral, np.integer)):
-            raise ValueError(
-                f"n_estimators must be an integer, " f"got {type(self.n_estimators)}."
-            )
-
-        if self.n_estimators <= 0:
-            raise ValueError(
-                f"n_estimators must be greater than zero, " f"got {self.n_estimators}."
-            )
        `estimator_` attribute."""
if self.estimator is not None and (
self.base_estimator not in [None, "deprecated"]
):
@@ -395,6 +408,7 @@ def fit(self, X, y):
Fitted estimator.
"""
# overwrite the base class method by disallowing `sample_weight`
self._validate_params()
return super().fit(X, y)

def _fit(self, X, y, max_samples=None, max_depth=None, sample_weight=None):
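
The net effect in `fit`: the manual `n_estimators` checks deleted from `_validate_estimator` are covered by the constraint inherited from `BaggingClassifier`, so invalid values fail uniformly. A sketch (message wording approximated from the scikit-learn framework):

from sklearn.datasets import make_classification

from imblearn.ensemble import BalancedBaggingClassifier

X, y = make_classification(weights=[0.8, 0.2], random_state=0)

# raises roughly: "The 'n_estimators' parameter of BalancedBaggingClassifier
# must be an int in the range [1, inf). Got -3 instead."
BalancedBaggingClassifier(n_estimators=-3).fit(X, y)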