diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 5c19a3d13..6015a8eac 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,6 +1,6 @@
 repos:
 - repo: https://github.com/pre-commit/pre-commit-hooks
-  rev: v2.3.0
+  rev: v4.4.0
   hooks:
   - id: check-yaml
   - id: end-of-file-fixer
@@ -9,7 +9,7 @@ repos:
   - id: check-merge-conflict
   - id: check-added-large-files
 - repo: https://github.com/psf/black
-  rev: stable
+  rev: 23.9.1
   hooks:
   - id: black
     files: sklego
diff --git a/sklego/meta/decay_estimator.py b/sklego/meta/decay_estimator.py
index 4a174068b..f439917de 100644
--- a/sklego/meta/decay_estimator.py
+++ b/sklego/meta/decay_estimator.py
@@ -31,6 +31,11 @@ def _is_classifier(self):
             ["ClassifierMixin" in p.__name__ for p in type(self.model).__bases__]
         )
 
+    @property
+    def _estimator_type(self):
+        """Estimator type is computed dynamically from the given model."""
+        return self.model._estimator_type
+
     def fit(self, X, y):
         """
         Fit the data after adapting the same weight.
diff --git a/tests/test_meta/test_decay_estimator.py b/tests/test_meta/test_decay_estimator.py
index 5f0b70d12..c0f3f3b9b 100644
--- a/tests/test_meta/test_decay_estimator.py
+++ b/tests/test_meta/test_decay_estimator.py
@@ -3,6 +3,7 @@
 from sklearn.neighbors import KNeighborsClassifier
 from sklearn.linear_model import LinearRegression, Ridge, LogisticRegression
 from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
+from sklearn.base import is_regressor, is_classifier
 
 from sklego.common import flatten
 
@@ -14,17 +15,13 @@
 )
 
 
-@pytest.mark.parametrize(
-    "test_fn", flatten([general_checks, regressor_checks])
-)
+@pytest.mark.parametrize("test_fn", flatten([general_checks, regressor_checks]))
 def test_estimator_checks_regression(test_fn):
     trf = DecayEstimator(LinearRegression())
     test_fn(DecayEstimator.__name__, trf)
 
 
-@pytest.mark.parametrize(
-    "test_fn", flatten([general_checks, classifier_checks])
-)
+@pytest.mark.parametrize("test_fn", flatten([general_checks, classifier_checks]))
 def test_estimator_checks_classification(test_fn):
     trf = DecayEstimator(LogisticRegression(solver="lbfgs"))
     test_fn(DecayEstimator.__name__, trf)
@@ -36,7 +33,7 @@ def test_estimator_checks_classification(test_fn):
 def test_decay_weight_regr(mod):
     X, y = np.random.normal(0, 1, (100, 100)), np.random.normal(0, 1, (100,))
     mod = DecayEstimator(mod, decay=0.95).fit(X, y)
-    assert mod.weights_[0] == pytest.approx(0.95 ** 100, abs=0.001)
+    assert mod.weights_[0] == pytest.approx(0.95**100, abs=0.001)
 
 
 @pytest.mark.parametrize(
@@ -48,7 +45,7 @@ def test_decay_weight_clf(mod):
         (np.random.normal(0, 1, (100,)) < 0).astype(int),
     )
     mod = DecayEstimator(mod, decay=0.95).fit(X, y)
-    assert mod.weights_[0] == pytest.approx(0.95 ** 100, abs=0.001)
+    assert mod.weights_[0] == pytest.approx(0.95**100, abs=0.001)
 
 
 @pytest.mark.parametrize("mod", flatten([KNeighborsClassifier()]))
@@ -58,3 +55,20 @@ def test_throw_warning(mod):
         DecayEstimator(mod, decay=0.95).fit(X, y)
     assert "sample_weight" in str(e)
     assert type(mod).__name__ in str(e)
+
+
+@pytest.mark.parametrize(
+    "mod, is_regr",
+    [
+        (LinearRegression(), True),
+        (Ridge(), True),
+        (DecisionTreeRegressor(), True),
+        (LogisticRegression(), False),
+        (DecisionTreeClassifier(), False),
+    ],
+)
+def test_estimator_type_regressor(mod, is_regr):
+    mod = DecayEstimator(mod)
+    assert mod._estimator_type == mod.model._estimator_type
+    assert is_regressor(mod) == is_regr
+    assert is_classifier(mod) == (not is_regr)