diff --git a/doc/over_sampling.rst b/doc/over_sampling.rst index 3c3ca80b7..3bc975b89 100644 --- a/doc/over_sampling.rst +++ b/doc/over_sampling.rst @@ -203,9 +203,9 @@ or relying on `dtype` inference if the columns are using the >>> print(sorted(Counter(y_resampled).items())) [(0, 30), (1, 30)] >>> print(X_resampled[-5:]) - [['A' 0.52... 2] + [['A' 0.19... 2] ['B' -0.36... 2] - ['B' 0.93... 2] + ['B' 0.87... 2] ['B' 0.37... 2] ['B' 0.33... 2]] diff --git a/doc/whats_new/v0.11.rst b/doc/whats_new/v0.11.rst index 387eb1ed3..666be580c 100644 --- a/doc/whats_new/v0.11.rst +++ b/doc/whats_new/v0.11.rst @@ -14,6 +14,15 @@ Bug fixes they are plugged into an Euclidean distance computation. :pr:`1014` by :user:`Guillaume Lemaitre `. +- Fix a bug in :class:`~imblearn.over_sampling.SMOTENC` where the median of standard + deviation of the continuous features was only computed on the minority class. Now, + we are computing this statistic for each class that is up-sampled. + :pr:`1015` by :user:`Guillaume Lemaitre `. + +- Fix a bug in :class:`~imblearn.over_sampling.SMOTENC` such that the case where + the median of standard deviation of the continuous features is null is handled + in the multiclass case as well. + :pr:`1015` by :user:`Guillaume Lemaitre `. Version 0.11.0 ============== diff --git a/imblearn/over_sampling/_smote/base.py b/imblearn/over_sampling/_smote/base.py index 4627d52be..c4721ca20 100644 --- a/imblearn/over_sampling/_smote/base.py +++ b/imblearn/over_sampling/_smote/base.py @@ -9,7 +9,6 @@ import math import numbers import warnings -from collections import Counter import numpy as np from scipy import sparse @@ -23,7 +22,6 @@ check_random_state, ) from sklearn.utils.sparsefuncs_fast import ( - csc_mean_variance_axis0, csr_mean_variance_axis0, ) from sklearn.utils.validation import _num_features @@ -116,11 +114,11 @@ def _make_samples( rows = np.floor_divide(samples_indices, nn_num.shape[1]) cols = np.mod(samples_indices, nn_num.shape[1]) - X_new = self._generate_samples(X, nn_data, nn_num, rows, cols, steps) + X_new = self._generate_samples(X, nn_data, nn_num, rows, cols, steps, y_type) y_new = np.full(n_samples, fill_value=y_type, dtype=y_dtype) return X_new, y_new - def _generate_samples(self, X, nn_data, nn_num, rows, cols, steps): + def _generate_samples(self, X, nn_data, nn_num, rows, cols, steps, y_type=None): r"""Generate a synthetic sample. The rule for the generation is: @@ -155,6 +153,9 @@ def _generate_samples(self, X, nn_data, nn_num, rows, cols, steps): steps : ndarray of shape (n_samples,), dtype=float Step sizes for new samples. + y_type : None + Unused parameter. Only for compatibility reason with SMOTE-NC. + Returns ------- X_new : {ndarray, sparse matrix} of shape (n_samples, n_features) @@ -465,8 +466,9 @@ class SMOTENC(SMOTE): continuous_features_ : ndarray of shape (n_cont_features,), dtype=np.int64 Indices of the continuous features. - median_std_ : float - Median of the standard deviation of the continuous features. + median_std_ : dict of int -> float + Median of the standard deviation of the continuous features for each + class to be over-sampled. n_features_ : int Number of features observed at `fit`. @@ -627,23 +629,8 @@ def _fit_resample(self, X, y): self._validate_column_types(X) self._validate_estimator() - # compute the median of the standard deviation of the minority class - target_stats = Counter(y) - class_minority = min(target_stats, key=target_stats.get) - X_continuous = _safe_indexing(X, self.continuous_features_, axis=1) X_continuous = check_array(X_continuous, accept_sparse=["csr", "csc"]) - X_minority = _safe_indexing(X_continuous, np.flatnonzero(y == class_minority)) - - if sparse.issparse(X): - if X.format == "csr": - _, var = csr_mean_variance_axis0(X_minority) - else: - _, var = csc_mean_variance_axis0(X_minority) - else: - var = X_minority.var(axis=0) - self.median_std_ = np.median(np.sqrt(var)) - X_categorical = _safe_indexing(X, self.categorical_features_, axis=1) if X_continuous.dtype.name != "object": dtype_ohe = X_continuous.dtype @@ -664,28 +651,54 @@ def _fit_resample(self, X, y): if not sparse.issparse(X_ohe): X_ohe = sparse.csr_matrix(X_ohe, dtype=dtype_ohe) - # we can replace the 1 entries of the categorical features with the - # median of the standard deviation. It will ensure that whenever - # distance is computed between 2 samples, the difference will be equal - # to the median of the standard deviation as in the original paper. - - # In the edge case where the median of the std is equal to 0, the 1s - # entries will be also nullified. In this case, we store the original - # categorical encoding which will be later used for inverting the OHE - if math.isclose(self.median_std_, 0): - self._X_categorical_minority_encoded = _safe_indexing( - X_ohe.toarray(), np.flatnonzero(y == class_minority) + X_encoded = sparse.hstack((X_continuous, X_ohe), format="csr", dtype=dtype_ohe) + X_resampled = [X_encoded.copy()] + y_resampled = [y.copy()] + + # SMOTE resampling starts here + self.median_std_ = {} + for class_sample, n_samples in self.sampling_strategy_.items(): + if n_samples == 0: + continue + target_class_indices = np.flatnonzero(y == class_sample) + X_class = _safe_indexing(X_encoded, target_class_indices) + + _, var = csr_mean_variance_axis0( + X_class[:, : self.continuous_features_.size] ) + self.median_std_[class_sample] = np.median(np.sqrt(var)) + + # In the edge case where the median of the std is equal to 0, the 1s + # entries will be also nullified. In this case, we store the original + # categorical encoding which will be later used for inverting the OHE + if math.isclose(self.median_std_[class_sample], 0): + # This variable will be used when generating data + self._X_categorical_minority_encoded = X_class[ + :, self.continuous_features_.size : + ].toarray() + + # we can replace the 1 entries of the categorical features with the + # median of the standard deviation. It will ensure that whenever + # distance is computed between 2 samples, the difference will be equal + # to the median of the standard deviation as in the original paper. + X_class_categorical = X_class[:, self.continuous_features_.size :] + # With one-hot encoding, the median will be repeated twice. We need + # to divide by sqrt(2) such that we only have one median value + # contributing to the Euclidean distance + X_class_categorical.data[:] = self.median_std_[class_sample] / np.sqrt(2) + X_class[:, self.continuous_features_.size :] = X_class_categorical - # With one-hot encoding, the median will be repeated twice. We need to divide - # by sqrt(2) such that we only have one median value contributing to the - # Euclidean distance - X_ohe.data = ( - np.ones_like(X_ohe.data, dtype=X_ohe.dtype) * self.median_std_ / np.sqrt(2) - ) - X_encoded = sparse.hstack((X_continuous, X_ohe), format="csr") + self.nn_k_.fit(X_class) + nns = self.nn_k_.kneighbors(X_class, return_distance=False)[:, 1:] + X_new, y_new = self._make_samples( + X_class, y.dtype, class_sample, X_class, nns, n_samples, 1.0 + ) + X_resampled.append(X_new) + y_resampled.append(y_new) - X_resampled, y_resampled = super()._fit_resample(X_encoded, y) + X_resampled = sparse.vstack(X_resampled, format=X_encoded.format) + y_resampled = np.hstack(y_resampled) + # SMOTE resampling ends here # reverse the encoding of the categorical features X_res_cat = X_resampled[:, self.continuous_features_.size :] @@ -723,7 +736,7 @@ def _fit_resample(self, X, y): return X_resampled, y_resampled - def _generate_samples(self, X, nn_data, nn_num, rows, cols, steps): + def _generate_samples(self, X, nn_data, nn_num, rows, cols, steps, y_type): """Generate a synthetic sample with an additional steps for the categorical features. @@ -741,7 +754,7 @@ def _generate_samples(self, X, nn_data, nn_num, rows, cols, steps): # In the case that the median std was equal to zeros, we have to # create non-null entry based on the encoded of OHE - if math.isclose(self.median_std_, 0): + if math.isclose(self.median_std_[y_type], 0): nn_data[ :, self.continuous_features_.size : ] = self._X_categorical_minority_encoded diff --git a/imblearn/over_sampling/_smote/tests/test_smote_nc.py b/imblearn/over_sampling/_smote/tests/test_smote_nc.py index 84dd6c252..1314ea98b 100644 --- a/imblearn/over_sampling/_smote/tests/test_smote_nc.py +++ b/imblearn/over_sampling/_smote/tests/test_smote_nc.py @@ -130,6 +130,8 @@ def test_smotenc(data): assert set(X[:, cat_idx]) == set(X_resampled[:, cat_idx]) assert X[:, cat_idx].dtype == X_resampled[:, cat_idx].dtype + assert isinstance(smote.median_std_, dict) + # part of the common test which apply to SMOTE-NC even if it is not default # constructible @@ -193,6 +195,7 @@ def test_smotenc_pandas(): X_res, y_res = smote.fit_resample(X, y) assert_array_equal(X_res_pd.to_numpy(), X_res) assert_allclose(y_res_pd, y_res) + assert set(smote.median_std_.keys()) == {0, 1} def test_smotenc_preserve_dtype(): @@ -234,20 +237,36 @@ def test_smote_nc_with_null_median_std(): [ [1, 2, 1, "A"], [2, 1, 2, "A"], + [2, 1, 2, "A"], [1, 2, 3, "B"], [1, 2, 4, "C"], [1, 2, 5, "C"], + [1, 2, 4, "C"], + [1, 2, 4, "C"], + [1, 2, 4, "C"], ], dtype="object", ) labels = np.array( - ["class_1", "class_1", "class_1", "class_2", "class_2"], dtype=object + [ + "class_1", + "class_1", + "class_1", + "class_1", + "class_2", + "class_2", + "class_3", + "class_3", + "class_3", + ], + dtype=object, ) smote = SMOTENC(categorical_features=[3], k_neighbors=1, random_state=0) X_res, y_res = smote.fit_resample(data, labels) # check that the categorical feature is not random but correspond to the # categories seen in the minority class samples - assert X_res[-1, -1] == "C" + assert_array_equal(X_res[-3:, -1], np.array(["C", "C", "C"], dtype=object)) + assert smote.median_std_ == {"class_2": 0.0, "class_3": 0.0} def test_smotenc_categorical_encoder():