diff --git a/skbase/tests/test_base.py b/skbase/tests/test_base.py index 3f053ed2..435453ea 100644 --- a/skbase/tests/test_base.py +++ b/skbase/tests/test_base.py @@ -1123,8 +1123,8 @@ def test_clone_class_rather_than_instance_raises_error( not _check_soft_dependencies("scikit-learn", severity="none"), reason="skip test if sklearn is not available", ) # sklearn is part of the dev dependency set, test should be executed with that -def test_clone_sklearn_composite(fixture_class_parent: Type[Parent]): - """Test clone with keyword parameter set to None.""" +def test_clone_sklearn_composite(): + """Test clone with a composite of sklearn and skbase.""" from sklearn.ensemble import GradientBoostingRegressor sklearn_obj = GradientBoostingRegressor(random_state=5, learning_rate=0.02) @@ -1134,6 +1134,23 @@ def test_clone_sklearn_composite(fixture_class_parent: Type[Parent]): assert composite_set.get_params()["a__random_state"] == 42 +@pytest.mark.skipif( + not _check_soft_dependencies("scikit-learn", severity="none"), + reason="skip test if sklearn is not available", +) # sklearn is part of the dev dependency set, test should be executed with that +def test_clone_sklearn_composite_retains_config(): + """Test that clone retains sklearn config if inside skbase composite.""" + from sklearn.preprocessing import StandardScaler + + sklearn_obj_w_config = StandardScaler().set_output(transform="pandas") + + composite = ResetTester(a=sklearn_obj_w_config) + composite_clone = composite.clone() + + assert hasattr(composite_clone, "_sklearn_output_config") + assert composite_clone._sklearn_output_config.get("transform", None) == "pandas" + + # Tests of BaseObject pretty printing representation inspired by sklearn def test_baseobject_repr( fixture_class_parent: Type[Parent],