Skip to content

Commit

Permalink
TST move check_do_not_raise_errors_in_init_or_set_params to common te…
Browse files Browse the repository at this point in the history
  • Loading branch information
adrinjalali authored Sep 11, 2024
1 parent 1d22a48 commit d9deffe
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 23 deletions.
22 changes: 1 addition & 21 deletions sklearn/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,9 @@
import re
import warnings
from functools import partial
from inspect import isgenerator, signature
from inspect import isgenerator
from itertools import chain

import numpy as np
import pytest
from scipy.linalg import LinAlgWarning

Expand Down Expand Up @@ -345,25 +344,6 @@ def test_estimators_get_feature_names_out_error(estimator):
check_get_feature_names_out_error(estimator_name, estimator)


@pytest.mark.parametrize(
"Estimator",
[est for name, est in all_estimators()],
)
def test_estimators_do_not_raise_errors_in_init_or_set_params(Estimator):
"""Check that init or set_param does not raise errors."""
params = signature(Estimator).parameters

smoke_test_values = [-1, 3.0, "helloworld", np.array([1.0, 4.0]), [1], {}, []]
for value in smoke_test_values:
new_params = {key: value for key in params}

# Does not raise
est = Estimator(**new_params)

# Also do does not raise
est.set_params(**new_params)


@pytest.mark.parametrize(
"estimator", list(_tested_estimators()), ids=_get_check_estimator_ids
)
Expand Down
17 changes: 17 additions & 0 deletions sklearn/utils/estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def _yield_api_checks(estimator):
yield check_no_attributes_set_in_init
yield check_fit_score_takes_y
yield check_estimators_overwrite_params
yield check_do_not_raise_errors_in_init_or_set_params


def _yield_checks(estimator):
Expand Down Expand Up @@ -4689,3 +4690,19 @@ def check_inplace_ensure_writeable(name, estimator_orig):

assert not X.flags.writeable
assert_allclose(X, X_copy)


def check_do_not_raise_errors_in_init_or_set_params(name, estimator_orig):
"""Check that init or set_param does not raise errors."""
Estimator = type(estimator_orig)
params = signature(Estimator).parameters

smoke_test_values = [-1, 3.0, "helloworld", np.array([1.0, 4.0]), [1], {}, []]
for value in smoke_test_values:
new_params = {key: value for key in params}

# Does not raise
est = Estimator(**new_params)

# Also do does not raise
est.set_params(**new_params)
5 changes: 3 additions & 2 deletions sklearn/utils/tests/test_estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
check_outlier_corruption,
check_regressor_data_not_an_array,
check_requires_y_none,
check_set_params,
set_random_state,
)
from sklearn.utils.fixes import CSR_CONTAINERS, SPARRAY_PRESENT
Expand Down Expand Up @@ -590,9 +591,9 @@ def test_check_estimator():
# check that values returned by get_params match set_params
msg = "get_params result does not match what was passed to set_params"
with raises(AssertionError, match=msg):
check_estimator(ModifiesValueInsteadOfRaisingError())
check_set_params("test", ModifiesValueInsteadOfRaisingError())
with warnings.catch_warnings(record=True) as records:
check_estimator(RaisesErrorInSetParams())
check_set_params("test", RaisesErrorInSetParams())
assert UserWarning in [rec.category for rec in records]

with raises(AssertionError, match=msg):
Expand Down

0 comments on commit d9deffe

Please sign in to comment.