diff --git a/sklego/meta/decay_estimator.py b/sklego/meta/decay_estimator.py index d3a4c0c7c..65f005b06 100644 --- a/sklego/meta/decay_estimator.py +++ b/sklego/meta/decay_estimator.py @@ -16,15 +16,27 @@ class DecayEstimator(BaseEstimator): This meta estimator will only work for estimators that have a "sample_weights" argument in their `.fit()` method. + The `fit` method computes the weights to pass to the estimator. + + .. warning:: By default all the checks on the inputs `X` and `y` are delegated to the wrapped estimator. + + To change such behaviour, set `check_input` to `True`. + + Note that if the check is skipped, then `y` should have a `shape` attribute, which is + used to extract the number of samples in training data, and compute the weights. + The DecayEstimator will use exponential decay to weight the parameters. w_{t-1} = decay * w_{t} """ - def __init__(self, model, decay: float = 0.999, decay_func="exponential"): + def __init__( + self, model, decay: float = 0.999, decay_func="exponential", check_input=False + ): self.model = model self.decay = decay self.decay_func = decay_func + self.check_input = check_input def _is_classifier(self): return any( @@ -40,12 +52,17 @@ def fit(self, X, y): """ Fit the data after adapting the same weight. - :param X: array-like, shape=(n_columns, n_samples,) training data. - :param y: array-like, shape=(n_samples,) training data. + :param X: array-like, shape=(n_samples, n_features,) training data. + :param y: array-like, shape=(n_samples,) target values. :return: Returns an instance of self. 
""" - X, y = check_X_y(X, y, estimator=self, dtype=FLOAT_DTYPES) - self.weights_ = np.cumprod(np.ones(X.shape[0]) * self.decay)[::-1] + + if self.check_input: + X, y = check_X_y( + X, y, estimator=self, dtype=FLOAT_DTYPES, ensure_min_features=0 + ) + + self.weights_ = np.cumprod(np.ones(y.shape[0]) * self.decay)[::-1] self.estimator_ = clone(self.model) try: self.estimator_.fit(X, y, sample_weight=self.weights_) @@ -62,7 +79,7 @@ def predict(self, X): """ Predict new data. - :param X: array-like, shape=(n_columns, n_samples,) training data. + :param X: array-like, shape=(n_samples, n_features,) data to predict. :return: array, shape=(n_samples,) the predicted data """ if self._is_classifier(): diff --git a/tests/test_meta/test_decay_estimator.py b/tests/test_meta/test_decay_estimator.py index c0f3f3b9b..642b511f4 100644 --- a/tests/test_meta/test_decay_estimator.py +++ b/tests/test_meta/test_decay_estimator.py @@ -17,13 +17,13 @@ @pytest.mark.parametrize("test_fn", flatten([general_checks, regressor_checks])) def test_estimator_checks_regression(test_fn): - trf = DecayEstimator(LinearRegression()) + trf = DecayEstimator(LinearRegression(), check_input=True) test_fn(DecayEstimator.__name__, trf) @pytest.mark.parametrize("test_fn", flatten([general_checks, classifier_checks])) def test_estimator_checks_classification(test_fn): - trf = DecayEstimator(LogisticRegression(solver="lbfgs")) + trf = DecayEstimator(LogisticRegression(solver="lbfgs"), check_input=True) test_fn(DecayEstimator.__name__, trf)