Merge pull request #582 from fabioscantamburlo/fabioscantamburlopatch…/579-grouped-predictor-classifier

Labels fix in `GroupedPredictor.predict_proba`
resolves #579
FBruzzesi authored Oct 12, 2023
2 parents 6f17125 + 3326b8a commit d5ea509
Showing 2 changed files with 114 additions and 12 deletions.
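
For orientation, a minimal sketch (not part of this commit, with made-up data) of the scenario the fix addresses: per-group classifiers trained on different label subsets. With the change below, `predict_proba` returns one column per distinct label across all groups, with probability 0 for labels a group's estimator never saw.

import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklego.meta import GroupedPredictor

np.random.seed(0)
# Group "A" only ever sees labels {0, 1}; group "B" sees {0, 1, 2}.
df = pd.DataFrame({
    "group": np.repeat(["A", "B"], 20),
    "x": np.random.normal(size=40),
    "y": np.hstack([np.resize([0, 1], 20), np.resize([0, 1, 2], 20)]),
})

mod = GroupedPredictor(estimator=LogisticRegression(), groups="group")
mod.fit(df[["group", "x"]], df["y"])

proba = mod.predict_proba(df[["group", "x"]])
print(proba.shape)  # (40, 3): one column per distinct label overall
# Rows belonging to group "A" have probability 0 in the column for label 2.
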
17 changes: 13 additions & 4 deletions sklego/meta/grouped_predictor.py
@@ -1,7 +1,7 @@
import numpy as np
import pandas as pd
from sklearn import clone
from sklearn.base import BaseEstimator
from sklearn.base import BaseEstimator, is_classifier
from sklearn.utils.metaestimators import available_if
from sklearn.utils.validation import (
check_is_fitted,
@@ -282,6 +282,7 @@ def __predict_single_group(self, group, X, method="predict"):
"""Predict a single group by getting its estimator from the fitted dict"""
# Keep track of the original index such that we can sort in __predict_groups
index = X.index

try:
group_predictor = self.estimators_[group]
except KeyError:
@@ -292,9 +293,15 @@ def __predict_single_group(self, group, X, method="predict"):
f"Found new group {group} during predict with use_global_model = False"
)

is_predict_proba = is_classifier(group_predictor) and method == "predict_proba"
# Ensure the resulting pd.DataFrame has the class labels as column names
extra_kwargs = {"columns": group_predictor.classes_} if is_predict_proba else {}

# getattr(group_predictor, method) returns the fitted model's predict method when
# method == "predict", and its predict_proba method when method == "predict_proba"
return pd.DataFrame(getattr(group_predictor, method)(X)).set_index(index)
return pd.DataFrame(
getattr(group_predictor, method)(X), **extra_kwargs
).set_index(index)

def __predict_groups(
self,
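
The `extra_kwargs` above is what keys each group's probability matrix by its estimator's `classes_`. A standalone illustration of why that matters (hypothetical `clf_a`/`clf_b` on synthetic data, not from this repo):

import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression

# Two classifiers fitted on different label sets.
clf_a = LogisticRegression().fit(np.random.normal(size=(20, 1)), np.resize([0, 1], 20))
clf_b = LogisticRegression().fit(np.random.normal(size=(20, 1)), np.resize([0, 1, 2], 20))

X_new = np.zeros((3, 1))
# Without column labels the two arrays would only line up positionally;
# labelling with classes_ lets pandas align probabilities by class label.
proba_a = pd.DataFrame(clf_a.predict_proba(X_new), columns=clf_a.classes_)  # columns [0, 1]
proba_b = pd.DataFrame(clf_b.predict_proba(X_new), columns=clf_b.classes_)  # columns [0, 1, 2]
print(pd.concat([proba_a, proba_b], axis=0).columns.tolist())  # [0, 1, 2]
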
Expand Down Expand Up @@ -324,6 +331,8 @@ def __predict_groups(
],
axis=0,
)
# Fill with prob = 0 the labels a group's estimator never saw (predict_proba only)
.fillna(0)
.sort_index()
.values.squeeze()
)
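
The added `.fillna(0)` closes the loop: after the per-group frames are concatenated, a row from a group whose estimator never saw some class holds NaN in that class's column, and filling with 0 turns that into an explicit zero probability. A tiny standalone illustration with made-up numbers:

import pandas as pd

proba_a = pd.DataFrame([[0.7, 0.3]], columns=[0, 1])          # group A: classes {0, 1}
proba_b = pd.DataFrame([[0.2, 0.5, 0.3]], columns=[0, 1, 2])  # group B: classes {0, 1, 2}

# Column alignment leaves NaN where group A has no class 2; fillna(0) makes it an explicit 0.
print(pd.concat([proba_a, proba_b], axis=0).fillna(0).to_numpy())
# [[0.7 0.3 0. ]
#  [0.2 0.5 0.3]]
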
@@ -350,7 +359,7 @@ def predict(self, X):
return self.__predict_shrinkage_groups(X_group, X_value, method="predict")

# This ensures that the meta-estimator only has the predict_proba method if the estimator has it
@available_if(lambda self: hasattr(self.estimator, 'predict_proba'))
@available_if(lambda self: hasattr(self.estimator, "predict_proba"))
def predict_proba(self, X):
"""
Predict probabilities on new data.
@@ -375,7 +384,7 @@ def predict_proba(self, X):
)

# This ensures that the meta-estimator only has the predict_proba method if the estimator has it
@available_if(lambda self: hasattr(self.estimator, 'decision_function'))
@available_if(lambda self: hasattr(self.estimator, "decision_function"))
def decision_function(self, X):
"""
Evaluate the decision function for the samples in X.
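
As the comments above note, `available_if` only exposes `predict_proba`/`decision_function` on the meta-estimator when the wrapped estimator has them. A minimal sketch of that mechanism (hypothetical `Wrapper` class, not from this repo):

from sklearn.base import BaseEstimator
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.utils.metaestimators import available_if


class Wrapper(BaseEstimator):
    def __init__(self, estimator):
        self.estimator = estimator

    # The attribute only exists when the condition on the inner estimator holds;
    # otherwise accessing it raises AttributeError, so hasattr returns False.
    @available_if(lambda self: hasattr(self.estimator, "predict_proba"))
    def predict_proba(self, X):
        return self.estimator.predict_proba(X)


print(hasattr(Wrapper(LogisticRegression()), "predict_proba"))  # True
print(hasattr(Wrapper(LinearRegression()), "predict_proba"))    # False
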
109 changes: 101 additions & 8 deletions tests/test_meta/test_grouped_predictor.py
@@ -13,6 +13,25 @@
from tests.conftest import general_checks, select_tests


@pytest.fixture
def random_xy_grouped_clf_different_classes(request):
group_size = request.param.get("group_size")
y_choices_grpa = request.param.get("y_choices_grpa")
y_choices_grpb = request.param.get("y_choices_grpb")

np.random.seed(43)
group_col = np.repeat(["A", "B"], group_size)
x_col = np.random.normal(size=group_size * 2)
y_col = np.hstack(
[
np.random.choice(y_choices_grpa, size=group_size),
np.random.choice(y_choices_grpb, size=group_size),
]
)
df = pd.DataFrame({"group": group_col, "x": x_col, "y": y_col})
return df
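
A note on the pytest mechanics in this fixture: with `indirect=True` in the parametrize calls further below, each parameter dict is routed to the fixture as `request.param` instead of being passed straight to the test. A generic sketch (hypothetical names, unrelated to this repo):

import pytest


@pytest.fixture
def doubled(request):
    # request.param receives each value from the parametrize list below
    return request.param * 2


@pytest.mark.parametrize("doubled", [1, 2, 3], indirect=True)
def test_doubled_is_even(doubled):
    assert doubled % 2 == 0
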


@pytest.mark.parametrize(
"test_fn",
select_tests(
@@ -73,6 +92,79 @@ def test_chickweight_can_do_fallback_proba():
assert (mod.predict_proba(to_predict)[0] == mod.predict_proba(to_predict)[1]).all()


@pytest.mark.parametrize(
"random_xy_grouped_clf_different_classes",
[
{"group_size": 10, "y_choices_grpa": [0, 1, 2], "y_choices_grpb": [0, 1, 2, 4]},
{"group_size": 10, "y_choices_grpa": [0, 2], "y_choices_grpb": [0, 2]},
{"group_size": 10, "y_choices_grpa": [0, 1, 2, 3], "y_choices_grpb": [0, 4]},
{"group_size": 10, "y_choices_grpa": [0, 1, 2], "y_choices_grpb": [0, 3]},
],
indirect=True,
)
def test_predict_proba_has_same_columns_as_distinct_labels(
random_xy_grouped_clf_different_classes,
):
mod = GroupedPredictor(estimator=LogisticRegression(), groups="group")
X, y = (
random_xy_grouped_clf_different_classes[["group", "x"]],
random_xy_grouped_clf_different_classes["y"],
)
_ = mod.fit(X, y)
y_proba = mod.predict_proba(X)

# Ensure the number of output columns always equals the number of distinct labels
assert (
len(random_xy_grouped_clf_different_classes["y"].unique()) == y_proba.shape[1]
)


@pytest.mark.parametrize(
"random_xy_grouped_clf_different_classes",
[
{"group_size": 5, "y_choices_grpa": [0, 1, 2], "y_choices_grpb": [0, 2]},
],
indirect=True,
)
def test_predict_proba_correct_zeros_same_and_different_labels(
random_xy_grouped_clf_different_classes,
):
mod = GroupedPredictor(estimator=LogisticRegression(), groups="group")

X, y = (
random_xy_grouped_clf_different_classes[["group", "x"]],
random_xy_grouped_clf_different_classes["y"],
)
_ = mod.fit(X, y)
y_proba = mod.predict_proba(X)

df_proba = pd.concat(
[random_xy_grouped_clf_different_classes["group"], pd.DataFrame(y_proba)],
axis=1,
)

# Take distinct labels for group A and group B
labels_a, labels_b = (
random_xy_grouped_clf_different_classes.groupby("group")
.agg({"y": set})
.sort_index()["y"]
)

# Ensure the labels common to both groups have no zero probabilities
in_common_labels = labels_a.intersection(labels_b)
assert all((df_proba.loc[:, label] != 0).all() for label in in_common_labels)

# Ensure the labels not seen by a group's estimator get only zero probabilities
label_not_in_group = {
"A": list(labels_b.difference(labels_a)),
"B": list(labels_a.difference(labels_b)),
}
for grp_name, grp in df_proba.groupby("group"):
assert all(
(grp.loc[:, label] == 0).all() for label in label_not_in_group[grp_name]
)


def test_fallback_can_raise_error():
df = load_chicken(as_frame=True)
mod = GroupedPredictor(
@@ -117,7 +209,6 @@ def test_chickweight_np_keys():


def test_chickweigt_string_groups():

df = load_chicken(as_frame=True)
df["diet"] = ["omgomgomg" + s for s in df["diet"].astype(str)]

@@ -545,25 +636,27 @@ def test_bad_shrinkage_value_error():
def test_missing_check():
df = load_chicken(as_frame=True)

X, y = df.drop(columns='weight'), df['weight']
X, y = df.drop(columns="weight"), df["weight"]
# create missing value
X.loc[0, 'chick'] = np.nan
model = make_pipeline(SimpleImputer(), LinearRegression())
X.loc[0, "chick"] = np.nan
model = make_pipeline(SimpleImputer(), LinearRegression())

# Should not raise error, check is disabled
m = GroupedPredictor(model, groups = ['diet'], check_X = False).fit(X, y)
m = GroupedPredictor(model, groups=["diet"], check_X=False).fit(X, y)
m.predict(X)

# Should raise error, check is still enabled
with pytest.raises(ValueError) as e:
GroupedPredictor(model, groups = ['diet']).fit(X, y)
GroupedPredictor(model, groups=["diet"]).fit(X, y)
assert "contains NaN" in str(e)


def test_has_decision_function():
    # Needed because e.g. cross_val_score(pipe, X, y, cv=5, scoring="roc_auc", error_score="raise")
    # may fail otherwise; see https://github.com/koaning/scikit-lego/issues/511
df = load_chicken(as_frame=True)

X, y = df.drop(columns='weight'), df['weight']
X, y = df.drop(columns="weight"), df["weight"]
# This should NOT raise errors
GroupedPredictor(LogisticRegression(max_iter=2000), groups=["diet"]).fit(X, y).decision_function(X)
GroupedPredictor(LogisticRegression(max_iter=2000), groups=["diet"]).fit(
X, y
).decision_function(X)
