GroupedPredictor patch #619

Merged · 6 commits · Feb 27, 2024
10 changes: 10 additions & 0 deletions docs/api/meta.md
@@ -20,6 +20,16 @@
show_root_full_path: true
show_root_heading: true

::: sklego.meta.grouped_predictor.GroupedClassifier
options:
show_root_full_path: true
show_root_heading: true

::: sklego.meta.grouped_predictor.GroupedRegressor
options:
show_root_full_path: true
show_root_heading: true

::: sklego.meta.grouped_transformer.GroupedTransformer
options:
show_root_full_path: true
8 changes: 2 additions & 6 deletions docs/rstudio.md
@@ -109,9 +109,7 @@ ggplot(data=cv_df) +
theme(legend.position="bottom")
```

<p align="center">
<img src="../_static/rstudio/Rplot1.png" />
</p>
![rplot1](_static/rstudio/Rplot1.png)

```r
ggplot(data=cv_df) +
@@ -122,9 +120,7 @@ ggplot(data=cv_df) +
theme(legend.position="bottom")
```

<p align="center">
<img src="../_static/rstudio/Rplot2.png" />
</p>
![rplot2](_static/rstudio/Rplot2.png)

## Important

4 changes: 1 addition & 3 deletions docs/user-guide/fairness.md
@@ -106,9 +106,7 @@ It does this by projecting all vectors away such that the remaining dataset is o

The [`InformationFilter`][filter-information-api] uses a variant of the [Gram–Schmidt process][gram–schmidt-process] to filter information out of the dataset. We can make it visual in two dimensions;

<p align="center">
<img src="../_static/fairness/projections.png" />
</p>
![projections](../_static/fairness/projections.png)

To explain what occurs in higher dimensions we need to resort to maths. Take a training matrix $X$ that contains columns $x_1, ..., x_k$.
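
For intuition, a minimal numpy sketch of a single Gram–Schmidt projection step of this kind (not code from the PR; the variable names are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))  # toy training matrix with columns x_1, x_2, x_3
x1 = X[:, [0]]                 # the column whose information we want to filter out

# Subtract from every column its projection onto x_1, so the remaining columns
# carry no linear information about x_1 (column 0 itself becomes zero).
X_filtered = X - x1 @ (x1.T @ X) / (x1.T @ x1)

print(np.abs(x1.T @ X_filtered).max())  # ~0: filtered columns are orthogonal to x_1
```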

25 changes: 18 additions & 7 deletions docs/user-guide/meta-models.md
@@ -58,14 +58,13 @@ Wall time: 917 ms

## Grouped Prediction

<p align="center">
<img src="../_static/meta-models/grouped-model.png" />
</p>
![grouped-model](../_static/meta-models/grouped-model.png)
koaning (Owner) commented on Feb 20, 2024:
Do we want to mention the classifier/regressor objects in the docs here maybe?

Collaborator Author (FBruzzesi) replied:
Sounds reasonable


To help explain what it can do we'll consider three methods to predict the chicken weight.

The chicken data has 578 rows and 4 columns from an experiment on the effect of diet on early growth of chicks.
The body weights of the chicks were measured at birth and every second day thereafter until day 20. They were also measured on day 21.
The body weights of the chicks were measured at birth and every second day thereafter until day 20.
They were also measured on day 21.
There were four groups of chicks on different protein diets.

### Setup
@@ -94,13 +93,13 @@ So let's see how the grouped model can address this.

### Model 2: Linear Regression in GroupedPredictor

The goal of the [GroupedPredictor][grouped-predictor-api] is to allow us to split up our data.
The goal of the [`GroupedPredictor`][grouped-predictor-api] is to allow us to split up our data.

The image below demonstrates what will happen.

![grouped](../_static/meta-models/grouped-df.png)

We train 5 models in total because the model will also train a fallback automatically (you can turn this off via `use_fallback=False`).
We train 5 models in total because the model will also train a fallback automatically (you can turn this off via `use_global_model=False`).

The idea behind the fallback is that we can predict something if there is a group at prediction time which is unseen during training.
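
As a quick illustration of the grouping-plus-fallback idea (a minimal sketch on made-up data, not code taken from this PR; the `diet`, `time` and `weight` column names are only illustrative):

```python
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklego.meta import GroupedPredictor

df = pd.DataFrame({
    "diet": [1, 1, 2, 2, 3, 3, 4, 4],
    "time": [0, 2, 0, 2, 0, 2, 0, 2],
    "weight": [42, 51, 40, 49, 43, 55, 41, 52],
})

# One LinearRegression per diet plus a global fallback model: 5 models in total.
# Pass use_global_model=False to skip the fallback.
mod = GroupedPredictor(estimator=LinearRegression(), groups=["diet"], use_global_model=True)
mod.fit(df[["diet", "time"]], df["weight"])

# A diet that was never seen during training is handled by the fallback model.
unseen = pd.DataFrame({"diet": [5], "time": [2]})
print(mod.predict(unseen))
```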

@@ -118,7 +117,7 @@ Such model looks a bit better.

### Model 3: Dummy Regression in GroupedEstimation

We could go a step further and train a [DummyRegressor][dummy-regressor-api] per diet per timestep.
We could go a step further and train a [`DummyRegressor`][dummy-regressor-api] per diet per timestep.

The code below works similarly to the previous example, but one difference is that the grouped model does not receive a dataframe but a numpy array.
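
A minimal sketch of that setup (not the guide's own code): with a numpy array as input, the group columns are referred to by position, here assuming columns 0 and 1 hold the diet and the timestep.

```python
import pandas as pd
from sklearn.dummy import DummyRegressor
from sklego.meta import GroupedPredictor

df = pd.DataFrame({
    "diet": [1, 1, 2, 2],
    "time": [0, 2, 0, 2],
    "weight": [42, 51, 40, 49],
})

# Columns 0 (diet) and 1 (time) are the groups, so one DummyRegressor is fitted
# per (diet, time) combination; with no value columns left it simply predicts
# the mean weight of that group.
X = df[["diet", "time"]].to_numpy()
y = df["weight"].to_numpy()

mod = GroupedPredictor(estimator=DummyRegressor(strategy="mean"), groups=[0, 1])
mod.fit(X, y)
print(mod.predict(X))
```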

@@ -135,6 +134,16 @@ The code that does this is listed below.

Note that these predictions seem to yield the lowest error, but take it with a grain of salt since these errors are only based on the train set.

### Specialized Estimators

!!! info "New in version 0.7.5"

Instead of using the generic `GroupedPredictor` directly, it is possible to work with _task specific_ estimators, namely: [`GroupedClassifier`][grouped-classifier-api] and [`GroupedRegressor`][grouped-regressor-api].

Their specs and functionalities are exactly the same as those of the `GroupedPredictor`[^1], but they are specialized for classification and regression tasks, respectively, by adding checks on the input estimator.

[^1]: Not entirely true, as `GroupedClassifier` doesn't allow for the `shrinkage` parameter.
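
For example (a hedged sketch on made-up data, not part of this PR's changes):

```python
import pandas as pd
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklego.meta import GroupedClassifier, GroupedRegressor

df = pd.DataFrame({
    "diet": [1, 1, 2, 2],
    "time": [0, 2, 0, 2],
    "weight": [42, 51, 40, 49],
    "heavy": [0, 1, 0, 1],
})

# Same interface as GroupedPredictor, with an extra check on the estimator type.
reg = GroupedRegressor(estimator=LinearRegression(), groups=["diet"])
reg.fit(df[["diet", "time"]], df["weight"])

clf = GroupedClassifier(estimator=LogisticRegression(), groups=["diet"])
clf.fit(df[["diet", "time"]], df["heavy"])

# Swapping the estimator types (a regressor into GroupedClassifier, or a
# classifier into GroupedRegressor) raises a ValueError at fit time.
```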

## Grouped Transformation

We can apply grouped prediction on estimators that have a `.predict()` implemented but we're also able to do something similar for transformers, like `StandardScaler`.
@@ -443,6 +452,8 @@ As a meta-estimator, the `OrdinalClassifier` fits N-1 binary classifiers, which

[thresholder-api]: ../../api/meta#sklego.meta.thresholder.Thresholder
[grouped-predictor-api]: ../../api/meta#sklego.meta.grouped_predictor.GroupedPredictor
[grouped-classifier-api]: ../../api/meta#sklego.meta.grouped_predictor.GroupedClassifier
[grouped-regressor-api]: ../../api/meta#sklego.meta.grouped_predictor.GroupedRegressor
[grouped-transformer-api]: ../../api/meta#sklego.meta.grouped_transformer.GroupedTransformer
[decay-api]: ../../api/meta#sklego.meta.decay_estimator.DecayEstimator
[decay-functions]: ../../api/decay-functions
4 changes: 3 additions & 1 deletion pyproject.toml
@@ -94,9 +94,11 @@ sklego = ["data/*.zip"]

[tool.ruff]
line-length = 120
extend-select = ["I"]
exclude = ["docs"]

[tool.ruff.lint]
extend-select = ["I"]

[tool.pytest.ini_options]
markers = [
"cvxpy: tests that require cvxpy (deselect with '-m \"not cvxpy\"')"
6 changes: 3 additions & 3 deletions sklego/meta/__init__.py
@@ -2,8 +2,9 @@
"ConfusionBalancer",
"DecayEstimator",
"EstimatorTransformer",
"GroupedEstimator",
"GroupedClassifier",
"GroupedPredictor",
"GroupedRegressor",
"GroupedTransformer",
"OrdinalClassifier",
"OutlierRemover",
@@ -17,8 +18,7 @@
from sklego.meta.confusion_balancer import ConfusionBalancer
from sklego.meta.decay_estimator import DecayEstimator
from sklego.meta.estimator_transformer import EstimatorTransformer
from sklego.meta.grouped_estimator import GroupedEstimator
from sklego.meta.grouped_predictor import GroupedPredictor
from sklego.meta.grouped_predictor import GroupedClassifier, GroupedPredictor, GroupedRegressor
from sklego.meta.grouped_transformer import GroupedTransformer
from sklego.meta.ordinal_classification import OrdinalClassifier
from sklego.meta.outlier_classifier import OutlierClassifier
12 changes: 0 additions & 12 deletions sklego/meta/grouped_estimator.py

This file was deleted.

132 changes: 128 additions & 4 deletions sklego/meta/grouped_predictor.py
@@ -1,7 +1,7 @@
import numpy as np
import pandas as pd
from sklearn import clone
from sklearn.base import BaseEstimator, is_classifier
from sklearn.base import BaseEstimator, ClassifierMixin, MetaEstimatorMixin, RegressorMixin, is_classifier, is_regressor
from sklearn.utils.metaestimators import available_if
from sklearn.utils.validation import check_array, check_is_fitted

@@ -14,9 +14,21 @@
)


class GroupedPredictor(BaseEstimator):
"""Construct an estimator per data group. Splits data by values of a single column and fits one estimator per such
column.
class GroupedPredictor(MetaEstimatorMixin, BaseEstimator):
"""`GroupedPredictor` is a meta-estimator that fits a separate estimator for each group in the input data.

The input data is split into a group and a value part: for each unique combination of the group columns, a separate
estimator is fitted to the corresponding value rows. The group columns are specified by the `groups` parameter.

If `use_global_model=True` a fallback estimator will be fitted on the entire dataset in case a group is not found
during `.predict()`.

If `shrinkage` is not `None`, the predictions of the group-level models are combined using a shrinkage method. The
shrinkage method can be one of the predefined methods `"constant"`, `"min_n_obs"`, `"relative"` or a custom
shrinkage function. The shrinkage method is specified by the `shrinkage` parameter.

!!! warning "Shrinkage"
Shrinkage is only available for regression models.

Parameters
----------
@@ -43,6 +55,19 @@ class GroupedPredictor(BaseEstimator):
If disabled, the model/pipeline is expected to handle e.g. missing, non-numeric, or non-finite values.
**shrinkage_kwargs : dict
Keyword arguments to the shrinkage function

Attributes
----------
estimators_ : dict
A dictionary with the fitted estimators per group
groups_ : list
A list of all the groups that were found during fitting
fallback_ : estimator
A fallback estimator that is used when `use_global_model=True` and a group is not found during `.predict()`
shrinkage_function_ : callable
The shrinkage function that is used to calculate the shrinkage factors
shrinkage_factors_ : dict
A dictionary with the shrinkage factors per group
"""

# Number of features in value df can be 0, e.g. for dummy models
Expand Down Expand Up @@ -212,6 +237,9 @@ def fit(self, X, y=None):
self : GroupedPredictor
The fitted estimator.
"""
if self.shrinkage is not None and not is_regressor(self.estimator):
raise ValueError("Shrinkage is only available for regression models")

X_group, X_value = _split_groups_and_values(
X, self.groups, min_value_cols=0, check_X=self.check_X, **self._check_kwargs
)
@@ -409,3 +437,99 @@ def decision_function(self, X):
return self.__predict_groups(X_group, X_value, method="decision_function")
else:
return self.__predict_shrinkage_groups(X_group, X_value, method="decision_function")

@property
def _estimator_type(self):
"""Computes `_estimator_type` dynamically from the wrapped model."""
return self.estimator._estimator_type
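
For illustration, a rough usage sketch of the shrinkage behaviour documented in the docstring above, using the predefined `"relative"` method so no extra keyword arguments need to be assumed (not code from this PR; the toy data is made up):

```python
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklego.meta import GroupedPredictor

df = pd.DataFrame({
    "diet": [1, 1, 1, 2, 2, 2],
    "time": [0, 2, 4, 0, 2, 4],
    "weight": [42, 51, 59, 40, 49, 58],
})

# With shrinkage="relative" the per-group predictions are blended with the
# global model, each level weighted by the number of observations it saw.
# A classifier in place of LinearRegression would now raise a ValueError.
mod = GroupedPredictor(estimator=LinearRegression(), groups=["diet"], shrinkage="relative")
mod.fit(df[["diet", "time"]], df["weight"])
print(mod.predict(df[["diet", "time"]]))
```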


class GroupedRegressor(GroupedPredictor, RegressorMixin):
"""`GroupedRegressor` is a meta-estimator that fits a separate regressor for each group in the input data.

Its spec is the same as [`GroupedPredictor`][sklego.meta.grouped_predictor.GroupedPredictor] but it is available
only for regression models.

!!! info "New in version 0.7.5"
"""

def fit(self, X, y):
"""Fit one regressor for each group of training data `X` and `y`.

Will also learn the groups that exist within the training dataset.

Parameters
----------
X : array-like of shape (n_samples, n_features)
Training data.
y : array-like of shape (n_samples,)
Target values.

Returns
-------
self : GroupedRegressor
The fitted regressor.

Raises
-------
ValueError
If the supplied estimator is not a regressor.
"""
if not is_regressor(self.estimator):
raise ValueError("GroupedRegressor is only available for regression models")

return super().fit(X, y)


class GroupedClassifier(GroupedPredictor, ClassifierMixin):
"""`GroupedClassifier` is a meta-estimator that fits a separate classifier for each group in the input data.

It is equivalent to [`GroupedPredictor`][sklego.meta.grouped_predictor.GroupedPredictor] with `shrinkage=None`
but it is available only for classification models.

!!! info "New in version 0.7.5"
Collaborator Author (FBruzzesi) commented:
How do you feel about this?
Proof of rendering: [image]

"""

def __init__(
self,
estimator,
groups,
use_global_model=True,
check_X=True,
**shrinkage_kwargs,
):
super().__init__(
estimator=estimator,
groups=groups,
shrinkage=None,
Collaborator Author (FBruzzesi) commented: Forcing shrinkage to be None

use_global_model=use_global_model,
check_X=check_X,
)

def fit(self, X, y):
"""Fit one classifier for each group of training data `X` and `y`.

Will also learn the groups that exist within the training dataset.

Parameters
----------
X : array-like of shape (n_samples, n_features)
Training data.
y : array-like of shape (n_samples,)
Target values.

Returns
-------
self : GroupedClassifier
The fitted classifier.

Raises
-------
ValueError
If the supplied estimator is not a classifier.
"""

if not is_classifier(self.estimator):
raise ValueError("GroupedClassifier is only available for classification models")
self.classes_ = np.unique(y)
return super().fit(X, y)
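
A quick sketch of the new estimator-type check in action (not part of the diff; the toy arrays are made up):

```python
import numpy as np
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklego.meta import GroupedClassifier

X = np.array([[1, 0.0], [1, 1.0], [2, 0.0], [2, 1.0]])  # column 0 is the group
y = np.array([0, 1, 0, 1])

# A regressor is rejected before any per-group fitting happens.
try:
    GroupedClassifier(estimator=LinearRegression(), groups=[0]).fit(X, y)
except ValueError as exc:
    print(exc)  # GroupedClassifier is only available for classification models

# A classifier is accepted and the fitted meta-estimator exposes classes_.
clf = GroupedClassifier(estimator=LogisticRegression(), groups=[0]).fit(X, y)
print(clf.classes_)  # [0 1]
```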