Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/whats_new/v0.11.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ Bug fixes
since it requires a conversion to dense matrices.
:pr:`1003` by :user:`Guillaume Lemaitre <glemaitre>`.

- Remove the spurious warning raised when a minority class gets over-sampled beyond
the number of samples in the majority class.
:pr:`1007` by :user:`Guillaume Lemaitre <glemaitre>`.

Compatibility
.............

Expand Down
11 changes: 2 additions & 9 deletions imblearn/utils/_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,8 @@ def _sampling_strategy_dict(sampling_strategy, y, sampling_type):
)
sampling_strategy_ = {}
if sampling_type == "over-sampling":
n_samples_majority = max(target_stats.values())
class_majority = max(target_stats, key=target_stats.get)
max(target_stats.values())
max(target_stats, key=target_stats.get)
for class_sample, n_samples in sampling_strategy.items():
if n_samples < target_stats[class_sample]:
raise ValueError(
Expand All @@ -318,13 +318,6 @@ def _sampling_strategy_dict(sampling_strategy, y, sampling_type):
f" Originally, there is {target_stats[class_sample]} "
f"samples and {n_samples} samples are asked."
)
if n_samples > n_samples_majority:
warnings.warn(
f"After over-sampling, the number of samples ({n_samples})"
f" in class {class_sample} will be larger than the number of"
f" samples in the majority class (class #{class_majority} ->"
f" {n_samples_majority})"
)
sampling_strategy_[class_sample] = n_samples - target_stats[class_sample]
elif sampling_type == "under-sampling":
for class_sample, n_samples in sampling_strategy.items():
Expand Down
8 changes: 0 additions & 8 deletions imblearn/utils/tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,14 +256,6 @@ def test_check_sampling_strategy(
assert sampling_strategy_ == expected_sampling_strategy


def test_sampling_strategy_dict_over_sampling():
    """Expect a ``UserWarning`` from a dict strategy in over-sampling mode.

    The target holds 50, 100 and 25 samples for classes 1, 2 and 3, while
    the requested dict asks for 70, 140 and 70 samples respectively. The
    call to ``check_sampling_strategy`` must emit a ``UserWarning`` whose
    message matches the "After over-sampling, ..." prefix.
    """
    target = np.array([1] * 50 + [2] * 100 + [3] * 25)
    requested_counts = {1: 70, 2: 140, 3: 70}
    warning_pattern = "After over-sampling, the number of samples "
    with pytest.warns(UserWarning, match=warning_pattern):
        check_sampling_strategy(requested_counts, target, "over-sampling")


def test_sampling_strategy_callable_args():
y = np.array([1] * 50 + [2] * 100 + [3] * 25)
multiplier = {1: 1.5, 2: 1, 3: 3}
Expand Down