Merged

24 commits
520c69f
implement InstanceHardnessCV
fritshermans Jan 31, 2025
f54f035
add documentation
fritshermans Feb 1, 2025
4b12063
add cross_validation.rst
fritshermans Feb 1, 2025
0e98d1f
fix plot_instance_hardness_cv.py
fritshermans Feb 1, 2025
2ceea7f
add initial documentation
fritshermans Feb 2, 2025
7ef85ff
add docstrings
fritshermans Feb 2, 2025
cc611e2
fix random seed in unit test
fritshermans Mar 26, 2025
f0c03bb
refactor the way groups are assigned by instance hardness in Instance…
fritshermans Mar 26, 2025
018df65
simplify plotting code in plot_instance_hardness_cv.py
fritshermans Mar 26, 2025
2fdca6f
update docstring
fritshermans Mar 26, 2025
0ce2eb3
update 'labels' to 'tick_labels' in boxplot code
fritshermans Mar 26, 2025
a394cf2
rename clf to estimator
fritshermans Mar 26, 2025
d06c580
change data generation in plot_instance_hardness_cv.py
fritshermans Mar 26, 2025
38509dd
describe InstanceHardnessCV in User Guide
fritshermans Mar 26, 2025
aded9e9
add x label to boxplot
fritshermans Mar 26, 2025
4647a2b
fix typo
fritshermans Mar 29, 2025
636dc5b
explain instance hardness in user guide
fritshermans Mar 29, 2025
1c642ce
remove default random forest as estimator for InstanceHardnessCV
fritshermans Mar 29, 2025
bf260c4
Merge remote-tracking branch 'origin' into pr/fritshermans/1125
glemaitre Aug 13, 2025
d544dc4
MAINT couple of fixes
glemaitre Aug 13, 2025
265f653
do not use mixin but implement functionality
glemaitre Aug 13, 2025
e9684c1
fix tests
glemaitre Aug 14, 2025
1946606
add entry in changelog
glemaitre Aug 14, 2025
b6d222d
fix docstring
glemaitre Aug 14, 2025
128 changes: 128 additions & 0 deletions doc/model_selection.rst
@@ -0,0 +1,128 @@
.. _cross_validation:

================
Cross validation
================

.. currentmodule:: imblearn.model_selection


.. _instance_hardness_threshold_cv:

The term *instance hardness* is used in the literature to express how difficult it
is to classify an instance correctly. An instance for which the predicted
probability of the true class is low has large instance hardness. The way these
hard-to-classify instances are distributed over the train and test sets in cross
validation has a significant effect on the test set performance metrics. The
:class:`~imblearn.model_selection.InstanceHardnessCV` splitter distributes samples
with large instance hardness evenly over the folds, resulting in more robust cross
validation.

We will discuss instance hardness in this document and explain how to use the
:class:`~imblearn.model_selection.InstanceHardnessCV` splitter.

Instance hardness and average precision
=======================================

Instance hardness is defined as 1 minus the probability of the most probable class:

.. math::

H(x) = 1 - P(\hat{y}|x)

In this equation, :math:`H(x)` is the instance hardness for a sample with features
:math:`x` and :math:`P(\hat{y}|x)` the probability of the predicted label
:math:`\hat{y}` given the features. If the model predicts label 0 and gives a
`predict_proba` output of [0.9, 0.1], the probability of the most probable class
(0) is 0.9 and the instance hardness is `1 - 0.9 = 0.1`.
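
As an illustration, instance hardness can be computed from any fitted classifier
that exposes `predict_proba`. The snippet below is a minimal, self-contained
sketch on a toy dataset; the `X_toy`, `y_toy` and `estimator` names are only used
for illustration and are not part of the imbalanced-learn API.

>>> from sklearn.datasets import make_classification
>>> from sklearn.linear_model import LogisticRegression
>>> X_toy, y_toy = make_classification(random_state=0)
>>> estimator = LogisticRegression().fit(X_toy, y_toy)
>>> instance_hardness = 1 - estimator.predict_proba(X_toy).max(axis=1)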

Samples with large instance hardness have a significant effect on the area under
the precision-recall curve, or average precision. In particular, samples with
label 0 and large instance hardness (meaning the model predicts label 1) reduce
the average precision considerably, as these points affect the left side of the
precision-recall curve, where the area is largest: the precision is lowered in
the range of low recall and high thresholds. When doing cross validation, e.g.
for hyperparameter tuning or recursive feature elimination, the random gathering
of these points in some folds introduces variance in the CV results that
deteriorates the robustness of the cross validation task. The
:class:`~imblearn.model_selection.InstanceHardnessCV` splitter aims to distribute
the samples with large instance hardness evenly over the folds in order to reduce
this undesired variance. Note that this splitter should be used to make model
*selection* tasks such as hyperparameter tuning and feature selection more
robust, but not for model *performance estimation*, for which you also want to
know the variance of the performance to be expected in production.
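
For example, the splitter can be passed as the `cv` argument of a model selection
utility such as :class:`~sklearn.model_selection.GridSearchCV`. The snippet below
is only a sketch, reusing the constructor arguments shown later in this document:

>>> from sklearn.linear_model import LogisticRegression
>>> from sklearn.model_selection import GridSearchCV
>>> from imblearn.model_selection import InstanceHardnessCV
>>> estimator = LogisticRegression()
>>> search = GridSearchCV(estimator, param_grid={"C": [0.1, 1, 10]},
...                       cv=InstanceHardnessCV(estimator=estimator, n_splits=5),
...                       scoring="average_precision")

Fitting `search` on a dataset then evaluates every hyperparameter candidate on
folds in which the hard samples are distributed evenly.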


Create an imbalanced dataset containing samples with large instance hardness
==============================================================================

Let's start by creating a dataset to work with. We create a dataset with 5% class
imbalance using scikit-learn's :func:`~sklearn.datasets.make_blobs` function.

>>> import numpy as np
>>> from matplotlib import pyplot as plt
>>> from sklearn.datasets import make_blobs
>>> random_state = 10
>>> X, y = make_blobs(n_samples=[950, 50], centers=((-3, 0), (3, 0)),
... random_state=random_state)
>>> plt.scatter(X[:, 0], X[:, 1], c=y)
>>> plt.show()

.. image:: ./auto_examples/model_selection/images/sphx_glr_plot_instance_hardness_cv_001.png
:target: ./auto_examples/model_selection/plot_instance_hardness_cv.html
:align: center

Now we add some samples with large instance hardness:

>>> X_hard, y_hard = make_blobs(n_samples=10, centers=((3, 0), (-3, 0)),
... cluster_std=1,
... random_state=random_state)
>>> X = np.vstack((X, X_hard))
>>> y = np.hstack((y, y_hard))
>>> plt.scatter(X[:, 0], X[:, 1], c=y)
>>> plt.show()

.. image:: ./auto_examples/model_selection/images/sphx_glr_plot_instance_hardness_cv_002.png
:target: ./auto_examples/model_selection/plot_instance_hardness_cv.html
:align: center

Assess cross validation performance variance using `InstanceHardnessCV` splitter
================================================================================

Then we take a :class:`~sklearn.linear_model.LogisticRegression` and assess the
cross validation performance using a :class:`~sklearn.model_selection.StratifiedKFold`
cv splitter and the :func:`~sklearn.model_selection.cross_validate` function.

>>> from sklearn.linear_model import LogisticRegression
>>> from sklearn.model_selection import StratifiedKFold, cross_validate
>>> clf = LogisticRegression(random_state=random_state)
>>> skf_cv = StratifiedKFold(n_splits=5, shuffle=True,
...                          random_state=random_state)
>>> skf_result = cross_validate(clf, X, y, cv=skf_cv, scoring="average_precision")

Now, we do the same using an :class:`~imblearn.model_selection.InstanceHardnessCV`
splitter. We provide our classifier to the splitter to calculate the instance
hardness and distribute the samples with large instance hardness evenly over the
folds.

>>> from imblearn.model_selection import InstanceHardnessCV
>>> ih_cv = InstanceHardnessCV(estimator=clf, n_splits=5,
...                            random_state=random_state)
>>> ih_result = cross_validate(clf, X, y, cv=ih_cv, scoring="average_precision")

When we plot the test scores for both cv splitters, we see that the variance using the
:class:`~imblearn.model_selection.InstanceHardnessCV` splitter is lower than for the
:class:`~sklearn.model_selection.StratifiedKFold` splitter.

>>> plt.boxplot([skf_result['test_score'], ih_result['test_score']],
... tick_labels=["StratifiedKFold", "InstanceHardnessCV"],
... vert=False)
>>> plt.xlabel('Average precision')
>>> plt.tight_layout()

.. image:: ./auto_examples/model_selection/images/sphx_glr_plot_instance_hardness_cv_003.png
:target: ./auto_examples/model_selection/plot_instance_hardness_cv.html
:align: center

Be aware that the most important property of a cross-validation splitter is to
simulate the conditions that one will encounter in production. Therefore, if
difficult samples are likely to show up in production, one should use a
cross-validation splitter that emulates this situation. In our case, the
:class:`~sklearn.model_selection.StratifiedKFold` splitter does not control how
the difficult samples are distributed over the folds, which made it a poor fit
for our use case.
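
Finally, :class:`~imblearn.model_selection.InstanceHardnessCV` follows the usual
scikit-learn splitter API, so the generated splits can also be inspected directly.
A small sketch, assuming the standard `split` signature and continuing from the
variables defined above:

>>> splits = list(ih_cv.split(X, y))
>>> len(splits)
5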
1 change: 1 addition & 0 deletions doc/references/index.rst
@@ -18,5 +18,6 @@ This is the full API documentation of the `imbalanced-learn` toolbox.
miscellaneous
pipeline
metrics
model_selection
datasets
utils
23 changes: 23 additions & 0 deletions doc/references/model_selection.rst
@@ -0,0 +1,23 @@
.. _model_selection_ref:

Model selection methods
=======================

.. automodule:: imblearn.model_selection
:no-members:
:no-inherited-members:

Cross-validation splitters
--------------------------

.. automodule:: imblearn.model_selection._split
:no-members:
:no-inherited-members:

.. currentmodule:: imblearn.model_selection

.. autosummary::
:toctree: generated/
:template: class.rst

InstanceHardnessCV
1 change: 1 addition & 0 deletions doc/user_guide.rst
@@ -19,6 +19,7 @@ User Guide
ensemble.rst
miscellaneous.rst
metrics.rst
model_selection.rst
common_pitfalls.rst
Dataset loading utilities <datasets/index.rst>
developers_utils.rst
4 changes: 4 additions & 0 deletions doc/whats_new/0.14.rst
@@ -14,6 +14,10 @@ Bug fixes
Enhancements
............

- Add :class:`~imblearn.model_selection.InstanceHardnessCV` to split data and ensure
that samples are distributed in folds based on their instance hardness.
:pr:`1125` by :user:`Frits Hermans <fritshermans>`.

Compatibility
.............

97 changes: 97 additions & 0 deletions examples/model_selection/plot_instance_hardness_cv.py
@@ -0,0 +1,97 @@
"""
====================================================
Distribute hard-to-classify datapoints over CV folds
====================================================

'Instance hardness' refers to the difficulty of classifying an instance. The way
hard-to-classify instances are distributed over the train and test sets has a
significant effect on the test set performance metrics. In this example we show
how to deal with this problem, making the comparison with the standard
:class:`~sklearn.model_selection.StratifiedKFold` cross-validation splitter.
"""

# Authors: Frits Hermans, https://fritshermans.github.io
# License: MIT

# %%
print(__doc__)

# %%
# Create an imbalanced dataset with instance hardness
# ---------------------------------------------------
#
# We create an imbalanced dataset using scikit-learn's
# :func:`~sklearn.datasets.make_blobs` function and set the class imbalance ratio
# to 5%.
import numpy as np
from matplotlib import pyplot as plt
from sklearn.datasets import make_blobs

X, y = make_blobs(n_samples=[950, 50], centers=((-3, 0), (3, 0)), random_state=10)
plt.scatter(X[:, 0], X[:, 1], c=y)

# %%
# To introduce instance hardness in our dataset, we add some hard-to-classify samples:
X_hard, y_hard = make_blobs(
n_samples=10, centers=((3, 0), (-3, 0)), cluster_std=1, random_state=10
)
X, y = np.vstack((X, X_hard)), np.hstack((y, y_hard))
plt.scatter(X[:, 0], X[:, 1], c=y)

# %%
# Compare cross validation scores using `StratifiedKFold` and `InstanceHardnessCV`
# --------------------------------------------------------------------------------
#
# Now, we want to assess a linear predictive model. Therefore, we should use
# cross-validation. The most important concept with cross-validation is to create
# training and test splits that are representative of the data in production, so
# that the statistical results are those one can expect in production.
#
# By applying a standard :class:`~sklearn.model_selection.StratifiedKFold`
# cross-validation splitter, we do not control in which fold the hard-to-classify
# samples will be.
#
# The :class:`~imblearn.model_selection.InstanceHardnessCV` splitter makes it
# possible to control the distribution of the hard-to-classify samples over the
# folds.
#
# Let's make an experiment to compare the results that we get with both splitters.
# We use a :class:`~sklearn.linear_model.LogisticRegression` classifier and
# :func:`~sklearn.model_selection.cross_validate` to calculate the cross validation
# scores. We use average precision for scoring.
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_validate

from imblearn.model_selection import InstanceHardnessCV

logistic_regression = LogisticRegression()

results = {}
for cv in (
StratifiedKFold(n_splits=5, shuffle=True, random_state=10),
InstanceHardnessCV(estimator=LogisticRegression(), n_splits=5, random_state=10),
):
result = cross_validate(
logistic_regression,
X,
y,
cv=cv,
scoring="average_precision",
)
results[cv.__class__.__name__] = result["test_score"]
results = pd.DataFrame(results)

# %%
ax = results.plot.box(vert=False, whis=[0, 100])
ax.set(
xlabel="Average precision",
title="Cross validation scores with different splitters",
xlim=(0, 1),
)

# %%
# The boxplot shows that the :class:`~imblearn.model_selection.InstanceHardnessCV`
# splitter results in less variation of the average precision than the
# :class:`~sklearn.model_selection.StratifiedKFold` splitter. When doing
# hyperparameter tuning or feature selection using a wrapper method (like
# :class:`~sklearn.feature_selection.RFECV`), this will give more stable results,
# as sketched below.
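
# %%
# As an illustrative sketch (an addition to this example, not part of the library
# documentation), the splitter can be passed as ``cv`` to
# :class:`~sklearn.feature_selection.RFECV`; we assume here that
# `InstanceHardnessCV` follows the standard scikit-learn CV splitter API.
from sklearn.feature_selection import RFECV

rfecv = RFECV(
    LogisticRegression(),
    cv=InstanceHardnessCV(
        estimator=LogisticRegression(), n_splits=5, random_state=10
    ),
    scoring="average_precision",
)
rfecv.fit(X, y)
print(f"Optimal number of features: {rfecv.n_features_}")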
6 changes: 5 additions & 1 deletion imblearn/__init__.py
@@ -11,14 +11,16 @@
Module which provides methods generating an ensemble of
under-sampled subsets.
exceptions
Module including custom warnings and error clases used across
Module including custom warnings and error classes used across
imbalanced-learn.
keras
Module which provides custom generator, layers for deep learning using
keras.
metrics
Module which provides metrics to quantify the classification performance
with imbalanced datasets.
model_selection
Module which provides methods to split the dataset into training and test sets.
over_sampling
Module which provides methods to over-sample a dataset.
tensorflow
@@ -54,6 +56,7 @@
ensemble,
exceptions,
metrics,
model_selection,
over_sampling,
pipeline,
tensorflow,
@@ -113,6 +116,7 @@ def __dir__(self):
"exceptions",
"keras",
"metrics",
"model_selection",
"over_sampling",
"tensorflow",
"under_sampling",
8 changes: 8 additions & 0 deletions imblearn/model_selection/__init__.py
@@ -0,0 +1,8 @@
"""
The :mod:`imblearn.model_selection` module provides methods to split the dataset
into training and test sets.
"""

from ._split import InstanceHardnessCV

__all__ = ["InstanceHardnessCV"]