Skip to content

Commit

Permalink
make testing function private
Browse files Browse the repository at this point in the history
  • Loading branch information
glemaitre committed Jan 16, 2022
1 parent 615a2bf commit b75b77d
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 11 deletions.
10 changes: 5 additions & 5 deletions imblearn/over_sampling/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
SMOTENC,
SVMSMOTE,
)
from imblearn.utils.testing import CustomNearestNeighbors
from imblearn.utils.testing import _CustomNearestNeighbors


@pytest.fixture
Expand Down Expand Up @@ -76,7 +76,7 @@ def test_smote_m_neighbors(numerical_data, smote):
def test_numerical_smote_custom_nn(numerical_data, smote, neighbor_estimator_name):
X, y = numerical_data
params = {
neighbor_estimator_name: CustomNearestNeighbors(n_neighbors=5),
neighbor_estimator_name: _CustomNearestNeighbors(n_neighbors=5),
}
smote.set_params(**params)
X_res, _ = smote.fit_resample(X, y)
Expand All @@ -86,7 +86,7 @@ def test_numerical_smote_custom_nn(numerical_data, smote, neighbor_estimator_nam

def test_categorical_smote_k_custom_nn(categorical_data):
X, y = categorical_data
smote = SMOTEN(k_neighbors=CustomNearestNeighbors(n_neighbors=5))
smote = SMOTEN(k_neighbors=_CustomNearestNeighbors(n_neighbors=5))
X_res, y_res = smote.fit_resample(X, y)

assert X_res.shape == (80, 3)
Expand All @@ -96,7 +96,7 @@ def test_categorical_smote_k_custom_nn(categorical_data):
def test_heterogeneous_smote_k_custom_nn(heterogeneous_data):
X, y, categorical_features = heterogeneous_data
smote = SMOTENC(
categorical_features, k_neighbors=CustomNearestNeighbors(n_neighbors=5)
categorical_features, k_neighbors=_CustomNearestNeighbors(n_neighbors=5)
)
X_res, y_res = smote.fit_resample(X, y)

Expand All @@ -111,7 +111,7 @@ def test_heterogeneous_smote_k_custom_nn(heterogeneous_data):
)
def test_numerical_smote_extra_custom_nn(numerical_data, smote):
X, y = numerical_data
smote.set_params(m_neighbors=CustomNearestNeighbors(n_neighbors=5))
smote.set_params(m_neighbors=_CustomNearestNeighbors(n_neighbors=5))
X_res, y_res = smote.fit_resample(X, y)

assert X_res.shape == (120, 2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from sklearn.datasets import make_classification

from imblearn.under_sampling import ClusterCentroids
from imblearn.utils.testing import CustomClusterer
from imblearn.utils.testing import _CustomClusterer

RND_SEED = 0
X = np.array(
Expand Down Expand Up @@ -170,4 +170,4 @@ def test_cluster_centroids_error_estimator():
"`cluster_centers_`."
)
with pytest.raises(RuntimeError, match=err_msg):
ClusterCentroids(estimator=CustomClusterer()).fit_resample(X, Y)
ClusterCentroids(estimator=_CustomClusterer()).fit_resample(X, Y)
4 changes: 2 additions & 2 deletions imblearn/utils/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def warns(expected_warning, match=None):
pass


class CustomNearestNeighbors(BaseEstimator):
class _CustomNearestNeighbors(BaseEstimator):
"""Basic implementation of nearest neighbors not relying on scikit-learn.
`kneighbors_graph` is ignored and `metric` does not have any impact.
Expand Down Expand Up @@ -197,7 +197,7 @@ def kneighbors_graph(X=None, n_neighbors=None, mode="connectivity"):
pass


class CustomClusterer(BaseEstimator):
class _CustomClusterer(BaseEstimator):
"""Class that mimics a cluster that does not expose `cluster_centers_`."""

def __init__(self, n_clusters=1):
Expand Down
4 changes: 2 additions & 2 deletions imblearn/utils/tests/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from sklearn.neighbors._base import KNeighborsMixin

from imblearn.base import SamplerMixin
from imblearn.utils.testing import all_estimators, CustomNearestNeighbors
from imblearn.utils.testing import all_estimators, _CustomNearestNeighbors

from imblearn.utils.testing import warns

Expand Down Expand Up @@ -69,7 +69,7 @@ def test_custom_nearest_neighbors():
"""Check that our custom nearest neighbors can be used for our internal
duck-typing."""

neareat_neighbors = CustomNearestNeighbors(n_neighbors=3)
neareat_neighbors = _CustomNearestNeighbors(n_neighbors=3)

assert not isinstance(neareat_neighbors, KNeighborsMixin)
assert hasattr(neareat_neighbors, "kneighbors")
Expand Down

0 comments on commit b75b77d

Please sign in to comment.