Skip to content

Commit

Permalink
Merge branch 'main' into fix/lowess-better-error-message
Browse files Browse the repository at this point in the history
  • Loading branch information
koaning authored Nov 13, 2024
2 parents 460355b + 3d51e7b commit 25c4052
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 12 deletions.
42 changes: 30 additions & 12 deletions sklego/meta/zero_inflated_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ class ZeroInflatedRegressor(BaseEstimator, RegressorMixin, MetaEstimatorMixin):
`ZeroInflatedRegressor` consists of a classifier and a regressor.
- The classifier's task is to find of if the target is zero or not.
- The regressor's task is to output a (usually positive) prediction whenever the classifier indicates that the
- The classifier's task is to find if the target is zero or not.
- The regressor's task is to output a (usually positive) prediction whenever the classifier indicates that
there should be a non-zero prediction.
The regressor is only trained on examples where the target is non-zero, which makes it easier for it to focus.
Expand All @@ -29,6 +29,11 @@ class ZeroInflatedRegressor(BaseEstimator, RegressorMixin, MetaEstimatorMixin):
regressor : scikit-learn compatible regressor
A regressor for predicting the target. Its prediction is only used if `classifier` says that the output is
non-zero.
handle_zero : Literal["error", "ignore"], default="error"
How to behave in the case that all train set output consists of zero values only.
- `handle_zero = 'error'`: will raise a `ValueError` (default).
- `handle_zero = 'ignore'`: will continue to train the regressor on the entire dataset.
Attributes
----------
Expand Down Expand Up @@ -63,9 +68,10 @@ class ZeroInflatedRegressor(BaseEstimator, RegressorMixin, MetaEstimatorMixin):

_required_parameters = ["classifier", "regressor"]

def __init__(self, classifier, regressor) -> None:
def __init__(self, classifier, regressor, handle_zero="error") -> None:
self.classifier = classifier
self.regressor = regressor
self.handle_zero = handle_zero

def fit(self, X, y, sample_weight=None):
"""Fit the underlying classifier and regressor using `X` and `y` as training data. The regressor is only trained
Expand All @@ -88,7 +94,9 @@ def fit(self, X, y, sample_weight=None):
Raises
------
ValueError
If `classifier` is not a classifier or `regressor` is not a regressor.
If `classifier` is not a classifier
If `regressor` is not a regressor
If all train target entirely consists of zeros and `handle_zero="error"`
"""
X, y = check_X_y(X, y)
self._check_n_features(X, reset=True)
Expand All @@ -98,6 +106,10 @@ def fit(self, X, y, sample_weight=None):
)
if not is_regressor(self.regressor):
raise ValueError(f"`regressor` has to be a regressor. Received instance of {type(self.regressor)} instead.")
if self.handle_zero not in {"ignore", "error"}:
raise ValueError(
f"`handle_zero` has to be one of {'ignore', 'error'}. Received '{self.handle_zero}' instead."
)

sample_weight = _check_sample_weight(sample_weight, X)
try:
Expand All @@ -112,9 +124,14 @@ def fit(self, X, y, sample_weight=None):
logging.warning("Classifier ignores sample_weight.")
self.classifier_.fit(X, y != 0)

non_zero_indices = np.where(y != 0)[0]
indices_for_training = np.where(y != 0)[0] # these are the non-zero indices
if (self.handle_zero == "ignore") & (
indices_for_training.size == 0
): # if we choose to ignore that all train set output is 0
logging.warning("Regressor will be training on `y` consisting of zero values only.")
indices_for_training = np.where(y == 0)[0] # use the whole train set

if non_zero_indices.size > 0:
if indices_for_training.size > 0:
try:
check_is_fitted(self.regressor)
self.regressor_ = self.regressor
Expand All @@ -123,20 +140,21 @@ def fit(self, X, y, sample_weight=None):

if "sample_weight" in signature(self.regressor_.fit).parameters:
self.regressor_.fit(
X[non_zero_indices],
y[non_zero_indices],
sample_weight=sample_weight[non_zero_indices] if sample_weight is not None else None,
X[indices_for_training],
y[indices_for_training],
sample_weight=sample_weight[indices_for_training] if sample_weight is not None else None,
)
else:
logging.warning("Regressor ignores sample_weight.")
self.regressor_.fit(
X[non_zero_indices],
y[non_zero_indices],
X[indices_for_training],
y[indices_for_training],
)
else:
raise ValueError(
"""The predicted training labels are all zero, making the regressor obsolete. Change the classifier
or use a plain regressor instead."""
or use a plain regressor instead. Alternatively, you can choose to ignore that predicted labels are
all zero by setting flag handle_zero = 'ignore'"""
)

return self
Expand Down
39 changes: 39 additions & 0 deletions tests/test_meta/test_zero_inflated_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,24 @@ def test_zero_inflated_with_sample_weights_example(classifier, regressor, perfor
assert zir_score > performance


def test_zero_inflated_with_handle_zero_ignore_example():
"""Test that if handle_zero='ignore' and all y are 0, no Exception will be thrown"""

np.random.seed(0)
size = 1_000
X = np.random.randn(size, 4)
y = np.zeros(size) # all outputs are 0

zir = ZeroInflatedRegressor(
classifier=ExtraTreesClassifier(max_depth=20, random_state=0, n_jobs=-1),
regressor=ExtraTreesRegressor(max_depth=20, random_state=0, n_jobs=-1),
handle_zero="ignore",
).fit(X, y)

# The predicted values should all be 0
assert (zir.predict(X) == np.zeros(size)).all()


def test_wrong_estimators_exceptions():
X = np.array([[0.0]])
y = np.array([0.0])
Expand All @@ -83,6 +101,27 @@ def test_wrong_estimators_exceptions():
zir = ZeroInflatedRegressor(ExtraTreesClassifier(), ExtraTreesClassifier())
zir.fit(X, y)

with pytest.raises(
ValueError, match="`handle_zero` has to be one of \('ignore', 'error'\). Received 'ignor' instead."
):
zir = ZeroInflatedRegressor(
classifier=ExtraTreesClassifier(max_depth=20, random_state=0, n_jobs=-1),
regressor=ExtraTreesRegressor(max_depth=20, random_state=0, n_jobs=-1),
handle_zero="ignor",
)
zir.fit(X, y)

error_text = """The predicted training labels are all zero, making the regressor obsolete\. Change the classifier
or use a plain regressor instead\. Alternatively, you can choose to ignore that predicted labels are
all zero by setting flag handle_zero = 'ignore'"""

with pytest.raises(ValueError, match=error_text):
zir = ZeroInflatedRegressor(
classifier=ExtraTreesClassifier(max_depth=20, random_state=0, n_jobs=-1),
regressor=ExtraTreesRegressor(max_depth=20, random_state=0, n_jobs=-1),
)
zir.fit(X, y) # default is handle_zero = 'error'


def approx_lte(x, y):
return ((x <= y) | np.isclose(x, y)).all()
Expand Down

0 comments on commit 25c4052

Please sign in to comment.