MAINT compatibility with sklearn 1.6 #1104

2 changes: 1 addition & 1 deletion examples/ensemble/plot_comparison_ensemble_classifier.py
@@ -197,7 +197,7 @@

from imblearn.ensemble import EasyEnsembleClassifier, RUSBoostClassifier

-estimator = AdaBoostClassifier(n_estimators=10, algorithm="SAMME")
+estimator = AdaBoostClassifier(n_estimators=10)
eec = EasyEnsembleClassifier(n_estimators=10, estimator=estimator)
eec.fit(X_train, y_train)
y_pred_eec = eec.predict(X_test)
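A note on why the example changes (my reading of scikit-learn's deprecation cycle, not stated in this diff): passing algorithm="SAMME" was needed on scikit-learn 1.4-1.5 to silence the SAMME.R deprecation, but from 1.6 on SAMME is the only algorithm and the parameter itself is deprecated, so the example drops it. Code that has to span both ranges can guard on the installed version, mirroring the _get_estimator change further down in this PR:

# Version-guarded construction (sketch mirroring _get_estimator below).
import sklearn
from sklearn.ensemble import AdaBoostClassifier
from sklearn.utils.fixes import parse_version

sklearn_version = parse_version(parse_version(sklearn.__version__).base_version)
if parse_version("1.4") <= sklearn_version < parse_version("1.6"):
    estimator = AdaBoostClassifier(n_estimators=10, algorithm="SAMME")
else:
    estimator = AdaBoostClassifier(n_estimators=10)
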
28 changes: 24 additions & 4 deletions imblearn/base.py
@@ -7,15 +7,17 @@
from abc import ABCMeta, abstractmethod

import numpy as np
import sklearn
from sklearn.base import BaseEstimator, OneToOneFeatureMixin
from sklearn.preprocessing import label_binarize
+from sklearn.utils.metaestimators import available_if
from sklearn.utils.multiclass import check_classification_targets

from .utils import check_sampling_strategy, check_target_type
+from .utils.fixes import check_version_package, validate_data
from .utils._param_validation import validate_parameter_constraints
from .utils._validation import ArraysTransformer


class _ParamsValidationMixin:
    """Mixin class to validate parameters."""

@@ -35,7 +37,7 @@ class attribute, which is a dictionary `param_name: list of constraints`. See
        )


-class SamplerMixin(_ParamsValidationMixin, BaseEstimator, metaclass=ABCMeta):
+class SamplerMixin(_ParamsValidationMixin, metaclass=ABCMeta):
    """Mixin class for samplers with abstract method.

    Warning: This class should not be used directly. Use the derived classes
@@ -133,7 +135,7 @@ def _fit_resample(self, X, y):
        pass


-class BaseSampler(SamplerMixin, OneToOneFeatureMixin):
+class BaseSampler(SamplerMixin, OneToOneFeatureMixin, BaseEstimator):
    """Base class for sampling algorithms.

    Warning: This class should not be used directly. Use the derived classes
@@ -147,7 +149,7 @@ def _check_X_y(self, X, y, accept_sparse=None):
        if accept_sparse is None:
            accept_sparse = ["csr", "csc"]
        y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
-        X, y = self._validate_data(X, y, reset=True, accept_sparse=accept_sparse)
+        X, y = validate_data(self, X=X, y=y, reset=True, accept_sparse=accept_sparse)
        return X, y, binarize_y

    def fit(self, X, y):
@@ -196,9 +198,27 @@ def fit_resample(self, X, y):
        self._validate_params()
        return super().fit_resample(X, y)

@available_if(check_version_package("sklearn", "<", "1.6"))
def _more_tags(self):
return {"X_types": ["2darray", "sparse", "dataframe"]}

@available_if(check_version_package("sklearn", ">=", "1.6"))
def __sklearn_tags__(self):
from .utils._tags import Tags, SamplerTags, TargetTags, InputTags
tags = Tags(
estimator_type="sampler",
target_tags=TargetTags(required=True),
transformer_tags=None,
regressor_tags=None,
classifier_tags=None,
sampler_tags=SamplerTags(),
)
tags.input_tags = InputTags()
tags.input_tags.two_d_array = True
tags.input_tags.sparse = True
tags.input_tags.dataframe = True
return tags


def _identity(X, y):
    return X, y
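The @available_if(check_version_package(...)) pattern above makes _more_tags exist only on scikit-learn < 1.6 and __sklearn_tags__ only on >= 1.6, so each version's tag machinery finds exactly the hook it expects. check_version_package comes from imblearn/utils/fixes.py, which this diff does not show; a minimal sketch of what it plausibly looks like, inferred from its call sites (names and details assumed):

# Plausible sketch only; the real helper in imblearn/utils/fixes.py may differ.
import importlib
import operator

from sklearn.utils.fixes import parse_version

_OPERATORS = {"<": operator.lt, "<=": operator.le, ">": operator.gt, ">=": operator.ge}

def check_version_package(package, comparison, version):
    """Return a predicate usable with `available_if` that checks an installed version."""

    def check(estimator):
        module = importlib.import_module(package)
        installed = parse_version(parse_version(module.__version__).base_version)
        return _OPERATORS[comparison](installed, parse_version(version))

    return check

available_if calls the predicate with the instance, so the `estimator` argument is accepted and ignored here.
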
12 changes: 7 additions & 5 deletions imblearn/ensemble/_bagging.py
@@ -26,10 +26,10 @@
from ..utils import Substitution, check_sampling_strategy, check_target_type
from ..utils._docstring import _n_jobs_docstring, _random_state_docstring
from ..utils._param_validation import HasMethods, Interval, StrOptions
-from ..utils.fixes import _fit_context
+from ..utils.fixes import _fit_context, check_version_package, validate_data
from ._common import _bagging_parameter_constraints, _estimator_has

-sklearn_version = parse_version(sklearn.__version__)
+sklearn_version = parse_version(parse_version(sklearn.__version__).base_version)


@Substitution(
@@ -382,12 +382,13 @@ def decision_function(self, X):
        check_is_fitted(self)

        # Check data
-        X = self._validate_data(
-            X,
+        X = validate_data(
+            self,
+            X=X,
            accept_sparse=["csr", "csc"],
            dtype=None,
-            force_all_finite=False,
            reset=False,
+            ensure_all_finite=False,
        )

        # Parallel loop
@@ -415,6 +416,7 @@ def base_estimator_(self):
        )
        raise error

@available_if(check_version_package("sklearn", "<", "1.6"))
def _more_tags(self):
tags = super()._more_tags()
tags_key = "_xfail_checks"
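The recurring edit in this file swaps the estimator method self._validate_data(X, ...) for a free function validate_data(self, X=X, ...) and renames force_all_finite to ensure_all_finite, matching scikit-learn 1.6's sklearn.utils.validation.validate_data. The shim imported from imblearn/utils/fixes.py is not shown in the diff; a sketch under that assumption:

# Assumed shape of the validate_data shim in imblearn/utils/fixes.py.
import sklearn
from sklearn.utils.fixes import parse_version

sklearn_version = parse_version(parse_version(sklearn.__version__).base_version)

def validate_data(estimator, X="no_validation", y="no_validation", **kwargs):
    """Dispatch to the 1.6+ free function or the older estimator method."""
    if sklearn_version >= parse_version("1.6"):
        from sklearn.utils.validation import validate_data as sk_validate_data

        return sk_validate_data(estimator, X=X, y=y, **kwargs)
    # scikit-learn < 1.6 only knows the old keyword name.
    if "ensure_all_finite" in kwargs:
        kwargs["force_all_finite"] = kwargs.pop("ensure_all_finite")
    return estimator._validate_data(X=X, y=y, **kwargs)
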
26 changes: 18 additions & 8 deletions imblearn/ensemble/_easy_ensemble.py
@@ -14,7 +14,6 @@
from sklearn.ensemble import AdaBoostClassifier, BaggingClassifier
from sklearn.ensemble._bagging import _parallel_decision_function
from sklearn.ensemble._base import _partition_estimators
-from sklearn.utils._tags import _safe_tags
from sklearn.utils.fixes import parse_version
from sklearn.utils.metaestimators import available_if
from sklearn.utils.parallel import Parallel, delayed
@@ -27,11 +26,11 @@
from ..utils import Substitution, check_sampling_strategy, check_target_type
from ..utils._docstring import _n_jobs_docstring, _random_state_docstring
from ..utils._param_validation import Interval, StrOptions
-from ..utils.fixes import _fit_context
+from ..utils.fixes import _fit_context, check_version_package, get_tags, validate_data
from ._common import _bagging_parameter_constraints, _estimator_has

MAX_INT = np.iinfo(np.int32).max
-sklearn_version = parse_version(sklearn.__version__)
+sklearn_version = parse_version(parse_version(sklearn.__version__).base_version)


@Substitution(
@@ -311,12 +310,13 @@ def decision_function(self, X):
        check_is_fitted(self)

        # Check data
-        X = self._validate_data(
-            X,
+        X = validate_data(
+            self,
+            X=X,
            accept_sparse=["csr", "csc"],
            dtype=None,
-            force_all_finite=False,
            reset=False,
+            ensure_all_finite=False,
        )

        # Parallel loop
@@ -346,9 +346,19 @@ def base_estimator_(self):

    def _get_estimator(self):
        if self.estimator is None:
-            return AdaBoostClassifier(algorithm="SAMME")
+            if parse_version("1.4") <= sklearn_version < parse_version("1.6"):
+                return AdaBoostClassifier(algorithm="SAMME")
+            else:
+                return AdaBoostClassifier()
        return self.estimator

    # TODO: remove when minimum supported version of scikit-learn is 1.5
+    @available_if(check_version_package("sklearn", "<", "1.6"))
    def _more_tags(self):
-        return {"allow_nan": _safe_tags(self._get_estimator(), "allow_nan")}
+        return {"allow_nan": get_tags(self._get_estimator())["allow_nan"]}

@available_if(check_version_package("sklearn", ">=", "1.6"))
def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.input_tags.allow_nan = get_tags(self._get_estimator()).input_tags.allow_nan
return tags
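get_tags papers over the two tag systems: scikit-learn 1.6 exposes a public sklearn.utils.get_tags returning a Tags dataclass, while older versions only have the private dict-based sklearn.utils._tags._safe_tags (whose direct import this diff removes). That dual return type is also why _forest.py below branches on is_dataclass(tags). A sketch under those assumptions:

# Assumed shape of the get_tags shim in imblearn/utils/fixes.py.
import sklearn
from sklearn.utils.fixes import parse_version

sklearn_version = parse_version(parse_version(sklearn.__version__).base_version)

def get_tags(estimator):
    """Return dataclass tags on scikit-learn >= 1.6, the legacy tags dict otherwise."""
    if sklearn_version >= parse_version("1.6"):
        from sklearn.utils import get_tags as sk_get_tags

        return sk_get_tags(estimator)
    from sklearn.utils._tags import _safe_tags

    return _safe_tags(estimator)
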
43 changes: 28 additions & 15 deletions imblearn/ensemble/_forest.py
@@ -5,6 +5,7 @@

import numbers
from copy import deepcopy
+from dataclasses import is_dataclass
from warnings import warn

import numpy as np
@@ -24,6 +25,7 @@
from sklearn.tree import DecisionTreeClassifier
from sklearn.utils import _safe_indexing, check_random_state
from sklearn.utils.fixes import parse_version
+from sklearn.utils.metaestimators import available_if
from sklearn.utils.multiclass import type_of_target
from sklearn.utils.parallel import Parallel, delayed
from sklearn.utils.validation import _check_sample_weight
@@ -35,11 +37,11 @@
from ..utils._docstring import _n_jobs_docstring, _random_state_docstring
from ..utils._param_validation import Hidden, Interval, StrOptions
from ..utils._validation import check_sampling_strategy
-from ..utils.fixes import _fit_context
+from ..utils.fixes import _fit_context, check_version_package, get_tags, validate_data
from ._common import _random_forest_classifier_parameter_constraints

MAX_INT = np.iinfo(np.int32).max
-sklearn_version = parse_version(sklearn.__version__)
+sklearn_version = parse_version(parse_version(sklearn.__version__).base_version)


def _local_parallel_build_trees(
@@ -77,7 +79,7 @@ def _local_parallel_build_trees(
"bootstrap": bootstrap,
}

if parse_version(sklearn_version.base_version) >= parse_version("1.4"):
if sklearn_version >= parse_version("1.4"):
# TODO: remove when the minimum supported version of scikit-learn will be 1.4
# support for missing values
params_parallel_build_trees["missing_values_in_feature_mask"] = (
@@ -474,7 +476,7 @@ def __init__(
            "max_samples": max_samples,
        }
        # TODO: remove when the minimum supported version of scikit-learn will be 1.4
-        if parse_version(sklearn_version.base_version) >= parse_version("1.4"):
+        if sklearn_version >= parse_version("1.4"):
            # use scikit-learn support for monotonic constraints
            params_random_forest["monotonic_cst"] = monotonic_cst
        else:
@@ -594,24 +596,25 @@ def fit(self, X, y, sample_weight=None):
        if issparse(y):
            raise ValueError("sparse multilabel-indicator for y is not supported.")

-        # TODO: remove when the minimum supported version of scipy will be 1.4
-        # Support for missing values
-        if parse_version(sklearn_version.base_version) >= parse_version("1.4"):
-            force_all_finite = False
+        # TODO (1.6): simplify because we will only have dataclass tags
+        tags = get_tags(self)
+        if is_dataclass(tags):
+            ensure_all_finite = not tags.input_tags.allow_nan
        else:
-            force_all_finite = True
+            ensure_all_finite = not tags.get("allow_nan", False)

-        X, y = self._validate_data(
-            X,
-            y,
+        X, y = validate_data(
+            self,
+            X=X,
+            y=y,
            multi_output=True,
            accept_sparse="csc",
            dtype=DTYPE,
-            force_all_finite=force_all_finite,
+            ensure_all_finite=ensure_all_finite,
        )

        # TODO: remove when the minimum supported version of scikit-learn will be 1.4
-        if parse_version(sklearn_version.base_version) >= parse_version("1.4"):
+        if sklearn_version >= parse_version("1.4"):
            # _compute_missing_values_in_feature_mask checks if X has missing values and
            # will raise an error if the underlying tree base estimator can't handle
            # missing values. Only the criterion is required to determine if the tree
@@ -880,5 +883,15 @@ def _compute_oob_predictions(self, X, y):

        return oob_pred

@available_if(check_version_package("sklearn", "<", "1.6"))
def _more_tags(self):
return {"multioutput": False, "multilabel": False}
allow_nan = sklearn_version >= parse_version("1.4")
return {"multioutput": False, "multilabel": False, "allow_nan": allow_nan}

@available_if(check_version_package("sklearn", ">=", "1.6"))
def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.target_tags.multi_output = False
tags.classifier_tags.multi_label = False
tags.input_tags.allow_nan = sklearn_version >= parse_version("1.4")
return tags
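Whichever side of the version gate is active, the forest advertises the same capabilities; only the entry point differs. An illustrative check (assuming BalancedRandomForestClassifier, the public class defined in this module):

# Illustration only: available_if leaves exactly one tags hook reachable.
from imblearn.ensemble import BalancedRandomForestClassifier

clf = BalancedRandomForestClassifier()
if hasattr(clf, "__sklearn_tags__"):  # scikit-learn >= 1.6
    print(clf.__sklearn_tags__().input_tags.allow_nan)
else:  # scikit-learn < 1.6
    print(clf._more_tags()["allow_nan"])
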
34 changes: 26 additions & 8 deletions imblearn/ensemble/_weight_boosting.py
@@ -1,5 +1,6 @@
import copy
import numbers
+import warnings
from copy import deepcopy

import numpy as np
@@ -10,6 +11,7 @@
from sklearn.tree import DecisionTreeClassifier
from sklearn.utils import _safe_indexing
from sklearn.utils.fixes import parse_version
+from sklearn.utils.metaestimators import available_if
from sklearn.utils.validation import has_fit_parameter

from ..base import _ParamsValidationMixin
@@ -18,8 +20,8 @@
from ..under_sampling.base import BaseUnderSampler
from ..utils import Substitution, check_target_type
from ..utils._docstring import _random_state_docstring
-from ..utils._param_validation import Interval, StrOptions
-from ..utils.fixes import _fit_context
+from ..utils._param_validation import Hidden, Interval, StrOptions
+from ..utils.fixes import _fit_context, check_version_package
from ._common import _adaboost_classifier_parameter_constraints

sklearn_version = parse_version(sklearn.__version__)
@@ -58,16 +60,15 @@ class RUSBoostClassifier(_ParamsValidationMixin, AdaBoostClassifier):
        ``learning_rate``. There is a trade-off between ``learning_rate`` and
        ``n_estimators``.

-    algorithm : {{'SAMME', 'SAMME.R'}}, default='SAMME.R'
+    algorithm : {{'SAMME', 'SAMME.R'}}, default='deprecated'
        If 'SAMME.R' then use the SAMME.R real boosting algorithm.
        ``base_estimator`` must support calculation of class probabilities.
        If 'SAMME' then use the SAMME discrete boosting algorithm.
        The SAMME.R algorithm typically converges faster than SAMME,
        achieving a lower test error with fewer boosting iterations.

        .. deprecated:: 0.12
-            `"SAMME.R"` is deprecated and will be removed in version 0.14.
-            '"SAMME"' will become the default.
+            `algorithm` is deprecated in 0.12 and will be removed in 0.14.

    {sampling_strategy}

@@ -109,7 +110,7 @@ class RUSBoostClassifier(_ParamsValidationMixin, AdaBoostClassifier):
        ensemble.

    feature_importances_ : ndarray of shape (n_features,)
-        The feature importances if supported by the ``base_estimator``.
+        The feature importances if supported by the ``estimator``.

    n_features_in_ : int
        Number of features in the input dataset.
@@ -167,6 +168,10 @@ class RUSBoostClassifier(_ParamsValidationMixin, AdaBoostClassifier):

    _parameter_constraints.update(
        {
+            "algorithm": [
+                StrOptions({"SAMME", "SAMME.R"}),
+                Hidden(StrOptions({"deprecated"})),
+            ],
            "sampling_strategy": [
                Interval(numbers.Real, 0, 1, closed="right"),
                StrOptions({"auto", "majority", "not minority", "not majority", "all"}),
@@ -186,17 +191,17 @@ def __init__(
        *,
        n_estimators=50,
        learning_rate=1.0,
-        algorithm="SAMME.R",
+        algorithm="deprecated",
        sampling_strategy="auto",
        replacement=False,
        random_state=None,
    ):
        super().__init__(
            n_estimators=n_estimators,
            learning_rate=learning_rate,
-            algorithm=algorithm,
            random_state=random_state,
        )
+        self.algorithm = algorithm
        self.estimator = estimator
        self.sampling_strategy = sampling_strategy
        self.replacement = replacement
@@ -394,3 +399,16 @@ def _boost_discrete(self, iboost, X, y, sample_weight, random_state):
        sample_weight *= np.exp(estimator_weight * incorrect * (sample_weight > 0))

        return sample_weight, estimator_weight, estimator_error

+    def _boost(self, iboost, X, y, sample_weight, random_state):
+        if self.algorithm != "deprecated":
+            warnings.warn(
+                "`algorithm` parameter is deprecated in 0.12 and will be removed in "
+                "0.14. In the future, the SAMME algorithm will always be used.",
+                FutureWarning,
+            )
+        if self.algorithm == "SAMME.R":
+            return self._boost_real(iboost, X, y, sample_weight, random_state)
+
+        else:  # elif self.algorithm == "SAMME":
+            return self._boost_discrete(iboost, X, y, sample_weight, random_state)