-
Notifications
You must be signed in to change notification settings - Fork 118
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
…#593) * relax patsy dependency * deprecating patsy, adding formulaic * formulaic tests * formulaic pickling test * patsy deprecation version change
- Loading branch information
Showing
6 changed files
with
273 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
try: | ||
import formulaic | ||
except ImportError: | ||
from sklego.notinstalled import NotInstalledPackage | ||
|
||
formulaic = NotInstalledPackage("formulaic") | ||
|
||
from sklearn.base import BaseEstimator, TransformerMixin | ||
from sklearn.utils.validation import check_is_fitted | ||
|
||
|
||
class FormulaicTransformer(TransformerMixin, BaseEstimator):
    """Build a model matrix from a dataframe using a `formulaic` formula.

    The formula both selects the columns to use and describes the transformations applied to them, in the spirit
    of R formulas. This can be useful as a first step in a pipeline.

    Parameters
    ----------
    formula : str
        A formulaic-compatible formula.
        Refer to the [formulaic documentation](https://matthewwardrop.github.io/formulaic/guides/grammar/) for more details.
    return_type : Literal["pandas", "numpy", "sparse"], default="numpy"
        The type of the returned matrix.
        Refer to the [formulaic documentation](https://matthewwardrop.github.io/formulaic/guides/model_specs/) for more details.

    Attributes
    ----------
    formula_ : formulaic.Formula
        The parsed formula specification.
    model_spec_ : formulaic.ModelSpec
        The model spec taken from the model matrix built during `fit`.
    n_features_in_ : int
        Number of columns of `X` seen during `fit`.
    """

    def __init__(self, formula, return_type="numpy"):
        self.formula = formula
        self.return_type = return_type

    def fit(self, X, y=None):
        """Parse the formula and derive the model spec from `X`.

        Parameters
        ----------
        X : pd.DataFrame of shape (n_samples, n_features)
            The data used to build the model spec.
        y : array-like of shape (n_samples,), default=None
            Ignored, present for compatibility.

        Returns
        -------
        self : FormulaicTransformer
            The fitted transformer.

        Raises
        ------
        ValueError
            If the parsed formula is flagged as structured, which is not supported.
        """
        self.formula_ = formulaic.Formula.from_spec(self.formula)

        # Reject formulas flagged as structured before building any matrix.
        if self.formula_._has_structure:
            raise ValueError(
                f"Formula specification {repr(self.formula_)} results in a structured formula, which is not supported."
            )

        # The matrix itself is discarded; only its spec is kept for later transforms.
        fitted_matrix = self.formula_.get_model_matrix(X, output=self.return_type)
        self.model_spec_ = fitted_matrix.model_spec
        self.n_features_in_ = X.shape[1]
        return self

    def transform(self, X, y=None):
        """Generate a model matrix from `X` using the model spec stored by `fit`.

        Parameters
        ----------
        X : pd.DataFrame of shape (n_samples, n_features)
            The data to transform.
        y : array-like of shape (n_samples,), default=None
            Ignored, present for compatibility.

        Returns
        -------
        X : array-like
            The model matrix produced by the fitted model spec.

        Raises
        ------
        ValueError
            If `X` does not have the same number of columns as the data seen in `fit`.
        """
        check_is_fitted(self, ["formula_", "model_spec_", "n_features_in_"])

        if X.shape[1] == self.n_features_in_:
            return self.model_spec_.get_model_matrix(X)

        raise ValueError(
            "`X` must have the same number of columns in fit and transform. "
            f"Expected {self.n_features_in_}, found {X.shape[1]}."
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
import pytest | ||
import joblib | ||
import numpy as np | ||
import pandas as pd | ||
from scipy.sparse import spmatrix | ||
from sklearn.preprocessing import StandardScaler | ||
from sklearn.pipeline import Pipeline | ||
from sklearn.linear_model import LogisticRegression | ||
|
||
from sklego.preprocessing import FormulaicTransformer | ||
|
||
|
||
@pytest.fixture()
def df():
    """Return a six-row frame with numeric columns `a`, `b`, `e` and string columns `c`, `d`."""
    columns = {
        "a": [1, 2, 3, 4, 5, 6],
        "b": np.log([10, 9, 8, 7, 6, 5]),
        "c": ["a", "b", "a", "b", "c", "c"],
        "d": ["b", "a", "a", "b", "a", "b"],
        "e": [0, 1, 0, 1, 0, 1],
    }
    return pd.DataFrame(columns)
|
||
@pytest.mark.parametrize(
    "return_type, expected_type",
    [
        ("numpy", np.ndarray),
        ("pandas", pd.DataFrame),
        ("sparse", spmatrix),
    ],
)
def test_return_type(df, return_type, expected_type):
    """The transformed output is an instance of the container matching `return_type`."""
    features = df[["a", "b", "c", "d"]]
    target = df[["e"]]

    transformer = FormulaicTransformer("a + b - 1", return_type=return_type)
    transformer.fit(features, target)

    assert isinstance(transformer.transform(features), expected_type)
|
||
|
||
|
||
@pytest.mark.parametrize(
    "formula, expected_shape",
    [
        ("a + b - 1", (6, 2)),
        ("a + np.log(a) + b - 1", (6, 3)),
        ("a*b - 1", (6, 3)),
        ("a + b + d", (6, 4)),
        ("a + b + c + d", (6, 6)),
    ],
)
def test_formula_output(df, formula, expected_shape):
    """Each formula yields a model matrix of the expected shape."""
    features = df[["a", "b", "c", "d"]]
    target = df[["e"]]

    transformer = FormulaicTransformer(formula=formula)
    transformed = transformer.fit(features, target).transform(features)

    assert transformed.shape == expected_shape
|
||
|
||
|
||
def test_pipeline(df):
    """The transformer works as the first step of a scikit-learn pipeline."""
    features = df[["a", "b", "c", "d"]]
    target = df[["e"]].values.ravel()

    steps = [
        ("design", FormulaicTransformer("a + np.log(a) + b - 1")),
        ("scale", StandardScaler()),
        ("model", LogisticRegression(solver="lbfgs")),
    ]
    predictions = Pipeline(steps).fit(features, target).predict(features)

    # One prediction per input row.
    assert predictions.shape[0] == features.shape[0]
|
||
|
||
def test_unseen_categories(df):
    """Transforming rows whose category `c` was absent at fit time keeps the column count."""
    feature_names = ["a", "b", "c", "d"]
    formula = "a + np.log(a) + b + c + d - 1"

    # The last two rows carry c == "c", which never occurs in the first four.
    train_rows, test_rows = df[:4], df[4:]
    X_train = train_rows[feature_names]
    y_train = train_rows[["e"]].values.ravel()
    X_test = test_rows[feature_names]

    transformer = FormulaicTransformer(formula)
    transformer.fit(X_train, y_train)

    assert transformer.transform(X_test).shape[1] == transformer.transform(X_train).shape[1]

    steps = [
        ("design", FormulaicTransformer(formula)),
        ("scale", StandardScaler()),
        ("model", LogisticRegression(solver="lbfgs")),
    ]
    pipeline = Pipeline(steps)
    pipeline.fit(X_train, y_train)

    assert pipeline.predict(X_test).shape[0] == X_test.shape[0]
|
||
def test_misshape(df):
    """Transforming data with fewer columns than seen in `fit` raises a `ValueError`."""
    train_rows, test_rows = df[:4], df[4:]

    X_train = train_rows[["a", "b", "c", "d"]]
    y_train = train_rows[["e"]].values.ravel()
    # Column "d" is deliberately left out of the test frame.
    X_test = test_rows[["a", "b", "c"]]

    transformer = FormulaicTransformer("a + np.log(a) + b + c + d - 1")
    transformer.fit(X_train, y_train)

    with pytest.raises(ValueError):
        transformer.transform(X_test)
|
||
|
||
@pytest.mark.parametrize(
    "return_type", ("numpy", "pandas")
)
@pytest.mark.parametrize(
    "formula", (
        "a + b - 1",
        "a + np.log(a) + b - 1",
        "a*b - 1",
        "a + b + d",
        "a + b + c + d",
    )
)
def test_pickling(tmp_path, df, return_type, formula):
    """A fitted pipeline containing the transformer still predicts after a joblib round trip."""
    feature_names = ["a", "b", "c", "d"]
    train_rows, test_rows = df[:4], df[4:]

    X_train = train_rows[feature_names]
    y_train = train_rows[["e"]].values.ravel()
    X_test = test_rows[feature_names]

    steps = [
        ("design", FormulaicTransformer(formula=formula, return_type=return_type)),
        ("scale", StandardScaler()),
        ("model", LogisticRegression(solver="lbfgs")),
    ]
    pipeline = Pipeline(steps)
    pipeline.fit(X_train, y_train)

    # Dump to the per-test temporary directory and load it back.
    pickle_file = tmp_path / "pipeline.pkl"
    joblib.dump(pipeline, pickle_file)
    restored = joblib.load(pickle_file)

    assert restored.predict(X_test).shape[0] == X_test.shape[0]