Skip to content

Commit

Permalink
[ENH] set_params to recognize unique suffixes as aliases for full p…
Browse files Browse the repository at this point in the history
…arameter string (#229)

Mirror of sktime/sktime#2931

This PR introduces experimental functionality to `set_params` to enable
it to recognize unique suffixes as aliases for full parameter strings.

That is, in the example in
sktime/sktime#2929, it
enables `set_params` - and, in consequence, the `param_grid` of grid
search - to recognize `max_depth` instead of
`forecaster__forecast__estimator__max_depth`.

This is primarily meant as convenience sugar for the user of grid search
or, more generally, `set_params`.

Contains tests for both nested parameter setting behaviour, as well as
parameter key aliasing in nested composites.
  • Loading branch information
fkiraly authored Oct 25, 2023
1 parent 70c8808 commit d55abc6
Show file tree
Hide file tree
Showing 2 changed files with 196 additions and 15 deletions.
114 changes: 99 additions & 15 deletions skbase/base/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,49 +371,133 @@ def get_params(self, deep=True):
def set_params(self, **params):
"""Set the parameters of this object.
The method works on simple estimators as well as on nested objects.
The latter have parameters of the form ``<component>__<parameter>`` so
that it's possible to update each component of a nested object.
The method works on simple estimators as well as on composite objects.
Parameter key strings ``<component>__<parameter>`` can be used for composites,
i.e., objects that contain other objects, to access ``<parameter>`` in
the component ``<component>``.
The string ``<parameter>``, without ``<component>__``, can also be used if
this makes the reference unambiguous, e.g., there are no two parameters of
components with the name ``<parameter>``.
Parameters
----------
**params : dict
BaseObject parameters.
BaseObject parameters, keys must be ``<component>__<parameter>`` strings.
__ suffixes can alias full strings, if unique among get_params keys.
Returns
-------
self
Reference to self (after parameters have been set).
self : reference to self (after parameters have been set)
"""
if not params:
# Simple optimization to gain speed (inspect is slow)
return self
valid_params = self.get_params(deep=True)

unmatched_keys = []

nested_params = defaultdict(dict) # grouped by prefix
for key, value in params.items():
key, delim, sub_key = key.partition("__")
for full_key, value in params.items():
# split full_key by first occurrence of __, if contains __
# "key_without_dblunderscore" -> "key_without_dbl_underscore", None, None
# "key__with__dblunderscore" -> "key", "__", "with__dblunderscore"
key, delim, sub_key = full_key.partition("__")
# if key not recognized, remember for suffix matching
if key not in valid_params:
raise ValueError(
"Invalid parameter %s for object %s. "
"Check the list of available parameters "
"with `object.get_params().keys()`." % (key, self)
)

if delim:
unmatched_keys += [key]
# if full_key contained __, collect suffix for component set_params
elif delim:
nested_params[key][sub_key] = value
# if key is found and did not contain __, set self.key to the value
else:
setattr(self, key, value)
valid_params[key] = value

# all matched params have now been set
# reset estimator to clean post-init state with those params
self.reset()

# recurse in components
for key, sub_params in nested_params.items():
valid_params[key].set_params(**sub_params)

# for unmatched keys, resolve by aliasing via available __ suffixes, recurse
if len(unmatched_keys) > 0:
valid_params = self.get_params(deep=True)
unmatched_params = {key: params[key] for key in unmatched_keys}

# aliasing, syntactic sugar to access uniquely named params more easily
aliased_params = self._alias_params(unmatched_params, valid_params)

# if none of the parameter names change through aliasing, raise error
if set(aliased_params) == set(unmatched_params):
raise ValueError(
f"Invalid parameter keys provided to set_params of object {self}. "
"Check the list of available parameters "
"with `object.get_params().keys()`. "
f"Invalid keys provided: {unmatched_keys}"
)

# recurse: repeat matching and aliasing until no further matches found
# termination condition is above, "no change in keys via aliasing"
self.set_params(**aliased_params)

return self

def _alias_params(self, d, valid_params):
"""Replace shorthands in d by full keys from valid_params.
Parameters
----------
d: dict with str keys
valid_params: dict with str keys
Result
------
alias_dict: dict with str keys, all keys in valid_params
values are as in d, with keys replaced by following rule:
If key is a __ suffix of exactly one key in valid_params,
it is replaced by that key. Otherwise an exception is raised.
A __ suffix of a str is any str obtained as suffix from partition by __.
Else, i.e., if key is in valid_params or not a __ suffix,
the key is replaced by itself, i.e., left unchanged.
Raises
------
ValueError if at least one key of d is neither contained in valid_params,
nor is it a __ suffix of exactly one key in valid_params
"""

def _is_suffix(x, y):
"""Return whether x is a strict __ suffix of y."""
return y.endswith(x) and y.endswith("__" + x)

def _get_alias(x, d):
"""Return alias of x in d."""
# if key is in valid_params, key is replaced by key (itself)
if any(x == y for y in d.keys()):
return x

suff_list = [y for y in d.keys() if _is_suffix(x, y)]

# if key is a __ suffix of exactly one key in valid_params,
# it is replaced by that key
ns = len(suff_list)
if ns > 1:
raise ValueError(
f"suffix {x} does not uniquely determine parameter key, of "
f"{type(self).__name__} instance"
f"the following parameter keys have the same suffix: {suff_list}"
)
if ns == 0:
return x
# if ns == 1
return suff_list[0]

alias_dict = {_get_alias(x, valid_params): d[x] for x in d.keys()}

return alias_dict

@classmethod
def get_class_tags(cls):
"""Get class tags from the class and all its parent classes.
Expand Down
97 changes: 97 additions & 0 deletions skbase/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
"test_components",
"test_components_raises_error_base_class_is_not_class",
"test_components_raises_error_base_class_is_not_baseobject_subclass",
"test_param_alias",
"test_nested_set_params_and_alias",
"test_reset",
"test_reset_composite",
"test_get_init_signature",
Expand Down Expand Up @@ -545,6 +547,101 @@ class SomeClass:
composite._components(SomeClass)


class AliasTester(BaseObject):
def __init__(self, a, bar=42):
self.a = a
self.bar = bar


def test_param_alias():
"""Tests parameter aliasing with parameter string shorthands.
Raises
------
AssertionError if parameters that should be set via __ are not set
AssertionError if error that should be raised is not raised
"""
non_composite = AliasTester(a=42, bar=4242)
composite = CompositionDummy(foo=non_composite)

# this should write to a of foo, because there is only one suffix called a
composite.set_params(**{"a": 424242})
assert composite.get_params()["foo__a"] == 424242

# this should write to bar of composite, because "bar" is a full parameter string
# there is a suffix in foo, but if the full string is there, it writes to that
composite.set_params(**{"bar": 424243})
assert composite.get_params()["bar"] == 424243

# trying to write to bad_param should raise an exception
# since bad_param is neither a suffix nor a full parameter string
with pytest.raises(ValueError, match=r"Invalid parameter keys provided to"):
composite.set_params(**{"bad_param": 424242})

# new example: highly nested composite with identical suffixes
non_composite1 = composite
non_composite2 = AliasTester(a=42, bar=4242)
uber_composite = CompositionDummy(foo=non_composite1, bar=non_composite2)

# trying to write to a should raise an exception
# since there are two suffix a, and a is not a full parameter string
with pytest.raises(ValueError, match=r"does not uniquely determine parameter key"):
uber_composite.set_params(**{"a": 424242})

# same as above, should overwrite "bar" of uber_composite
uber_composite.set_params(**{"bar": 424243})
assert uber_composite.get_params()["bar"] == 424243


def test_nested_set_params_and_alias():
"""Tests that nested param setting works correctly.
This specifically tests that parameters of components can be provided,
even if that component is not present in the object that set_params is called on,
but is also being set in the same set_params call.
Also tests alias resolution, using recursive end state after set_params.
Raises
------
AssertionError if parameters that should be set via __ are not set
AssertionError if error that should be raised is not raised
"""
non_composite = AliasTester(a=42, bar=4242)
composite = CompositionDummy(foo=0)

# this should write to a of foo
# potential error here is that composite does not have foo__a to start with
# so error catching or writing foo__a to early could cause an exception
composite.set_params(**{"foo": non_composite, "foo__a": 424242})
assert composite.get_params()["foo__a"] == 424242

non_composite = AliasTester(a=42, bar=4242)
composite = CompositionDummy(foo=0)

# same, and recognizing that foo__a is the only matching suffix in the end state
composite.set_params(**{"foo": non_composite, "a": 424242})
assert composite.get_params()["foo__a"] == 424242

# new example: highly nested composite with identical suffixes
non_composite1 = composite
non_composite2 = AliasTester(a=42, bar=4242)
uber_composite = CompositionDummy(foo=42, bar=42)

# trying to write to a should raise an exception
# since there are two suffix a, and a is not a full parameter string
with pytest.raises(ValueError, match=r"does not uniquely determine parameter key"):
uber_composite.set_params(
**{"a": 424242, "foo": non_composite1, "bar": non_composite2}
)

uber_composite = CompositionDummy(foo=non_composite1, bar=42)

# same as above, should overwrite "bar" of uber_composite
uber_composite.set_params(**{"bar": 424243})
assert uber_composite.get_params()["bar"] == 424243


# Test parameter interface (get_params, set_params, reset and related methods)
# Some tests of get_params and set_params are adapted from sklearn tests
def test_reset(fixture_reset_tester: Type[ResetTester]):
Expand Down

0 comments on commit d55abc6

Please sign in to comment.