diff --git a/.all-contributorsrc b/.all-contributorsrc index d30de699..9278b0c4 100644 --- a/.all-contributorsrc +++ b/.all-contributorsrc @@ -68,6 +68,7 @@ "contributions": [ "bug", "code", + "doc", "ideas", "maintenance", "test" diff --git a/skbase/base/_base.py b/skbase/base/_base.py index 3b0ca27c..73b342f2 100644 --- a/skbase/base/_base.py +++ b/skbase/base/_base.py @@ -31,8 +31,8 @@ class name: BaseObject set/clone dynamic tags - clone_tags(estimator, tag_names=None) Blueprinting: resetting and cloning, post-init state with same hyper-parameters - reset estimator to post-init - reset() - cloneestimator (copy&reset) - clone() + reset object to post-init - reset() + clone object (copy&reset) - clone() Testing with default parameters methods getting default parameters (all sets) - get_test_params() @@ -144,10 +144,10 @@ def reset(self): return self def clone(self): - """Obtain a clone of the object with same hyper-parameters. + """Obtain a clone of the object with same hyper-parameters and config. A clone is a different object without shared references, in post-init state. - This function is equivalent to returning sklearn.clone of self. + This function is equivalent to returning ``sklearn.clone`` of ``self``. Raises ------ @@ -197,7 +197,7 @@ def _get_init_signature(cls): for p in parameters: if p.kind == p.VAR_POSITIONAL: raise RuntimeError( - "scikit-learn compatible estimators should always " + "scikit-base compatible classes should always " "specify their parameters in the signature" " of their __init__ (no varargs)." " %s with constructor %s doesn't " @@ -277,7 +277,7 @@ def get_params(self, deep=True): def set_params(self, **params): """Set the parameters of this object. - The method works on simple estimators as well as on composite objects. + The method works on simple skbase objects as well as on composite objects. Parameter key strings ``<component>__<parameter>`` can be used for composites, i.e., objects that contain other objects, to access ``<parameter>`` in the component ``<component>``.
@@ -320,7 +320,7 @@ def set_params(self, **params): valid_params[key] = value # all matched params have now been set - # reset estimator to clean post-init state with those params + # reset object to clean post-init state with those params self.reset() # recurse in components @@ -447,7 +447,7 @@ def get_class_tag(cls, tag_name, tag_value_default=None): ) def get_tags(self): - """Get tags from estimator class and dynamic tag overrides. + """Get tags from skbase class and dynamic tag overrides. Returns ------- @@ -459,7 +459,7 @@ class attribute via nested inheritance and then any overrides return self._get_flags(flag_attr_name="_tags") def get_tag(self, tag_name, tag_value_default=None, raise_error=True): - """Get tag value from estimator class and dynamic tag overrides. + """Get tag value from object class and dynamic tag overrides. Parameters ---------- @@ -510,11 +510,11 @@ def set_tags(self, **tag_dict): return self def clone_tags(self, estimator, tag_names=None): - """Clone tags from another estimator as dynamic override. + """Clone tags from another object as dynamic override. Parameters ---------- - estimator : estimator inheriting from :class:BaseEstimator + estimator : An instance of :class:BaseObject or derived class tag_names : str or list of str, default = None Names of tags to clone. If None then all tags in estimator are used as `tag_names`. @@ -569,7 +569,7 @@ def set_config(self, **config_dict): @classmethod def get_test_params(cls, parameter_set="default"): - """Return testing parameter settings for the estimator. + """Return testing parameter settings for the skbase object. 
Parameters ---------- @@ -592,7 +592,7 @@ def get_test_params(cls, parameter_set="default"): # if non-default parameters are required, but none have been found, raise error if len(params_without_defaults) > 0: raise ValueError( - f"Estimator: {cls} has parameters without default values, " + f"skbase object {cls} has parameters without default values, " f"but these are not set in get_test_params. " f"Please set them in get_test_params, or provide default values. " f"Also see the respective extension template, if applicable." @@ -605,7 +605,7 @@ def get_test_params(cls, parameter_set="default"): @classmethod def create_test_instance(cls, parameter_set="default"): - """Construct Estimator instance if possible. + """Construct an instance of the class, using first test parameter set. Parameters ---------- @@ -891,18 +891,18 @@ def _repr_mimebundle_(self, **kwargs): def set_random_state(self, random_state=None, deep=True, self_policy="copy"): """Set random_state pseudo-random seed parameters for self. - Finds ``random_state`` named parameters via ``estimator.get_params``, + Finds ``random_state`` named parameters via ``self.get_params``, and sets them to integers derived from ``random_state`` via ``set_params``. These integers are sampled from chain hashing via ``sample_dependent_seed``, and guarantee pseudo-random independence of seeded random generators. - Applies to ``random_state`` parameters in ``estimator`` depending on - ``self_policy``, and remaining component estimators + Applies to ``random_state`` parameters in ``self``, depending on + ``self_policy``, and remaining component objects if and only if ``deep=True``. Note: calls ``set_params`` even if ``self`` does not have a ``random_state``, or none of the components have a ``random_state`` parameter. - Therefore, ``set_random_state`` will reset any ``scikit-base`` estimator, + Therefore, ``set_random_state`` will reset any ``scikit-base`` object, even those without a ``random_state`` parameter. 
Parameters @@ -912,15 +912,18 @@ def set_random_state(self, random_state=None, deep=True, self_policy="copy"): integers. Pass int for reproducible output across multiple function calls. deep : bool, default=True - Whether to set the random state in sub-estimators. - If False, will set only ``self``'s ``random_state`` parameter, if exists. - If True, will set ``random_state`` parameters in sub-estimators as well. + Whether to set the random state in skbase object valued parameters, i.e., + component estimators. + + * If False, will set only ``self``'s ``random_state`` parameter, if exists. + * If True, will set ``random_state`` parameters in component objects + as well. self_policy : str, one of {"copy", "keep", "new"}, default="copy" - * "copy" : ``estimator.random_state`` is set to input ``random_state`` - * "keep" : ``estimator.random_state`` is kept as is - * "new" : ``estimator.random_state`` is set to a new random state, + * "copy" : ``self.random_state`` is set to input ``random_state`` + * "keep" : ``self.random_state`` is kept as is + * "new" : ``self.random_state`` is set to a new random state, derived from input ``random_state``, and in general different from it Returns @@ -972,7 +975,21 @@ def __init__(self): @classmethod def get_class_tags(cls): - """Get class tags from estimator class and all its parent classes. + """Get class tags from class, with tag level inheritance from parent classes. + + Every ``scikit-base`` compatible class has a set of tags, + which are used to store metadata about the object. + + This is a class method, and retrieves tags applicable to the class, + with tag level overrides in the following order of decreasing priority: + + 1. class tags of the class, of which the object is an instance + 2. class tags of all parent classes, in method resolution order + + Instances can override these tags depending on hyper-parameters. + + To retrieve tags with potential instance overrides, use + the ``get_tags`` method instead. 
Returns ------- @@ -987,7 +1004,22 @@ class attribute via nested inheritance. NOT overridden by dynamic @classmethod def get_class_tag(cls, tag_name, tag_value_default=None): - """Get tag value from estimator class (only class tags). + """Get class tag value from class, with tag level inheritance from parents. + + Every ``scikit-base`` compatible class has a set of tags, + which are used to store metadata about the object. + + This is a class method, and retrieves the value of a tag applicable + to the class, + with tag level overrides in the following order of decreasing priority: + + 1. class tags of the class, of which the object is an instance + 2. class tags of all parent classes, in method resolution order + + Instances can override these tags depending on hyper-parameters. + + To retrieve tag values with potential instance overrides, use + the ``get_tag`` method instead. Parameters ---------- @@ -1008,7 +1040,17 @@ def get_class_tag(cls, tag_name, tag_value_default=None): ) def get_tags(self): - """Get tags from estimator class and dynamic tag overrides. + """Get tags from instance, with tag level inheritance and overrides. + + Every ``scikit-base`` compatible object has a set of tags, + which are used to store metadata about the object. + + This method retrieves all tags as a dictionary, with tag level overrides in the + following order of decreasing priority: + + 1. dynamic tags set at construction, e.g., dependent on hyper-parameters + 2. class tags of the class, of which the object is an instance + 3. class tags of all parent classes, in method resolution order Returns ------- @@ -1022,7 +1064,17 @@ class attribute via nested inheritance and then any overrides return collected_tags def get_tag(self, tag_name, tag_value_default=None, raise_error=True): - """Get tag value from estimator class and dynamic tag overrides. + """Get tag value from instance, with tag level inheritance and overrides. 
+ + Every ``scikit-base`` compatible object has a set of tags, + which are used to store metadata about the object. + + This method retrieves the value of a single tag, with tag level overrides in the + following order of decreasing priority: + + 1. dynamic tags set at construction, e.g., dependent on hyper-parameters + 2. class tags of the class, of which the object is an instance + 3. class tags of all parent classes, in method resolution order Parameters ---------- @@ -1218,19 +1270,13 @@ def get_fitted_params(self, deep=True): if not deep: return fitted_params - def sh(x): - """Shorthand to remove all underscores at end of a string.""" - if x.endswith("_"): - return sh(x[:-1]) - else: - return x - # add all nested parameters from components that are skbase BaseEstimator c_dict = self._components() for c, comp in c_dict.items(): if isinstance(comp, BaseEstimator) and comp._is_fitted: c_f_params = comp.get_fitted_params(deep=deep) - c_f_params = {f"{sh(c)}__{k}": v for k, v in c_f_params.items()} + c = c.rstrip("_") + c_f_params = {f"{c}__{k}": v for k, v in c_f_params.items()} fitted_params.update(c_f_params) # add all nested parameters from components that are sklearn estimators @@ -1243,7 +1289,8 @@ def sh(x): for c, comp in old_new_params.items(): if isinstance(comp, self.GET_FITTED_PARAMS_NESTING): c_f_params = self._get_fitted_params_default(comp) - c_f_params = {f"{sh(c)}__{k}": v for k, v in c_f_params.items()} + c = c.rstrip("_") + c_f_params = {f"{c}__{k}": v for k, v in c_f_params.items()} new_params.update(c_f_params) fitted_params.update(new_params) old_new_params = new_params.copy()