From 6b61d44c5cff2236ebb8ef1910750e5509f2e5c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sat, 7 Oct 2023 20:58:11 +0200 Subject: [PATCH] set_params sugar --- skbase/base/_base.py | 114 +++++++++++++++++++++++++++++++++----- skbase/tests/test_base.py | 97 ++++++++++++++++++++++++++++++++ 2 files changed, 196 insertions(+), 15 deletions(-) diff --git a/skbase/base/_base.py b/skbase/base/_base.py index 1cc51ae7..361d234c 100644 --- a/skbase/base/_base.py +++ b/skbase/base/_base.py @@ -371,49 +371,133 @@ def get_params(self, deep=True): def set_params(self, **params): """Set the parameters of this object. - The method works on simple estimators as well as on nested objects. - The latter have parameters of the form ``__`` so - that it's possible to update each component of a nested object. + The method works on simple estimators as well as on composite objects. + Parameter key strings ``__`` can be used for composites, + i.e., objects that contain other objects, to access ```` in + the component ````. + The string ````, without ``__``, can also be used if + this makes the reference unambiguous, e.g., there are no two parameters of + components with the name ````. Parameters ---------- **params : dict - BaseObject parameters. + BaseObject parameters, keys must be ``__`` strings. + __ suffixes can alias full strings, if unique among get_params keys. Returns ------- - self - Reference to self (after parameters have been set). + self : reference to self (after parameters have been set) """ if not params: # Simple optimization to gain speed (inspect is slow) return self valid_params = self.get_params(deep=True) + unmatched_keys = [] + nested_params = defaultdict(dict) # grouped by prefix - for key, value in params.items(): - key, delim, sub_key = key.partition("__") + for full_key, value in params.items(): + # split full_key by first occurrence of __, if contains __ + # "key_without_dblunderscore" -> "key_without_dbl_underscore", None, None + # "key__with__dblunderscore" -> "key", "__", "with__dblunderscore" + key, delim, sub_key = full_key.partition("__") + # if key not recognized, remember for suffix matching if key not in valid_params: - raise ValueError( - "Invalid parameter %s for object %s. " - "Check the list of available parameters " - "with `object.get_params().keys()`." % (key, self) - ) - - if delim: + unmatched_keys += [key] + # if full_key contained __, collect suffix for component set_params + elif delim: nested_params[key][sub_key] = value + # if key is found and did not contain __, set self.key to the value else: setattr(self, key, value) valid_params[key] = value + # all matched params have now been set + # reset estimator to clean post-init state with those params self.reset() # recurse in components for key, sub_params in nested_params.items(): valid_params[key].set_params(**sub_params) + # for unmatched keys, resolve by aliasing via available __ suffixes, recurse + if len(unmatched_keys) > 0: + valid_params = self.get_params(deep=True) + unmatched_params = {key: params[key] for key in unmatched_keys} + + # aliasing, syntactic sugar to access uniquely named params more easily + aliased_params = self._alias_params(unmatched_params, valid_params) + + # if none of the parameter names change through aliasing, raise error + if set(aliased_params) == set(unmatched_params): + raise ValueError( + f"Invalid parameter keys provided to set_params of object {self}. " + "Check the list of available parameters " + "with `object.get_params().keys()`. " + f"Invalid keys provided: {unmatched_keys}" + ) + + # recurse: repeat matching and aliasing until no further matches found + # termination condition is above, "no change in keys via aliasing" + self.set_params(**aliased_params) + return self + def _alias_params(self, d, valid_params): + """Replace shorthands in d by full keys from valid_params. + + Parameters + ---------- + d: dict with str keys + valid_params: dict with str keys + + Result + ------ + alias_dict: dict with str keys, all keys in valid_params + values are as in d, with keys replaced by following rule: + If key is a __ suffix of exactly one key in valid_params, + it is replaced by that key. Otherwise an exception is raised. + A __ suffix of a str is any str obtained as suffix from partition by __. + Else, i.e., if key is in valid_params or not a __ suffix, + the key is replaced by itself, i.e., left unchanged. + + Raises + ------ + ValueError if at least one key of d is neither contained in valid_params, + nor is it a __ suffix of exactly one key in valid_params + """ + + def _is_suffix(x, y): + """Return whether x is a strict __ suffix of y.""" + return y.endswith(x) and y.endswith("__" + x) + + def _get_alias(x, d): + """Return alias of x in d.""" + # if key is in valid_params, key is replaced by key (itself) + if any(x == y for y in d.keys()): + return x + + suff_list = [y for y in d.keys() if _is_suffix(x, y)] + + # if key is a __ suffix of exactly one key in valid_params, + # it is replaced by that key + ns = len(suff_list) + if ns > 1: + raise ValueError( + f"suffix {x} does not uniquely determine parameter key, of " + f"{type(self).__name__} instance" + f"the following parameter keys have the same suffix: {suff_list}" + ) + if ns == 0: + return x + # if ns == 1 + return suff_list[0] + + alias_dict = {_get_alias(x, valid_params): d[x] for x in d.keys()} + + return alias_dict + @classmethod def get_class_tags(cls): """Get class tags from the class and all its parent classes. diff --git a/skbase/tests/test_base.py b/skbase/tests/test_base.py index 0408ce5f..98b71efe 100644 --- a/skbase/tests/test_base.py +++ b/skbase/tests/test_base.py @@ -33,6 +33,8 @@ "test_components", "test_components_raises_error_base_class_is_not_class", "test_components_raises_error_base_class_is_not_baseobject_subclass", + "test_param_alias", + "test_nested_set_params_and_alias", "test_reset", "test_reset_composite", "test_get_init_signature", @@ -545,6 +547,101 @@ class SomeClass: composite._components(SomeClass) +class AliasTester(BaseObject): + def __init__(self, a, bar=42): + self.a = a + self.bar = bar + + +def test_param_alias(): + """Tests parameter aliasing with parameter string shorthands. + + Raises + ------ + AssertionError if parameters that should be set via __ are not set + AssertionError if error that should be raised is not raised + """ + non_composite = AliasTester(a=42, bar=4242) + composite = CompositionDummy(foo=non_composite) + + # this should write to a of foo, because there is only one suffix called a + composite.set_params(**{"a": 424242}) + assert composite.get_params()["foo__a"] == 424242 + + # this should write to bar of composite, because "bar" is a full parameter string + # there is a suffix in foo, but if the full string is there, it writes to that + composite.set_params(**{"bar": 424243}) + assert composite.get_params()["bar"] == 424243 + + # trying to write to bad_param should raise an exception + # since bad_param is neither a suffix nor a full parameter string + with pytest.raises(ValueError, match=r"Invalid parameter keys provided to"): + composite.set_params(**{"bad_param": 424242}) + + # new example: highly nested composite with identical suffixes + non_composite1 = composite + non_composite2 = AliasTester(a=42, bar=4242) + uber_composite = CompositionDummy(foo=non_composite1, bar=non_composite2) + + # trying to write to a should raise an exception + # since there are two suffix a, and a is not a full parameter string + with pytest.raises(ValueError, match=r"does not uniquely determine parameter key"): + uber_composite.set_params(**{"a": 424242}) + + # same as above, should overwrite "bar" of uber_composite + uber_composite.set_params(**{"bar": 424243}) + assert uber_composite.get_params()["bar"] == 424243 + + +def test_nested_set_params_and_alias(): + """Tests that nested param setting works correctly. + + This specifically tests that parameters of components can be provided, + even if that component is not present in the object that set_params is called on, + but is also being set in the same set_params call. + + Also tests alias resolution, using recursive end state after set_params. + + Raises + ------ + AssertionError if parameters that should be set via __ are not set + AssertionError if error that should be raised is not raised + """ + non_composite = AliasTester(a=42, bar=4242) + composite = CompositionDummy(foo=0) + + # this should write to a of foo + # potential error here is that composite does not have foo__a to start with + # so error catching or writing foo__a to early could cause an exception + composite.set_params(**{"foo": non_composite, "foo__a": 424242}) + assert composite.get_params()["foo__a"] == 424242 + + non_composite = AliasTester(a=42, bar=4242) + composite = CompositionDummy(foo=0) + + # same, and recognizing that foo__a is the only matching suffix in the end state + composite.set_params(**{"foo": non_composite, "a": 424242}) + assert composite.get_params()["foo__a"] == 424242 + + # new example: highly nested composite with identical suffixes + non_composite1 = composite + non_composite2 = AliasTester(a=42, bar=4242) + uber_composite = CompositionDummy(foo=42, bar=42) + + # trying to write to a should raise an exception + # since there are two suffix a, and a is not a full parameter string + with pytest.raises(ValueError, match=r"does not uniquely determine parameter key"): + uber_composite.set_params( + **{"a": 424242, "foo": non_composite1, "bar": non_composite2} + ) + + uber_composite = CompositionDummy(foo=non_composite1, bar=42) + + # same as above, should overwrite "bar" of uber_composite + uber_composite.set_params(**{"bar": 424243}) + assert uber_composite.get_params()["bar"] == 424243 + + # Test parameter interface (get_params, set_params, reset and related methods) # Some tests of get_params and set_params are adapted from sklearn tests def test_reset(fixture_reset_tester: Type[ResetTester]):