Skip to content

Commit a84b63f

Browse files
authored
MAINT add parameter validation framework (#955)
1 parent 063da64 commit a84b63f

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+2303
-453
lines changed

doc/whats_new/v0.10.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ Compatibility
1919
- Maintenance release to be compatible with scikit-learn >= 1.0.2.
2020
:pr:`946`, :pr:`947`, :pr:`949` by :user:`Guillaume Lemaitre <glemaitre>`.
2121

22+
- Add support for automatic parameter validation as in scikit-learn >= 1.2.
23+
:pr:`955` by :user:`Guillaume Lemaitre <glemaitre>`.
24+
2225
Deprecation
2326
...........
2427

imblearn/base.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from sklearn.utils.multiclass import check_classification_targets
1313

1414
from .utils import check_sampling_strategy, check_target_type
15+
from .utils._param_validation import validate_parameter_constraints
1516
from .utils._validation import ArraysTransformer
1617

1718

@@ -113,7 +114,26 @@ def _fit_resample(self, X, y):
113114
pass
114115

115116

116-
class BaseSampler(SamplerMixin):
117+
class _ParamsValidationMixin:
118+
"""Mixin class to validate parameters."""
119+
120+
def _validate_params(self):
121+
"""Validate types and values of constructor parameters.
122+
123+
The expected type and values must be defined in the `_parameter_constraints`
124+
class attribute, which is a dictionary `param_name: list of constraints`. See
125+
the docstring of `validate_parameter_constraints` for a description of the
126+
accepted constraints.
127+
"""
128+
if hasattr(self, "_parameter_constraints"):
129+
validate_parameter_constraints(
130+
self._parameter_constraints,
131+
self.get_params(deep=False),
132+
caller_name=self.__class__.__name__,
133+
)
134+
135+
136+
class BaseSampler(SamplerMixin, _ParamsValidationMixin):
117137
"""Base class for sampling algorithms.
118138
119139
Warning: This class should not be used directly. Use the derive classes
@@ -130,6 +150,52 @@ def _check_X_y(self, X, y, accept_sparse=None):
130150
X, y = self._validate_data(X, y, reset=True, accept_sparse=accept_sparse)
131151
return X, y, binarize_y
132152

153+
def fit(self, X, y):
154+
"""Check inputs and statistics of the sampler.
155+
156+
You should use ``fit_resample`` in all cases.
157+
158+
Parameters
159+
----------
160+
X : {array-like, dataframe, sparse matrix} of shape \
161+
(n_samples, n_features)
162+
Data array.
163+
164+
y : array-like of shape (n_samples,)
165+
Target array.
166+
167+
Returns
168+
-------
169+
self : object
170+
Return the instance itself.
171+
"""
172+
self._validate_params()
173+
return super().fit(X, y)
174+
175+
def fit_resample(self, X, y):
176+
"""Resample the dataset.
177+
178+
Parameters
179+
----------
180+
X : {array-like, dataframe, sparse matrix} of shape \
181+
(n_samples, n_features)
182+
Matrix containing the data which have to be sampled.
183+
184+
y : array-like of shape (n_samples,)
185+
Corresponding label for each sample in X.
186+
187+
Returns
188+
-------
189+
X_resampled : {array-like, dataframe, sparse matrix} of shape \
190+
(n_samples_new, n_features)
191+
The array containing the resampled data.
192+
193+
y_resampled : array-like of shape (n_samples_new,)
194+
The corresponding label of `X_resampled`.
195+
"""
196+
self._validate_params()
197+
return super().fit_resample(X, y)
198+
133199
def _more_tags(self):
134200
return {"X_types": ["2darray", "sparse", "dataframe"]}
135201

@@ -241,6 +307,13 @@ class FunctionSampler(BaseSampler):
241307

242308
_sampling_type = "bypass"
243309

310+
_parameter_constraints: dict = {
311+
"func": [callable, None],
312+
"accept_sparse": ["boolean"],
313+
"kw_args": [dict, None],
314+
"validate": ["boolean"],
315+
}
316+
244317
def __init__(self, *, func=None, accept_sparse=True, kw_args=None, validate=True):
245318
super().__init__()
246319
self.func = func
@@ -267,6 +340,7 @@ def fit(self, X, y):
267340
self : object
268341
Return the instance itself.
269342
"""
343+
self._validate_params()
270344
# we need to overwrite SamplerMixin.fit to bypass the validation
271345
if self.validate:
272346
check_classification_targets(y)
@@ -298,6 +372,7 @@ def fit_resample(self, X, y):
298372
y_resampled : array-like of shape (n_samples_new,)
299373
The corresponding label of `X_resampled`.
300374
"""
375+
self._validate_params()
301376
arrays_transformer = ArraysTransformer(X, y)
302377

303378
if self.validate:

imblearn/combine/_smote_enn.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# Christos Aridas
55
# License: MIT
66

7+
import numbers
8+
79
from sklearn.base import clone
810
from sklearn.utils import check_X_y
911

@@ -102,6 +104,13 @@ class SMOTEENN(BaseSampler):
102104

103105
_sampling_type = "over-sampling"
104106

107+
_parameter_constraints: dict = {
108+
**BaseOverSampler._parameter_constraints,
109+
"smote": [SMOTE, None],
110+
"enn": [EditedNearestNeighbours, None],
111+
"n_jobs": [numbers.Integral, None],
112+
}
113+
105114
def __init__(
106115
self,
107116
*,
@@ -121,14 +130,7 @@ def __init__(
121130
def _validate_estimator(self):
122131
"Private function to validate SMOTE and ENN objects"
123132
if self.smote is not None:
124-
if isinstance(self.smote, SMOTE):
125-
self.smote_ = clone(self.smote)
126-
else:
127-
raise ValueError(
128-
f"smote needs to be a SMOTE object."
129-
f"Got {type(self.smote)} instead."
130-
)
131-
# Otherwise create a default SMOTE
133+
self.smote_ = clone(self.smote)
132134
else:
133135
self.smote_ = SMOTE(
134136
sampling_strategy=self.sampling_strategy,
@@ -137,14 +139,7 @@ def _validate_estimator(self):
137139
)
138140

139141
if self.enn is not None:
140-
if isinstance(self.enn, EditedNearestNeighbours):
141-
self.enn_ = clone(self.enn)
142-
else:
143-
raise ValueError(
144-
f"enn needs to be an EditedNearestNeighbours."
145-
f" Got {type(self.enn)} instead."
146-
)
147-
# Otherwise create a default EditedNearestNeighbours
142+
self.enn_ = clone(self.enn)
148143
else:
149144
self.enn_ = EditedNearestNeighbours(
150145
sampling_strategy="all", n_jobs=self.n_jobs

imblearn/combine/_smote_tomek.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
# Christos Aridas
66
# License: MIT
77

8+
import numbers
9+
810
from sklearn.base import clone
911
from sklearn.utils import check_X_y
1012

@@ -100,6 +102,13 @@ class SMOTETomek(BaseSampler):
100102

101103
_sampling_type = "over-sampling"
102104

105+
_parameter_constraints: dict = {
106+
**BaseOverSampler._parameter_constraints,
107+
"smote": [SMOTE, None],
108+
"tomek": [TomekLinks, None],
109+
"n_jobs": [numbers.Integral, None],
110+
}
111+
103112
def __init__(
104113
self,
105114
*,
@@ -120,14 +129,7 @@ def _validate_estimator(self):
120129
"Private function to validate SMOTE and ENN objects"
121130

122131
if self.smote is not None:
123-
if isinstance(self.smote, SMOTE):
124-
self.smote_ = clone(self.smote)
125-
else:
126-
raise ValueError(
127-
f"smote needs to be a SMOTE object."
128-
f"Got {type(self.smote)} instead."
129-
)
130-
# Otherwise create a default SMOTE
132+
self.smote_ = clone(self.smote)
131133
else:
132134
self.smote_ = SMOTE(
133135
sampling_strategy=self.sampling_strategy,
@@ -136,14 +138,7 @@ def _validate_estimator(self):
136138
)
137139

138140
if self.tomek is not None:
139-
if isinstance(self.tomek, TomekLinks):
140-
self.tomek_ = clone(self.tomek)
141-
else:
142-
raise ValueError(
143-
f"tomek needs to be a TomekLinks object."
144-
f"Got {type(self.tomek)} instead."
145-
)
146-
# Otherwise create a default TomekLinks
141+
self.tomek_ = clone(self.tomek)
147142
else:
148143
self.tomek_ = TomekLinks(sampling_strategy="all", n_jobs=self.n_jobs)
149144

imblearn/combine/tests/test_smote_enn.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
# License: MIT
55

66
import numpy as np
7-
import pytest
87
from sklearn.utils._testing import assert_allclose, assert_array_equal
98

109
from imblearn.combine import SMOTEENN
@@ -156,16 +155,3 @@ def test_parallelisation():
156155
assert smt.n_jobs == 8
157156
assert smt.smote_.n_jobs == 8
158157
assert smt.enn_.n_jobs == 8
159-
160-
161-
@pytest.mark.parametrize(
162-
"smote_params, err_msg",
163-
[
164-
({"smote": "rnd"}, "smote needs to be a SMOTE"),
165-
({"enn": "rnd"}, "enn needs to be an "),
166-
],
167-
)
168-
def test_error_wrong_object(smote_params, err_msg):
169-
smt = SMOTEENN(**smote_params)
170-
with pytest.raises(ValueError, match=err_msg):
171-
smt.fit_resample(X, Y)

imblearn/combine/tests/test_smote_tomek.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
# License: MIT
55

66
import numpy as np
7-
import pytest
87
from sklearn.utils._testing import assert_allclose, assert_array_equal
98

109
from imblearn.combine import SMOTETomek
@@ -166,16 +165,3 @@ def test_parallelisation():
166165
assert smt.n_jobs == 8
167166
assert smt.smote_.n_jobs == 8
168167
assert smt.tomek_.n_jobs == 8
169-
170-
171-
@pytest.mark.parametrize(
172-
"smote_params, err_msg",
173-
[
174-
({"smote": "rnd"}, "smote needs to be a SMOTE"),
175-
({"tomek": "rnd"}, "tomek needs to be a TomekLinks"),
176-
],
177-
)
178-
def test_error_wrong_object(smote_params, err_msg):
179-
smt = SMOTETomek(**smote_params)
180-
with pytest.raises(ValueError, match=err_msg):
181-
smt.fit_resample(X, Y)

imblearn/ensemble/_bagging.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# Christos Aridas
55
# License: MIT
66

7+
import copy
78
import inspect
89
import numbers
910
import warnings
@@ -18,21 +19,23 @@
1819
from sklearn.utils.fixes import delayed
1920
from sklearn.utils.validation import check_is_fitted
2021

22+
from ..base import _ParamsValidationMixin
2123
from ..pipeline import Pipeline
2224
from ..under_sampling import RandomUnderSampler
2325
from ..under_sampling.base import BaseUnderSampler
2426
from ..utils import Substitution, check_sampling_strategy, check_target_type
2527
from ..utils._available_if import available_if
2628
from ..utils._docstring import _n_jobs_docstring, _random_state_docstring
27-
from ._common import _estimator_has
29+
from ..utils._param_validation import HasMethods, Interval, StrOptions
30+
from ._common import _bagging_parameter_constraints, _estimator_has
2831

2932

3033
@Substitution(
3134
sampling_strategy=BaseUnderSampler._sampling_strategy_docstring,
3235
n_jobs=_n_jobs_docstring,
3336
random_state=_random_state_docstring,
3437
)
35-
class BalancedBaggingClassifier(BaggingClassifier):
38+
class BalancedBaggingClassifier(BaggingClassifier, _ParamsValidationMixin):
3639
"""A Bagging classifier with additional balancing.
3740
3841
This implementation of Bagging is similar to the scikit-learn
@@ -252,6 +255,26 @@ class BalancedBaggingClassifier(BaggingClassifier):
252255
[ 2 225]]
253256
"""
254257

258+
# make a deepcopy to not modify the original dictionary
259+
if hasattr(BaggingClassifier, "_parameter_constraints"):
260+
# scikit-learn >= 1.2
261+
_parameter_constraints = copy.deepcopy(BaggingClassifier._parameter_constraints)
262+
else:
263+
_parameter_constraints = copy.deepcopy(_bagging_parameter_constraints)
264+
265+
_parameter_constraints.update(
266+
{
267+
"sampling_strategy": [
268+
Interval(numbers.Real, 0, 1, closed="right"),
269+
StrOptions({"auto", "majority", "not minority", "not majority", "all"}),
270+
dict,
271+
callable,
272+
],
273+
"replacement": ["boolean"],
274+
"sampler": [HasMethods(["fit_resample"]), None],
275+
}
276+
)
277+
255278
def __init__(
256279
self,
257280
estimator=None,
@@ -316,17 +339,7 @@ def _validate_y(self, y):
316339

317340
def _validate_estimator(self, default=DecisionTreeClassifier()):
318341
"""Check the estimator and the n_estimator attribute, set the
319-
`base_estimator_` attribute."""
320-
if not isinstance(self.n_estimators, (numbers.Integral, np.integer)):
321-
raise ValueError(
322-
f"n_estimators must be an integer, " f"got {type(self.n_estimators)}."
323-
)
324-
325-
if self.n_estimators <= 0:
326-
raise ValueError(
327-
f"n_estimators must be greater than zero, " f"got {self.n_estimators}."
328-
)
329-
342+
`estimator_` attribute."""
330343
if self.estimator is not None and (
331344
self.base_estimator not in [None, "deprecated"]
332345
):
@@ -395,6 +408,7 @@ def fit(self, X, y):
395408
Fitted estimator.
396409
"""
397410
# overwrite the base class method by disallowing `sample_weight`
411+
self._validate_params()
398412
return super().fit(X, y)
399413

400414
def _fit(self, X, y, max_samples=None, max_depth=None, sample_weight=None):

0 commit comments

Comments
 (0)