Skip to content

Commit

Permalink
added `_estimator_type` property alias that delegates to the wrapped model's estimator type
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi committed Sep 24, 2023
1 parent 69a77d7 commit b97f70a
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 10 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.3.0
rev: v4.4.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
Expand All @@ -9,7 +9,7 @@ repos:
- id: check-merge-conflict
- id: check-added-large-files
- repo: https://github.com/psf/black
rev: stable
rev: 23.9.1
hooks:
- id: black
files: sklego
5 changes: 5 additions & 0 deletions sklego/meta/decay_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ def _is_classifier(self):
["ClassifierMixin" in p.__name__ for p in type(self.model).__bases__]
)

@property
def _estimator_type(self):
    """Mirror the wrapped model's scikit-learn estimator type.

    Delegating this attribute lets sklearn helpers such as
    ``is_classifier`` / ``is_regressor`` classify the meta-estimator
    the same way they would classify ``self.model``.
    """
    wrapped_type = self.model._estimator_type
    return wrapped_type

def fit(self, X, y):
"""
Fit the data after adapting the same weight.
Expand Down
30 changes: 22 additions & 8 deletions tests/test_meta/test_decay_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LinearRegression, Ridge, LogisticRegression
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.base import is_regressor, is_classifier


from sklego.common import flatten
Expand All @@ -14,17 +15,13 @@
)


@pytest.mark.parametrize(
"test_fn", flatten([general_checks, regressor_checks])
)
@pytest.mark.parametrize("test_fn", flatten([general_checks, regressor_checks]))
def test_estimator_checks_regression(test_fn):
    """Run the shared general + regressor estimator checks on DecayEstimator."""
    estimator = DecayEstimator(LinearRegression())
    test_fn(DecayEstimator.__name__, estimator)


@pytest.mark.parametrize(
"test_fn", flatten([general_checks, classifier_checks])
)
@pytest.mark.parametrize("test_fn", flatten([general_checks, classifier_checks]))
def test_estimator_checks_classification(test_fn):
    """Run the shared general + classifier estimator checks on DecayEstimator."""
    estimator = DecayEstimator(LogisticRegression(solver="lbfgs"))
    test_fn(DecayEstimator.__name__, estimator)
Expand All @@ -36,7 +33,7 @@ def test_estimator_checks_classification(test_fn):
def test_decay_weight_regr(mod):
    """The oldest sample's weight must equal decay ** n_samples for regressors."""
    n_samples, n_features = 100, 100
    X = np.random.normal(0, 1, (n_samples, n_features))
    y = np.random.normal(0, 1, (n_samples,))
    fitted = DecayEstimator(mod, decay=0.95).fit(X, y)
    # First (oldest) observation has been decayed n_samples times.
    assert fitted.weights_[0] == pytest.approx(0.95**n_samples, abs=0.001)


@pytest.mark.parametrize(
Expand All @@ -48,7 +45,7 @@ def test_decay_weight_clf(mod):
(np.random.normal(0, 1, (100,)) < 0).astype(int),
)
mod = DecayEstimator(mod, decay=0.95).fit(X, y)
assert mod.weights_[0] == pytest.approx(0.95 ** 100, abs=0.001)
assert mod.weights_[0] == pytest.approx(0.95**100, abs=0.001)


@pytest.mark.parametrize("mod", flatten([KNeighborsClassifier()]))
Expand All @@ -58,3 +55,20 @@ def test_throw_warning(mod):
DecayEstimator(mod, decay=0.95).fit(X, y)
assert "sample_weight" in str(e)
assert type(mod).__name__ in str(e)


@pytest.mark.parametrize(
    "mod, is_regr",
    [
        (LinearRegression(), True),
        (Ridge(), True),
        (DecisionTreeRegressor(), True),
        (LogisticRegression(), False),
        (DecisionTreeClassifier(), False),
    ],
)
def test_estimator_type_regressor(mod, is_regr):
    """DecayEstimator must expose the wrapped model's `_estimator_type` so
    sklearn's is_regressor / is_classifier helpers classify it correctly."""
    wrapped = DecayEstimator(mod)
    assert wrapped._estimator_type == wrapped.model._estimator_type
    assert is_regressor(wrapped) == is_regr
    assert is_classifier(wrapped) == (not is_regr)

0 comments on commit b97f70a

Please sign in to comment.