From 226318db517e2e905b60122356b4de0494c89786 Mon Sep 17 00:00:00 2001 From: rnkuhns Date: Sun, 12 Mar 2023 20:47:33 -0400 Subject: [PATCH 01/16] Add pretty printing functionality --- skbase/base/_base.py | 95 ++++ .../_pretty_printing/_object_html_repr.py | 398 +++++++++++++++++ skbase/base/_pretty_printing/_pprint.py | 412 ++++++++++++++++++ skbase/config/__init__.py | 393 +++++++++++++++++ skbase/config/tests/__init__.py | 4 + skbase/config/tests/test_config.py | 103 +++++ skbase/tests/conftest.py | 23 +- skbase/tests/test_base.py | 5 +- skbase/utils/_check.py | 53 +++ skbase/utils/tests/test_check.py | 24 + 10 files changed, 1507 insertions(+), 3 deletions(-) create mode 100644 skbase/base/_pretty_printing/_object_html_repr.py create mode 100644 skbase/base/_pretty_printing/_pprint.py create mode 100644 skbase/config/__init__.py create mode 100644 skbase/config/tests/__init__.py create mode 100644 skbase/config/tests/test_config.py create mode 100644 skbase/utils/_check.py create mode 100644 skbase/utils/tests/test_check.py diff --git a/skbase/base/_base.py b/skbase/base/_base.py index 970bf350..fdf4d393 100644 --- a/skbase/base/_base.py +++ b/skbase/base/_base.py @@ -53,6 +53,7 @@ class name: BaseEstimator fitted state check - check_is_fitted (raises error if not is_fitted) """ import inspect +import re import warnings from collections import defaultdict from copy import deepcopy @@ -62,7 +63,9 @@ class name: BaseEstimator from sklearn.base import BaseEstimator as _BaseEstimator from skbase._exceptions import NotFittedError +from skbase.base._pretty_printing._object_html_repr import _object_html_repr from skbase.base._tagmanager import _FlagManager +from skbase.config import get_config # type: ignore __author__: List[str] = ["mloning", "RNKuhns", "fkiraly"] __all__: List[str] = ["BaseEstimator", "BaseObject"] @@ -682,6 +685,98 @@ def _components(self, base_class=None): return comp_dict + def __repr__(self, n_char_max: int = 700): + """Represent class as string. + + This follows the scikit-learn implementation for the string representation + of parameterized objects. + + Parameters + ---------- + n_char_max : int + Maximum (approximate) number of non-blank characters to render. This + can be useful in testing. + """ + from skbase.base._pretty_printing._pprint import _BaseObjectPrettyPrinter + + n_max_elements_to_show = 30 # number of elements to show in sequences + # use ellipsis for sequences with a lot of elements + pp = _BaseObjectPrettyPrinter( + compact=True, + indent=1, + indent_at_name=True, + n_max_elements_to_show=n_max_elements_to_show, + changed_only=get_config()["print_changed_only"], + ) + + repr_ = pp.pformat(self) + + # Use bruteforce ellipsis when there are a lot of non-blank characters + n_nonblank = len("".join(repr_.split())) + if n_nonblank > n_char_max: + lim = n_char_max // 2 # apprx number of chars to keep on both ends + regex = r"^(\s*\S){%d}" % lim + # The regex '^(\s*\S){%d}' matches from the start of the string + # until the nth non-blank character: + # - ^ matches the start of string + # - (pattern){n} matches n repetitions of pattern + # - \s*\S matches a non-blank char following zero or more blanks + left_match = re.match(regex, repr_) + right_match = re.match(regex, repr_[::-1]) + left_lim = left_match.end() if left_match is not None else 0 + right_lim = right_match.end() if right_match is not None else 0 + + if "\n" in repr_[left_lim:-right_lim]: + # The left side and right side aren't on the same line. + # To avoid weird cuts, e.g.: + # categoric...ore', + # we need to start the right side with an appropriate newline + # character so that it renders properly as: + # categoric... + # handle_unknown='ignore', + # so we add [^\n]*\n which matches until the next \n + regex += r"[^\n]*\n" + right_match = re.match(regex, repr_[::-1]) + right_lim = right_match.end() if right_match is not None else 0 + + ellipsis = "..." + if left_lim + len(ellipsis) < len(repr_) - right_lim: + # Only add ellipsis if it results in a shorter repr + repr_ = repr_[:left_lim] + "..." + repr_[-right_lim:] + + return repr_ + + @property + def _repr_html_(self): + """HTML representation of BaseObject. + + This is redundant with the logic of `_repr_mimebundle_`. The latter + should be favorted in the long term, `_repr_html_` is only + implemented for consumers who do not interpret `_repr_mimbundle_`. + """ + if get_config()["display"] != "diagram": + raise AttributeError( + "_repr_html_ is only defined when the " + "`display` configuration option is set to 'diagram'." + ) + return self._repr_html_inner + + def _repr_html_inner(self): + """Return HTML representation of class. + + This function is returned by the @property `_repr_html_` to make + `hasattr(BaseObject, "_repr_html_") return `True` or `False` depending + on `get_config()["display"]`. + """ + return _object_html_repr(self) + + def _repr_mimebundle_(self, **kwargs): + """Mime bundle used by jupyter kernels to display instances of BaseObject.""" + output = {"text/plain": repr(self)} + if get_config()["display"] == "diagram": + output["text/html"] = _object_html_repr(self) + return output + class TagAliaserMixin: """Mixin class for tag aliasing and deprecation of old tags. diff --git a/skbase/base/_pretty_printing/_object_html_repr.py b/skbase/base/_pretty_printing/_object_html_repr.py new file mode 100644 index 00000000..650a2072 --- /dev/null +++ b/skbase/base/_pretty_printing/_object_html_repr.py @@ -0,0 +1,398 @@ +# -*- coding: utf-8 -*- +# copyright: skbase developers, BSD-3-Clause License (see LICENSE file) +# Many elements of this code were developed in scikit-learn. These elements +# are copyrighted by the scikit-learn developers, BSD-3-Clause License. For +# conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING +"""Functionality to represent instance of BaseObject as html.""" + +import html +import uuid +from contextlib import closing, suppress +from io import StringIO +from string import Template + +from skbase.config import config_context # type: ignore + +__author__ = ["RNKuhns"] + + +class _VisualBlock: + """HTML Representation of BaseObject. + + Parameters + ---------- + kind : {'serial', 'parallel', 'single'} + kind of HTML block + + objs : list of BaseObjects or `_VisualBlock`s or a single BaseObject + If kind != 'single', then `objs` is a list of + BaseObjects. If kind == 'single', then `objs` is a single BaseObject. + + names : list of str, default=None + If kind != 'single', then `names` corresponds to BaseObjects. + If kind == 'single', then `names` is a single string corresponding to + the single BaseObject. + + name_details : list of str, str, or None, default=None + If kind != 'single', then `name_details` corresponds to `names`. + If kind == 'single', then `name_details` is a single string + corresponding to the single BaseObject. + + dash_wrapped : bool, default=True + If true, wrapped HTML element will be wrapped with a dashed border. + Only active when kind != 'single'. + """ + + def __init__(self, kind, objs, *, names=None, name_details=None, dash_wrapped=True): + self.kind = kind + self.objs = objs + self.dash_wrapped = dash_wrapped + + if self.kind in ("parallel", "serial"): + if names is None: + names = (None,) * len(objs) + if name_details is None: + name_details = (None,) * len(objs) + + self.names = names + self.name_details = name_details + + def _sk_visual_block_(self): + return self + + +def _write_label_html( + out, + name, + name_details, + outer_class="sk-label-container", + inner_class="sk-label", + checked=False, +): + """Write labeled html with or without a dropdown with named details.""" + out.write(f'
') + name = html.escape(name) + + if name_details is not None: + name_details = html.escape(str(name_details)) + label_class = "sk-toggleable__label sk-toggleable__label-arrow" + + checked_str = "checked" if checked else "" + est_id = uuid.uuid4() + out.write( + '' + f'' + f'
{name_details}'
+            "
" + ) + else: + out.write(f"") + out.write("
") # outer_class inner_class + + +def _get_visual_block(base_object): + """Generate information about how to display a BaseObject.""" + with suppress(AttributeError): + return base_object._sk_visual_block_() + + if isinstance(base_object, str): + return _VisualBlock( + "single", base_object, names=base_object, name_details=base_object + ) + elif base_object is None: + return _VisualBlock("single", base_object, names="None", name_details="None") + + # check if base_object looks like a meta base_object wraps base_object + if hasattr(base_object, "get_params"): + base_objects = [] + for key, value in base_object.get_params().items(): + # Only look at the BaseObjects in the first layer + if "__" not in key and hasattr(value, "get_params"): + base_objects.append(value) + if len(base_objects): + return _VisualBlock("parallel", base_objects, names=None) + + return _VisualBlock( + "single", + base_object, + names=base_object.__class__.__name__, + name_details=str(base_object), + ) + + +def _write_base_object_html( + out, base_object, base_object_label, base_object_label_details, first_call=False +): + """Write BaseObject to html in serial, parallel, or by itself (single).""" + if first_call: + est_block = _get_visual_block(base_object) + else: + with config_context(print_changed_only=True): + est_block = _get_visual_block(base_object) + + if est_block.kind in ("serial", "parallel"): + dashed_wrapped = first_call or est_block.dash_wrapped + dash_cls = " sk-dashed-wrapped" if dashed_wrapped else "" + out.write(f'
') + + if base_object_label: + _write_label_html(out, base_object_label, base_object_label_details) + + kind = est_block.kind + out.write(f'
') + est_infos = zip(est_block.objs, est_block.names, est_block.name_details) + + for est, name, name_details in est_infos: + if kind == "serial": + _write_base_object_html(out, est, name, name_details) + else: # parallel + out.write('
') + # wrap element in a serial visualblock + serial_block = _VisualBlock("serial", [est], dash_wrapped=False) + _write_base_object_html(out, serial_block, name, name_details) + out.write("
") # sk-parallel-item + + out.write("
") + elif est_block.kind == "single": + _write_label_html( + out, + est_block.names, + est_block.name_details, + outer_class="sk-item", + inner_class="sk-estimator", + checked=first_call, + ) + + +_STYLE = """ +#$id { + color: black; + background-color: white; +} +#$id pre{ + padding: 0; +} +#$id div.sk-toggleable { + background-color: white; +} +#$id label.sk-toggleable__label { + cursor: pointer; + display: block; + width: 100%; + margin-bottom: 0; + padding: 0.3em; + box-sizing: border-box; + text-align: center; +} +#$id label.sk-toggleable__label-arrow:before { + content: "▸"; + float: left; + margin-right: 0.25em; + color: #696969; +} +#$id label.sk-toggleable__label-arrow:hover:before { + color: black; +} +#$id div.sk-estimator:hover label.sk-toggleable__label-arrow:before { + color: black; +} +#$id div.sk-toggleable__content { + max-height: 0; + max-width: 0; + overflow: hidden; + text-align: left; + background-color: #f0f8ff; +} +#$id div.sk-toggleable__content pre { + margin: 0.2em; + color: black; + border-radius: 0.25em; + background-color: #f0f8ff; +} +#$id input.sk-toggleable__control:checked~div.sk-toggleable__content { + max-height: 200px; + max-width: 100%; + overflow: auto; +} +#$id input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before { + content: "▾"; +} +#$id div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label { + background-color: #d4ebff; +} +#$id div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label { + background-color: #d4ebff; +} +#$id input.sk-hidden--visually { + border: 0; + clip: rect(1px 1px 1px 1px); + clip: rect(1px, 1px, 1px, 1px); + height: 1px; + margin: -1px; + overflow: hidden; + padding: 0; + position: absolute; + width: 1px; +} +#$id div.sk-estimator { + font-family: monospace; + background-color: #f0f8ff; + border: 1px dotted black; + border-radius: 0.25em; + box-sizing: border-box; + margin-bottom: 0.5em; +} +#$id div.sk-estimator:hover { + background-color: #d4ebff; +} +#$id div.sk-parallel-item::after { + content: ""; + width: 100%; + border-bottom: 1px solid gray; + flex-grow: 1; +} +#$id div.sk-label:hover label.sk-toggleable__label { + background-color: #d4ebff; +} +#$id div.sk-serial::before { + content: ""; + position: absolute; + border-left: 1px solid gray; + box-sizing: border-box; + top: 2em; + bottom: 0; + left: 50%; +} +#$id div.sk-serial { + display: flex; + flex-direction: column; + align-items: center; + background-color: white; + padding-right: 0.2em; + padding-left: 0.2em; +} +#$id div.sk-item { + z-index: 1; +} +#$id div.sk-parallel { + display: flex; + align-items: stretch; + justify-content: center; + background-color: white; +} +#$id div.sk-parallel::before { + content: ""; + position: absolute; + border-left: 1px solid gray; + box-sizing: border-box; + top: 2em; + bottom: 0; + left: 50%; +} +#$id div.sk-parallel-item { + display: flex; + flex-direction: column; + position: relative; + background-color: white; +} +#$id div.sk-parallel-item:first-child::after { + align-self: flex-end; + width: 50%; +} +#$id div.sk-parallel-item:last-child::after { + align-self: flex-start; + width: 50%; +} +#$id div.sk-parallel-item:only-child::after { + width: 0; +} +#$id div.sk-dashed-wrapped { + border: 1px dashed gray; + margin: 0 0.4em 0.5em 0.4em; + box-sizing: border-box; + padding-bottom: 0.4em; + background-color: white; + position: relative; +} +#$id div.sk-label label { + font-family: monospace; + font-weight: bold; + background-color: white; + display: inline-block; + line-height: 1.2em; +} +#$id div.sk-label-container { + position: relative; + z-index: 2; + text-align: center; +} +#$id div.sk-container { + /* jupyter's `normalize.less` sets `[hidden] { display: none; }` + but bootstrap.min.css set `[hidden] { display: none !important; }` + so we also need the `!important` here to be able to override the + default hidden behavior on the sphinx rendered scikit-learn.org. + See: https://github.com/scikit-learn/scikit-learn/issues/21755 */ + display: inline-block !important; + position: relative; +} +#$id div.sk-text-repr-fallback { + display: none; +} +""".replace( + " ", "" +).replace( + "\n", "" +) # noqa + + +def _object_html_repr(base_object): + """Build a HTML representation of a BaseObject. + + Parameters + ---------- + base_object : base object + The BaseObject or inheritting class to visualize. + + Returns + ------- + html: str + HTML representation of BaseObject. + """ + with closing(StringIO()) as out: + container_id = "sk-" + str(uuid.uuid4()) + style_template = Template(_STYLE) + style_with_id = style_template.substitute(id=container_id) + base_object_str = str(base_object) + + # The fallback message is shown by default and loading the CSS sets + # div.sk-text-repr-fallback to display: none to hide the fallback message. + # + # If the notebook is trusted, the CSS is loaded which hides the fallback + # message. If the notebook is not trusted, then the CSS is not loaded and the + # fallback message is shown by default. + # + # The reverse logic applies to HTML repr div.sk-container. + # div.sk-container is hidden by default and the loading the CSS displays it. + fallback_msg = ( + "Please rerun this cell to show the HTML repr or trust the notebook." + ) + out.write( + f"" + f'
' + '
' + f"
{html.escape(base_object_str)}
{fallback_msg}" + "
" + '
") + + html_output = out.getvalue() + return html_output diff --git a/skbase/base/_pretty_printing/_pprint.py b/skbase/base/_pretty_printing/_pprint.py new file mode 100644 index 00000000..2e8a8a80 --- /dev/null +++ b/skbase/base/_pretty_printing/_pprint.py @@ -0,0 +1,412 @@ +# -*- coding: utf-8 -*- +# copyright: skbase developers, BSD-3-Clause License (see LICENSE file) +# Many elements of this code were developed in scikit-learn. These elements +# are copyrighted by the scikit-learn developers, BSD-3-Clause License. For +# conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING +"""Utility functionality for pretty-printing objects used in BaseObject.__repr__.""" +import inspect +import pprint +from collections import OrderedDict + +from skbase.base import BaseObject + +# from skbase.config import get_config +from skbase.utils._check import _is_scalar_nan + + +class KeyValTuple(tuple): + """Dummy class for correctly rendering key-value tuples from dicts.""" + + def __repr__(self): + """Represent as string.""" + # needed for _dispatch[tuple.__repr__] not to be overridden + return super().__repr__() + + +class KeyValTupleParam(KeyValTuple): + """Dummy class for correctly rendering key-value tuples from parameters.""" + + pass + + +def _changed_params(base_object): + """Return dict (param_name: value) of parameters with non-default values.""" + params = base_object.get_params(deep=False) + init_func = getattr( + base_object.__init__, "deprecated_original", base_object.__init__ + ) + init_params = inspect.signature(init_func).parameters + init_params = {name: param.default for name, param in init_params.items()} + + def has_changed(k, v): + if k not in init_params: # happens if k is part of a **kwargs + return True + if init_params[k] == inspect._empty: # k has no default value + return True + # try to avoid calling repr on nested BaseObjects + if isinstance(v, BaseObject) and v.__class__ != init_params[k].__class__: + return True + # Use repr as a last resort. It may be expensive. + if repr(v) != repr(init_params[k]) and not ( + _is_scalar_nan(init_params[k]) and _is_scalar_nan(v) + ): + return True + return False + + return {k: v for k, v in params.items() if has_changed(k, v)} + + +class _BaseObjectPrettyPrinter(pprint.PrettyPrinter): + """Pretty Printer class for BaseObjects. + + This extends the pprint.PrettyPrinter class similar to scikit-learn's + implementation, so that: + + - BaseObjects are printed with their parameters, e.g. + BaseObject(param1=value1, ...) which is not supported by default. + - the 'compact' parameter of PrettyPrinter is ignored for dicts, which + may lead to very long representations that we want to avoid. + + Quick overview of pprint.PrettyPrinter (see also + https://stackoverflow.com/questions/49565047/pprint-with-hex-numbers): + + - the entry point is the _format() method which calls format() (overridden + here) + - format() directly calls _safe_repr() for a first try at rendering the + object + - _safe_repr formats the whole object recursively, only calling itself, + not caring about line length or anything + - back to _format(), if the output string is too long, _format() then calls + the appropriate _pprint_TYPE() method (e.g. _pprint_list()) depending on + the type of the object. This where the line length and the compact + parameters are taken into account. + - those _pprint_TYPE() methods will internally use the format() method for + rendering the nested objects of an object (e.g. the elements of a list) + + In the end, everything has to be implemented twice: in _safe_repr and in + the custom _pprint_TYPE methods. Unfortunately PrettyPrinter is really not + straightforward to extend (especially when we want a compact output), so + the code is a bit convoluted. + + This class overrides: + - format() to support the changed_only parameter + - _safe_repr to support printing of BaseObjects that fit on a single line + - _format_dict_items so that dict are correctly 'compacted' + - _format_items so that ellipsis is used on long lists and tuples + + When BaseObjects cannot be printed on a single line, the builtin _format() + will call _pprint_object() because it was registered to do so (see + _dispatch[BaseObject.__repr__] = _pprint_object). + + both _format_dict_items() and _pprint_Object() use the + _format_params_or_dict_items() method that will format parameters and + key-value pairs respecting the compact parameter. This method needs another + subroutine _pprint_key_val_tuple() used when a parameter or a key-value + pair is too long to fit on a single line. This subroutine is called in + _format() and is registered as well in the _dispatch dict (just like + _pprint_object). We had to create the two classes KeyValTuple and + KeyValTupleParam for this. + """ + + def __init__( + self, + indent=1, + width=80, + depth=None, + stream=None, + *, + compact=False, + indent_at_name=True, + n_max_elements_to_show=None, + changed_only=True, + ): + super().__init__(indent, width, depth, stream, compact=compact) + self._indent_at_name = indent_at_name + if self._indent_at_name: + self._indent_per_level = 1 # ignore indent param + self.changed_only = changed_only + # Max number of elements in a list, dict, tuple until we start using + # ellipsis. This also affects the number of arguments of a BaseObject + # (they are treated as dicts) + self.n_max_elements_to_show = n_max_elements_to_show + + def format(self, obj, context, maxlevels, level): # noqa + return _safe_repr( + obj, context, maxlevels, level, changed_only=self.changed_only + ) + + def _pprint_object(self, obj, stream, indent, allowance, context, level): + stream.write(obj.__class__.__name__ + "(") + if self._indent_at_name: + indent += len(obj.__class__.__name__) + + if self.changed_only: + params = _changed_params(obj) + else: + params = obj.get_params(deep=False) + + params = OrderedDict((name, val) for (name, val) in sorted(params.items())) + + self._format_params( + params.items(), stream, indent, allowance + 1, context, level + ) + stream.write(")") + + def _format_dict_items(self, items, stream, indent, allowance, context, level): + return self._format_params_or_dict_items( + items, stream, indent, allowance, context, level, is_dict=True + ) + + def _format_params(self, items, stream, indent, allowance, context, level): + return self._format_params_or_dict_items( + items, stream, indent, allowance, context, level, is_dict=False + ) + + def _format_params_or_dict_items( + self, obj, stream, indent, allowance, context, level, is_dict + ): + """Format dict items or parameters respecting the compact=True parameter. + + For some reason, the builtin rendering of dict items doesn't + respect compact=True and will use one line per key-value if all cannot + fit in a single line. + Dict items will be rendered as <'key': value> while params will be + rendered as . The implementation is mostly copy/pasting from + the builtin _format_items(). + This also adds ellipsis if the number of items is greater than + self.n_max_elements_to_show. + """ + write = stream.write + indent += self._indent_per_level + delimnl = ",\n" + " " * indent + delim = "" + width = max_width = self._width - indent + 1 + it = iter(obj) + try: + next_ent = next(it) + except StopIteration: + return + last = False + n_items = 0 + while not last: + if n_items == self.n_max_elements_to_show: + write(", ...") + break + n_items += 1 + ent = next_ent + try: + next_ent = next(it) + except StopIteration: + last = True + max_width -= allowance + width -= allowance + if self._compact: + k, v = ent + krepr = self._repr(k, context, level) + vrepr = self._repr(v, context, level) + if not is_dict: + krepr = krepr.strip("'") + middle = ": " if is_dict else "=" + rep = krepr + middle + vrepr + w = len(rep) + 2 + if width < w: + width = max_width + if delim: + delim = delimnl + if width >= w: + width -= w + write(delim) + delim = ", " + write(rep) + continue + write(delim) + delim = delimnl + class_ = KeyValTuple if is_dict else KeyValTupleParam + self._format( + class_(ent), stream, indent, allowance if last else 1, context, level + ) + + def _format_items(self, items, stream, indent, allowance, context, level): + """Format the items of an iterable (list, tuple...). + + Same as the built-in _format_items, with support for ellipsis if the + number of elements is greater than self.n_max_elements_to_show. + """ + write = stream.write + indent += self._indent_per_level + if self._indent_per_level > 1: + write((self._indent_per_level - 1) * " ") + delimnl = ",\n" + " " * indent + delim = "" + width = max_width = self._width - indent + 1 + it = iter(items) + try: + next_ent = next(it) + except StopIteration: + return + last = False + n_items = 0 + while not last: + if n_items == self.n_max_elements_to_show: + write(", ...") + break + n_items += 1 + ent = next_ent + try: + next_ent = next(it) + except StopIteration: + last = True + max_width -= allowance + width -= allowance + if self._compact: + rep = self._repr(ent, context, level) + w = len(rep) + 2 + if width < w: + width = max_width + if delim: + delim = delimnl + if width >= w: + width -= w + write(delim) + delim = ", " + write(rep) + continue + write(delim) + delim = delimnl + self._format(ent, stream, indent, allowance if last else 1, context, level) + + def _pprint_key_val_tuple(self, obj, stream, indent, allowance, context, level): + """Pretty printing for key-value tuples from dict or parameters.""" + k, v = obj + rep = self._repr(k, context, level) + if isinstance(obj, KeyValTupleParam): + rep = rep.strip("'") + middle = "=" + else: + middle = ": " + stream.write(rep) + stream.write(middle) + self._format( + v, stream, indent + len(rep) + len(middle), allowance, context, level + ) + + # Follow what scikit-learn did here and copy _dispatch to prevent instances + # of the builtin PrettyPrinter class to call methods of + # _BaseObjectPrettyPrinter (see scikit-learn Github issue 12906) + # mypy error: "Type[PrettyPrinter]" has no attribute "_dispatch" + _dispatch = pprint.PrettyPrinter._dispatch.copy() # type: ignore + _dispatch[BaseObject.__repr__] = _pprint_object + _dispatch[KeyValTuple.__repr__] = _pprint_key_val_tuple + + +def _safe_repr(obj, context, maxlevels, level, changed_only=False): + """Safe string representation logic. + + Same as the builtin _safe_repr, with added support for BaseObjects. + """ + typ = type(obj) + + if typ in pprint._builtin_scalars: + return repr(obj), True, False + + r = getattr(typ, "__repr__", None) + if issubclass(typ, dict) and r is dict.__repr__: + if not obj: + return "{}", True, False + objid = id(obj) + if maxlevels and level >= maxlevels: + return "{...}", False, objid in context + if objid in context: + return pprint._recursion(obj), False, True + context[objid] = 1 + readable = True + recursive = False + components = [] + append = components.append + level += 1 + saferepr = _safe_repr + items = sorted(obj.items(), key=pprint._safe_tuple) + for k, v in items: + krepr, kreadable, krecur = saferepr( + k, context, maxlevels, level, changed_only=changed_only + ) + vrepr, vreadable, vrecur = saferepr( + v, context, maxlevels, level, changed_only=changed_only + ) + append("%s: %s" % (krepr, vrepr)) + readable = readable and kreadable and vreadable + if krecur or vrecur: + recursive = True + del context[objid] + return "{%s}" % ", ".join(components), readable, recursive + + if (issubclass(typ, list) and r is list.__repr__) or ( + issubclass(typ, tuple) and r is tuple.__repr__ + ): + if issubclass(typ, list): + if not obj: + return "[]", True, False + format_ = "[%s]" + elif len(obj) == 1: + format_ = "(%s,)" + else: + if not obj: + return "()", True, False + format_ = "(%s)" + objid = id(obj) + if maxlevels and level >= maxlevels: + return format_ % "...", False, objid in context + if objid in context: + return pprint._recursion(obj), False, True + context[objid] = 1 + readable = True + recursive = False + components = [] + append = components.append + level += 1 + for o in obj: + orepr, oreadable, orecur = _safe_repr( + o, context, maxlevels, level, changed_only=changed_only + ) + append(orepr) + if not oreadable: + readable = False + if orecur: + recursive = True + del context[objid] + return format_ % ", ".join(components), readable, recursive + + if issubclass(typ, BaseObject): + objid = id(obj) + if maxlevels and level >= maxlevels: + return "{...}", False, objid in context + if objid in context: + return pprint._recursion(obj), False, True + context[objid] = 1 + readable = True + recursive = False + if changed_only: + params = _changed_params(obj) + else: + params = obj.get_params(deep=False) + components = [] + append = components.append + level += 1 + saferepr = _safe_repr + items = sorted(params.items(), key=pprint._safe_tuple) + for k, v in items: + krepr, kreadable, krecur = saferepr( + k, context, maxlevels, level, changed_only=changed_only + ) + vrepr, vreadable, vrecur = saferepr( + v, context, maxlevels, level, changed_only=changed_only + ) + append("%s=%s" % (krepr.strip("'"), vrepr)) + readable = readable and kreadable and vreadable + if krecur or vrecur: + recursive = True + del context[objid] + return ("%s(%s)" % (typ.__name__, ", ".join(components)), readable, recursive) + + rep = repr(obj) + return rep, (rep and not rep.startswith("<")), False diff --git a/skbase/config/__init__.py b/skbase/config/__init__.py new file mode 100644 index 00000000..a438b9ca --- /dev/null +++ b/skbase/config/__init__.py @@ -0,0 +1,393 @@ +# -*- coding: utf-8 -*- +""":mod:`skbase.config` provides tools for global configuration of ``skbase``.""" +# -*- coding: utf-8 -*- +# copyright: skbase developers, BSD-3-Clause License (see LICENSE file) +# Includes functionality like get_config, set_config, and config_context +# that is similar to scikit-learn. These elements are copyrighted by the +# scikit-learn developers, BSD-3-Clause License. For conditions see +# https://github.com/scikit-learn/scikit-learn/blob/main/COPYING + +import sys +import threading +import warnings +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union + +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal # type: ignore + +__author__: List[str] = ["RNKuhns"] + + +@dataclass +class GlobalConfigParamSetting: + """Define types of the setting information for a given config parameter.""" + + os_environ_name: str + expected_type: Union[type, Tuple[type]] + allowed_values: Optional[Tuple[Any, ...]] + default_value: Any + + def get_allowed_values(self): + """Get `allowed_values` or empty tuple if `allowed_values` is None. + + Returns + ------- + tuple + Allowable values if any. + """ + if self.allowed_values is None: + return () + else: + return self.allowed_values + + def is_valid_param_value(self, value): + """Validate that a global configuration value is valid. + + Verifies that the value set for a global configuration parameter is valid + based on the its configuration settings. + + Returns + ------- + bool + Whether a parameter value is valid. + """ + if ( + isinstance(self.allowed_values, (list, tuple)) + or self.allowed_values is None + ): + allowed_values = self.allowed_values + else: + allowed_values = (self.allowed_values,) + + valid_param: bool + if not isinstance(value, self.expected_type): + valid_param = False + elif allowed_values is not None and value not in allowed_values: + valid_param = False + else: + valid_param = True + return valid_param + + def get_valid_param_or_default(self, value): + """Validate `value` and return default if it is not valid.""" + if self.is_valid_param_value(value): + return value + else: + msg = f"{self.os_environ_name} should be one of " + msg += f"{', '.join(map(repr, self.allowed_values))}." + msg += "Using default value for this configuration as a result." + warnings.warn(msg, UserWarning, stacklevel=2) + return self.default_value + + +_CONFIG_REGISTRY: Dict[str, GlobalConfigParamSetting] = { + "print_changed_only": GlobalConfigParamSetting( + os_environ_name="SKBASE_PRINT_CHANGED_ONLY", + expected_type=bool, + allowed_values=(True, False), + default_value=True, + ), + "display": GlobalConfigParamSetting( + os_environ_name="SKBASE_OBJECT_DISPLAY", + expected_type=str, + allowed_values=("text", "diagram"), + default_value="text", + ), +} + +_DEFAULT_GLOBAL_CONFIG: Dict[str, Any] = { + config_name: config_info.default_value + for config_name, config_info in _CONFIG_REGISTRY.items() +} + +global_config = _DEFAULT_GLOBAL_CONFIG.copy() +_THREAD_LOCAL_DATA = threading.local() + + +def _get_threadlocal_config() -> Dict[str, Any]: + """Get a threadlocal **mutable** configuration. + + If the configuration does not exist, copy the default global configuration. + + Returns + ------- + dict + Threadlocal global config or copy of default global configuration. + """ + if not hasattr(_THREAD_LOCAL_DATA, "global_config"): + _THREAD_LOCAL_DATA.global_config = global_config.copy() + return _THREAD_LOCAL_DATA.global_config + + +def get_config_os_env_names() -> List[str]: + """Retrieve the os environment names for configurable settings. + + Returns + ------- + env_names : list + The os environment names that can be used to set configurable settings. + + See Also + -------- + config_context : + Configuration context manager. + get_default_config : + Retrieve ``skbase``'s default configuration. + get_config : + Retrieve current global configuration values. + set_config : + Set global configuration. + reset_config : + Reset configuration to ``skbase`` default. + + Examples + -------- + >>> from skbase.config import get_config_os_env_names + >>> get_config_os_env_names() + ['SKBASE_PRINT_CHANGED_ONLY', 'SKBASE_OBJECT_DISPLAY'] + """ + return [config_info.os_environ_name for config_info in _CONFIG_REGISTRY.values()] + + +def get_default_config() -> Dict[str, Any]: + """Retrive the default global configuration. + + This will always return the default ``skbase`` global configuration. + + Returns + ------- + config : dict + The default configurable settings (keys) and their default values (values). + + See Also + -------- + config_context : + Configuration context manager. + get_config : + Retrieve current global configuration values. + get_config_os_env_names : + Retrieve os environment names that can be used to set configuration. + set_config : + Set global configuration. + reset_config : + Reset configuration to ``skbase`` default. + + Examples + -------- + >>> from skbase.config import get_default_config + >>> get_default_config() + {'print_changed_only': True, 'display': 'text'} + """ + return _DEFAULT_GLOBAL_CONFIG.copy() + + +def get_config() -> Dict[str, Any]: + """Retrieve current values for configuration set by :meth:`set_config`. + + Willr return the default configuration if know updated configuration has + been set by :meth:`set_config`. + + Returns + ------- + config : dict + The configurable settings (keys) and their default values (values). + + See Also + -------- + config_context : + Configuration context manager. + get_default_config : + Retrieve ``skbase``'s default configuration. + get_config_os_env_names : + Retrieve os environment names that can be used to set configuration. + set_config : + Set global configuration. + reset_config : + Reset configuration to ``skbase`` default. + + Examples + -------- + >>> from skbase.config import get_config + >>> get_config() + {'print_changed_only': True, 'display': 'text'} + """ + return _get_threadlocal_config().copy() + + +def set_config( + print_changed_only: Optional[bool] = None, + display: Literal["text", "diagram"] = None, + local_threadsafe: bool = False, +) -> None: + """Set global configuration. + + Allows the ``skbase`` global configuration to be updated. + + Parameters + ---------- + print_changed_only : bool, default=None + If True, only the parameters that were set to non-default + values will be printed when printing a BaseObject instance. For example, + ``print(SVC())`` while True will only print 'SVC()', but would print + 'SVC(C=1.0, cache_size=200, ...)' with all the non-changed parameters + when False. If None, the existing value won't change. + display : {'text', 'diagram'}, default=None + If 'diagram', instances inheritting from BaseOBject will be displayed + as a diagram in a Jupyter lab or notebook context. If 'text', instances + inheritting from BaseObject will be displayed as text. If None, the + existing value won't change. + local_threadsafe : bool, default=False + If False, set the backend as default for all threads. + + Returns + ------- + None + No output returned. + + See Also + -------- + config_context : + Configuration context manager. + get_default_config : + Retrieve ``skbase``'s default configuration. + get_config : + Retrieve current global configuration values. + get_config_os_env_names : + Retrieve os environment names that can be used to set configuration. + reset_config : + Reset configuration to default. + + Examples + -------- + >>> from skbase.config import get_config, set_config + >>> get_config() + {'print_changed_only': True, 'display': 'text'} + >>> set_config(display='diagram') + >>> get_config() + {'print_changed_only': True, 'display': 'diagram'} + """ + local_config = _get_threadlocal_config() + + if print_changed_only is not None: + local_config["print_changed_only"] = _CONFIG_REGISTRY[ + "print_changed_only" + ].get_valid_param_or_default(print_changed_only) + if display is not None: + local_config["display"] = _CONFIG_REGISTRY[ + "display" + ].get_valid_param_or_default(display) + + if not local_threadsafe: + global_config.update(local_config) + + return None + + +def reset_config() -> None: + """Reset the global configuration to the default. + + Will remove any user updates to the global configuration and reset the values + back to the ``skbase`` defaults. + + Returns + ------- + None + No output returned. + + See Also + -------- + config_context : + Configuration context manager. + get_default_config : + Retrieve ``skbase``'s default configuration. + get_config : + Retrieve current global configuration values. + get_config_os_env_names : + Retrieve os environment names that can be used to set configuration. + set_config : + Set global configuration. + + Examples + -------- + >>> from skbase.config import get_config, set_config, reset_config + >>> get_config() + {'print_changed_only': True, 'display': 'text'} + >>> set_config(display='diagram') + >>> get_config() + {'print_changed_only': True, 'display': 'diagram'} + >>> reset_config() + >>> get_config() + {'print_changed_only': True, 'display': 'text'} + """ + default_config = get_default_config() + set_config(**default_config) + return None + + +@contextmanager +def config_context( + print_changed_only: Optional[bool] = None, + display: Literal["text", "diagram"] = None, + local_threadsafe: bool = False, +) -> Iterator[None]: + """Context manager for global configuration. + + Parameters + ---------- + print_changed_only : bool, default=None + If True, only the parameters that were set to non-default + values will be printed when printing a BaseObject instance. For example, + ``print(SVC())`` while True will only print 'SVC()', but would print + 'SVC(C=1.0, cache_size=200, ...)' with all the non-changed parameters + when False. If None, the existing value won't change. + display : {'text', 'diagram'}, default=None + If 'diagram', instances inheritting from BaseOBject will be displayed + as a diagram in a Jupyter lab or notebook context. If 'text', instances + inheritting from BaseObject will be displayed as text. If None, the + existing value won't change. + local_threadsafe : bool, default=False + If False, set the config as default for all threads. + + Yields + ------ + None + + See Also + -------- + get_default_config : + Retrieve ``skbase``'s default configuration. + get_config : + Retrieve current values of the global configuration. + get_config_os_env_names : + Retrieve os environment names that can be used to set configuration. + set_config : + Set global configuration. + reset_config : + Reset configuration to ``skbase`` default. + + Notes + ----- + All settings, not just those presently modified, will be returned to + their previous values when the context manager is exited. + + Examples + -------- + >>> from skbase.config import config_context + >>> with config_context(display='diagram'): + ... pass + """ + old_config = get_config() + set_config( + print_changed_only=print_changed_only, + display=display, + local_threadsafe=local_threadsafe, + ) + + try: + yield + finally: + set_config(**old_config) diff --git a/skbase/config/tests/__init__.py b/skbase/config/tests/__init__.py new file mode 100644 index 00000000..ac71ad0c --- /dev/null +++ b/skbase/config/tests/__init__.py @@ -0,0 +1,4 @@ +#!/usr/bin/env python3 -u +# -*- coding: utf-8 -*- +# copyright: skbase developers, BSD-3-Clause License (see LICENSE file) +"""Test functionality of :mod:`skbase.config`.""" diff --git a/skbase/config/tests/test_config.py b/skbase/config/tests/test_config.py new file mode 100644 index 00000000..aff880c2 --- /dev/null +++ b/skbase/config/tests/test_config.py @@ -0,0 +1,103 @@ +# -*- coding: utf-8 -*- +"""Test configuration functionality.""" +import pytest + +from skbase.config import ( + _CONFIG_REGISTRY, + _DEFAULT_GLOBAL_CONFIG, + config_context, + get_config, + get_config_os_env_names, + get_default_config, + reset_config, + set_config, +) + +PRINT_CHANGE_ONLY_VALUES = _CONFIG_REGISTRY["print_changed_only"].get_allowed_values() +DISPLAY_VALUES = _CONFIG_REGISTRY["display"].get_allowed_values() + + +@pytest.fixture +def config_registry(): + """Config registry fixture.""" + return _CONFIG_REGISTRY + + +@pytest.fixture +def global_config_default(): + """Config registry fixture.""" + return _DEFAULT_GLOBAL_CONFIG + + +def test_get_config_os_env_names(config_registry): + """Verify that get_config_os_env_names returns expected values.""" + os_env_names = get_config_os_env_names() + expected_os_env_names = [ + config_info.os_environ_name for config_info in config_registry.values() + ] + msg = "`get_config_os_env_names` does not return expected value.\n" + msg += f"Expected {expected_os_env_names}, but returned {os_env_names}." + assert os_env_names == expected_os_env_names, msg + + +def test_get_default_config(global_config_default): + """Verify that get_default_config returns default global config values.""" + retrieved_default = get_default_config() + msg = "`get_default_config` does not return expected values.\n" + msg += f"Expected {global_config_default}, but returned {retrieved_default}." + assert retrieved_default == global_config_default, msg + + +@pytest.mark.parametrize("print_changed_only", PRINT_CHANGE_ONLY_VALUES) +@pytest.mark.parametrize("display", DISPLAY_VALUES) +def test_set_config_then_get_config_returns_expected_value(print_changed_only, display): + """Verify that get_config returns set config values if set_config run.""" + set_config(print_changed_only=print_changed_only, display=display) + retrieved_default = get_config() + expected_config = {"print_changed_only": print_changed_only, "display": display} + msg = "`get_config` used after `set_config` does not return expected values.\n" + msg += "After set_config is run, get_config should return the set values.\n " + msg += f"Expected {expected_config}, but returned {retrieved_default}." + assert retrieved_default == expected_config, msg + + +@pytest.mark.parametrize("print_changed_only", PRINT_CHANGE_ONLY_VALUES) +@pytest.mark.parametrize("display", DISPLAY_VALUES) +def test_reset_config_resets_the_config(print_changed_only, display): + """Verify that get_config returns default config if reset_config run.""" + default_config = get_default_config() + set_config(print_changed_only=print_changed_only, display=display) + reset_config() + retrieved_config = get_config() + + msg = "`get_config` does not return expected values after `reset_config`.\n" + msg += "`After reset_config is run, get_config` should return defaults.\n" + msg += f"Expected {default_config}, but returned {retrieved_config}." + assert retrieved_config == default_config, msg + + +@pytest.mark.parametrize("print_changed_only", PRINT_CHANGE_ONLY_VALUES) +@pytest.mark.parametrize("display", DISPLAY_VALUES) +def test_config_context(print_changed_only, display): + """Verify that config_context affects context but not overall configuration.""" + # Make sure config is reset to default values then retrieve it + reset_config() + retrieved_config = get_config() + # Now lets make sure the config_context is changing the context of those values + # within the scope of the context manager as expected + for print_changed_only in (True, False): + with config_context(print_changed_only=print_changed_only, display=display): + retrieved_context_config = get_config() + expected_config = {"print_changed_only": print_changed_only, "display": display} + msg = "`get_config` does not return expected values within `config_context`.\n" + msg += "`get_config` should return config defined by `config_context`.\n" + msg += f"Expected {expected_config}, but returned {retrieved_context_config}." + assert retrieved_context_config == expected_config, msg + + # Outside of the config_context we should have not affected the retrieved config + # set by call to reset_config() + config_post_config_context = get_config() + msg = "`get_config` does not return expected values after `config_context`a.\n" + msg += "`config_context` should not affect configuration outside its context.\n" + msg += f"Expected {config_post_config_context}, but returned {retrieved_config}." + assert retrieved_config == config_post_config_context, msg diff --git a/skbase/tests/conftest.py b/skbase/tests/conftest.py index 4f36034a..24a92cf2 100644 --- a/skbase/tests/conftest.py +++ b/skbase/tests/conftest.py @@ -23,6 +23,7 @@ "skbase.base._base", "skbase.base._meta", "skbase.base._tagmanager", + "skbase.config", "skbase.lookup", "skbase.lookup.tests", "skbase.lookup.tests.test_lookup", @@ -42,6 +43,7 @@ "skbase.tests.test_baseestimator", "skbase.tests.mock_package.test_mock_package", "skbase.utils", + "skbase.utils._check", "skbase.utils._iter", "skbase.utils._nested_iter", "skbase.validate", @@ -51,6 +53,7 @@ SKBASE_PUBLIC_MODULES = ( "skbase", "skbase.base", + "skbase.config", "skbase.lookup", "skbase.lookup.tests", "skbase.lookup.tests.test_lookup", @@ -74,6 +77,7 @@ "skbase.base": ("BaseEstimator", "BaseMetaEstimator", "BaseObject"), "skbase.base._base": ("BaseEstimator", "BaseObject"), "skbase.base._meta": ("BaseMetaEstimator",), + "skbase.config": ("GlobalConfigParamSetting",), "skbase.lookup._lookup": ("ClassInfo", "FunctionInfo", "ModuleInfo"), "skbase.testing": ("BaseFixtureGenerator", "QuickTester", "TestAllObjects"), "skbase.testing.test_all_objects": ( @@ -85,11 +89,18 @@ SKBASE_CLASSES_BY_MODULE = SKBASE_PUBLIC_CLASSES_BY_MODULE.copy() SKBASE_CLASSES_BY_MODULE.update( { - "skbase.base._meta": ("BaseMetaEstimator",), "skbase.base._tagmanager": ("_FlagManager",), } ) SKBASE_PUBLIC_FUNCTIONS_BY_MODULE = { + "skbase.config": ( + "config_context", + "get_config", + "get_default_config", + "get_config_os_env_names", + "reset_config", + "set_config", + ), "skbase.lookup": ("all_objects", "get_package_metadata"), "skbase.lookup._lookup": ("all_objects", "get_package_metadata"), "skbase.testing.utils._conditional_fixtures": ( @@ -124,6 +135,15 @@ SKBASE_FUNCTIONS_BY_MODULE = SKBASE_PUBLIC_FUNCTIONS_BY_MODULE.copy() SKBASE_FUNCTIONS_BY_MODULE.update( { + "skbase.config": ( + "_get_threadlocal_config", + "config_context", + "get_config", + "get_default_config", + "get_config_os_env_names", + "reset_config", + "set_config", + ), "skbase.lookup._lookup": ( "_determine_module_path", "_get_return_tags", @@ -155,6 +175,7 @@ "_coerce_list", ), "skbase.testing.utils.inspect": ("_get_args",), + "skbase.utils._check": ("_is_scalar_nan",), "skbase.utils._iter": ( "_format_seq_to_str", "_remove_type_text", diff --git a/skbase/tests/test_base.py b/skbase/tests/test_base.py index 38d2d5a9..913b5a8f 100644 --- a/skbase/tests/test_base.py +++ b/skbase/tests/test_base.py @@ -73,12 +73,12 @@ import numpy as np import pytest import scipy.sparse as sp -from sklearn import config_context # TODO: Update with import of skbase clone function once implemented from sklearn.base import clone from skbase.base import BaseEstimator, BaseObject +from skbase.config import config_context # type: ignore from skbase.tests.conftest import Child, Parent from skbase.tests.mock_package.test_mock_package import CompositionDummy @@ -878,12 +878,13 @@ def test_baseobject_repr(fixture_class_parent, fixture_composition_dummy): """Test BaseObject repr works as expected.""" # Simple test where all parameters are left at defaults # Should not see parameters and values in printed representation + base_obj = fixture_class_parent() assert repr(base_obj) == "Parent()" # Check that we can alter the detail about params that is printed # using config_context with ``print_changed_only=False`` - with config_context(print_changed_only=False): + with config_context(print_changed_only=False, display="text"): assert repr(base_obj) == "Parent(a='something', b=7, c=None)" simple_composite = fixture_composition_dummy(foo=fixture_class_parent()) diff --git a/skbase/utils/_check.py b/skbase/utils/_check.py new file mode 100644 index 00000000..7a6830c4 --- /dev/null +++ b/skbase/utils/_check.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# copyright: skbase developers, BSD-3-Clause License (see LICENSE file) +# Elements of _is_scalar_nan re-use code developed in scikit-learn. These elements +# are copyrighted by the scikit-learn developers, BSD-3-Clause License. For +# conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING + +"""Utility functions to perform various types of checks.""" +from __future__ import annotations + +import math +import numbers +from typing import Any + +__all__ = ["_is_scalar_nan"] +__author__ = ["RNKuhns"] + + +def _is_scalar_nan(x: Any) -> bool: + """Test if x is NaN. + + This function is meant to overcome the issue that np.isnan does not allow + non-numerical types as input, and that np.nan is not float('nan'). + + Parameters + ---------- + x : Any + The item to be checked to determine if it is a scalar nan value. + + Returns + ------- + bool + True if `x` is a scalar nan value + + Notes + ----- + This code follows scikit-learn's implementation. + + Examples + -------- + >>> import numpy as np + >>> from skbase.utils._check import _is_scalar_nan + >>> _is_scalar_nan(np.nan) + True + >>> _is_scalar_nan(float("nan")) + True + >>> _is_scalar_nan(None) + False + >>> _is_scalar_nan("") + False + >>> _is_scalar_nan([np.nan]) + False + """ + return isinstance(x, numbers.Real) and math.isnan(x) diff --git a/skbase/utils/tests/test_check.py b/skbase/utils/tests/test_check.py new file mode 100644 index 00000000..f998bae4 --- /dev/null +++ b/skbase/utils/tests/test_check.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python3 -u +# -*- coding: utf-8 -*- +# copyright: skbase developers, BSD-3-Clause License (see LICENSE file) +"""Tests of the utility functionality for performing various checks. + +tests in this module include: + +- test_is_scalar_nan_output to verify _is_scalar_nan outputs expected value for + different inputs. +""" +import numpy as np + +from skbase.utils._check import _is_scalar_nan + +__author__ = ["RNKuhns"] + + +def test_is_scalar_nan_output(): + """Test that _is_scalar_nan outputs expected value for different inputs.""" + assert _is_scalar_nan(np.nan) is True + assert _is_scalar_nan(float("nan")) is True + assert _is_scalar_nan(None) is False + assert _is_scalar_nan("") is False + assert _is_scalar_nan([np.nan]) is False From aa166f2c5e6153caa8390f329be8ec165843981f Mon Sep 17 00:00:00 2001 From: rnkuhns Date: Sun, 12 Mar 2023 21:19:02 -0400 Subject: [PATCH 02/16] Minor improvements to config test coverage --- skbase/config/__init__.py | 35 +++++++++++++------- skbase/config/tests/test_config.py | 51 ++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 12 deletions(-) diff --git a/skbase/config/__init__.py b/skbase/config/__init__.py index a438b9ca..3f38839b 100644 --- a/skbase/config/__init__.py +++ b/skbase/config/__init__.py @@ -6,7 +6,7 @@ # that is similar to scikit-learn. These elements are copyrighted by the # scikit-learn developers, BSD-3-Clause License. For conditions see # https://github.com/scikit-learn/scikit-learn/blob/main/COPYING - +import collections import sys import threading import warnings @@ -19,6 +19,8 @@ else: from typing_extensions import Literal # type: ignore +from skbase.utils._iter import _format_seq_to_str + __author__: List[str] = ["RNKuhns"] @@ -26,6 +28,7 @@ class GlobalConfigParamSetting: """Define types of the setting information for a given config parameter.""" + name: str os_environ_name: str expected_type: Union[type, Tuple[type]] allowed_values: Optional[Tuple[Any, ...]] @@ -41,8 +44,14 @@ def get_allowed_values(self): """ if self.allowed_values is None: return () - else: + elif isinstance(self.allowed_values, tuple): return self.allowed_values + elif isinstance( + self.allowed_values, collections.abc.Iterable + ) and not isinstance(self.allowed_values, str): + return tuple(self.allowed_values) + else: + return (self.allowed_values,) def is_valid_param_value(self, value): """Validate that a global configuration value is valid. @@ -55,13 +64,7 @@ def is_valid_param_value(self, value): bool Whether a parameter value is valid. """ - if ( - isinstance(self.allowed_values, (list, tuple)) - or self.allowed_values is None - ): - allowed_values = self.allowed_values - else: - allowed_values = (self.allowed_values,) + allowed_values = self.get_allowed_values() valid_param: bool if not isinstance(value, self.expected_type): @@ -77,21 +80,29 @@ def get_valid_param_or_default(self, value): if self.is_valid_param_value(value): return value else: - msg = f"{self.os_environ_name} should be one of " - msg += f"{', '.join(map(repr, self.allowed_values))}." - msg += "Using default value for this configuration as a result." + msg = "Attempting to set an invalid value for a global configuration.\n" + msg += "Using default configuration value of parameter as a result.\n" + msg + f"When setting global config values for `{self.name}`, the values " + msg += f"should be of type {self.expected_type}." + if self.allowed_values is not None: + values_str = _format_seq_to_str( + self.get_allowed_values(), last_sep="or", remove_type_text=True + ) + msg += f"Allowed values are be one of {values_str}." warnings.warn(msg, UserWarning, stacklevel=2) return self.default_value _CONFIG_REGISTRY: Dict[str, GlobalConfigParamSetting] = { "print_changed_only": GlobalConfigParamSetting( + name="print_changed_only", os_environ_name="SKBASE_PRINT_CHANGED_ONLY", expected_type=bool, allowed_values=(True, False), default_value=True, ), "display": GlobalConfigParamSetting( + name="display", os_environ_name="SKBASE_OBJECT_DISPLAY", expected_type=str, allowed_values=("text", "diagram"), diff --git a/skbase/config/tests/test_config.py b/skbase/config/tests/test_config.py index aff880c2..67a104ff 100644 --- a/skbase/config/tests/test_config.py +++ b/skbase/config/tests/test_config.py @@ -5,6 +5,7 @@ from skbase.config import ( _CONFIG_REGISTRY, _DEFAULT_GLOBAL_CONFIG, + GlobalConfigParamSetting, config_context, get_config, get_config_os_env_names, @@ -29,6 +30,39 @@ def global_config_default(): return _DEFAULT_GLOBAL_CONFIG +@pytest.mark.parametrize("allowed_values", (None, (), "something", range(1, 8))) +def test_global_config_param_get_allowed_values(allowed_values): + """Test GlobalConfigParamSetting behavior works as expected.""" + some_config_param = GlobalConfigParamSetting( + name="some_param", + os_environ_name="SKBASE_OBJECT_DISPLAY", + expected_type=str, + allowed_values=allowed_values, + default_value="text", + ) + # Verify we always coerce output of get_allowed_values to tuple + values = some_config_param.get_allowed_values() + assert isinstance(values, tuple) + + +@pytest.mark.parametrize("value", (None, (), "wrong_string", "text", range(1, 8))) +def test_global_config_param_is_valid_param_value(value): + """Test GlobalConfigParamSetting behavior works as expected.""" + some_config_param = GlobalConfigParamSetting( + name="some_param", + os_environ_name="SKBASE_OBJECT_DISPLAY", + expected_type=str, + allowed_values=("text", "diagram"), + default_value="text", + ) + # Verify we correctly identify invalid parameters + if value in ("text", "diagram"): + expected_valid = True + else: + expected_valid = False + assert some_config_param.is_valid_param_value(value) == expected_valid + + def test_get_config_os_env_names(config_registry): """Verify that get_config_os_env_names returns expected values.""" os_env_names = get_config_os_env_names() @@ -74,6 +108,7 @@ def test_reset_config_resets_the_config(print_changed_only, display): msg += "`After reset_config is run, get_config` should return defaults.\n" msg += f"Expected {default_config}, but returned {retrieved_config}." assert retrieved_config == default_config, msg + reset_config() @pytest.mark.parametrize("print_changed_only", PRINT_CHANGE_ONLY_VALUES) @@ -101,3 +136,19 @@ def test_config_context(print_changed_only, display): msg += "`config_context` should not affect configuration outside its context.\n" msg += f"Expected {config_post_config_context}, but returned {retrieved_config}." assert retrieved_config == config_post_config_context, msg + reset_config() + + +def test_set_config_behavior_invalid_value(): + """Test set_config uses default and raises warning when setting invalid value.""" + reset_config() + with pytest.warns(UserWarning, match=r"Attempting to set an invalid value.*"): + set_config(print_changed_only="False") + + assert get_config() == get_default_config() + + with pytest.warns(UserWarning, match=r"Attempting to set an invalid value.*"): + set_config(print_changed_only=7) + + assert get_config() == get_default_config() + reset_config() From 5ad1707e9e692b6b4bc773c50d476e2ef2d89084 Mon Sep 17 00:00:00 2001 From: rnkuhns Date: Sun, 12 Mar 2023 21:21:21 -0400 Subject: [PATCH 03/16] Add global config --- skbase/config/__init__.py | 404 +++++++++++++++++++++++++++++ skbase/config/tests/__init__.py | 4 + skbase/config/tests/test_config.py | 154 +++++++++++ 3 files changed, 562 insertions(+) create mode 100644 skbase/config/__init__.py create mode 100644 skbase/config/tests/__init__.py create mode 100644 skbase/config/tests/test_config.py diff --git a/skbase/config/__init__.py b/skbase/config/__init__.py new file mode 100644 index 00000000..3f38839b --- /dev/null +++ b/skbase/config/__init__.py @@ -0,0 +1,404 @@ +# -*- coding: utf-8 -*- +""":mod:`skbase.config` provides tools for global configuration of ``skbase``.""" +# -*- coding: utf-8 -*- +# copyright: skbase developers, BSD-3-Clause License (see LICENSE file) +# Includes functionality like get_config, set_config, and config_context +# that is similar to scikit-learn. These elements are copyrighted by the +# scikit-learn developers, BSD-3-Clause License. For conditions see +# https://github.com/scikit-learn/scikit-learn/blob/main/COPYING +import collections +import sys +import threading +import warnings +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union + +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal # type: ignore + +from skbase.utils._iter import _format_seq_to_str + +__author__: List[str] = ["RNKuhns"] + + +@dataclass +class GlobalConfigParamSetting: + """Define types of the setting information for a given config parameter.""" + + name: str + os_environ_name: str + expected_type: Union[type, Tuple[type]] + allowed_values: Optional[Tuple[Any, ...]] + default_value: Any + + def get_allowed_values(self): + """Get `allowed_values` or empty tuple if `allowed_values` is None. + + Returns + ------- + tuple + Allowable values if any. + """ + if self.allowed_values is None: + return () + elif isinstance(self.allowed_values, tuple): + return self.allowed_values + elif isinstance( + self.allowed_values, collections.abc.Iterable + ) and not isinstance(self.allowed_values, str): + return tuple(self.allowed_values) + else: + return (self.allowed_values,) + + def is_valid_param_value(self, value): + """Validate that a global configuration value is valid. + + Verifies that the value set for a global configuration parameter is valid + based on the its configuration settings. + + Returns + ------- + bool + Whether a parameter value is valid. + """ + allowed_values = self.get_allowed_values() + + valid_param: bool + if not isinstance(value, self.expected_type): + valid_param = False + elif allowed_values is not None and value not in allowed_values: + valid_param = False + else: + valid_param = True + return valid_param + + def get_valid_param_or_default(self, value): + """Validate `value` and return default if it is not valid.""" + if self.is_valid_param_value(value): + return value + else: + msg = "Attempting to set an invalid value for a global configuration.\n" + msg += "Using default configuration value of parameter as a result.\n" + msg + f"When setting global config values for `{self.name}`, the values " + msg += f"should be of type {self.expected_type}." + if self.allowed_values is not None: + values_str = _format_seq_to_str( + self.get_allowed_values(), last_sep="or", remove_type_text=True + ) + msg += f"Allowed values are be one of {values_str}." + warnings.warn(msg, UserWarning, stacklevel=2) + return self.default_value + + +_CONFIG_REGISTRY: Dict[str, GlobalConfigParamSetting] = { + "print_changed_only": GlobalConfigParamSetting( + name="print_changed_only", + os_environ_name="SKBASE_PRINT_CHANGED_ONLY", + expected_type=bool, + allowed_values=(True, False), + default_value=True, + ), + "display": GlobalConfigParamSetting( + name="display", + os_environ_name="SKBASE_OBJECT_DISPLAY", + expected_type=str, + allowed_values=("text", "diagram"), + default_value="text", + ), +} + +_DEFAULT_GLOBAL_CONFIG: Dict[str, Any] = { + config_name: config_info.default_value + for config_name, config_info in _CONFIG_REGISTRY.items() +} + +global_config = _DEFAULT_GLOBAL_CONFIG.copy() +_THREAD_LOCAL_DATA = threading.local() + + +def _get_threadlocal_config() -> Dict[str, Any]: + """Get a threadlocal **mutable** configuration. + + If the configuration does not exist, copy the default global configuration. + + Returns + ------- + dict + Threadlocal global config or copy of default global configuration. + """ + if not hasattr(_THREAD_LOCAL_DATA, "global_config"): + _THREAD_LOCAL_DATA.global_config = global_config.copy() + return _THREAD_LOCAL_DATA.global_config + + +def get_config_os_env_names() -> List[str]: + """Retrieve the os environment names for configurable settings. + + Returns + ------- + env_names : list + The os environment names that can be used to set configurable settings. + + See Also + -------- + config_context : + Configuration context manager. + get_default_config : + Retrieve ``skbase``'s default configuration. + get_config : + Retrieve current global configuration values. + set_config : + Set global configuration. + reset_config : + Reset configuration to ``skbase`` default. + + Examples + -------- + >>> from skbase.config import get_config_os_env_names + >>> get_config_os_env_names() + ['SKBASE_PRINT_CHANGED_ONLY', 'SKBASE_OBJECT_DISPLAY'] + """ + return [config_info.os_environ_name for config_info in _CONFIG_REGISTRY.values()] + + +def get_default_config() -> Dict[str, Any]: + """Retrive the default global configuration. + + This will always return the default ``skbase`` global configuration. + + Returns + ------- + config : dict + The default configurable settings (keys) and their default values (values). + + See Also + -------- + config_context : + Configuration context manager. + get_config : + Retrieve current global configuration values. + get_config_os_env_names : + Retrieve os environment names that can be used to set configuration. + set_config : + Set global configuration. + reset_config : + Reset configuration to ``skbase`` default. + + Examples + -------- + >>> from skbase.config import get_default_config + >>> get_default_config() + {'print_changed_only': True, 'display': 'text'} + """ + return _DEFAULT_GLOBAL_CONFIG.copy() + + +def get_config() -> Dict[str, Any]: + """Retrieve current values for configuration set by :meth:`set_config`. + + Willr return the default configuration if know updated configuration has + been set by :meth:`set_config`. + + Returns + ------- + config : dict + The configurable settings (keys) and their default values (values). + + See Also + -------- + config_context : + Configuration context manager. + get_default_config : + Retrieve ``skbase``'s default configuration. + get_config_os_env_names : + Retrieve os environment names that can be used to set configuration. + set_config : + Set global configuration. + reset_config : + Reset configuration to ``skbase`` default. + + Examples + -------- + >>> from skbase.config import get_config + >>> get_config() + {'print_changed_only': True, 'display': 'text'} + """ + return _get_threadlocal_config().copy() + + +def set_config( + print_changed_only: Optional[bool] = None, + display: Literal["text", "diagram"] = None, + local_threadsafe: bool = False, +) -> None: + """Set global configuration. + + Allows the ``skbase`` global configuration to be updated. + + Parameters + ---------- + print_changed_only : bool, default=None + If True, only the parameters that were set to non-default + values will be printed when printing a BaseObject instance. For example, + ``print(SVC())`` while True will only print 'SVC()', but would print + 'SVC(C=1.0, cache_size=200, ...)' with all the non-changed parameters + when False. If None, the existing value won't change. + display : {'text', 'diagram'}, default=None + If 'diagram', instances inheritting from BaseOBject will be displayed + as a diagram in a Jupyter lab or notebook context. If 'text', instances + inheritting from BaseObject will be displayed as text. If None, the + existing value won't change. + local_threadsafe : bool, default=False + If False, set the backend as default for all threads. + + Returns + ------- + None + No output returned. + + See Also + -------- + config_context : + Configuration context manager. + get_default_config : + Retrieve ``skbase``'s default configuration. + get_config : + Retrieve current global configuration values. + get_config_os_env_names : + Retrieve os environment names that can be used to set configuration. + reset_config : + Reset configuration to default. + + Examples + -------- + >>> from skbase.config import get_config, set_config + >>> get_config() + {'print_changed_only': True, 'display': 'text'} + >>> set_config(display='diagram') + >>> get_config() + {'print_changed_only': True, 'display': 'diagram'} + """ + local_config = _get_threadlocal_config() + + if print_changed_only is not None: + local_config["print_changed_only"] = _CONFIG_REGISTRY[ + "print_changed_only" + ].get_valid_param_or_default(print_changed_only) + if display is not None: + local_config["display"] = _CONFIG_REGISTRY[ + "display" + ].get_valid_param_or_default(display) + + if not local_threadsafe: + global_config.update(local_config) + + return None + + +def reset_config() -> None: + """Reset the global configuration to the default. + + Will remove any user updates to the global configuration and reset the values + back to the ``skbase`` defaults. + + Returns + ------- + None + No output returned. + + See Also + -------- + config_context : + Configuration context manager. + get_default_config : + Retrieve ``skbase``'s default configuration. + get_config : + Retrieve current global configuration values. + get_config_os_env_names : + Retrieve os environment names that can be used to set configuration. + set_config : + Set global configuration. + + Examples + -------- + >>> from skbase.config import get_config, set_config, reset_config + >>> get_config() + {'print_changed_only': True, 'display': 'text'} + >>> set_config(display='diagram') + >>> get_config() + {'print_changed_only': True, 'display': 'diagram'} + >>> reset_config() + >>> get_config() + {'print_changed_only': True, 'display': 'text'} + """ + default_config = get_default_config() + set_config(**default_config) + return None + + +@contextmanager +def config_context( + print_changed_only: Optional[bool] = None, + display: Literal["text", "diagram"] = None, + local_threadsafe: bool = False, +) -> Iterator[None]: + """Context manager for global configuration. + + Parameters + ---------- + print_changed_only : bool, default=None + If True, only the parameters that were set to non-default + values will be printed when printing a BaseObject instance. For example, + ``print(SVC())`` while True will only print 'SVC()', but would print + 'SVC(C=1.0, cache_size=200, ...)' with all the non-changed parameters + when False. If None, the existing value won't change. + display : {'text', 'diagram'}, default=None + If 'diagram', instances inheritting from BaseOBject will be displayed + as a diagram in a Jupyter lab or notebook context. If 'text', instances + inheritting from BaseObject will be displayed as text. If None, the + existing value won't change. + local_threadsafe : bool, default=False + If False, set the config as default for all threads. + + Yields + ------ + None + + See Also + -------- + get_default_config : + Retrieve ``skbase``'s default configuration. + get_config : + Retrieve current values of the global configuration. + get_config_os_env_names : + Retrieve os environment names that can be used to set configuration. + set_config : + Set global configuration. + reset_config : + Reset configuration to ``skbase`` default. + + Notes + ----- + All settings, not just those presently modified, will be returned to + their previous values when the context manager is exited. + + Examples + -------- + >>> from skbase.config import config_context + >>> with config_context(display='diagram'): + ... pass + """ + old_config = get_config() + set_config( + print_changed_only=print_changed_only, + display=display, + local_threadsafe=local_threadsafe, + ) + + try: + yield + finally: + set_config(**old_config) diff --git a/skbase/config/tests/__init__.py b/skbase/config/tests/__init__.py new file mode 100644 index 00000000..ac71ad0c --- /dev/null +++ b/skbase/config/tests/__init__.py @@ -0,0 +1,4 @@ +#!/usr/bin/env python3 -u +# -*- coding: utf-8 -*- +# copyright: skbase developers, BSD-3-Clause License (see LICENSE file) +"""Test functionality of :mod:`skbase.config`.""" diff --git a/skbase/config/tests/test_config.py b/skbase/config/tests/test_config.py new file mode 100644 index 00000000..67a104ff --- /dev/null +++ b/skbase/config/tests/test_config.py @@ -0,0 +1,154 @@ +# -*- coding: utf-8 -*- +"""Test configuration functionality.""" +import pytest + +from skbase.config import ( + _CONFIG_REGISTRY, + _DEFAULT_GLOBAL_CONFIG, + GlobalConfigParamSetting, + config_context, + get_config, + get_config_os_env_names, + get_default_config, + reset_config, + set_config, +) + +PRINT_CHANGE_ONLY_VALUES = _CONFIG_REGISTRY["print_changed_only"].get_allowed_values() +DISPLAY_VALUES = _CONFIG_REGISTRY["display"].get_allowed_values() + + +@pytest.fixture +def config_registry(): + """Config registry fixture.""" + return _CONFIG_REGISTRY + + +@pytest.fixture +def global_config_default(): + """Config registry fixture.""" + return _DEFAULT_GLOBAL_CONFIG + + +@pytest.mark.parametrize("allowed_values", (None, (), "something", range(1, 8))) +def test_global_config_param_get_allowed_values(allowed_values): + """Test GlobalConfigParamSetting behavior works as expected.""" + some_config_param = GlobalConfigParamSetting( + name="some_param", + os_environ_name="SKBASE_OBJECT_DISPLAY", + expected_type=str, + allowed_values=allowed_values, + default_value="text", + ) + # Verify we always coerce output of get_allowed_values to tuple + values = some_config_param.get_allowed_values() + assert isinstance(values, tuple) + + +@pytest.mark.parametrize("value", (None, (), "wrong_string", "text", range(1, 8))) +def test_global_config_param_is_valid_param_value(value): + """Test GlobalConfigParamSetting behavior works as expected.""" + some_config_param = GlobalConfigParamSetting( + name="some_param", + os_environ_name="SKBASE_OBJECT_DISPLAY", + expected_type=str, + allowed_values=("text", "diagram"), + default_value="text", + ) + # Verify we correctly identify invalid parameters + if value in ("text", "diagram"): + expected_valid = True + else: + expected_valid = False + assert some_config_param.is_valid_param_value(value) == expected_valid + + +def test_get_config_os_env_names(config_registry): + """Verify that get_config_os_env_names returns expected values.""" + os_env_names = get_config_os_env_names() + expected_os_env_names = [ + config_info.os_environ_name for config_info in config_registry.values() + ] + msg = "`get_config_os_env_names` does not return expected value.\n" + msg += f"Expected {expected_os_env_names}, but returned {os_env_names}." + assert os_env_names == expected_os_env_names, msg + + +def test_get_default_config(global_config_default): + """Verify that get_default_config returns default global config values.""" + retrieved_default = get_default_config() + msg = "`get_default_config` does not return expected values.\n" + msg += f"Expected {global_config_default}, but returned {retrieved_default}." + assert retrieved_default == global_config_default, msg + + +@pytest.mark.parametrize("print_changed_only", PRINT_CHANGE_ONLY_VALUES) +@pytest.mark.parametrize("display", DISPLAY_VALUES) +def test_set_config_then_get_config_returns_expected_value(print_changed_only, display): + """Verify that get_config returns set config values if set_config run.""" + set_config(print_changed_only=print_changed_only, display=display) + retrieved_default = get_config() + expected_config = {"print_changed_only": print_changed_only, "display": display} + msg = "`get_config` used after `set_config` does not return expected values.\n" + msg += "After set_config is run, get_config should return the set values.\n " + msg += f"Expected {expected_config}, but returned {retrieved_default}." + assert retrieved_default == expected_config, msg + + +@pytest.mark.parametrize("print_changed_only", PRINT_CHANGE_ONLY_VALUES) +@pytest.mark.parametrize("display", DISPLAY_VALUES) +def test_reset_config_resets_the_config(print_changed_only, display): + """Verify that get_config returns default config if reset_config run.""" + default_config = get_default_config() + set_config(print_changed_only=print_changed_only, display=display) + reset_config() + retrieved_config = get_config() + + msg = "`get_config` does not return expected values after `reset_config`.\n" + msg += "`After reset_config is run, get_config` should return defaults.\n" + msg += f"Expected {default_config}, but returned {retrieved_config}." + assert retrieved_config == default_config, msg + reset_config() + + +@pytest.mark.parametrize("print_changed_only", PRINT_CHANGE_ONLY_VALUES) +@pytest.mark.parametrize("display", DISPLAY_VALUES) +def test_config_context(print_changed_only, display): + """Verify that config_context affects context but not overall configuration.""" + # Make sure config is reset to default values then retrieve it + reset_config() + retrieved_config = get_config() + # Now lets make sure the config_context is changing the context of those values + # within the scope of the context manager as expected + for print_changed_only in (True, False): + with config_context(print_changed_only=print_changed_only, display=display): + retrieved_context_config = get_config() + expected_config = {"print_changed_only": print_changed_only, "display": display} + msg = "`get_config` does not return expected values within `config_context`.\n" + msg += "`get_config` should return config defined by `config_context`.\n" + msg += f"Expected {expected_config}, but returned {retrieved_context_config}." + assert retrieved_context_config == expected_config, msg + + # Outside of the config_context we should have not affected the retrieved config + # set by call to reset_config() + config_post_config_context = get_config() + msg = "`get_config` does not return expected values after `config_context`a.\n" + msg += "`config_context` should not affect configuration outside its context.\n" + msg += f"Expected {config_post_config_context}, but returned {retrieved_config}." + assert retrieved_config == config_post_config_context, msg + reset_config() + + +def test_set_config_behavior_invalid_value(): + """Test set_config uses default and raises warning when setting invalid value.""" + reset_config() + with pytest.warns(UserWarning, match=r"Attempting to set an invalid value.*"): + set_config(print_changed_only="False") + + assert get_config() == get_default_config() + + with pytest.warns(UserWarning, match=r"Attempting to set an invalid value.*"): + set_config(print_changed_only=7) + + assert get_config() == get_default_config() + reset_config() From 34a9c5acb5e9ae5525d0fe083a155c8622cf8610 Mon Sep 17 00:00:00 2001 From: rnkuhns Date: Sun, 12 Mar 2023 23:51:10 -0400 Subject: [PATCH 04/16] Update global config tests --- skbase/config/__init__.py | 40 ------------------------------ skbase/config/tests/test_config.py | 30 +++++++--------------- skbase/tests/conftest.py | 19 +++++++++++++- 3 files changed, 27 insertions(+), 62 deletions(-) diff --git a/skbase/config/__init__.py b/skbase/config/__init__.py index 3f38839b..55d58a3c 100644 --- a/skbase/config/__init__.py +++ b/skbase/config/__init__.py @@ -134,36 +134,6 @@ def _get_threadlocal_config() -> Dict[str, Any]: return _THREAD_LOCAL_DATA.global_config -def get_config_os_env_names() -> List[str]: - """Retrieve the os environment names for configurable settings. - - Returns - ------- - env_names : list - The os environment names that can be used to set configurable settings. - - See Also - -------- - config_context : - Configuration context manager. - get_default_config : - Retrieve ``skbase``'s default configuration. - get_config : - Retrieve current global configuration values. - set_config : - Set global configuration. - reset_config : - Reset configuration to ``skbase`` default. - - Examples - -------- - >>> from skbase.config import get_config_os_env_names - >>> get_config_os_env_names() - ['SKBASE_PRINT_CHANGED_ONLY', 'SKBASE_OBJECT_DISPLAY'] - """ - return [config_info.os_environ_name for config_info in _CONFIG_REGISTRY.values()] - - def get_default_config() -> Dict[str, Any]: """Retrive the default global configuration. @@ -180,8 +150,6 @@ def get_default_config() -> Dict[str, Any]: Configuration context manager. get_config : Retrieve current global configuration values. - get_config_os_env_names : - Retrieve os environment names that can be used to set configuration. set_config : Set global configuration. reset_config : @@ -213,8 +181,6 @@ def get_config() -> Dict[str, Any]: Configuration context manager. get_default_config : Retrieve ``skbase``'s default configuration. - get_config_os_env_names : - Retrieve os environment names that can be used to set configuration. set_config : Set global configuration. reset_config : @@ -267,8 +233,6 @@ def set_config( Retrieve ``skbase``'s default configuration. get_config : Retrieve current global configuration values. - get_config_os_env_names : - Retrieve os environment names that can be used to set configuration. reset_config : Reset configuration to default. @@ -317,8 +281,6 @@ def reset_config() -> None: Retrieve ``skbase``'s default configuration. get_config : Retrieve current global configuration values. - get_config_os_env_names : - Retrieve os environment names that can be used to set configuration. set_config : Set global configuration. @@ -373,8 +335,6 @@ def config_context( Retrieve ``skbase``'s default configuration. get_config : Retrieve current values of the global configuration. - get_config_os_env_names : - Retrieve os environment names that can be used to set configuration. set_config : Set global configuration. reset_config : diff --git a/skbase/config/tests/test_config.py b/skbase/config/tests/test_config.py index 67a104ff..11f3b5cf 100644 --- a/skbase/config/tests/test_config.py +++ b/skbase/config/tests/test_config.py @@ -8,7 +8,6 @@ GlobalConfigParamSetting, config_context, get_config, - get_config_os_env_names, get_default_config, reset_config, set_config, @@ -63,23 +62,11 @@ def test_global_config_param_is_valid_param_value(value): assert some_config_param.is_valid_param_value(value) == expected_valid -def test_get_config_os_env_names(config_registry): - """Verify that get_config_os_env_names returns expected values.""" - os_env_names = get_config_os_env_names() - expected_os_env_names = [ - config_info.os_environ_name for config_info in config_registry.values() - ] - msg = "`get_config_os_env_names` does not return expected value.\n" - msg += f"Expected {expected_os_env_names}, but returned {os_env_names}." - assert os_env_names == expected_os_env_names, msg - - def test_get_default_config(global_config_default): - """Verify that get_default_config returns default global config values.""" - retrieved_default = get_default_config() - msg = "`get_default_config` does not return expected values.\n" - msg += f"Expected {global_config_default}, but returned {retrieved_default}." - assert retrieved_default == global_config_default, msg + """Test get_default_config alwasy returns the default config.""" + assert get_default_config() == global_config_default + set_config(print_changed_only=False) + assert get_default_config() == global_config_default @pytest.mark.parametrize("print_changed_only", PRINT_CHANGE_ONLY_VALUES) @@ -97,17 +84,18 @@ def test_set_config_then_get_config_returns_expected_value(print_changed_only, d @pytest.mark.parametrize("print_changed_only", PRINT_CHANGE_ONLY_VALUES) @pytest.mark.parametrize("display", DISPLAY_VALUES) -def test_reset_config_resets_the_config(print_changed_only, display): +def test_reset_config_resets_the_config( + print_changed_only, display, global_config_default +): """Verify that get_config returns default config if reset_config run.""" - default_config = get_default_config() set_config(print_changed_only=print_changed_only, display=display) reset_config() retrieved_config = get_config() msg = "`get_config` does not return expected values after `reset_config`.\n" msg += "`After reset_config is run, get_config` should return defaults.\n" - msg += f"Expected {default_config}, but returned {retrieved_config}." - assert retrieved_config == default_config, msg + msg += f"Expected {global_config_default}, but returned {retrieved_config}." + assert retrieved_config == global_config_default, msg reset_config() diff --git a/skbase/tests/conftest.py b/skbase/tests/conftest.py index 4f36034a..f4e0e847 100644 --- a/skbase/tests/conftest.py +++ b/skbase/tests/conftest.py @@ -23,6 +23,7 @@ "skbase.base._base", "skbase.base._meta", "skbase.base._tagmanager", + "skbase.config", "skbase.lookup", "skbase.lookup.tests", "skbase.lookup.tests.test_lookup", @@ -51,6 +52,7 @@ SKBASE_PUBLIC_MODULES = ( "skbase", "skbase.base", + "skbase.config", "skbase.lookup", "skbase.lookup.tests", "skbase.lookup.tests.test_lookup", @@ -74,6 +76,7 @@ "skbase.base": ("BaseEstimator", "BaseMetaEstimator", "BaseObject"), "skbase.base._base": ("BaseEstimator", "BaseObject"), "skbase.base._meta": ("BaseMetaEstimator",), + "skbase.config": ("GlobalConfigParamSetting",), "skbase.lookup._lookup": ("ClassInfo", "FunctionInfo", "ModuleInfo"), "skbase.testing": ("BaseFixtureGenerator", "QuickTester", "TestAllObjects"), "skbase.testing.test_all_objects": ( @@ -85,11 +88,17 @@ SKBASE_CLASSES_BY_MODULE = SKBASE_PUBLIC_CLASSES_BY_MODULE.copy() SKBASE_CLASSES_BY_MODULE.update( { - "skbase.base._meta": ("BaseMetaEstimator",), "skbase.base._tagmanager": ("_FlagManager",), } ) SKBASE_PUBLIC_FUNCTIONS_BY_MODULE = { + "skbase.config": ( + "get_config", + "get_default_config", + "set_config", + "reset_config", + "config_context", + ), "skbase.lookup": ("all_objects", "get_package_metadata"), "skbase.lookup._lookup": ("all_objects", "get_package_metadata"), "skbase.testing.utils._conditional_fixtures": ( @@ -124,6 +133,14 @@ SKBASE_FUNCTIONS_BY_MODULE = SKBASE_PUBLIC_FUNCTIONS_BY_MODULE.copy() SKBASE_FUNCTIONS_BY_MODULE.update( { + "skbase.config": ( + "_get_threadlocal_config", + "get_config", + "get_default_config", + "set_config", + "reset_config", + "config_context", + ), "skbase.lookup._lookup": ( "_determine_module_path", "_get_return_tags", From fc5bee32ff0d2d6481ca343401baeae417fcadde Mon Sep 17 00:00:00 2001 From: rnkuhns Date: Mon, 13 Mar 2023 00:24:06 -0400 Subject: [PATCH 05/16] Add updated tests of local configs --- skbase/base/_base.py | 24 ++- skbase/tests/test_base.py | 321 ++++++++++++++++++++++++++++++-------- 2 files changed, 275 insertions(+), 70 deletions(-) diff --git a/skbase/base/_base.py b/skbase/base/_base.py index 970bf350..a7216434 100644 --- a/skbase/base/_base.py +++ b/skbase/base/_base.py @@ -63,6 +63,7 @@ class name: BaseEstimator from skbase._exceptions import NotFittedError from skbase.base._tagmanager import _FlagManager +from skbase.config import get_config __author__: List[str] = ["mloning", "RNKuhns", "fkiraly"] __all__: List[str] = ["BaseEstimator", "BaseObject"] @@ -446,7 +447,28 @@ def get_config(self): class attribute via nested inheritance and then any overrides and new tags from _onfig_dynamic object attribute. """ - return self._get_flags(flag_attr_name="_config") + config = get_config().copy() + + # Get any extension configuration interface defined in the class + # for example if downstream package wants to extend skbase to retrieve + # their own config + if hasattr(self, "__skbase_get_config__") and callable( + self.__skbase_get_config__ + ): + skbase_get_config_extension_dict = self.__skbase_get_config__() + else: + skbase_get_config_extension_dict = {} + if isinstance(skbase_get_config_extension_dict, dict): + config.update(skbase_get_config_extension_dict) + else: + msg = "Use of `__skbase_get_config__` to extend the interface for local " + msg += "overrides of the global configuration must return a dictionary.\n" + msg += f"But a {type(skbase_get_config_extension_dict)} was found." + warnings.warn(msg, UserWarning, stacklevel=2) + local_config = self._get_flags(flag_attr_name="_config") + config.update(local_config) + + return config def set_config(self, **config_dict): """Set config flags to given values. diff --git a/skbase/tests/test_base.py b/skbase/tests/test_base.py index 38d2d5a9..66dcf45b 100644 --- a/skbase/tests/test_base.py +++ b/skbase/tests/test_base.py @@ -69,6 +69,7 @@ import inspect from copy import deepcopy +from typing import Any, Dict, Type import numpy as np import pytest @@ -79,6 +80,7 @@ from sklearn.base import clone from skbase.base import BaseEstimator, BaseObject +from skbase.config import get_config from skbase.tests.conftest import Child, Parent from skbase.tests.mock_package.test_mock_package import CompositionDummy @@ -200,13 +202,13 @@ def fixture_reset_tester(): @pytest.fixture -def fixture_class_child_tags(fixture_class_child): +def fixture_class_child_tags(fixture_class_child: Type[Child]): """Pytest fixture for tags of Child.""" return fixture_class_child.get_class_tags() @pytest.fixture -def fixture_object_instance_set_tags(fixture_tag_class_object): +def fixture_object_instance_set_tags(fixture_tag_class_object: Child): """Fixture class instance to test tag setting.""" fixture_tag_set = {"A": 42424243, "E": 3} return fixture_tag_class_object.set_tags(**fixture_tag_set) @@ -266,7 +268,9 @@ def fixture_class_instance_no_param_interface(): return NoParamInterface() -def test_get_class_tags(fixture_class_child, fixture_class_child_tags): +def test_get_class_tags( + fixture_class_child: Type[Child], fixture_class_child_tags: Any +): """Test get_class_tags class method of BaseObject for correctness. Raises @@ -280,7 +284,7 @@ def test_get_class_tags(fixture_class_child, fixture_class_child_tags): assert child_tags == fixture_class_child_tags, msg -def test_get_class_tag(fixture_class_child, fixture_class_child_tags): +def test_get_class_tag(fixture_class_child: Type[Child], fixture_class_child_tags: Any): """Test get_class_tag class method of BaseObject for correctness. Raises @@ -307,7 +311,7 @@ def test_get_class_tag(fixture_class_child, fixture_class_child_tags): assert child_tag_default_none is None, msg -def test_get_tags(fixture_tag_class_object, fixture_object_tags): +def test_get_tags(fixture_tag_class_object: Child, fixture_object_tags: Dict[str, Any]): """Test get_tags method of BaseObject for correctness. Raises @@ -321,7 +325,7 @@ def test_get_tags(fixture_tag_class_object, fixture_object_tags): assert object_tags == fixture_object_tags, msg -def test_get_tag(fixture_tag_class_object, fixture_object_tags): +def test_get_tag(fixture_tag_class_object: Child, fixture_object_tags: Dict[str, Any]): """Test get_tag method of BaseObject for correctness. Raises @@ -351,7 +355,7 @@ def test_get_tag(fixture_tag_class_object, fixture_object_tags): assert object_tag_default_none is None, msg -def test_get_tag_raises(fixture_tag_class_object): +def test_get_tag_raises(fixture_tag_class_object: Child): """Test that get_tag method raises error for unknown tag. Raises @@ -363,9 +367,9 @@ def test_get_tag_raises(fixture_tag_class_object): def test_set_tags( - fixture_object_instance_set_tags, - fixture_object_set_tags, - fixture_object_dynamic_tags, + fixture_object_instance_set_tags: Any, + fixture_object_set_tags: Dict[str, Any], + fixture_object_dynamic_tags: Dict[str, int], ): """Test set_tags method of BaseObject for correctness. @@ -381,7 +385,9 @@ def test_set_tags( assert fixture_object_instance_set_tags.get_tags() == fixture_object_set_tags, msg -def test_set_tags_works_with_missing_tags_dynamic_attribute(fixture_tag_class_object): +def test_set_tags_works_with_missing_tags_dynamic_attribute( + fixture_tag_class_object: Child, +): """Test set_tags will still work if _tags_dynamic is missing.""" base_obj = deepcopy(fixture_tag_class_object) delattr(base_obj, "_tags_dynamic") @@ -460,7 +466,7 @@ class AnotherTestClass(BaseObject): assert test_obj_tags.get(tag) == another_base_obj_tags[tag] -def test_is_composite(fixture_composition_dummy): +def test_is_composite(fixture_composition_dummy: Type[CompositionDummy]): """Test is_composite tag for correctness. Raises @@ -474,7 +480,11 @@ def test_is_composite(fixture_composition_dummy): assert composite.is_composite() -def test_components(fixture_object, fixture_class_parent, fixture_composition_dummy): +def test_components( + fixture_object: Type[BaseObject], + fixture_class_parent: Type[Parent], + fixture_composition_dummy: Type[CompositionDummy], +): """Test component retrieval. Raises @@ -507,7 +517,7 @@ def test_components(fixture_object, fixture_class_parent, fixture_composition_du def test_components_raises_error_base_class_is_not_class( - fixture_object, fixture_composition_dummy + fixture_object: Type[BaseObject], fixture_composition_dummy: Type[CompositionDummy] ): """Test _component method raises error if base_class param is not class.""" non_composite = fixture_composition_dummy(foo=42) @@ -526,7 +536,7 @@ def test_components_raises_error_base_class_is_not_class( def test_components_raises_error_base_class_is_not_baseobject_subclass( - fixture_composition_dummy, + fixture_composition_dummy: Type[CompositionDummy], ): """Test _component method raises error if base_class is not BaseObject subclass.""" @@ -540,7 +550,7 @@ class SomeClass: # Test parameter interface (get_params, set_params, reset and related methods) # Some tests of get_params and set_params are adapted from sklearn tests -def test_reset(fixture_reset_tester): +def test_reset(fixture_reset_tester: Type[ResetTester]): """Test reset method for correct behaviour, on a simple estimator. Raises @@ -567,7 +577,7 @@ def test_reset(fixture_reset_tester): assert hasattr(x, "foo") -def test_reset_composite(fixture_reset_tester): +def test_reset_composite(fixture_reset_tester: Type[ResetTester]): """Test reset method for correct behaviour, on a composite estimator.""" y = fixture_reset_tester(42) x = fixture_reset_tester(a=y) @@ -582,7 +592,7 @@ def test_reset_composite(fixture_reset_tester): assert not hasattr(x.a, "d") -def test_get_init_signature(fixture_class_parent): +def test_get_init_signature(fixture_class_parent: Type[Parent]): """Test error is raised when invalid init signature is used.""" init_sig = fixture_class_parent._get_init_signature() init_sig_is_list = isinstance(init_sig, list) @@ -594,14 +604,18 @@ def test_get_init_signature(fixture_class_parent): ), "`_get_init_signature` is not returning expected result." -def test_get_init_signature_raises_error_for_invalid_signature(fixture_invalid_init): +def test_get_init_signature_raises_error_for_invalid_signature( + fixture_invalid_init: Type[InvalidInitSignatureTester], +): """Test error is raised when invalid init signature is used.""" with pytest.raises(RuntimeError): fixture_invalid_init._get_init_signature() def test_get_param_names( - fixture_object, fixture_class_parent, fixture_class_parent_expected_params + fixture_object: Type[BaseObject], + fixture_class_parent: Type[Parent], + fixture_class_parent_expected_params: Dict[str, Any], ): """Test that get_param_names returns list of string parameter names.""" param_names = fixture_class_parent.get_param_names() @@ -612,10 +626,10 @@ def test_get_param_names( def test_get_params( - fixture_class_parent, - fixture_class_parent_expected_params, - fixture_class_instance_no_param_interface, - fixture_composition_dummy, + fixture_class_parent: Type[Parent], + fixture_class_parent_expected_params: Dict[str, Any], + fixture_class_instance_no_param_interface: NoParamInterface, + fixture_composition_dummy: Type[CompositionDummy], ): """Test get_params returns expected parameters.""" # Simple test of returned params @@ -638,7 +652,10 @@ def test_get_params( assert "foo" in params and "bar" in params and len(params) == 2 -def test_get_params_invariance(fixture_class_parent, fixture_composition_dummy): +def test_get_params_invariance( + fixture_class_parent: Type[Parent], + fixture_composition_dummy: Type[CompositionDummy], +): """Test that get_params(deep=False) is subset of get_params(deep=True).""" composite = fixture_composition_dummy(foo=fixture_class_parent(), bar=84) shallow_params = composite.get_params(deep=False) @@ -646,7 +663,7 @@ def test_get_params_invariance(fixture_class_parent, fixture_composition_dummy): assert all(item in deep_params.items() for item in shallow_params.items()) -def test_get_params_after_set_params(fixture_class_parent): +def test_get_params_after_set_params(fixture_class_parent: Type[Parent]): """Test that get_params returns the same thing before and after set_params. Based on scikit-learn check in check_estimator. @@ -687,9 +704,9 @@ def test_get_params_after_set_params(fixture_class_parent): def test_set_params( - fixture_class_parent, - fixture_class_parent_expected_params, - fixture_composition_dummy, + fixture_class_parent: Type[Parent], + fixture_class_parent_expected_params: Dict[str, Any], + fixture_composition_dummy: Type[CompositionDummy], ): """Test set_params works as expected.""" # Simple case of setting a parameter @@ -711,7 +728,8 @@ def test_set_params( def test_set_params_raises_error_non_existent_param( - fixture_class_parent_instance, fixture_composition_dummy + fixture_class_parent_instance: Parent, + fixture_composition_dummy: Type[CompositionDummy], ): """Test set_params raises an error when passed a non-existent parameter name.""" # non-existing parameter in svc @@ -727,7 +745,8 @@ def test_set_params_raises_error_non_existent_param( def test_set_params_raises_error_non_interface_composite( - fixture_class_instance_no_param_interface, fixture_composition_dummy + fixture_class_instance_no_param_interface: NoParamInterface, + fixture_composition_dummy: Type[CompositionDummy], ): """Test set_params raises error when setting param of non-conforming composite.""" # When a composite is made up of a class that doesn't have the BaseObject @@ -753,7 +772,9 @@ def __init__(self, param=5): est.get_params() -def test_set_params_with_no_param_to_set_returns_object(fixture_class_parent): +def test_set_params_with_no_param_to_set_returns_object( + fixture_class_parent: Type[Parent], +): """Test set_params correctly returns self when no parameters are set.""" base_obj = fixture_class_parent() orig_params = deepcopy(base_obj.get_params()) @@ -767,7 +788,7 @@ def test_set_params_with_no_param_to_set_returns_object(fixture_class_parent): # This section tests the clone functionality # These have been adapted from sklearn's tests of clone to use the clone # method that is included as part of the BaseObject interface -def test_clone(fixture_class_parent_instance): +def test_clone(fixture_class_parent_instance: Parent): """Test that clone is making a deep copy as expected.""" # Creates a BaseObject and makes a copy of its original state # (which, in this case, is the current state of the BaseObject), @@ -777,7 +798,7 @@ def test_clone(fixture_class_parent_instance): assert fixture_class_parent_instance.get_params() == new_base_obj.get_params() -def test_clone_2(fixture_class_parent_instance): +def test_clone_2(fixture_class_parent_instance: Parent): """Test that clone does not copy attributes not set in constructor.""" # We first create an estimator, give it an own attribute, and # make a copy of its original state. Then we check that the copy doesn't @@ -790,7 +811,9 @@ def test_clone_2(fixture_class_parent_instance): def test_clone_raises_error_for_nonconforming_objects( - fixture_invalid_init, fixture_buggy, fixture_modify_param + fixture_invalid_init: Type[InvalidInitSignatureTester], + fixture_buggy: Type[Buggy], + fixture_modify_param: Type[ModifyParam], ): """Test that clone raises an error on nonconforming BaseObjects.""" buggy = fixture_buggy() @@ -807,7 +830,7 @@ def test_clone_raises_error_for_nonconforming_objects( obj_that_modifies.clone() -def test_clone_param_is_none(fixture_class_parent): +def test_clone_param_is_none(fixture_class_parent: Type[Parent]): """Test clone with keyword parameter set to None.""" base_obj = fixture_class_parent(c=None) new_base_obj = clone(base_obj) @@ -816,7 +839,7 @@ def test_clone_param_is_none(fixture_class_parent): assert base_obj.c is new_base_obj2.c -def test_clone_empty_array(fixture_class_parent): +def test_clone_empty_array(fixture_class_parent: Type[Parent]): """Test clone with keyword parameter is scipy sparse matrix. This test is based on scikit-learn regression test to make sure clone @@ -830,7 +853,7 @@ def test_clone_empty_array(fixture_class_parent): np.testing.assert_array_equal(base_obj.c, new_base_obj2.c) -def test_clone_sparse_matrix(fixture_class_parent): +def test_clone_sparse_matrix(fixture_class_parent: Type[Parent]): """Test clone with keyword parameter is scipy sparse matrix. This test is based on scikit-learn regression test to make sure clone @@ -843,7 +866,7 @@ def test_clone_sparse_matrix(fixture_class_parent): np.testing.assert_array_equal(base_obj.c, new_base_obj2.c) -def test_clone_nan(fixture_class_parent): +def test_clone_nan(fixture_class_parent: Type[Parent]): """Test clone with keyword parameter is np.nan. This test is based on scikit-learn regression test to make sure clone @@ -858,7 +881,7 @@ def test_clone_nan(fixture_class_parent): assert base_obj.c is new_base_obj2.c -def test_clone_estimator_types(fixture_class_parent): +def test_clone_estimator_types(fixture_class_parent: Type[Parent]): """Test clone works for parameters that are types rather than instances.""" base_obj = fixture_class_parent(c=fixture_class_parent) new_base_obj = base_obj.clone() @@ -866,7 +889,9 @@ def test_clone_estimator_types(fixture_class_parent): assert base_obj.c == new_base_obj.c -def test_clone_class_rather_than_instance_raises_error(fixture_class_parent): +def test_clone_class_rather_than_instance_raises_error( + fixture_class_parent: Type[Parent], +): """Test clone raises expected error when cloning a class instead of an instance.""" msg = "You should provide an instance of scikit-learn estimator" with pytest.raises(TypeError, match=msg): @@ -874,7 +899,10 @@ def test_clone_class_rather_than_instance_raises_error(fixture_class_parent): # Tests of BaseObject pretty printing representation inspired by sklearn -def test_baseobject_repr(fixture_class_parent, fixture_composition_dummy): +def test_baseobject_repr( + fixture_class_parent: Type[Parent], + fixture_composition_dummy: Type[CompositionDummy], +): """Test BaseObject repr works as expected.""" # Simple test where all parameters are left at defaults # Should not see parameters and values in printed representation @@ -893,12 +921,12 @@ def test_baseobject_repr(fixture_class_parent, fixture_composition_dummy): assert len(repr(long_base_obj_repr)) == 535 -def test_baseobject_str(fixture_class_parent_instance): +def test_baseobject_str(fixture_class_parent_instance: Parent): """Test BaseObject string representation works.""" str(fixture_class_parent_instance) -def test_baseobject_repr_mimebundle_(fixture_class_parent_instance): +def test_baseobject_repr_mimebundle_(fixture_class_parent_instance: Parent): """Test display configuration controls output.""" # Checks the display configuration flag controls the json output with config_context(display="diagram"): @@ -912,7 +940,7 @@ def test_baseobject_repr_mimebundle_(fixture_class_parent_instance): assert "text/html" not in output -def test_repr_html_wraps(fixture_class_parent_instance): +def test_repr_html_wraps(fixture_class_parent_instance: Parent): """Test display configuration flag controls the html output.""" with config_context(display="diagram"): output = fixture_class_parent_instance._repr_html_() @@ -925,20 +953,24 @@ def test_repr_html_wraps(fixture_class_parent_instance): # Test BaseObject's ability to generate test instances -def test_get_test_params(fixture_class_parent_instance): +def test_get_test_params(fixture_class_parent_instance: Parent): """Test get_test_params returns empty dictionary.""" base_obj = fixture_class_parent_instance test_params = base_obj.get_test_params() assert isinstance(test_params, dict) and len(test_params) == 0 -def test_get_test_params_raises_error_when_params_required(fixture_required_param): +def test_get_test_params_raises_error_when_params_required( + fixture_required_param: Type[RequiredParam], +): """Test get_test_params raises an error when parameters are required.""" with pytest.raises(ValueError): fixture_required_param(7).get_test_params() -def test_create_test_instance(fixture_class_parent, fixture_class_parent_instance): +def test_create_test_instance( + fixture_class_parent: Type[Parent], fixture_class_parent_instance: Parent +): """Test first that create_test_instance logic works.""" base_obj = fixture_class_parent.create_test_instance() @@ -957,7 +989,7 @@ def test_create_test_instance(fixture_class_parent, fixture_class_parent_instanc assert hasattr(base_obj, "_tags_dynamic"), msg -def test_create_test_instances_and_names(fixture_class_parent_instance): +def test_create_test_instances_and_names(fixture_class_parent_instance: Parent): """Test that create_test_instances_and_names works.""" base_objs, names = fixture_class_parent_instance.create_test_instances_and_names() @@ -990,7 +1022,7 @@ def test_create_test_instances_and_names(fixture_class_parent_instance): # Tests _has_implementation_of interface def test_has_implementation_of( - fixture_class_parent_instance, fixture_class_child_instance + fixture_class_parent_instance: Parent, fixture_class_child_instance: Child ): """Test _has_implementation_of detects methods in class with overrides in mro.""" # When the class overrides a parent classes method should return True @@ -1014,33 +1046,184 @@ def __init__(self, a, b=42): self.c = 84 -def test_set_get_config(): - """Test logic behind get_config, set_config. +class AnotherConfigTester(BaseObject): + _config = {"print_changed_only": False, "bar": "a"} - Raises - ------ - AssertionError if logic behind get_config, set_config is incorrect, logic tested: - calling get_fitted_params on a non-composite fittable returns the fitted param - calling get_fitted_params on a composite returns all nested params + clsvar = 210 + + def __init__(self, a, b=42): + self.a = a + self.b = b + self.c = 84 + + +class ConfigExtensionInterfaceTester(BaseObject): + _config = {"print_changed_only": False, "bar": "a"} + + clsvar = 210 + + def __init__(self, a, b=42): + self.a = a + self.b = b + self.c = 84 + + def __skbase_get_config__(self): + """Return get_config extension.""" + return {"print_changed_only": True, "some_other_config": 70} + + +def test_local_config_without_use_of_extension_interface(): + """Test ``BaseObject().get_config`` and ``BaseObject().set_config``. + + ``BaseObject.get_config()`` should return the global dict updated with any local + configs defined in ``BaseObject._config`` or set on the instance with + ``BaseObject().set_config()``. """ - obj = ConfigTester(4242) + # Initially test that We can retrieve the local config with local configs + # as defined in BaseObject._config - config_start = obj.get_config() - assert isinstance(config_start, dict) - assert set(config_start.keys()) == {"foo_config", "bar"} - assert config_start["foo_config"] == 42 - assert config_start["bar"] == "a" + # Case 1: local configs are not part of global config param names + obj = ConfigTester(4242) + obj_config = obj.get_config() + current_global_config = get_config().copy() + expected_global_config = set(current_global_config.keys()) + assert isinstance(obj_config, dict) + assert set(obj_config.keys()) == expected_global_config | {"foo_config", "bar"} + assert obj_config["foo_config"] == 42 + assert obj_config["bar"] == "a" + for param_name in current_global_config: + assert obj_config[param_name] == current_global_config[param_name] + + # Case 2: local configs overlap with (will override) global params + obj = AnotherConfigTester(4242) + obj_config = obj.get_config() + current_global_config = get_config().copy() + expected_global_config = set(current_global_config.keys()) + assert isinstance(obj_config, dict) + assert set(obj_config.keys()) == expected_global_config | {"bar"} + assert obj_config["bar"] == "a" + # Should have overrided global config value which is set to True + assert obj_config["print_changed_only"] is False + for param_name in current_global_config: + if param_name != "print_changed_only": + assert obj_config[param_name] == current_global_config[param_name] + + # Case 3: local configs are not part of global config param names and we also + # make use of dynamic BaseObject.set_config() + obj = ConfigTester(4242) + obj_config = obj.get_config() + current_global_config = get_config().copy() + expected_global_config = set(current_global_config.keys()) + # Verify set config returns the original object setconfig_return = obj.set_config(foobar=126) assert obj is setconfig_return obj.set_config(**{"bar": "b"}) - config_end = obj.get_config() - assert isinstance(config_end, dict) - assert set(config_end.keys()) == {"foo_config", "bar", "foobar"} - assert config_end["foo_config"] == 42 - assert config_end["bar"] == "b" - assert config_end["foobar"] == 126 + updated_obj_config = obj.get_config() + assert isinstance(updated_obj_config, dict) + assert set(updated_obj_config.keys()) == ( + expected_global_config | {"foo_config", "bar", "foobar"} + ) + assert updated_obj_config["foo_config"] == 42 + assert updated_obj_config["bar"] == "b" + assert updated_obj_config["foobar"] == 126 + + # Case 4: local configs are not part of global config param names and we also + # make use of dynamic BaseObject.set_config() to update a config that is also + # part of global config + obj = ConfigTester(4242) + obj_config = obj.get_config() + current_global_config = get_config().copy() + expected_global_config = set(current_global_config.keys()) + + # Verify set config returns the original object + setconfig_return = obj.set_config(print_changed_only=False) + assert obj is setconfig_return + + updated_obj_config = obj.get_config() + assert isinstance(updated_obj_config, dict) + assert set(updated_obj_config.keys()) == ( + expected_global_config | {"foo_config", "bar"} + ) + assert updated_obj_config["foo_config"] == 42 + assert updated_obj_config["bar"] == "a" + assert updated_obj_config["print_changed_only"] is False + for param_name in current_global_config: + if param_name != "print_changed_only": + assert updated_obj_config[param_name] == current_global_config[param_name] + + # Case 5: local configs overlap with (will override) global params + # Then the local config defined in AnotherConfigTester._config is overrode again + # by calling AnotherConfigTester().set_config() + obj = AnotherConfigTester(4242) + obj.set_config(print_changed_only=True) + obj_config = obj.get_config() + current_global_config = get_config().copy() + expected_global_config = set(current_global_config.keys()) + assert isinstance(obj_config, dict) + assert set(obj_config.keys()) == expected_global_config | {"bar"} + assert obj_config["bar"] == "a" + # Should have overrided global config value which is set to True + assert obj_config["print_changed_only"] is True + for param_name in current_global_config: + if param_name != "print_changed_only": + assert obj_config[param_name] == current_global_config[param_name] + + +def test_local_config_with_use_of_extension_interface(): + """Test BaseObject local config interface when ``__skbase_get_config__`` defined. + + BaseObject.get_config() should return the global dict updated in the following + order: + + - Any config returned by ``BaseObject.__skbase_get_config__``. + - Any configs defined in ``BaseObject._config`` or set on the instance with + ``BaseObject().set_config()``. + """ + current_global_config = get_config().copy() + obj = ConfigExtensionInterfaceTester(4242) + obj_config = obj.get_config() + assert "some_other_config" in obj_config + assert obj_config["print_changed_only"] is False + + expected_global_config = set(current_global_config.keys()) + assert isinstance(obj_config, dict) + assert set(obj_config.keys()) == expected_global_config | { + "some_other_config", + "bar", + } + assert obj_config["bar"] == "a" + # Should have overrided global config value which is set to True + + for param_name in current_global_config: + if param_name != "print_changed_only": + assert obj_config[param_name] == current_global_config[param_name] + + # Now lets verify we can override the config items only returned by + # __skbase_get_config__ extension interface + obj.set_config(some_other_config=22) + obj_config_updated = obj.get_config() + assert obj_config_updated["some_other_config"] == 22 + for param_name in obj_config: + if param_name != "some_other_config": + assert obj_config_updated[param_name] == obj_config[param_name] + + +def local_config_interface_does_not_affect_global_config_interface(): + """Test that calls to instance config interface doesn't impact global config.""" + from skbase.config import get_default_config, global_config + + obj = AnotherConfigTester(4242) + _global_config_before = global_config.copy() + _default_config_before = get_default_config() + global_config_before = get_config().copy() + obj.set_config(print_changed_only=False, some_other_param=7) + global_config_after = get_config().copy() + assert global_config_before == global_config_after + assert "some_other_param" not in global_config_after + assert _global_config_before == global_config + assert _default_config_before == get_default_config() class FittableCompositionDummy(BaseEstimator): From d89e7caec80a59db40a5dae657b30b351f3d0bc3 Mon Sep 17 00:00:00 2001 From: rnkuhns Date: Mon, 13 Mar 2023 00:54:21 -0400 Subject: [PATCH 06/16] Add global config to docs --- docs/source/api_reference.rst | 21 +++++++++++++++++++ docs/source/user_documentation/user_guide.rst | 15 +++++++++++++ .../user_guide/configuration.rst | 11 ++++++++++ skbase/config/__init__.py | 11 ++++++++-- 4 files changed, 56 insertions(+), 2 deletions(-) create mode 100644 docs/source/user_documentation/user_guide/configuration.rst diff --git a/docs/source/api_reference.rst b/docs/source/api_reference.rst index 03026bdf..9508d39e 100644 --- a/docs/source/api_reference.rst +++ b/docs/source/api_reference.rst @@ -28,6 +28,27 @@ Base Classes BaseObject BaseEstimator +.. _global_config: + +Configure ``skbase`` +==================== + +.. automodule:: skbase.config + :no-members: + :no-inherited-members: + +.. currentmodule:: skbase.config + +.. autosummary:: + :toctree: api_reference/auto_generated/ + :template: function.rst + + get_config + get_default_config + set_config + reset_config + config_context + .. _obj_retrieval: Object Retrieval diff --git a/docs/source/user_documentation/user_guide.rst b/docs/source/user_documentation/user_guide.rst index c8febd31..d571554b 100644 --- a/docs/source/user_documentation/user_guide.rst +++ b/docs/source/user_documentation/user_guide.rst @@ -29,6 +29,7 @@ that ``skbase`` provides, see the :ref:`api_ref`. user_guide/lookup user_guide/validate user_guide/testing + user_guide/configuration .. grid:: 1 2 2 2 @@ -103,3 +104,17 @@ that ``skbase`` provides, see the :ref:`api_ref`. :expand: Testing + + .. grid-item-card:: Configuration + :text-align: center + + Configure ``skbase``. + + +++ + + .. button-ref:: user_guide/testing + :color: primary + :click-parent: + :expand: + + Configuration diff --git a/docs/source/user_documentation/user_guide/configuration.rst b/docs/source/user_documentation/user_guide/configuration.rst new file mode 100644 index 00000000..05b76843 --- /dev/null +++ b/docs/source/user_documentation/user_guide/configuration.rst @@ -0,0 +1,11 @@ +.. _user_guide_global_config: + +==================== +Configure ``skbase`` +==================== + +.. note:: + + The user guide is under development. We have created a basic + structure and are looking for contributions to develop the user guide + further. diff --git a/skbase/config/__init__.py b/skbase/config/__init__.py index 55d58a3c..b8c8d3c0 100644 --- a/skbase/config/__init__.py +++ b/skbase/config/__init__.py @@ -1,5 +1,9 @@ # -*- coding: utf-8 -*- -""":mod:`skbase.config` provides tools for global configuration of ``skbase``.""" +""":mod:`skbase.config` provides tools for the global configuration of ``skbase``. + +For more information on configuration usage patterns see the +:ref:`user guide `. +""" # -*- coding: utf-8 -*- # copyright: skbase developers, BSD-3-Clause License (see LICENSE file) # Includes functionality like get_config, set_config, and config_context @@ -167,7 +171,7 @@ def get_default_config() -> Dict[str, Any]: def get_config() -> Dict[str, Any]: """Retrieve current values for configuration set by :meth:`set_config`. - Willr return the default configuration if know updated configuration has + Will return the default configuration if know updated configuration has been set by :meth:`set_config`. Returns @@ -309,6 +313,9 @@ def config_context( ) -> Iterator[None]: """Context manager for global configuration. + Provides the ability to run code using different configuration without + having to update the global config. + Parameters ---------- print_changed_only : bool, default=None From 68560d36dc47d5cbc7ac2a17dd37f839d7e2f56e Mon Sep 17 00:00:00 2001 From: rnkuhns Date: Mon, 13 Mar 2023 00:59:20 -0400 Subject: [PATCH 07/16] Fix user_guide panel for configuration --- docs/source/user_documentation/user_guide.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/user_documentation/user_guide.rst b/docs/source/user_documentation/user_guide.rst index d571554b..4e15f3ec 100644 --- a/docs/source/user_documentation/user_guide.rst +++ b/docs/source/user_documentation/user_guide.rst @@ -112,7 +112,7 @@ that ``skbase`` provides, see the :ref:`api_ref`. +++ - .. button-ref:: user_guide/testing + .. button-ref:: user_guide/configuration :color: primary :click-parent: :expand: From ddaa90e03d0bd7fd424817527ea2b8ccc065867e Mon Sep 17 00:00:00 2001 From: rnkuhns Date: Mon, 13 Mar 2023 01:01:24 -0400 Subject: [PATCH 08/16] Fix docstring typo --- skbase/config/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skbase/config/__init__.py b/skbase/config/__init__.py index b8c8d3c0..a7ca56b8 100644 --- a/skbase/config/__init__.py +++ b/skbase/config/__init__.py @@ -139,7 +139,7 @@ def _get_threadlocal_config() -> Dict[str, Any]: def get_default_config() -> Dict[str, Any]: - """Retrive the default global configuration. + """Retrieve the default global configuration. This will always return the default ``skbase`` global configuration. From b8f62d1e7d0b81d6e01ec0b0c60803c7c6762374 Mon Sep 17 00:00:00 2001 From: rnkuhns Date: Mon, 13 Mar 2023 01:23:03 -0400 Subject: [PATCH 09/16] Fix pre-commit errors --- skbase/base/_pretty_printing/_object_html_repr.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/skbase/base/_pretty_printing/_object_html_repr.py b/skbase/base/_pretty_printing/_object_html_repr.py index 650a2072..851061c0 100644 --- a/skbase/base/_pretty_printing/_object_html_repr.py +++ b/skbase/base/_pretty_printing/_object_html_repr.py @@ -70,7 +70,7 @@ def _write_label_html( checked=False, ): """Write labeled html with or without a dropdown with named details.""" - out.write(f'
') + out.write(f'
') name = html.escape(name) if name_details is not None: @@ -81,8 +81,8 @@ def _write_label_html( est_id = uuid.uuid4() out.write( '' - f'' + f'id={est_id!r} type="checkbox" {checked_str}>' + f"" f'
{name_details}'
             "
" ) @@ -379,7 +379,7 @@ def _object_html_repr(base_object): ) out.write( f"" - f'
' + f'
' '
' f"
{html.escape(base_object_str)}
{fallback_msg}" "
" From b20283bde83522f7d03711d28b1c150907a57714 Mon Sep 17 00:00:00 2001 From: rnkuhns Date: Mon, 13 Mar 2023 01:25:47 -0400 Subject: [PATCH 10/16] Add missing __init__ --- skbase/base/_pretty_printing/__init__.py | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 skbase/base/_pretty_printing/__init__.py diff --git a/skbase/base/_pretty_printing/__init__.py b/skbase/base/_pretty_printing/__init__.py new file mode 100644 index 00000000..669c800c --- /dev/null +++ b/skbase/base/_pretty_printing/__init__.py @@ -0,0 +1,11 @@ +#!/usr/bin/env python3 -u +# -*- coding: utf-8 -*- +# copyright: skbase developers, BSD-3-Clause License (see LICENSE file) +# Many elements of this code were developed in scikit-learn. These elements +# are copyrighted by the scikit-learn developers, BSD-3-Clause License. For +# conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING +"""Functionality for pretty printing BaseObjects.""" +from typing import List + +__author__: List[str] = ["RNKuhns"] +__all__: List[str] = [] From eeee466a7bb52ae725c3f27a2cf965cab19715dc Mon Sep 17 00:00:00 2001 From: rnkuhns Date: Mon, 13 Mar 2023 02:01:51 -0400 Subject: [PATCH 11/16] Tweak behavior when invalid values encountered --- skbase/base/_base.py | 15 +++++++-- skbase/config/__init__.py | 52 +++++++++++++++++++++--------- skbase/config/tests/test_config.py | 8 +++-- skbase/tests/test_base.py | 1 - 4 files changed, 55 insertions(+), 21 deletions(-) diff --git a/skbase/base/_base.py b/skbase/base/_base.py index a7216434..ae10cf29 100644 --- a/skbase/base/_base.py +++ b/skbase/base/_base.py @@ -63,7 +63,7 @@ class name: BaseEstimator from skbase._exceptions import NotFittedError from skbase.base._tagmanager import _FlagManager -from skbase.config import get_config +from skbase.config import _CONFIG_REGISTRY, get_config __author__: List[str] = ["mloning", "RNKuhns", "fkiraly"] __all__: List[str] = ["BaseEstimator", "BaseObject"] @@ -465,7 +465,18 @@ class attribute via nested inheritance and then any overrides msg += "overrides of the global configuration must return a dictionary.\n" msg += f"But a {type(skbase_get_config_extension_dict)} was found." warnings.warn(msg, UserWarning, stacklevel=2) - local_config = self._get_flags(flag_attr_name="_config") + local_config = self._get_flags(flag_attr_name="_config").copy() + # IF the local config is one of + for config_param, config_value in local_config.items(): + if config_param in _CONFIG_REGISTRY: + msg = "Invalid value encountered for global configuration parameter " + msg += f"{config_param}. Using global parameter configuration value.\n" + config_value = _CONFIG_REGISTRY[ + config_param + ].get_valid_param_or_default( + config_value, default_value=config[config_param] + ) + local_config[config_param] = config_value config.update(local_config) return config diff --git a/skbase/config/__init__.py b/skbase/config/__init__.py index a7ca56b8..2d1abfde 100644 --- a/skbase/config/__init__.py +++ b/skbase/config/__init__.py @@ -35,10 +35,10 @@ class GlobalConfigParamSetting: name: str os_environ_name: str expected_type: Union[type, Tuple[type]] - allowed_values: Optional[Tuple[Any, ...]] + allowed_values: Optional[Union[Tuple[Any, ...], List[Any]]] default_value: Any - def get_allowed_values(self): + def get_allowed_values(self) -> List[Any]: """Get `allowed_values` or empty tuple if `allowed_values` is None. Returns @@ -47,15 +47,15 @@ def get_allowed_values(self): Allowable values if any. """ if self.allowed_values is None: - return () - elif isinstance(self.allowed_values, tuple): + return [] + elif isinstance(self.allowed_values, list): return self.allowed_values elif isinstance( self.allowed_values, collections.abc.Iterable ) and not isinstance(self.allowed_values, str): - return tuple(self.allowed_values) + return list(self.allowed_values) else: - return (self.allowed_values,) + return [self.allowed_values] def is_valid_param_value(self, value): """Validate that a global configuration value is valid. @@ -79,22 +79,36 @@ def is_valid_param_value(self, value): valid_param = True return valid_param - def get_valid_param_or_default(self, value): - """Validate `value` and return default if it is not valid.""" + def get_valid_param_or_default(self, value, default_value=None, msg=None): + """Validate `value` and return default if it is not valid. + + Parameters + ---------- + value : Any + The configuration parameter value to set. + default_value : Any, default=None + An optional default value to use to set the configuration parameter + if `value` is not valid based on defined expected type and allowed + values. If None, and `value` is invalid then the classes `default_value` + parameter is used. + msg : str, default=None + An optional message to be used as start of the UserWarning message. + """ if self.is_valid_param_value(value): return value else: - msg = "Attempting to set an invalid value for a global configuration.\n" - msg += "Using default configuration value of parameter as a result.\n" + if msg is None: + msg = "" msg + f"When setting global config values for `{self.name}`, the values " - msg += f"should be of type {self.expected_type}." + msg += f"should be of type {self.expected_type}.\n" if self.allowed_values is not None: values_str = _format_seq_to_str( self.get_allowed_values(), last_sep="or", remove_type_text=True ) - msg += f"Allowed values are be one of {values_str}." + msg += f"Allowed values are be one of {values_str}. " + msg += f"But found {value}." warnings.warn(msg, UserWarning, stacklevel=2) - return self.default_value + return default_value if default_value is not None else self.default_value _CONFIG_REGISTRY: Dict[str, GlobalConfigParamSetting] = { @@ -251,14 +265,22 @@ def set_config( """ local_config = _get_threadlocal_config() + msg = "Attempting to set an invalid value for a global configuration.\n" + msg += "Using current configuration value of parameter as a result.\n" if print_changed_only is not None: local_config["print_changed_only"] = _CONFIG_REGISTRY[ "print_changed_only" - ].get_valid_param_or_default(print_changed_only) + ].get_valid_param_or_default( + print_changed_only, + default_value=local_config["print_changed_only"], + msg=msg, + ) if display is not None: local_config["display"] = _CONFIG_REGISTRY[ "display" - ].get_valid_param_or_default(display) + ].get_valid_param_or_default( + display, default_value=local_config["display"], msg=msg + ) if not local_threadsafe: global_config.update(local_config) diff --git a/skbase/config/tests/test_config.py b/skbase/config/tests/test_config.py index 11f3b5cf..0569125e 100644 --- a/skbase/config/tests/test_config.py +++ b/skbase/config/tests/test_config.py @@ -41,7 +41,7 @@ def test_global_config_param_get_allowed_values(allowed_values): ) # Verify we always coerce output of get_allowed_values to tuple values = some_config_param.get_allowed_values() - assert isinstance(values, tuple) + assert isinstance(values, list) @pytest.mark.parametrize("value", (None, (), "wrong_string", "text", range(1, 8))) @@ -130,13 +130,15 @@ def test_config_context(print_changed_only, display): def test_set_config_behavior_invalid_value(): """Test set_config uses default and raises warning when setting invalid value.""" reset_config() + original_config = get_config().copy() with pytest.warns(UserWarning, match=r"Attempting to set an invalid value.*"): set_config(print_changed_only="False") - assert get_config() == get_default_config() + assert get_config() == original_config + original_config = get_config().copy() with pytest.warns(UserWarning, match=r"Attempting to set an invalid value.*"): set_config(print_changed_only=7) - assert get_config() == get_default_config() + assert get_config() == original_config reset_config() diff --git a/skbase/tests/test_base.py b/skbase/tests/test_base.py index 66dcf45b..f934b2c2 100644 --- a/skbase/tests/test_base.py +++ b/skbase/tests/test_base.py @@ -1133,7 +1133,6 @@ def test_local_config_without_use_of_extension_interface(): # make use of dynamic BaseObject.set_config() to update a config that is also # part of global config obj = ConfigTester(4242) - obj_config = obj.get_config() current_global_config = get_config().copy() expected_global_config = set(current_global_config.keys()) From 91a72902888e035ce7de6eac049df071ea2eebfa Mon Sep 17 00:00:00 2001 From: rnkuhns Date: Mon, 13 Mar 2023 02:08:52 -0400 Subject: [PATCH 12/16] Remove unused code line --- skbase/tests/test_base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/skbase/tests/test_base.py b/skbase/tests/test_base.py index f934b2c2..27ca97bf 100644 --- a/skbase/tests/test_base.py +++ b/skbase/tests/test_base.py @@ -1111,7 +1111,6 @@ def test_local_config_without_use_of_extension_interface(): # Case 3: local configs are not part of global config param names and we also # make use of dynamic BaseObject.set_config() obj = ConfigTester(4242) - obj_config = obj.get_config() current_global_config = get_config().copy() expected_global_config = set(current_global_config.keys()) From 289d0ab3604517bb345055d151c997ef668393c9 Mon Sep 17 00:00:00 2001 From: rnkuhns Date: Mon, 13 Mar 2023 02:23:04 -0400 Subject: [PATCH 13/16] Move config implementation out of __init__ --- skbase/base/_base.py | 3 +- skbase/config/__init__.py | 395 ++-------------------------- skbase/config/_config.py | 397 +++++++++++++++++++++++++++++ skbase/config/tests/test_config.py | 3 +- skbase/tests/conftest.py | 11 +- 5 files changed, 427 insertions(+), 382 deletions(-) create mode 100644 skbase/config/_config.py diff --git a/skbase/base/_base.py b/skbase/base/_base.py index ae10cf29..455782a1 100644 --- a/skbase/base/_base.py +++ b/skbase/base/_base.py @@ -63,7 +63,8 @@ class name: BaseEstimator from skbase._exceptions import NotFittedError from skbase.base._tagmanager import _FlagManager -from skbase.config import _CONFIG_REGISTRY, get_config +from skbase.config import get_config +from skbase.config._config import _CONFIG_REGISTRY __author__: List[str] = ["mloning", "RNKuhns", "fkiraly"] __all__: List[str] = ["BaseEstimator", "BaseObject"] diff --git a/skbase/config/__init__.py b/skbase/config/__init__.py index 2d1abfde..fd8a750a 100644 --- a/skbase/config/__init__.py +++ b/skbase/config/__init__.py @@ -10,384 +10,23 @@ # that is similar to scikit-learn. These elements are copyrighted by the # scikit-learn developers, BSD-3-Clause License. For conditions see # https://github.com/scikit-learn/scikit-learn/blob/main/COPYING -import collections -import sys -import threading -import warnings -from contextlib import contextmanager -from dataclasses import dataclass -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union +from typing import List -if sys.version_info >= (3, 8): - from typing import Literal -else: - from typing_extensions import Literal # type: ignore - -from skbase.utils._iter import _format_seq_to_str +from skbase.config._config import ( + GlobalConfigParamSetting, + config_context, + get_config, + get_default_config, + reset_config, + set_config, +) __author__: List[str] = ["RNKuhns"] - - -@dataclass -class GlobalConfigParamSetting: - """Define types of the setting information for a given config parameter.""" - - name: str - os_environ_name: str - expected_type: Union[type, Tuple[type]] - allowed_values: Optional[Union[Tuple[Any, ...], List[Any]]] - default_value: Any - - def get_allowed_values(self) -> List[Any]: - """Get `allowed_values` or empty tuple if `allowed_values` is None. - - Returns - ------- - tuple - Allowable values if any. - """ - if self.allowed_values is None: - return [] - elif isinstance(self.allowed_values, list): - return self.allowed_values - elif isinstance( - self.allowed_values, collections.abc.Iterable - ) and not isinstance(self.allowed_values, str): - return list(self.allowed_values) - else: - return [self.allowed_values] - - def is_valid_param_value(self, value): - """Validate that a global configuration value is valid. - - Verifies that the value set for a global configuration parameter is valid - based on the its configuration settings. - - Returns - ------- - bool - Whether a parameter value is valid. - """ - allowed_values = self.get_allowed_values() - - valid_param: bool - if not isinstance(value, self.expected_type): - valid_param = False - elif allowed_values is not None and value not in allowed_values: - valid_param = False - else: - valid_param = True - return valid_param - - def get_valid_param_or_default(self, value, default_value=None, msg=None): - """Validate `value` and return default if it is not valid. - - Parameters - ---------- - value : Any - The configuration parameter value to set. - default_value : Any, default=None - An optional default value to use to set the configuration parameter - if `value` is not valid based on defined expected type and allowed - values. If None, and `value` is invalid then the classes `default_value` - parameter is used. - msg : str, default=None - An optional message to be used as start of the UserWarning message. - """ - if self.is_valid_param_value(value): - return value - else: - if msg is None: - msg = "" - msg + f"When setting global config values for `{self.name}`, the values " - msg += f"should be of type {self.expected_type}.\n" - if self.allowed_values is not None: - values_str = _format_seq_to_str( - self.get_allowed_values(), last_sep="or", remove_type_text=True - ) - msg += f"Allowed values are be one of {values_str}. " - msg += f"But found {value}." - warnings.warn(msg, UserWarning, stacklevel=2) - return default_value if default_value is not None else self.default_value - - -_CONFIG_REGISTRY: Dict[str, GlobalConfigParamSetting] = { - "print_changed_only": GlobalConfigParamSetting( - name="print_changed_only", - os_environ_name="SKBASE_PRINT_CHANGED_ONLY", - expected_type=bool, - allowed_values=(True, False), - default_value=True, - ), - "display": GlobalConfigParamSetting( - name="display", - os_environ_name="SKBASE_OBJECT_DISPLAY", - expected_type=str, - allowed_values=("text", "diagram"), - default_value="text", - ), -} - -_DEFAULT_GLOBAL_CONFIG: Dict[str, Any] = { - config_name: config_info.default_value - for config_name, config_info in _CONFIG_REGISTRY.items() -} - -global_config = _DEFAULT_GLOBAL_CONFIG.copy() -_THREAD_LOCAL_DATA = threading.local() - - -def _get_threadlocal_config() -> Dict[str, Any]: - """Get a threadlocal **mutable** configuration. - - If the configuration does not exist, copy the default global configuration. - - Returns - ------- - dict - Threadlocal global config or copy of default global configuration. - """ - if not hasattr(_THREAD_LOCAL_DATA, "global_config"): - _THREAD_LOCAL_DATA.global_config = global_config.copy() - return _THREAD_LOCAL_DATA.global_config - - -def get_default_config() -> Dict[str, Any]: - """Retrieve the default global configuration. - - This will always return the default ``skbase`` global configuration. - - Returns - ------- - config : dict - The default configurable settings (keys) and their default values (values). - - See Also - -------- - config_context : - Configuration context manager. - get_config : - Retrieve current global configuration values. - set_config : - Set global configuration. - reset_config : - Reset configuration to ``skbase`` default. - - Examples - -------- - >>> from skbase.config import get_default_config - >>> get_default_config() - {'print_changed_only': True, 'display': 'text'} - """ - return _DEFAULT_GLOBAL_CONFIG.copy() - - -def get_config() -> Dict[str, Any]: - """Retrieve current values for configuration set by :meth:`set_config`. - - Will return the default configuration if know updated configuration has - been set by :meth:`set_config`. - - Returns - ------- - config : dict - The configurable settings (keys) and their default values (values). - - See Also - -------- - config_context : - Configuration context manager. - get_default_config : - Retrieve ``skbase``'s default configuration. - set_config : - Set global configuration. - reset_config : - Reset configuration to ``skbase`` default. - - Examples - -------- - >>> from skbase.config import get_config - >>> get_config() - {'print_changed_only': True, 'display': 'text'} - """ - return _get_threadlocal_config().copy() - - -def set_config( - print_changed_only: Optional[bool] = None, - display: Literal["text", "diagram"] = None, - local_threadsafe: bool = False, -) -> None: - """Set global configuration. - - Allows the ``skbase`` global configuration to be updated. - - Parameters - ---------- - print_changed_only : bool, default=None - If True, only the parameters that were set to non-default - values will be printed when printing a BaseObject instance. For example, - ``print(SVC())`` while True will only print 'SVC()', but would print - 'SVC(C=1.0, cache_size=200, ...)' with all the non-changed parameters - when False. If None, the existing value won't change. - display : {'text', 'diagram'}, default=None - If 'diagram', instances inheritting from BaseOBject will be displayed - as a diagram in a Jupyter lab or notebook context. If 'text', instances - inheritting from BaseObject will be displayed as text. If None, the - existing value won't change. - local_threadsafe : bool, default=False - If False, set the backend as default for all threads. - - Returns - ------- - None - No output returned. - - See Also - -------- - config_context : - Configuration context manager. - get_default_config : - Retrieve ``skbase``'s default configuration. - get_config : - Retrieve current global configuration values. - reset_config : - Reset configuration to default. - - Examples - -------- - >>> from skbase.config import get_config, set_config - >>> get_config() - {'print_changed_only': True, 'display': 'text'} - >>> set_config(display='diagram') - >>> get_config() - {'print_changed_only': True, 'display': 'diagram'} - """ - local_config = _get_threadlocal_config() - - msg = "Attempting to set an invalid value for a global configuration.\n" - msg += "Using current configuration value of parameter as a result.\n" - if print_changed_only is not None: - local_config["print_changed_only"] = _CONFIG_REGISTRY[ - "print_changed_only" - ].get_valid_param_or_default( - print_changed_only, - default_value=local_config["print_changed_only"], - msg=msg, - ) - if display is not None: - local_config["display"] = _CONFIG_REGISTRY[ - "display" - ].get_valid_param_or_default( - display, default_value=local_config["display"], msg=msg - ) - - if not local_threadsafe: - global_config.update(local_config) - - return None - - -def reset_config() -> None: - """Reset the global configuration to the default. - - Will remove any user updates to the global configuration and reset the values - back to the ``skbase`` defaults. - - Returns - ------- - None - No output returned. - - See Also - -------- - config_context : - Configuration context manager. - get_default_config : - Retrieve ``skbase``'s default configuration. - get_config : - Retrieve current global configuration values. - set_config : - Set global configuration. - - Examples - -------- - >>> from skbase.config import get_config, set_config, reset_config - >>> get_config() - {'print_changed_only': True, 'display': 'text'} - >>> set_config(display='diagram') - >>> get_config() - {'print_changed_only': True, 'display': 'diagram'} - >>> reset_config() - >>> get_config() - {'print_changed_only': True, 'display': 'text'} - """ - default_config = get_default_config() - set_config(**default_config) - return None - - -@contextmanager -def config_context( - print_changed_only: Optional[bool] = None, - display: Literal["text", "diagram"] = None, - local_threadsafe: bool = False, -) -> Iterator[None]: - """Context manager for global configuration. - - Provides the ability to run code using different configuration without - having to update the global config. - - Parameters - ---------- - print_changed_only : bool, default=None - If True, only the parameters that were set to non-default - values will be printed when printing a BaseObject instance. For example, - ``print(SVC())`` while True will only print 'SVC()', but would print - 'SVC(C=1.0, cache_size=200, ...)' with all the non-changed parameters - when False. If None, the existing value won't change. - display : {'text', 'diagram'}, default=None - If 'diagram', instances inheritting from BaseOBject will be displayed - as a diagram in a Jupyter lab or notebook context. If 'text', instances - inheritting from BaseObject will be displayed as text. If None, the - existing value won't change. - local_threadsafe : bool, default=False - If False, set the config as default for all threads. - - Yields - ------ - None - - See Also - -------- - get_default_config : - Retrieve ``skbase``'s default configuration. - get_config : - Retrieve current values of the global configuration. - set_config : - Set global configuration. - reset_config : - Reset configuration to ``skbase`` default. - - Notes - ----- - All settings, not just those presently modified, will be returned to - their previous values when the context manager is exited. - - Examples - -------- - >>> from skbase.config import config_context - >>> with config_context(display='diagram'): - ... pass - """ - old_config = get_config() - set_config( - print_changed_only=print_changed_only, - display=display, - local_threadsafe=local_threadsafe, - ) - - try: - yield - finally: - set_config(**old_config) +__all__: List[str] = [ + "GlobalConfigParamSetting", + "get_default_config", + "get_config", + "set_config", + "reset_config", + "config_context", +] diff --git a/skbase/config/_config.py b/skbase/config/_config.py new file mode 100644 index 00000000..ec2fc13c --- /dev/null +++ b/skbase/config/_config.py @@ -0,0 +1,397 @@ +# -*- coding: utf-8 -*- +"""Implement logic for global configuration of skbase.""" +# -*- coding: utf-8 -*- +# copyright: skbase developers, BSD-3-Clause License (see LICENSE file) +# Includes functionality like get_config, set_config, and config_context +# that is similar to scikit-learn. These elements are copyrighted by the +# scikit-learn developers, BSD-3-Clause License. For conditions see +# https://github.com/scikit-learn/scikit-learn/blob/main/COPYING +import collections +import sys +import threading +import warnings +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union + +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal # type: ignore + +from skbase.utils._iter import _format_seq_to_str + +__author__: List[str] = ["RNKuhns"] +__all__: List[str] = [ + "GlobalConfigParamSetting", + "get_default_config", + "get_config", + "set_config", + "reset_config", + "config_context", +] + + +@dataclass +class GlobalConfigParamSetting: + """Define types of the setting information for a given config parameter.""" + + name: str + os_environ_name: str + expected_type: Union[type, Tuple[type]] + allowed_values: Optional[Union[Tuple[Any, ...], List[Any]]] + default_value: Any + + def get_allowed_values(self) -> List[Any]: + """Get `allowed_values` or empty tuple if `allowed_values` is None. + + Returns + ------- + tuple + Allowable values if any. + """ + if self.allowed_values is None: + return [] + elif isinstance(self.allowed_values, list): + return self.allowed_values + elif isinstance( + self.allowed_values, collections.abc.Iterable + ) and not isinstance(self.allowed_values, str): + return list(self.allowed_values) + else: + return [self.allowed_values] + + def is_valid_param_value(self, value): + """Validate that a global configuration value is valid. + + Verifies that the value set for a global configuration parameter is valid + based on the its configuration settings. + + Returns + ------- + bool + Whether a parameter value is valid. + """ + allowed_values = self.get_allowed_values() + + valid_param: bool + if not isinstance(value, self.expected_type): + valid_param = False + elif allowed_values is not None and value not in allowed_values: + valid_param = False + else: + valid_param = True + return valid_param + + def get_valid_param_or_default(self, value, default_value=None, msg=None): + """Validate `value` and return default if it is not valid. + + Parameters + ---------- + value : Any + The configuration parameter value to set. + default_value : Any, default=None + An optional default value to use to set the configuration parameter + if `value` is not valid based on defined expected type and allowed + values. If None, and `value` is invalid then the classes `default_value` + parameter is used. + msg : str, default=None + An optional message to be used as start of the UserWarning message. + """ + if self.is_valid_param_value(value): + return value + else: + if msg is None: + msg = "" + msg + f"When setting global config values for `{self.name}`, the values " + msg += f"should be of type {self.expected_type}.\n" + if self.allowed_values is not None: + values_str = _format_seq_to_str( + self.get_allowed_values(), last_sep="or", remove_type_text=True + ) + msg += f"Allowed values are be one of {values_str}. " + msg += f"But found {value}." + warnings.warn(msg, UserWarning, stacklevel=2) + return default_value if default_value is not None else self.default_value + + +_CONFIG_REGISTRY: Dict[str, GlobalConfigParamSetting] = { + "print_changed_only": GlobalConfigParamSetting( + name="print_changed_only", + os_environ_name="SKBASE_PRINT_CHANGED_ONLY", + expected_type=bool, + allowed_values=(True, False), + default_value=True, + ), + "display": GlobalConfigParamSetting( + name="display", + os_environ_name="SKBASE_OBJECT_DISPLAY", + expected_type=str, + allowed_values=("text", "diagram"), + default_value="text", + ), +} + +_DEFAULT_GLOBAL_CONFIG: Dict[str, Any] = { + config_name: config_info.default_value + for config_name, config_info in _CONFIG_REGISTRY.items() +} + +global_config = _DEFAULT_GLOBAL_CONFIG.copy() +_THREAD_LOCAL_DATA = threading.local() + + +def _get_threadlocal_config() -> Dict[str, Any]: + """Get a threadlocal **mutable** configuration. + + If the configuration does not exist, copy the default global configuration. + + Returns + ------- + dict + Threadlocal global config or copy of default global configuration. + """ + if not hasattr(_THREAD_LOCAL_DATA, "global_config"): + _THREAD_LOCAL_DATA.global_config = global_config.copy() + return _THREAD_LOCAL_DATA.global_config + + +def get_default_config() -> Dict[str, Any]: + """Retrieve the default global configuration. + + This will always return the default ``skbase`` global configuration. + + Returns + ------- + config : dict + The default configurable settings (keys) and their default values (values). + + See Also + -------- + config_context : + Configuration context manager. + get_config : + Retrieve current global configuration values. + set_config : + Set global configuration. + reset_config : + Reset configuration to ``skbase`` default. + + Examples + -------- + >>> from skbase.config import get_default_config + >>> get_default_config() + {'print_changed_only': True, 'display': 'text'} + """ + return _DEFAULT_GLOBAL_CONFIG.copy() + + +def get_config() -> Dict[str, Any]: + """Retrieve current values for configuration set by :meth:`set_config`. + + Will return the default configuration if know updated configuration has + been set by :meth:`set_config`. + + Returns + ------- + config : dict + The configurable settings (keys) and their default values (values). + + See Also + -------- + config_context : + Configuration context manager. + get_default_config : + Retrieve ``skbase``'s default configuration. + set_config : + Set global configuration. + reset_config : + Reset configuration to ``skbase`` default. + + Examples + -------- + >>> from skbase.config import get_config + >>> get_config() + {'print_changed_only': True, 'display': 'text'} + """ + return _get_threadlocal_config().copy() + + +def set_config( + print_changed_only: Optional[bool] = None, + display: Literal["text", "diagram"] = None, + local_threadsafe: bool = False, +) -> None: + """Set global configuration. + + Allows the ``skbase`` global configuration to be updated. + + Parameters + ---------- + print_changed_only : bool, default=None + If True, only the parameters that were set to non-default + values will be printed when printing a BaseObject instance. For example, + ``print(SVC())`` while True will only print 'SVC()', but would print + 'SVC(C=1.0, cache_size=200, ...)' with all the non-changed parameters + when False. If None, the existing value won't change. + display : {'text', 'diagram'}, default=None + If 'diagram', instances inheritting from BaseOBject will be displayed + as a diagram in a Jupyter lab or notebook context. If 'text', instances + inheritting from BaseObject will be displayed as text. If None, the + existing value won't change. + local_threadsafe : bool, default=False + If False, set the backend as default for all threads. + + Returns + ------- + None + No output returned. + + See Also + -------- + config_context : + Configuration context manager. + get_default_config : + Retrieve ``skbase``'s default configuration. + get_config : + Retrieve current global configuration values. + reset_config : + Reset configuration to default. + + Examples + -------- + >>> from skbase.config import get_config, set_config + >>> get_config() + {'print_changed_only': True, 'display': 'text'} + >>> set_config(display='diagram') + >>> get_config() + {'print_changed_only': True, 'display': 'diagram'} + """ + local_config = _get_threadlocal_config() + + msg = "Attempting to set an invalid value for a global configuration.\n" + msg += "Using current configuration value of parameter as a result.\n" + if print_changed_only is not None: + local_config["print_changed_only"] = _CONFIG_REGISTRY[ + "print_changed_only" + ].get_valid_param_or_default( + print_changed_only, + default_value=local_config["print_changed_only"], + msg=msg, + ) + if display is not None: + local_config["display"] = _CONFIG_REGISTRY[ + "display" + ].get_valid_param_or_default( + display, default_value=local_config["display"], msg=msg + ) + + if not local_threadsafe: + global_config.update(local_config) + + return None + + +def reset_config() -> None: + """Reset the global configuration to the default. + + Will remove any user updates to the global configuration and reset the values + back to the ``skbase`` defaults. + + Returns + ------- + None + No output returned. + + See Also + -------- + config_context : + Configuration context manager. + get_default_config : + Retrieve ``skbase``'s default configuration. + get_config : + Retrieve current global configuration values. + set_config : + Set global configuration. + + Examples + -------- + >>> from skbase.config import get_config, set_config, reset_config + >>> get_config() + {'print_changed_only': True, 'display': 'text'} + >>> set_config(display='diagram') + >>> get_config() + {'print_changed_only': True, 'display': 'diagram'} + >>> reset_config() + >>> get_config() + {'print_changed_only': True, 'display': 'text'} + """ + default_config = get_default_config() + set_config(**default_config) + return None + + +@contextmanager +def config_context( + print_changed_only: Optional[bool] = None, + display: Literal["text", "diagram"] = None, + local_threadsafe: bool = False, +) -> Iterator[None]: + """Context manager for global configuration. + + Provides the ability to run code using different configuration without + having to update the global config. + + Parameters + ---------- + print_changed_only : bool, default=None + If True, only the parameters that were set to non-default + values will be printed when printing a BaseObject instance. For example, + ``print(SVC())`` while True will only print 'SVC()', but would print + 'SVC(C=1.0, cache_size=200, ...)' with all the non-changed parameters + when False. If None, the existing value won't change. + display : {'text', 'diagram'}, default=None + If 'diagram', instances inheritting from BaseOBject will be displayed + as a diagram in a Jupyter lab or notebook context. If 'text', instances + inheritting from BaseObject will be displayed as text. If None, the + existing value won't change. + local_threadsafe : bool, default=False + If False, set the config as default for all threads. + + Yields + ------ + None + + See Also + -------- + get_default_config : + Retrieve ``skbase``'s default configuration. + get_config : + Retrieve current values of the global configuration. + set_config : + Set global configuration. + reset_config : + Reset configuration to ``skbase`` default. + + Notes + ----- + All settings, not just those presently modified, will be returned to + their previous values when the context manager is exited. + + Examples + -------- + >>> from skbase.config import config_context + >>> with config_context(display='diagram'): + ... pass + """ + old_config = get_config() + set_config( + print_changed_only=print_changed_only, + display=display, + local_threadsafe=local_threadsafe, + ) + + try: + yield + finally: + set_config(**old_config) diff --git a/skbase/config/tests/test_config.py b/skbase/config/tests/test_config.py index 0569125e..1af163e4 100644 --- a/skbase/config/tests/test_config.py +++ b/skbase/config/tests/test_config.py @@ -3,8 +3,6 @@ import pytest from skbase.config import ( - _CONFIG_REGISTRY, - _DEFAULT_GLOBAL_CONFIG, GlobalConfigParamSetting, config_context, get_config, @@ -12,6 +10,7 @@ reset_config, set_config, ) +from skbase.config._config import _CONFIG_REGISTRY, _DEFAULT_GLOBAL_CONFIG PRINT_CHANGE_ONLY_VALUES = _CONFIG_REGISTRY["print_changed_only"].get_allowed_values() DISPLAY_VALUES = _CONFIG_REGISTRY["display"].get_allowed_values() diff --git a/skbase/tests/conftest.py b/skbase/tests/conftest.py index f4e0e847..a39e2251 100644 --- a/skbase/tests/conftest.py +++ b/skbase/tests/conftest.py @@ -24,6 +24,7 @@ "skbase.base._meta", "skbase.base._tagmanager", "skbase.config", + "skbase.config._config", "skbase.lookup", "skbase.lookup.tests", "skbase.lookup.tests.test_lookup", @@ -77,6 +78,7 @@ "skbase.base._base": ("BaseEstimator", "BaseObject"), "skbase.base._meta": ("BaseMetaEstimator",), "skbase.config": ("GlobalConfigParamSetting",), + "skbase.config._config": ("GlobalConfigParamSetting",), "skbase.lookup._lookup": ("ClassInfo", "FunctionInfo", "ModuleInfo"), "skbase.testing": ("BaseFixtureGenerator", "QuickTester", "TestAllObjects"), "skbase.testing.test_all_objects": ( @@ -99,6 +101,13 @@ "reset_config", "config_context", ), + "skbase.config._config": ( + "get_config", + "get_default_config", + "set_config", + "reset_config", + "config_context", + ), "skbase.lookup": ("all_objects", "get_package_metadata"), "skbase.lookup._lookup": ("all_objects", "get_package_metadata"), "skbase.testing.utils._conditional_fixtures": ( @@ -133,7 +142,7 @@ SKBASE_FUNCTIONS_BY_MODULE = SKBASE_PUBLIC_FUNCTIONS_BY_MODULE.copy() SKBASE_FUNCTIONS_BY_MODULE.update( { - "skbase.config": ( + "skbase.config._config": ( "_get_threadlocal_config", "get_config", "get_default_config", From 64819441c42d43da9b12339a8c2be15998ca7fc4 Mon Sep 17 00:00:00 2001 From: rnkuhns Date: Mon, 13 Mar 2023 02:35:25 -0400 Subject: [PATCH 14/16] Add metadata to pass metadata functionality tests --- skbase/tests/conftest.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/skbase/tests/conftest.py b/skbase/tests/conftest.py index 0fbd8025..f0cf89b1 100644 --- a/skbase/tests/conftest.py +++ b/skbase/tests/conftest.py @@ -22,6 +22,9 @@ "skbase.base", "skbase.base._base", "skbase.base._meta", + "skbase.base._pretty_printing", + "skbase.base._pretty_printing._object_html_repr", + "skbase.base._pretty_printing._pprint", "skbase.base._tagmanager", "skbase.config", "skbase.config._config", @@ -78,6 +81,7 @@ "skbase.base": ("BaseEstimator", "BaseMetaEstimator", "BaseObject"), "skbase.base._base": ("BaseEstimator", "BaseObject"), "skbase.base._meta": ("BaseMetaEstimator",), + "skbase.base._pretty_printing._pprint": ("KeyValTuple", "KeyValTupleParam"), "skbase.config": ("GlobalConfigParamSetting",), "skbase.config._config": ("GlobalConfigParamSetting",), "skbase.lookup._lookup": ("ClassInfo", "FunctionInfo", "ModuleInfo"), @@ -91,6 +95,12 @@ SKBASE_CLASSES_BY_MODULE = SKBASE_PUBLIC_CLASSES_BY_MODULE.copy() SKBASE_CLASSES_BY_MODULE.update( { + "skbase.base._pretty_printing._object_html_repr": ("_VisualBlock",), + "skbase.base._pretty_printing._pprint": ( + "KeyValTuple", + "KeyValTupleParam", + "_BaseObjectPrettyPrinter", + ), "skbase.base._tagmanager": ("_FlagManager",), } ) @@ -143,6 +153,13 @@ SKBASE_FUNCTIONS_BY_MODULE = SKBASE_PUBLIC_FUNCTIONS_BY_MODULE.copy() SKBASE_FUNCTIONS_BY_MODULE.update( { + "skbase.base._pretty_printing._object_html_repr": ( + "_get_visual_block", + "_object_html_repr", + "_write_base_object_html", + "_write_label_html", + ), + "skbase.base._pretty_printing._pprint": ("_changed_params", "_safe_repr"), "skbase.config._config": ( "_get_threadlocal_config", "get_config", From c47e96dba3363eac3cbf80c32eb64b78e7bdda58 Mon Sep 17 00:00:00 2001 From: rnkuhns Date: Mon, 13 Mar 2023 02:41:46 -0400 Subject: [PATCH 15/16] Use call to local config not global config for pretty printing configuration --- skbase/base/_base.py | 8 ++++---- skbase/base/_pretty_printing/_object_html_repr.py | 2 ++ 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/skbase/base/_base.py b/skbase/base/_base.py index eceebb10..1393e267 100644 --- a/skbase/base/_base.py +++ b/skbase/base/_base.py @@ -739,7 +739,7 @@ def __repr__(self, n_char_max: int = 700): indent=1, indent_at_name=True, n_max_elements_to_show=n_max_elements_to_show, - changed_only=get_config()["print_changed_only"], + changed_only=self.get_config()["print_changed_only"], ) repr_ = pp.pformat(self) @@ -787,7 +787,7 @@ def _repr_html_(self): should be favorted in the long term, `_repr_html_` is only implemented for consumers who do not interpret `_repr_mimbundle_`. """ - if get_config()["display"] != "diagram": + if self.get_config()["display"] != "diagram": raise AttributeError( "_repr_html_ is only defined when the " "`display` configuration option is set to 'diagram'." @@ -799,14 +799,14 @@ def _repr_html_inner(self): This function is returned by the @property `_repr_html_` to make `hasattr(BaseObject, "_repr_html_") return `True` or `False` depending - on `get_config()["display"]`. + on `self.get_config()["display"]`. """ return _object_html_repr(self) def _repr_mimebundle_(self, **kwargs): """Mime bundle used by jupyter kernels to display instances of BaseObject.""" output = {"text/plain": repr(self)} - if get_config()["display"] == "diagram": + if self.get_config()["display"] == "diagram": output["text/html"] = _object_html_repr(self) return output diff --git a/skbase/base/_pretty_printing/_object_html_repr.py b/skbase/base/_pretty_printing/_object_html_repr.py index 851061c0..16515b9d 100644 --- a/skbase/base/_pretty_printing/_object_html_repr.py +++ b/skbase/base/_pretty_printing/_object_html_repr.py @@ -128,6 +128,8 @@ def _write_base_object_html( if first_call: est_block = _get_visual_block(base_object) else: + # So it is easier to read, always use print_changed_only==True + # regardless of configuration with config_context(print_changed_only=True): est_block = _get_visual_block(base_object) From 3c49f7d165c488298ffbf7a858a137d9b467f5f9 Mon Sep 17 00:00:00 2001 From: rnkuhns Date: Sun, 19 Mar 2023 18:05:04 -0400 Subject: [PATCH 16/16] Add more pretty printing test cases --- skbase/tests/test_base.py | 44 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/skbase/tests/test_base.py b/skbase/tests/test_base.py index 6284526c..28d54d5a 100644 --- a/skbase/tests/test_base.py +++ b/skbase/tests/test_base.py @@ -914,16 +914,58 @@ def test_baseobject_repr( with config_context(print_changed_only=False, display="text"): assert repr(base_obj) == "Parent(a='something', b=7, c=None)" + # Check that local config works as expected + base_obj.set_config(print_changed_only=False) + assert repr(base_obj) == "Parent(a='something', b=7, c=None)" + + # Test with dict parameter (note that dict is sorted by keys when printed) + # not printed in order it was created + base_obj = fixture_class_parent(c={"c": 1, "a": 2}) + assert repr(base_obj) == "Parent(c={'a': 2, 'c': 1})" + + # Now test when one params values are named object tuples + named_objs = [ + ("step 1", fixture_class_parent()), + ("step 2", fixture_class_parent()), + ] + base_obj = fixture_class_parent(c=named_objs) + assert repr(base_obj) == "Parent(c=[('step 1', Parent()), ('step 2', Parent())])" + + # Or when they are just lists of tuples or just tuples as param + base_obj = fixture_class_parent(c=[("one", 1), ("two", 2)]) + assert repr(base_obj) == "Parent(c=[('one', 1), ('two', 2)])" + + base_obj = fixture_class_parent(c=(1, 2, 3)) + assert repr(base_obj) == "Parent(c=(1, 2, 3))" + simple_composite = fixture_composition_dummy(foo=fixture_class_parent()) assert repr(simple_composite) == "CompositionDummy(foo=Parent())" long_base_obj_repr = fixture_class_parent(a=["long_params"] * 1000) assert len(repr(long_base_obj_repr)) == 535 + named_objs = [(f"Step {i+1}", Child()) for i in range(25)] + base_comp = CompositionDummy(foo=Parent(c=Child(c=named_objs))) + assert len(repr(base_comp)) == 1362 + with config_context(print_changed_only=False): + assert "..." in repr(base_comp) + def test_baseobject_str(fixture_class_parent_instance: Parent): """Test BaseObject string representation works.""" - str(fixture_class_parent_instance) + assert ( + str(fixture_class_parent_instance) == "Parent()" + ), "String representation of instance not working." + # Check that we can alter the detail about params that is printed + # using config_context with ``print_changed_only=False`` + with config_context(print_changed_only=False, display="text"): + assert ( + str(fixture_class_parent_instance) == "Parent(a='something', b=7, c=None)" + ) + + # Check that local config works as expected + fixture_class_parent_instance.set_config(print_changed_only=False) + assert str(fixture_class_parent_instance) == "Parent(a='something', b=7, c=None)" def test_baseobject_repr_mimebundle_(fixture_class_parent_instance: Parent):