diff --git a/doc/whats_new/v0.11.rst b/doc/whats_new/v0.11.rst index a6cb21821..8e54a275b 100644 --- a/doc/whats_new/v0.11.rst +++ b/doc/whats_new/v0.11.rst @@ -43,3 +43,8 @@ Enhancements parameters. A new fitted parameter `categorical_encoder_` is exposed to access the fitted encoder. :pr:`1001` by :user:`Guillaume Lemaitre `. + +- :class:`~imblearn.under_sampling.RandomUnderSampler` and + :class:`~imblearn.over_sampling.RandomOverSampler` (when `shrinkage is not + None`) now accept any data types and will not attempt any data conversion. + :pr:`1004` by :user:`Guillaume Lemaitre `. diff --git a/examples/api/plot_sampling_strategy_usage.py b/examples/api/plot_sampling_strategy_usage.py index b739a41c6..dbb52fcdf 100644 --- a/examples/api/plot_sampling_strategy_usage.py +++ b/examples/api/plot_sampling_strategy_usage.py @@ -59,10 +59,9 @@ # resampling and the number of samples in the minority class, respectively. # %% -import numpy as np # select only 2 classes since the ratio make sense in this case -binary_mask = np.bitwise_or(y == 0, y == 2) +binary_mask = y.isin([0, 1]) binary_y = y[binary_mask] binary_X = X[binary_mask] diff --git a/imblearn/datasets/tests/test_imbalance.py b/imblearn/datasets/tests/test_imbalance.py index 2d8e278fa..ac3b417c7 100644 --- a/imblearn/datasets/tests/test_imbalance.py +++ b/imblearn/datasets/tests/test_imbalance.py @@ -67,11 +67,14 @@ def test_make_imbalance_dict(iris, sampling_strategy, expected_counts): ], ) def test_make_imbalanced_iris(as_frame, sampling_strategy, expected_counts): - pytest.importorskip("pandas") - iris = load_iris(as_frame=True) + pd = pytest.importorskip("pandas") + iris = load_iris(as_frame=as_frame) X, y = iris.data, iris.target y = iris.target_names[iris.target] + if as_frame: + y = pd.Series(iris.target_names[iris.target], name="target") X_res, y_res = make_imbalance(X, y, sampling_strategy=sampling_strategy) if as_frame: assert hasattr(X_res, "loc") + pd.testing.assert_index_equal(X_res.index, y_res.index) assert Counter(y_res) == expected_counts diff --git a/imblearn/ensemble/tests/test_bagging.py b/imblearn/ensemble/tests/test_bagging.py index 1f1d408ef..01532add8 100644 --- a/imblearn/ensemble/tests/test_bagging.py +++ b/imblearn/ensemble/tests/test_bagging.py @@ -572,11 +572,12 @@ def roughly_balanced_bagging(X, y, replace=False): # Roughly Balanced Bagging rbb = BalancedBaggingClassifier( - estimator=CountDecisionTreeClassifier(), + estimator=CountDecisionTreeClassifier(random_state=0), n_estimators=2, sampler=FunctionSampler( func=roughly_balanced_bagging, kw_args={"replace": replace} ), + random_state=0, ) rbb.fit(X, y) diff --git a/imblearn/over_sampling/_random_over_sampler.py b/imblearn/over_sampling/_random_over_sampler.py index 7175855ea..63b5a66a7 100644 --- a/imblearn/over_sampling/_random_over_sampler.py +++ b/imblearn/over_sampling/_random_over_sampler.py @@ -15,6 +15,7 @@ from ..utils import Substitution, check_target_type from ..utils._docstring import _random_state_docstring from ..utils._param_validation import Interval +from ..utils._validation import _check_X from .base import BaseOverSampler @@ -154,14 +155,9 @@ def __init__( def _check_X_y(self, X, y): y, binarize_y = check_target_type(y, indicate_one_vs_all=True) - X, y = self._validate_data( - X, - y, - reset=True, - accept_sparse=["csr", "csc"], - dtype=None, - force_all_finite=False, - ) + X = _check_X(X) + self._check_n_features(X, reset=True) + self._check_feature_names(X, reset=True) return X, y, binarize_y def _fit_resample(self, X, y): @@ -258,4 +254,7 @@ def _more_tags(self): "X_types": ["2darray", "string", "sparse", "dataframe"], "sample_indices": True, "allow_nan": True, + "_xfail_checks": { + "check_complex_data": "Robust to this type of data.", + }, } diff --git a/imblearn/over_sampling/_smote/base.py b/imblearn/over_sampling/_smote/base.py index da880f87c..f26b91b19 100644 --- a/imblearn/over_sampling/_smote/base.py +++ b/imblearn/over_sampling/_smote/base.py @@ -27,6 +27,7 @@ from ...utils import Substitution, check_neighbors_object, check_target_type from ...utils._docstring import _n_jobs_docstring, _random_state_docstring from ...utils._param_validation import HasMethods, Interval +from ...utils._validation import _check_X from ...utils.fixes import _mode from ..base import BaseOverSampler @@ -559,9 +560,9 @@ def _check_X_y(self, X, y): features. """ y, binarize_y = check_target_type(y, indicate_one_vs_all=True) - if not (hasattr(X, "__array__") or sparse.issparse(X)): - X = check_array(X, dtype=object) + X = _check_X(X) self._check_n_features(X, reset=True) + self._check_feature_names(X, reset=True) return X, y, binarize_y def _validate_estimator(self): diff --git a/imblearn/over_sampling/tests/test_random_over_sampler.py b/imblearn/over_sampling/tests/test_random_over_sampler.py index b72132d19..6ad4b75ef 100644 --- a/imblearn/over_sampling/tests/test_random_over_sampler.py +++ b/imblearn/over_sampling/tests/test_random_over_sampler.py @@ -4,6 +4,7 @@ # License: MIT from collections import Counter +from datetime import datetime import numpy as np import pytest @@ -273,3 +274,16 @@ def test_random_over_sampler_strings(sampling_strategy): random_state=0, ) RandomOverSampler(sampling_strategy=sampling_strategy).fit_resample(X, y) + + +def test_random_over_sampling_datetime(): + """Check that we don't convert input data and only sample from it.""" + pd = pytest.importorskip("pandas") + X = pd.DataFrame({"label": [0, 0, 0, 1], "td": [datetime.now()] * 4}) + y = X["label"] + ros = RandomOverSampler(random_state=0) + X_res, y_res = ros.fit_resample(X, y) + + pd.testing.assert_series_equal(X_res.dtypes, X.dtypes) + pd.testing.assert_index_equal(X_res.index, y_res.index) + assert_array_equal(y_res.to_numpy(), np.array([0, 0, 0, 1, 1, 1])) diff --git a/imblearn/under_sampling/_prototype_selection/_random_under_sampler.py b/imblearn/under_sampling/_prototype_selection/_random_under_sampler.py index ed47fe586..876195a6d 100644 --- a/imblearn/under_sampling/_prototype_selection/_random_under_sampler.py +++ b/imblearn/under_sampling/_prototype_selection/_random_under_sampler.py @@ -9,6 +9,7 @@ from ...utils import Substitution, check_target_type from ...utils._docstring import _random_state_docstring +from ...utils._validation import _check_X from ..base import BaseUnderSampler @@ -97,14 +98,9 @@ def __init__( def _check_X_y(self, X, y): y, binarize_y = check_target_type(y, indicate_one_vs_all=True) - X, y = self._validate_data( - X, - y, - reset=True, - accept_sparse=["csr", "csc"], - dtype=None, - force_all_finite=False, - ) + X = _check_X(X) + self._check_n_features(X, reset=True) + self._check_feature_names(X, reset=True) return X, y, binarize_y def _fit_resample(self, X, y): @@ -140,4 +136,7 @@ def _more_tags(self): "X_types": ["2darray", "string", "sparse", "dataframe"], "sample_indices": True, "allow_nan": True, + "_xfail_checks": { + "check_complex_data": "Robust to this type of data.", + }, } diff --git a/imblearn/under_sampling/_prototype_selection/tests/test_random_under_sampler.py b/imblearn/under_sampling/_prototype_selection/tests/test_random_under_sampler.py index bcb8682e2..9fc9f084c 100644 --- a/imblearn/under_sampling/_prototype_selection/tests/test_random_under_sampler.py +++ b/imblearn/under_sampling/_prototype_selection/tests/test_random_under_sampler.py @@ -4,6 +4,7 @@ # License: MIT from collections import Counter +from datetime import datetime import numpy as np import pytest @@ -148,3 +149,16 @@ def test_random_under_sampler_strings(sampling_strategy): random_state=0, ) RandomUnderSampler(sampling_strategy=sampling_strategy).fit_resample(X, y) + + +def test_random_under_sampling_datetime(): + """Check that we don't convert input data and only sample from it.""" + pd = pytest.importorskip("pandas") + X = pd.DataFrame({"label": [0, 0, 0, 1], "td": [datetime.now()] * 4}) + y = X["label"] + rus = RandomUnderSampler(random_state=0) + X_res, y_res = rus.fit_resample(X, y) + + pd.testing.assert_series_equal(X_res.dtypes, X.dtypes) + pd.testing.assert_index_equal(X_res.index, y_res.index) + assert_array_equal(y_res.to_numpy(), np.array([0, 1])) diff --git a/imblearn/utils/_validation.py b/imblearn/utils/_validation.py index b8df4b926..da1e492f4 100644 --- a/imblearn/utils/_validation.py +++ b/imblearn/utils/_validation.py @@ -12,8 +12,11 @@ import numpy as np from sklearn.base import clone from sklearn.neighbors import NearestNeighbors -from sklearn.utils import column_or_1d +from sklearn.utils import check_array, column_or_1d from sklearn.utils.multiclass import type_of_target +from sklearn.utils.validation import _num_samples + +from .fixes import _is_pandas_df SAMPLING_KIND = ( "over-sampling", @@ -35,6 +38,12 @@ def __init__(self, X, y): def transform(self, X, y): X = self._transfrom_one(X, self.x_props) y = self._transfrom_one(y, self.y_props) + if self.x_props["type"].lower() == "dataframe" and self.y_props[ + "type" + ].lower() in {"series", "dataframe"}: + # We lost the y.index during resampling. We can safely use X.index to align + # them. + y.index = X.index return X, y def _gets_props(self, array): @@ -607,3 +616,18 @@ def inner_f(*args, **kwargs): return f(**kwargs) return inner_f + + +def _check_X(X): + """Check X and do not check it if a dataframe.""" + n_samples = _num_samples(X) + if n_samples < 1: + raise ValueError( + f"Found array with {n_samples} sample(s) while a minimum of 1 is " + "required." + ) + if _is_pandas_df(X): + return X + return check_array( + X, dtype=None, accept_sparse=["csr", "csc"], force_all_finite=False + ) diff --git a/imblearn/utils/fixes.py b/imblearn/utils/fixes.py index 1868cb1fd..023d8a152 100644 --- a/imblearn/utils/fixes.py +++ b/imblearn/utils/fixes.py @@ -5,6 +5,7 @@ which the fix is no longer needed. """ import functools +import sys import numpy as np import scipy @@ -132,3 +133,18 @@ def _is_fitted(estimator, attributes=None, all_or_any=all): else: from sklearn.utils.validation import _is_fitted # type: ignore[no-redef] + +try: + from sklearn.utils.validation import _is_pandas_df +except ImportError: + + def _is_pandas_df(X): + """Return True if the X is a pandas dataframe.""" + if hasattr(X, "columns") and hasattr(X, "iloc"): + # Likely a pandas DataFrame, we explicitly check the type to confirm. + try: + pd = sys.modules["pandas"] + except KeyError: + return False + return isinstance(X, pd.DataFrame) + return False