Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DOC] Replace use of "estimator" term in base object interfaces with more general references #293

Merged
merged 21 commits into from
Sep 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .all-contributorsrc
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
"contributions": [
"bug",
"code",
"doc",
"ideas",
"maintenance",
"test"
Expand Down
121 changes: 84 additions & 37 deletions skbase/base/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
set/clone dynamic tags - clone_tags(estimator, tag_names=None)

Blueprinting: resetting and cloning, post-init state with same hyper-parameters
reset estimator to post-init - reset()
cloneestimator (copy&reset) - clone()
reset object to post-init - reset()
clone object (copy&reset) - clone()

Testing with default parameters methods
getting default parameters (all sets) - get_test_params()
Expand Down Expand Up @@ -144,10 +144,10 @@
return self

def clone(self):
"""Obtain a clone of the object with same hyper-parameters.
"""Obtain a clone of the object with same hyper-parameters and config.

A clone is a different object without shared references, in post-init state.
This function is equivalent to returning sklearn.clone of self.
This function is equivalent to returning ``sklearn.clone`` of ``self``.

Raises
------
Expand Down Expand Up @@ -197,7 +197,7 @@
for p in parameters:
if p.kind == p.VAR_POSITIONAL:
raise RuntimeError(
"scikit-learn compatible estimators should always "
"scikit-base compatible classes should always "
"specify their parameters in the signature"
" of their __init__ (no varargs)."
" %s with constructor %s doesn't "
Expand Down Expand Up @@ -277,7 +277,7 @@
def set_params(self, **params):
"""Set the parameters of this object.

The method works on simple estimators as well as on composite objects.
The method works on simple skbase objects as well as on composite objects.
Parameter key strings ``<component>__<parameter>`` can be used for composites,
i.e., objects that contain other objects, to access ``<parameter>`` in
the component ``<component>``.
Expand Down Expand Up @@ -320,7 +320,7 @@
valid_params[key] = value

# all matched params have now been set
# reset estimator to clean post-init state with those params
# reset object to clean post-init state with those params
self.reset()

# recurse in components
Expand Down Expand Up @@ -447,7 +447,7 @@
)

def get_tags(self):
"""Get tags from estimator class and dynamic tag overrides.
"""Get tags from skbase class and dynamic tag overrides.

Returns
-------
Expand All @@ -459,7 +459,7 @@
return self._get_flags(flag_attr_name="_tags")

def get_tag(self, tag_name, tag_value_default=None, raise_error=True):
"""Get tag value from estimator class and dynamic tag overrides.
"""Get tag value from object class and dynamic tag overrides.

Parameters
----------
Expand Down Expand Up @@ -510,11 +510,11 @@
return self

def clone_tags(self, estimator, tag_names=None):
"""Clone tags from another estimator as dynamic override.
"""Clone tags from another object as dynamic override.

Parameters
----------
estimator : estimator inheriting from :class:BaseEstimator
estimator : An instance of :class:BaseObject or derived class
tag_names : str or list of str, default = None
Names of tags to clone. If None then all tags in estimator are used
as `tag_names`.
Expand Down Expand Up @@ -569,7 +569,7 @@

@classmethod
def get_test_params(cls, parameter_set="default"):
"""Return testing parameter settings for the estimator.
"""Return testing parameter settings for the skbase object.

Parameters
----------
Expand All @@ -592,7 +592,7 @@
# if non-default parameters are required, but none have been found, raise error
if len(params_without_defaults) > 0:
raise ValueError(
f"Estimator: {cls} has parameters without default values, "
f"skbase object {cls} has parameters without default values, "
f"but these are not set in get_test_params. "
f"Please set them in get_test_params, or provide default values. "
f"Also see the respective extension template, if applicable."
Expand All @@ -605,7 +605,7 @@

@classmethod
def create_test_instance(cls, parameter_set="default"):
"""Construct Estimator instance if possible.
"""Construct an instance of the class, using first test parameter set.

Parameters
----------
Expand Down Expand Up @@ -891,18 +891,18 @@
def set_random_state(self, random_state=None, deep=True, self_policy="copy"):
"""Set random_state pseudo-random seed parameters for self.

Finds ``random_state`` named parameters via ``estimator.get_params``,
Finds ``random_state`` named parameters via ``self.get_params``,
and sets them to integers derived from ``random_state`` via ``set_params``.
These integers are sampled from chain hashing via ``sample_dependent_seed``,
and guarantee pseudo-random independence of seeded random generators.

Applies to ``random_state`` parameters in ``estimator`` depending on
``self_policy``, and remaining component estimators
Applies to ``random_state`` parameters in ``self``, depending on
``self_policy``, and remaining component objects
if and only if ``deep=True``.

Note: calls ``set_params`` even if ``self`` does not have a ``random_state``,
or none of the components have a ``random_state`` parameter.
Therefore, ``set_random_state`` will reset any ``scikit-base`` estimator,
Therefore, ``set_random_state`` will reset any ``scikit-base`` object,
even those without a ``random_state`` parameter.

Parameters
Expand All @@ -912,15 +912,18 @@
integers. Pass int for reproducible output across multiple function calls.

deep : bool, default=True
Whether to set the random state in sub-estimators.
If False, will set only ``self``'s ``random_state`` parameter, if exists.
If True, will set ``random_state`` parameters in sub-estimators as well.
Whether to set the random state in skbase object valued parameters, i.e.,
component estimators.

* If False, will set only ``self``'s ``random_state`` parameter, if exists.
* If True, will set ``random_state`` parameters in component objects
as well.

self_policy : str, one of {"copy", "keep", "new"}, default="copy"

* "copy" : ``estimator.random_state`` is set to input ``random_state``
* "keep" : ``estimator.random_state`` is kept as is
* "new" : ``estimator.random_state`` is set to a new random state,
* "copy" : ``self.random_state`` is set to input ``random_state``
* "keep" : ``self.random_state`` is kept as is
* "new" : ``self.random_state`` is set to a new random state,
derived from input ``random_state``, and in general different from it

Returns
Expand Down Expand Up @@ -972,7 +975,21 @@

@classmethod
def get_class_tags(cls):
"""Get class tags from estimator class and all its parent classes.
"""Get class tags from class, with tag level inheritance from parent classes.

Every ``scikit-base`` compatible class has a set of tags,
which are used to store metadata about the object.

This is a class method, and retrieves tags applicable to the class,
with tag level overrides in the following order of decreasing priority:

1. class tags of the class, of which the object is an instance
2. class tags of all parent classes, in method resolution order

Check warning on line 987 in skbase/base/_base.py

View check run for this annotation

Codecov / codecov/patch

skbase/base/_base.py#L987

Added line #L987 was not covered by tests

Instances can override these tags depending on hyper-parameters.

To retrieve tags with potential instance overrides, use
the ``get_tags`` method instead.

Returns
-------
Expand All @@ -987,7 +1004,22 @@

@classmethod
def get_class_tag(cls, tag_name, tag_value_default=None):
"""Get tag value from estimator class (only class tags).
"""Get class tag value from class, with tag level inheritance from parents.

Every ``scikit-base`` compatible class has a set of tags,
which are used to store metadata about the object.

This is a class method, and retrieves the value of a tag applicable
to the class,
with tag level overrides in the following order of decreasing priority:

1. class tags of the class, of which the object is an instance

Check warning on line 1016 in skbase/base/_base.py

View check run for this annotation

Codecov / codecov/patch

skbase/base/_base.py#L1014-L1016

Added lines #L1014 - L1016 were not covered by tests
2. class tags of all parent classes, in method resolution order

Instances can override these tags depending on hyper-parameters.

To retrieve tag values with potential instance overrides, use
the ``get_tag`` method instead.

Parameters
----------
Expand All @@ -1008,7 +1040,17 @@
)

def get_tags(self):
"""Get tags from estimator class and dynamic tag overrides.
"""Get tags from instance, with tag level inheritance and overrides.

Every ``scikit-base`` compatible object has a set of tags,
which are used to store metadata about the object.

This method retrieves all tags as a dictionary, with tag level overrides in the
following order of decreasing priority:

1. dynamic tags set at construction, e.g., dependent on hyper-parameters

Check warning on line 1051 in skbase/base/_base.py

View check run for this annotation

Codecov / codecov/patch

skbase/base/_base.py#L1050-L1051

Added lines #L1050 - L1051 were not covered by tests
2. class tags of the class, of which the object is an instance
3. class tags of all parent classes, in method resolution order

Returns
-------
Expand All @@ -1022,7 +1064,17 @@
return collected_tags

def get_tag(self, tag_name, tag_value_default=None, raise_error=True):
"""Get tag value from estimator class and dynamic tag overrides.
"""Get tag value from instance, with tag level inheritance and overrides.

Every ``scikit-base`` compatible object has a set of tags,
which are used to store metadata about the object.

This method retrieves the value of a single tag, with tag level overrides in the
following order of decreasing priority:

1. dynamic tags set at construction, e.g., dependent on hyper-parameters
2. class tags of the class, of which the object is an instance
3. class tags of all parent classes, in method resolution order

Check warning on line 1077 in skbase/base/_base.py

View check run for this annotation

Codecov / codecov/patch

skbase/base/_base.py#L1075-L1077

Added lines #L1075 - L1077 were not covered by tests

Parameters
----------
Expand Down Expand Up @@ -1218,19 +1270,13 @@
if not deep:
return fitted_params

def sh(x):
"""Shorthand to remove all underscores at end of a string."""
if x.endswith("_"):
return sh(x[:-1])
else:
return x

# add all nested parameters from components that are skbase BaseEstimator
c_dict = self._components()
for c, comp in c_dict.items():
if isinstance(comp, BaseEstimator) and comp._is_fitted:
c_f_params = comp.get_fitted_params(deep=deep)
c_f_params = {f"{sh(c)}__{k}": v for k, v in c_f_params.items()}
c = c.rstrip("_")
c_f_params = {f"{c}__{k}": v for k, v in c_f_params.items()}
fitted_params.update(c_f_params)

# add all nested parameters from components that are sklearn estimators
Expand All @@ -1243,7 +1289,8 @@
for c, comp in old_new_params.items():
if isinstance(comp, self.GET_FITTED_PARAMS_NESTING):
c_f_params = self._get_fitted_params_default(comp)
c_f_params = {f"{sh(c)}__{k}": v for k, v in c_f_params.items()}
c = c.rstrip("_")
c_f_params = {f"{c}__{k}": v for k, v in c_f_params.items()}
new_params.update(c_f_params)
fitted_params.update(new_params)
old_new_params = new_params.copy()
Expand Down
Loading