Skip to content

Commit

Permalink
[DOC] Replace use of "estimator" term in base object interfaces with …
Browse files Browse the repository at this point in the history
…more general references (#293)

#### Reference Issues/PRs

Depends on #281

#### What does this implement/fix? Explain your changes.

See #281 (comment)
and relevant replies.
  • Loading branch information
tpvasconcelos authored Sep 22, 2024
1 parent 320c206 commit 0784625
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 37 deletions.
1 change: 1 addition & 0 deletions .all-contributorsrc
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
"contributions": [
"bug",
"code",
"doc",
"ideas",
"maintenance",
"test"
Expand Down
121 changes: 84 additions & 37 deletions skbase/base/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ class name: BaseObject
set/clone dynamic tags - clone_tags(estimator, tag_names=None)
Blueprinting: resetting and cloning, post-init state with same hyper-parameters
reset estimator to post-init - reset()
cloneestimator (copy&reset) - clone()
reset object to post-init - reset()
clone object (copy&reset) - clone()
Testing with default parameters methods
getting default parameters (all sets) - get_test_params()
Expand Down Expand Up @@ -144,10 +144,10 @@ def reset(self):
return self

def clone(self):
"""Obtain a clone of the object with same hyper-parameters.
"""Obtain a clone of the object with same hyper-parameters and config.
A clone is a different object without shared references, in post-init state.
This function is equivalent to returning sklearn.clone of self.
This function is equivalent to returning ``sklearn.clone`` of ``self``.
Raises
------
Expand Down Expand Up @@ -197,7 +197,7 @@ def _get_init_signature(cls):
for p in parameters:
if p.kind == p.VAR_POSITIONAL:
raise RuntimeError(
"scikit-learn compatible estimators should always "
"scikit-base compatible classes should always "
"specify their parameters in the signature"
" of their __init__ (no varargs)."
" %s with constructor %s doesn't "
Expand Down Expand Up @@ -290,7 +290,7 @@ def get_params(self, deep=True):
def set_params(self, **params):
"""Set the parameters of this object.
The method works on simple estimators as well as on composite objects.
The method works on simple skbase objects as well as on composite objects.
Parameter key strings ``<component>__<parameter>`` can be used for composites,
i.e., objects that contain other objects, to access ``<parameter>`` in
the component ``<component>``.
Expand Down Expand Up @@ -333,7 +333,7 @@ def set_params(self, **params):
valid_params[key] = value

# all matched params have now been set
# reset estimator to clean post-init state with those params
# reset object to clean post-init state with those params
self.reset()

# recurse in components
Expand Down Expand Up @@ -460,7 +460,7 @@ def get_class_tag(cls, tag_name, tag_value_default=None):
)

def get_tags(self):
"""Get tags from estimator class and dynamic tag overrides.
"""Get tags from skbase class and dynamic tag overrides.
Returns
-------
Expand All @@ -472,7 +472,7 @@ class attribute via nested inheritance and then any overrides
return self._get_flags(flag_attr_name="_tags")

def get_tag(self, tag_name, tag_value_default=None, raise_error=True):
"""Get tag value from estimator class and dynamic tag overrides.
"""Get tag value from object class and dynamic tag overrides.
Parameters
----------
Expand Down Expand Up @@ -523,11 +523,11 @@ def set_tags(self, **tag_dict):
return self

def clone_tags(self, estimator, tag_names=None):
"""Clone tags from another estimator as dynamic override.
"""Clone tags from another object as dynamic override.
Parameters
----------
estimator : estimator inheriting from :class:BaseEstimator
estimator : An instance of :class:`BaseObject` or derived class
tag_names : str or list of str, default = None
Names of tags to clone. If None then all tags in estimator are used
as `tag_names`.
Expand Down Expand Up @@ -582,7 +582,7 @@ def set_config(self, **config_dict):

@classmethod
def get_test_params(cls, parameter_set="default"):
"""Return testing parameter settings for the estimator.
"""Return testing parameter settings for the skbase object.
Parameters
----------
Expand All @@ -605,7 +605,7 @@ def get_test_params(cls, parameter_set="default"):
# if non-default parameters are required, but none have been found, raise error
if len(params_without_defaults) > 0:
raise ValueError(
f"Estimator: {cls} has parameters without default values, "
f"skbase object {cls} has parameters without default values, "
f"but these are not set in get_test_params. "
f"Please set them in get_test_params, or provide default values. "
f"Also see the respective extension template, if applicable."
Expand All @@ -618,7 +618,7 @@ def get_test_params(cls, parameter_set="default"):

@classmethod
def create_test_instance(cls, parameter_set="default"):
"""Construct Estimator instance if possible.
"""Construct an instance of the class, using first test parameter set.
Parameters
----------
Expand Down Expand Up @@ -904,18 +904,18 @@ def _repr_mimebundle_(self, **kwargs):
def set_random_state(self, random_state=None, deep=True, self_policy="copy"):
"""Set random_state pseudo-random seed parameters for self.
Finds ``random_state`` named parameters via ``estimator.get_params``,
Finds ``random_state`` named parameters via ``self.get_params``,
and sets them to integers derived from ``random_state`` via ``set_params``.
These integers are sampled from chain hashing via ``sample_dependent_seed``,
and guarantee pseudo-random independence of seeded random generators.
Applies to ``random_state`` parameters in ``estimator`` depending on
``self_policy``, and remaining component estimators
Applies to ``random_state`` parameters in ``self``, depending on
``self_policy``, and remaining component objects
if and only if ``deep=True``.
Note: calls ``set_params`` even if ``self`` does not have a ``random_state``,
or none of the components have a ``random_state`` parameter.
Therefore, ``set_random_state`` will reset any ``scikit-base`` estimator,
Therefore, ``set_random_state`` will reset any ``scikit-base`` object,
even those without a ``random_state`` parameter.
Parameters
Expand All @@ -925,15 +925,18 @@ def set_random_state(self, random_state=None, deep=True, self_policy="copy"):
integers. Pass int for reproducible output across multiple function calls.
deep : bool, default=True
Whether to set the random state in sub-estimators.
If False, will set only ``self``'s ``random_state`` parameter, if exists.
If True, will set ``random_state`` parameters in sub-estimators as well.
Whether to set the random state in skbase object valued parameters, i.e.,
component objects.
* If False, will set only ``self``'s ``random_state`` parameter, if exists.
* If True, will set ``random_state`` parameters in component objects
as well.
self_policy : str, one of {"copy", "keep", "new"}, default="copy"
* "copy" : ``estimator.random_state`` is set to input ``random_state``
* "keep" : ``estimator.random_state`` is kept as is
* "new" : ``estimator.random_state`` is set to a new random state,
* "copy" : ``self.random_state`` is set to input ``random_state``
* "keep" : ``self.random_state`` is kept as is
* "new" : ``self.random_state`` is set to a new random state,
derived from input ``random_state``, and in general different from it
Returns
Expand Down Expand Up @@ -985,7 +988,21 @@ def __init__(self):

@classmethod
def get_class_tags(cls):
"""Get class tags from estimator class and all its parent classes.
"""Get class tags from class, with tag level inheritance from parent classes.
Every ``scikit-base`` compatible class has a set of tags,
which are used to store metadata about the object.
This is a class method, and retrieves tags applicable to the class,
with tag level overrides in the following order of decreasing priority:
1. class tags of the class, of which the object is an instance
2. class tags of all parent classes, in method resolution order
Instances can override these tags depending on hyper-parameters.
To retrieve tags with potential instance overrides, use
the ``get_tags`` method instead.
Returns
-------
Expand All @@ -1000,7 +1017,22 @@ class attribute via nested inheritance. NOT overridden by dynamic

@classmethod
def get_class_tag(cls, tag_name, tag_value_default=None):
"""Get tag value from estimator class (only class tags).
"""Get class tag value from class, with tag level inheritance from parents.
Every ``scikit-base`` compatible class has a set of tags,
which are used to store metadata about the object.
This is a class method, and retrieves the value of a tag applicable
to the class,
with tag level overrides in the following order of decreasing priority:
1. class tags of the class, of which the object is an instance
2. class tags of all parent classes, in method resolution order
Instances can override these tags depending on hyper-parameters.
To retrieve tag values with potential instance overrides, use
the ``get_tag`` method instead.
Parameters
----------
Expand All @@ -1021,7 +1053,17 @@ def get_class_tag(cls, tag_name, tag_value_default=None):
)

def get_tags(self):
"""Get tags from estimator class and dynamic tag overrides.
"""Get tags from instance, with tag level inheritance and overrides.
Every ``scikit-base`` compatible object has a set of tags,
which are used to store metadata about the object.
This method retrieves all tags as a dictionary, with tag level overrides in the
following order of decreasing priority:
1. dynamic tags set at construction, e.g., dependent on hyper-parameters
2. class tags of the class, of which the object is an instance
3. class tags of all parent classes, in method resolution order
Returns
-------
Expand All @@ -1035,7 +1077,17 @@ class attribute via nested inheritance and then any overrides
return collected_tags

def get_tag(self, tag_name, tag_value_default=None, raise_error=True):
"""Get tag value from estimator class and dynamic tag overrides.
"""Get tag value from instance, with tag level inheritance and overrides.
Every ``scikit-base`` compatible object has a set of tags,
which are used to store metadata about the object.
This method retrieves the value of a single tag, with tag level overrides in the
following order of decreasing priority:
1. dynamic tags set at construction, e.g., dependent on hyper-parameters
2. class tags of the class, of which the object is an instance
3. class tags of all parent classes, in method resolution order
Parameters
----------
Expand Down Expand Up @@ -1231,19 +1283,13 @@ def get_fitted_params(self, deep=True):
if not deep:
return fitted_params

def sh(x):
"""Shorthand to remove all underscores at end of a string."""
if x.endswith("_"):
return sh(x[:-1])
else:
return x

# add all nested parameters from components that are skbase BaseEstimator
c_dict = self._components()
for c, comp in c_dict.items():
if isinstance(comp, BaseEstimator) and comp._is_fitted:
c_f_params = comp.get_fitted_params(deep=deep)
c_f_params = {f"{sh(c)}__{k}": v for k, v in c_f_params.items()}
c = c.rstrip("_")
c_f_params = {f"{c}__{k}": v for k, v in c_f_params.items()}
fitted_params.update(c_f_params)

# add all nested parameters from components that are sklearn estimators
Expand All @@ -1256,7 +1302,8 @@ def sh(x):
for c, comp in old_new_params.items():
if isinstance(comp, self.GET_FITTED_PARAMS_NESTING):
c_f_params = self._get_fitted_params_default(comp)
c_f_params = {f"{sh(c)}__{k}": v for k, v in c_f_params.items()}
c = c.rstrip("_")
c_f_params = {f"{c}__{k}": v for k, v in c_f_params.items()}
new_params.update(c_f_params)
fitted_params.update(new_params)
old_new_params = new_params.copy()
Expand Down

0 comments on commit 0784625

Please sign in to comment.