Skip to content

Commit

Permalink
test for retaining sklearn config
Browse files Browse the repository at this point in the history
  • Loading branch information
fkiraly committed Nov 10, 2024
1 parent 6aa2221 commit 894c694
Showing 1 changed file with 19 additions and 2 deletions.
21 changes: 19 additions & 2 deletions skbase/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1123,8 +1123,8 @@ def test_clone_class_rather_than_instance_raises_error(
not _check_soft_dependencies("scikit-learn", severity="none"),
reason="skip test if sklearn is not available",
) # sklearn is part of the dev dependency set, test should be executed with that
def test_clone_sklearn_composite(fixture_class_parent: Type[Parent]):
"""Test clone with keyword parameter set to None."""
def test_clone_sklearn_composite():
"""Test clone with a composite of sklearn and skbase."""
from sklearn.ensemble import GradientBoostingRegressor

sklearn_obj = GradientBoostingRegressor(random_state=5, learning_rate=0.02)
Expand All @@ -1134,6 +1134,23 @@ def test_clone_sklearn_composite(fixture_class_parent: Type[Parent]):
assert composite_set.get_params()["a__random_state"] == 42


@pytest.mark.skipif(
not _check_soft_dependencies("scikit-learn", severity="none"),
reason="skip test if sklearn is not available",
) # sklearn is part of the dev dependency set, test should be executed with that
def test_clone_sklearn_composite_retains_config():
"""Test that clone retains sklearn config if inside skbase composite."""
from sklearn.preprocessing import StandardScaler

sklearn_obj_w_config = StandardScaler().set_output(transform="pandas")

composite = ResetTester(a=sklearn_obj_w_config)
composite_clone = composite.clone()

assert hasattr(composite_clone, "_sklearn_output_config")
assert composite_clone._sklearn_output_config.get("transform", None) == "pandas"


# Tests of BaseObject pretty printing representation inspired by sklearn
def test_baseobject_repr(
fixture_class_parent: Type[Parent],
Expand Down

0 comments on commit 894c694

Please sign in to comment.