Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DOC] Replace use of "estimator" term in base object interfaces with more general references #293

Merged
merged 21 commits into from
Sep 22, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .all-contributorsrc
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
"contributions": [
"bug",
"code",
"doc",
"ideas",
"maintenance",
"test"
Expand Down
137 changes: 66 additions & 71 deletions skbase/base/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
set/clone dynamic tags - clone_tags(estimator, tag_names=None)

Blueprinting: resetting and cloning, post-init state with same hyper-parameters
reset estimator to post-init - reset()
clone estimator (copy&reset) - clone()
reset object to post-init - reset()
clone object (copy&reset) - clone()

Testing with default parameters methods
getting default parameters (all sets) - get_test_params()
Expand Down Expand Up @@ -144,10 +144,10 @@
return self

def clone(self):
"""Obtain a clone of the object with same hyper-parameters.
"""Obtain a clone of the object with same parameters and config.

A clone is a different object without shared references, in post-init state.
This function is equivalent to returning sklearn.clone of self.
This function behaves similarly to returning a sklearn clone of self.

Raises
------
Expand Down Expand Up @@ -197,7 +197,7 @@
for p in parameters:
if p.kind == p.VAR_POSITIONAL:
raise RuntimeError(
"scikit-learn compatible estimators should always "
"BaseObject compatible classes should always "
"specify their parameters in the signature"
" of their __init__ (no varargs)."
" %s with constructor %s doesn't "
Expand Down Expand Up @@ -277,7 +277,7 @@
def set_params(self, **params):
"""Set the parameters of this object.

The method works on simple estimators as well as on composite objects.
The method works on simple base objects as well as on composite objects.
Parameter key strings ``<component>__<parameter>`` can be used for composites,
i.e., objects that contain other objects, to access ``<parameter>`` in
the component ``<component>``.
Expand Down Expand Up @@ -320,7 +320,7 @@
valid_params[key] = value

# all matched params have now been set
# reset estimator to clean post-init state with those params
# reset object to clean post-init state with those params
self.reset()

# recurse in components
Expand Down Expand Up @@ -447,7 +447,7 @@
)

def get_tags(self):
"""Get tags from estimator class and dynamic tag overrides.
"""Get tags from object class and dynamic tag overrides.

Returns
-------
Expand All @@ -459,7 +459,7 @@
return self._get_flags(flag_attr_name="_tags")

def get_tag(self, tag_name, tag_value_default=None, raise_error=True):
"""Get tag value from estimator class and dynamic tag overrides.
"""Get tag value from object class and dynamic tag overrides.

Parameters
----------
Expand Down Expand Up @@ -510,11 +510,11 @@
return self

def clone_tags(self, estimator, tag_names=None):
"""Clone tags from another estimator as dynamic override.
"""Clone tags from another object as dynamic override.

Parameters
----------
estimator : estimator inheriting from :class:BaseEstimator
estimator : An instance of :class:BaseObject or derived class
tag_names : str or list of str, default = None
Names of tags to clone. If None then all tags in estimator are used
as `tag_names`.
Expand Down Expand Up @@ -569,7 +569,7 @@

@classmethod
def get_test_params(cls, parameter_set="default"):
"""Return testing parameter settings for the estimator.
"""Return testing parameter settings for the object.

Parameters
----------
Expand All @@ -592,7 +592,7 @@
# if non-default parameters are required, but none have been found, raise error
if len(params_without_defaults) > 0:
raise ValueError(
f"Estimator: {cls} has parameters without default values, "
f"{cls} has parameters without default values, "
f"but these are not set in get_test_params. "
f"Please set them in get_test_params, or provide default values. "
f"Also see the respective extension template, if applicable."
Expand All @@ -605,7 +605,7 @@

@classmethod
def create_test_instance(cls, parameter_set="default"):
"""Construct Estimator instance if possible.
"""Construct a BaseObject instance if possible.

Parameters
----------
Expand Down Expand Up @@ -891,18 +891,18 @@
def set_random_state(self, random_state=None, deep=True, self_policy="copy"):
"""Set random_state pseudo-random seed parameters for self.

Finds ``random_state`` named parameters via ``estimator.get_params``,
Finds ``random_state`` named parameters via ``self.get_params``,
and sets them to integers derived from ``random_state`` via ``set_params``.
These integers are sampled from chain hashing via ``sample_dependent_seed``,
and guarantee pseudo-random independence of seeded random generators.

Applies to ``random_state`` parameters in ``estimator`` depending on
``self_policy``, and remaining component estimators
Applies to ``random_state`` parameters in self depending on
``self_policy``, and remaining component objects
if and only if ``deep=True``.

Note: calls ``set_params`` even if ``self`` does not have a ``random_state``,
or none of the components have a ``random_state`` parameter.
Therefore, ``set_random_state`` will reset any ``scikit-base`` estimator,
Therefore, ``set_random_state`` will reset any ``scikit-base`` object,
even those without a ``random_state`` parameter.

Parameters
Expand All @@ -912,15 +912,15 @@
integers. Pass int for reproducible output across multiple function calls.

deep : bool, default=True
Whether to set the random state in sub-estimators.
Whether to set the random state in sub-objects.
If False, will set only ``self``'s ``random_state`` parameter, if exists.
If True, will set ``random_state`` parameters in sub-estimators as well.
If True, will set ``random_state`` parameters in sub-objects as well.

self_policy : str, one of {"copy", "keep", "new"}, default="copy"

* "copy" : ``estimator.random_state`` is set to input ``random_state``
* "keep" : ``estimator.random_state`` is kept as is
* "new" : ``estimator.random_state`` is set to a new random state,
* "copy" : ``self.random_state`` is set to input ``random_state``
* "keep" : ``self.random_state`` is kept as is
* "new" : ``self.random_state`` is set to a new random state,
derived from input ``random_state``, and in general different from it

Returns
Expand Down Expand Up @@ -972,7 +972,7 @@

@classmethod
def get_class_tags(cls):
"""Get class tags from estimator class and all its parent classes.
"""Get class tags from a base object class and all its parent classes.

Returns
-------
Expand All @@ -987,7 +987,7 @@

@classmethod
def get_class_tag(cls, tag_name, tag_value_default=None):
"""Get tag value from estimator class (only class tags).
"""Get tag value from a base object class (only class tags).

Parameters
----------
Expand All @@ -1008,7 +1008,7 @@
)

def get_tags(self):
"""Get tags from estimator class and dynamic tag overrides.
"""Get tags from a base object class and dynamic tag overrides.

Returns
-------
Expand All @@ -1022,7 +1022,7 @@
return collected_tags

def get_tag(self, tag_name, tag_value_default=None, raise_error=True):
"""Get tag value from estimator class and dynamic tag overrides.
"""Get tag value from a base object class and dynamic tag overrides.

Parameters
----------
Expand Down Expand Up @@ -1218,19 +1218,13 @@
if not deep:
return fitted_params

def sh(x):
"""Shorthand to remove all underscores at end of a string."""
if x.endswith("_"):
return sh(x[:-1])
else:
return x

# add all nested parameters from components that are skbase BaseEstimator
c_dict = self._components()
for c, comp in c_dict.items():
if isinstance(comp, BaseEstimator) and comp._is_fitted:
c_f_params = comp.get_fitted_params(deep=deep)
c_f_params = {f"{sh(c)}__{k}": v for k, v in c_f_params.items()}
c = c.rstrip("_")
c_f_params = {f"{c}__{k}": v for k, v in c_f_params.items()}
fitted_params.update(c_f_params)

# add all nested parameters from components that are sklearn estimators
Expand All @@ -1243,7 +1237,8 @@
for c, comp in old_new_params.items():
if isinstance(comp, self.GET_FITTED_PARAMS_NESTING):
c_f_params = self._get_fitted_params_default(comp)
c_f_params = {f"{sh(c)}__{k}": v for k, v in c_f_params.items()}
c = c.rstrip("_")
c_f_params = {f"{c}__{k}": v for k, v in c_f_params.items()}
new_params.update(c_f_params)
fitted_params.update(new_params)
old_new_params = new_params.copy()
Expand Down Expand Up @@ -1303,60 +1298,61 @@


# Adapted from sklearn's `_clone_parametrized()`
def _clone(estimator, *, safe=True):
"""Construct a new unfitted estimator with the same parameters.
def _clone(obj, *, safe=True):
"""Construct a new object with the same parameters and config.

Clone does a deep copy of the model in an estimator
without actually copying attached data. It returns a new estimator
with the same parameters that has not been fitted on any data.
Clone does a deep copy of the parameters and config values in
a base object instance.

Parameters
----------
estimator : {list, tuple, set} of estimator instance or a single \
estimator instance
The estimator or group of estimators to be cloned.
obj : {list, tuple, set} instances of base objects or a single \
base object instance
The object or group of objects to clone.
safe : bool, default=True
If safe is False, clone will fall back to a deep copy on objects
that are not estimators.
that do not have a `get_params` method.

Returns
-------
estimator : object
The deep copy of the input, an estimator if input is an estimator.
cloned_obj : {list, tuple, set} instances of base objects or a single \
base object instance
The object or group of objects with the same parameters and config
as the original object.

Notes
-----
If the estimator's `random_state` parameter is an integer (or if the
estimator doesn't have a `random_state` parameter), an *exact clone* is
returned: the clone and the original estimator will give the exact same
If the object's `random_state` parameter is an integer (or if the
object doesn't have a `random_state` parameter), an *exact clone* is
returned: the clone and the original object will give the exact same
results. Otherwise, *statistical clone* is returned: the clone might
return different results from the original estimator. More details can be
return different results from the original object. More details can be
found in :ref:`randomness`.
"""
estimator_type = type(estimator)
obj_type = type(obj)
# XXX: not handling dictionaries
if estimator_type in (list, tuple, set, frozenset):
return estimator_type([_clone(e, safe=safe) for e in estimator])
elif not hasattr(estimator, "get_params") or isinstance(estimator, type):
if obj_type in (list, tuple, set, frozenset):
return obj_type(_clone(e, safe=safe) for e in obj)
if not hasattr(obj, "get_params") or isinstance(obj, type):
if not safe:
return deepcopy(estimator)
return deepcopy(obj)
else:
if isinstance(estimator, type):
if isinstance(obj, type):
raise TypeError(
"Cannot clone object. "
+ "You should provide an instance of "
+ "scikit-learn estimator instead of a class."
+ "`BaseObject` instead of a class."
)
else:
raise TypeError(
"Cannot clone object '%s' (type %s): "
"it does not seem to be a scikit-learn "
"estimator as it does not implement a "
"'get_params' method." % (repr(estimator), type(estimator))
"it does not seem to be a BaseObject "
"instance as it does not implement a "
"'get_params' method." % (repr(obj), type(obj))
)

klass = estimator.__class__
new_object_params = estimator.get_params(deep=False)
klass = obj.__class__
new_object_params = obj.get_params(deep=False)
for name, param in new_object_params.items():
new_object_params[name] = _clone(param, safe=False)
new_object = klass(**new_object_params)
Expand All @@ -1369,34 +1365,33 @@
if param1 is not param2:
raise RuntimeError(
"Cannot clone object %s, as the constructor "
"either does not set or modifies parameter %s" % (estimator, name)
"either does not set or modifies parameter %s" % (obj, name)
)

# This is an extension to the original sklearn implementation
if isinstance(estimator, BaseObject) and estimator.get_config()["clone_config"]:
new_object.set_config(**estimator.get_config())
if isinstance(obj, BaseObject) and obj.get_config()["clone_config"]:
new_object.set_config(**obj.get_config())

return new_object


def _check_clone(original, clone):
from skbase.utils.deep_equals import deep_equals

self_params = original.get_params(deep=False)
original_params = original.get_params(deep=False)

# check that all attributes are written to the clone
for attrname in self_params.keys():
for attrname in original_params.keys():
if not hasattr(clone, attrname):
raise RuntimeError(
f"error in {original}.clone, __init__ must write all arguments "
f"to self and not mutate them, but {attrname} was not found. "
f"Please check __init__ of {original}."
)

clone_attrs = {attr: getattr(clone, attr) for attr in self_params.keys()}
clone_params = {attr: getattr(clone, attr) for attr in original_params.keys()}

# check equality of parameters post-clone and pre-clone
clone_attrs_valid, msg = deep_equals(self_params, clone_attrs, return_msg=True)
clone_attrs_valid, msg = deep_equals(original_params, clone_params, return_msg=True)
if not clone_attrs_valid:
raise RuntimeError(
f"error in {original}.clone, __init__ must write all arguments "
Expand Down
Loading