[MNT] Refactor .clone() #281

Merged · 10 commits · Feb 25, 2024
213 changes: 107 additions & 106 deletions skbase/base/_base.py
@@ -63,7 +63,7 @@
from skbase.base._pretty_printing._object_html_repr import _object_html_repr
from skbase.base._tagmanager import _FlagManager

__author__: List[str] = ["mloning", "RNKuhns", "fkiraly"]
__author__: List[str] = ["fkiraly", "mloning", "RNKuhns", "tpvasconcelos"]
__all__: List[str] = ["BaseEstimator", "BaseObject"]


@@ -157,113 +157,11 @@
-----
If successful, equal in value to ``type(self)(**self.get_params(deep=False))``.
"""
self_params = self.get_params(deep=False)
self_clone = self._clone(self)

# if checking the clone is turned off, return now
if not self.get_config()["check_clone"]:
return self_clone

from skbase.utils.deep_equals import deep_equals

# check that all attributes are written to the clone
for attrname in self_params.keys():
if not hasattr(self_clone, attrname):
raise RuntimeError(
f"error in {self}.clone, __init__ must write all arguments "
f"to self and not mutate them, but {attrname} was not found. "
f"Please check __init__ of {self}."
)

clone_attrs = {attr: getattr(self_clone, attr) for attr in self_params.keys()}

# check equality of parameters post-clone and pre-clone
clone_attrs_valid, msg = deep_equals(self_params, clone_attrs, return_msg=True)
if not clone_attrs_valid:
raise RuntimeError(
f"error in {self}.clone, __init__ must write all arguments "
f"to self and not mutate them, but this is not the case. "
f"Error on equality check of arguments (x) vs parameters (y): {msg}"
)

self_clone = _clone(self)
if self.get_config()["check_clone"]:
_check_clone(original=self, clone=self_clone)
return self_clone
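
For illustration of the contract stated in the docstring above (the clone should be equal in value to type(self)(**self.get_params(deep=False))), a minimal sketch with a hypothetical BaseObject subclass; MyObject and its parameters are illustrative only and not part of this change:

from skbase.base import BaseObject

class MyObject(BaseObject):  # hypothetical class, for illustration only
    def __init__(self, a=1, b="x"):
        self.a = a
        self.b = b
        super().__init__()

obj = MyObject(a=42)
new = obj.clone()
assert new is not obj  # a distinct, unfitted object
assert new.get_params(deep=False) == obj.get_params(deep=False)
# i.e., equal in value to type(obj)(**obj.get_params(deep=False))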

# copied from sklearn
def _clone(self, estimator, *, safe=True):
"""Construct a new unfitted estimator with the same parameters.

Clone does a deep copy of the model in an estimator
without actually copying attached data. It returns a new estimator
with the same parameters that has not been fitted on any data.

Parameters
----------
estimator : {list, tuple, set} of estimator instance or a single \
estimator instance
The estimator or group of estimators to be cloned.
safe : bool, default=True
If safe is False, clone will fall back to a deep copy on objects
that are not estimators.

Returns
-------
estimator : object
The deep copy of the input, an estimator if input is an estimator.

Notes
-----
If the estimator's `random_state` parameter is an integer (or if the
estimator doesn't have a `random_state` parameter), an *exact clone* is
returned: the clone and the original estimator will give the exact same
results. Otherwise, *statistical clone* is returned: the clone might
return different results from the original estimator. More details can be
found in :ref:`randomness`.
"""
estimator_type = type(estimator)
# XXX: not handling dictionaries
if estimator_type in (list, tuple, set, frozenset):
return estimator_type([self._clone(e, safe=safe) for e in estimator])
elif not hasattr(estimator, "get_params") or isinstance(estimator, type):
if not safe:
return deepcopy(estimator)
else:
if isinstance(estimator, type):
raise TypeError(
"Cannot clone object. "
+ "You should provide an instance of "
+ "scikit-learn estimator instead of a class."
)
else:
raise TypeError(
"Cannot clone object '%s' (type %s): "
"it does not seem to be a scikit-learn "
"estimator as it does not implement a "
"'get_params' method." % (repr(estimator), type(estimator))
)

klass = estimator.__class__
new_object_params = estimator.get_params(deep=False)
for name, param in new_object_params.items():
new_object_params[name] = self._clone(param, safe=False)
new_object = klass(**new_object_params)
params_set = new_object.get_params(deep=False)

# quick sanity check of the parameters of the clone
for name in new_object_params:
param1 = new_object_params[name]
param2 = params_set[name]
if param1 is not param2:
raise RuntimeError(
"Cannot clone object %s, as the constructor "
"either does not set or modifies parameter %s" % (estimator, name)
)

# This is an extension to the original sklearn implementation
if isinstance(estimator, BaseObject) and estimator.get_config()["clone_config"]:
new_object.set_config(**estimator.get_config())

return new_object

@classmethod
def _get_init_signature(cls):
"""Get class init signature.
@@ -1405,3 +1303,106 @@
fitted parameters, keyed by names of fitted parameter
"""
return self._get_fitted_params_default()


# Adapted from sklearn's `_clone_parametrized()`
def _clone(estimator, *, safe=True):
"""Construct a new unfitted estimator with the same parameters.

Clone does a deep copy of the model in an estimator
without actually copying attached data. It returns a new estimator
with the same parameters that has not been fitted on any data.

Parameters
----------
estimator : {list, tuple, set} of estimator instance or a single \
estimator instance
The estimator or group of estimators to be cloned.
safe : bool, default=True
If safe is False, clone will fall back to a deep copy on objects
that are not estimators.

Returns
-------
estimator : object
The deep copy of the input, an estimator if input is an estimator.

Notes
-----
If the estimator's `random_state` parameter is an integer (or if the
estimator doesn't have a `random_state` parameter), an *exact clone* is
returned: the clone and the original estimator will give the exact same
results. Otherwise, *statistical clone* is returned: the clone might
return different results from the original estimator. More details can be
found in :ref:`randomness`.
"""
estimator_type = type(estimator)
# XXX: not handling dictionaries
if estimator_type in (list, tuple, set, frozenset):
return estimator_type([_clone(e, safe=safe) for e in estimator])
elif not hasattr(estimator, "get_params") or isinstance(estimator, type):
if not safe:
return deepcopy(estimator)
else:
if isinstance(estimator, type):
raise TypeError(
"Cannot clone object. "
+ "You should provide an instance of "
+ "scikit-learn estimator instead of a class."
)
else:
raise TypeError(
"Cannot clone object '%s' (type %s): "
"it does not seem to be a scikit-learn "
"estimator as it does not implement a "
"'get_params' method." % (repr(estimator), type(estimator))
)

klass = estimator.__class__
new_object_params = estimator.get_params(deep=False)
for name, param in new_object_params.items():
new_object_params[name] = _clone(param, safe=False)
new_object = klass(**new_object_params)
params_set = new_object.get_params(deep=False)

# quick sanity check of the parameters of the clone
for name in new_object_params:
param1 = new_object_params[name]
param2 = params_set[name]
if param1 is not param2:
raise RuntimeError(
"Cannot clone object %s, as the constructor "
"either does not set or modifies parameter %s" % (estimator, name)
)

# This is an extension to the original sklearn implementation
if isinstance(estimator, BaseObject) and estimator.get_config()["clone_config"]:
new_object.set_config(**estimator.get_config())

return new_object
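
To illustrate the config-propagation extension noted in the comment above, a minimal hedged sketch; MyObject is hypothetical, and both config flags are set explicitly rather than relying on their defaults:

from skbase.base import BaseObject

class MyObject(BaseObject):  # hypothetical class, for illustration only
    def __init__(self, a=1):
        self.a = a
        super().__init__()

obj = MyObject()
obj.set_config(clone_config=True, check_clone=True)  # set both flags explicitly
new = obj.clone()
# because clone_config is on, the clone inherits the original's config via set_config
assert new.get_config()["check_clone"] is True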


def _check_clone(original, clone):
from skbase.utils.deep_equals import deep_equals

self_params = original.get_params(deep=False)

# check that all attributes are written to the clone
for attrname in self_params.keys():
if not hasattr(clone, attrname):
raise RuntimeError(
f"error in {original}.clone, __init__ must write all arguments "
f"to self and not mutate them, but {attrname} was not found. "
f"Please check __init__ of {original}."
)

clone_attrs = {attr: getattr(clone, attr) for attr in self_params.keys()}

# check equality of parameters post-clone and pre-clone
clone_attrs_valid, msg = deep_equals(self_params, clone_attrs, return_msg=True)
if not clone_attrs_valid:
raise RuntimeError(
f"error in {original}.clone, __init__ must write all arguments "
f"to self and not mutate them, but this is not the case. "
f"Error on equality check of arguments (x) vs parameters (y): {msg}"
)
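
To illustrate what _check_clone guards against, a hedged sketch of an __init__ that writes the same object back to self (so the identity check in _clone passes) but mutates it in place; MutatingObject is hypothetical and the check_clone flag is enabled explicitly:

from skbase.base import BaseObject

class MutatingObject(BaseObject):  # hypothetical class, for illustration only
    def __init__(self, values):
        values.append("sentinel")  # in-place mutation of the constructor argument
        self.values = values       # same object written back, so _clone's identity check passes
        super().__init__()

obj = MutatingObject(values=[1, 2])
obj.set_config(check_clone=True)  # opt in to the post-clone equality check
try:
    obj.clone()
except RuntimeError as err:
    print(err)  # "__init__ must write all arguments to self and not mutate them ..."

Because the original's stored list already contains the sentinel while the clone's list gains a second one, deep_equals reports a mismatch and _check_clone raises.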
4 changes: 4 additions & 0 deletions skbase/tests/conftest.py
@@ -178,6 +178,10 @@
SKBASE_FUNCTIONS_BY_MODULE = SKBASE_PUBLIC_FUNCTIONS_BY_MODULE.copy()
SKBASE_FUNCTIONS_BY_MODULE.update(
{
"skbase.base._base": (
"_clone",
"_check_clone",
),
"skbase.base._pretty_printing._object_html_repr": (
"_get_visual_block",
"_object_html_repr",