Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions doc/whats_new/v0.12.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,20 @@
.. _changes_0_12:

Version 0.12.1
==============

**In progress**

Changelog
---------

Bug fixes
.........

- Fix a bug in :class:`~imblearn.under_sampling.InstanceHardnessThreshold` where
`estimator` could not be a :class:`~sklearn.pipeline.Pipeline` object.
:pr:`1049` by :user:`Gonenc Mogol <gmogol>`.

Version 0.12.0
==============

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from collections import Counter

import numpy as np
from sklearn.base import ClassifierMixin, clone
from sklearn.base import clone, is_classifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble._base import _set_random_states
from sklearn.model_selection import StratifiedKFold, cross_val_predict
Expand Down Expand Up @@ -140,7 +140,7 @@ def _validate_estimator(self, random_state):

if (
self.estimator is not None
and isinstance(self.estimator, ClassifierMixin)
and is_classifier(self.estimator)
and hasattr(self.estimator, "predict_proba")
):
self.estimator_ = clone(self.estimator)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
from sklearn.naive_bayes import GaussianNB as NB
from sklearn.pipeline import make_pipeline
from sklearn.utils._testing import assert_array_equal

from imblearn.under_sampling import InstanceHardnessThreshold
Expand Down Expand Up @@ -93,3 +94,19 @@ def test_iht_fit_resample_default_estimator():
assert isinstance(iht.estimator_, RandomForestClassifier)
assert X_resampled.shape == (12, 2)
assert y_resampled.shape == (12,)


def test_iht_estimator_pipeline():
"""Check that we can pass a pipeline containing a classifier.

Checking if we have a classifier should not be based on inheriting from
`ClassifierMixin`.

Non-regression test for:
https://github.com/scikit-learn-contrib/imbalanced-learn/pull/1049
"""
model = make_pipeline(GradientBoostingClassifier(random_state=RND_SEED))
iht = InstanceHardnessThreshold(estimator=model, random_state=RND_SEED)
X_resampled, y_resampled = iht.fit_resample(X, Y)
assert X_resampled.shape == (12, 2)
assert y_resampled.shape == (12,)