From 5ad1707e9e692b6b4bc773c50d476e2ef2d89084 Mon Sep 17 00:00:00 2001 From: rnkuhns Date: Sun, 12 Mar 2023 21:21:21 -0400 Subject: [PATCH 01/12] Add global config --- skbase/config/__init__.py | 404 +++++++++++++++++++++++++++++ skbase/config/tests/__init__.py | 4 + skbase/config/tests/test_config.py | 154 +++++++++++ 3 files changed, 562 insertions(+) create mode 100644 skbase/config/__init__.py create mode 100644 skbase/config/tests/__init__.py create mode 100644 skbase/config/tests/test_config.py diff --git a/skbase/config/__init__.py b/skbase/config/__init__.py new file mode 100644 index 00000000..3f38839b --- /dev/null +++ b/skbase/config/__init__.py @@ -0,0 +1,404 @@ +# -*- coding: utf-8 -*- +""":mod:`skbase.config` provides tools for global configuration of ``skbase``.""" +# -*- coding: utf-8 -*- +# copyright: skbase developers, BSD-3-Clause License (see LICENSE file) +# Includes functionality like get_config, set_config, and config_context +# that is similar to scikit-learn. These elements are copyrighted by the +# scikit-learn developers, BSD-3-Clause License. For conditions see +# https://github.com/scikit-learn/scikit-learn/blob/main/COPYING +import collections +import sys +import threading +import warnings +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union + +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal # type: ignore + +from skbase.utils._iter import _format_seq_to_str + +__author__: List[str] = ["RNKuhns"] + + +@dataclass +class GlobalConfigParamSetting: + """Define types of the setting information for a given config parameter.""" + + name: str + os_environ_name: str + expected_type: Union[type, Tuple[type]] + allowed_values: Optional[Tuple[Any, ...]] + default_value: Any + + def get_allowed_values(self): + """Get `allowed_values` or empty tuple if `allowed_values` is None. + + Returns + ------- + tuple + Allowable values if any. + """ + if self.allowed_values is None: + return () + elif isinstance(self.allowed_values, tuple): + return self.allowed_values + elif isinstance( + self.allowed_values, collections.abc.Iterable + ) and not isinstance(self.allowed_values, str): + return tuple(self.allowed_values) + else: + return (self.allowed_values,) + + def is_valid_param_value(self, value): + """Validate that a global configuration value is valid. + + Verifies that the value set for a global configuration parameter is valid + based on the its configuration settings. + + Returns + ------- + bool + Whether a parameter value is valid. + """ + allowed_values = self.get_allowed_values() + + valid_param: bool + if not isinstance(value, self.expected_type): + valid_param = False + elif allowed_values is not None and value not in allowed_values: + valid_param = False + else: + valid_param = True + return valid_param + + def get_valid_param_or_default(self, value): + """Validate `value` and return default if it is not valid.""" + if self.is_valid_param_value(value): + return value + else: + msg = "Attempting to set an invalid value for a global configuration.\n" + msg += "Using default configuration value of parameter as a result.\n" + msg + f"When setting global config values for `{self.name}`, the values " + msg += f"should be of type {self.expected_type}." + if self.allowed_values is not None: + values_str = _format_seq_to_str( + self.get_allowed_values(), last_sep="or", remove_type_text=True + ) + msg += f"Allowed values are be one of {values_str}." + warnings.warn(msg, UserWarning, stacklevel=2) + return self.default_value + + +_CONFIG_REGISTRY: Dict[str, GlobalConfigParamSetting] = { + "print_changed_only": GlobalConfigParamSetting( + name="print_changed_only", + os_environ_name="SKBASE_PRINT_CHANGED_ONLY", + expected_type=bool, + allowed_values=(True, False), + default_value=True, + ), + "display": GlobalConfigParamSetting( + name="display", + os_environ_name="SKBASE_OBJECT_DISPLAY", + expected_type=str, + allowed_values=("text", "diagram"), + default_value="text", + ), +} + +_DEFAULT_GLOBAL_CONFIG: Dict[str, Any] = { + config_name: config_info.default_value + for config_name, config_info in _CONFIG_REGISTRY.items() +} + +global_config = _DEFAULT_GLOBAL_CONFIG.copy() +_THREAD_LOCAL_DATA = threading.local() + + +def _get_threadlocal_config() -> Dict[str, Any]: + """Get a threadlocal **mutable** configuration. + + If the configuration does not exist, copy the default global configuration. + + Returns + ------- + dict + Threadlocal global config or copy of default global configuration. + """ + if not hasattr(_THREAD_LOCAL_DATA, "global_config"): + _THREAD_LOCAL_DATA.global_config = global_config.copy() + return _THREAD_LOCAL_DATA.global_config + + +def get_config_os_env_names() -> List[str]: + """Retrieve the os environment names for configurable settings. + + Returns + ------- + env_names : list + The os environment names that can be used to set configurable settings. + + See Also + -------- + config_context : + Configuration context manager. + get_default_config : + Retrieve ``skbase``'s default configuration. + get_config : + Retrieve current global configuration values. + set_config : + Set global configuration. + reset_config : + Reset configuration to ``skbase`` default. + + Examples + -------- + >>> from skbase.config import get_config_os_env_names + >>> get_config_os_env_names() + ['SKBASE_PRINT_CHANGED_ONLY', 'SKBASE_OBJECT_DISPLAY'] + """ + return [config_info.os_environ_name for config_info in _CONFIG_REGISTRY.values()] + + +def get_default_config() -> Dict[str, Any]: + """Retrive the default global configuration. + + This will always return the default ``skbase`` global configuration. + + Returns + ------- + config : dict + The default configurable settings (keys) and their default values (values). + + See Also + -------- + config_context : + Configuration context manager. + get_config : + Retrieve current global configuration values. + get_config_os_env_names : + Retrieve os environment names that can be used to set configuration. + set_config : + Set global configuration. + reset_config : + Reset configuration to ``skbase`` default. + + Examples + -------- + >>> from skbase.config import get_default_config + >>> get_default_config() + {'print_changed_only': True, 'display': 'text'} + """ + return _DEFAULT_GLOBAL_CONFIG.copy() + + +def get_config() -> Dict[str, Any]: + """Retrieve current values for configuration set by :meth:`set_config`. + + Willr return the default configuration if know updated configuration has + been set by :meth:`set_config`. + + Returns + ------- + config : dict + The configurable settings (keys) and their default values (values). + + See Also + -------- + config_context : + Configuration context manager. + get_default_config : + Retrieve ``skbase``'s default configuration. + get_config_os_env_names : + Retrieve os environment names that can be used to set configuration. + set_config : + Set global configuration. + reset_config : + Reset configuration to ``skbase`` default. + + Examples + -------- + >>> from skbase.config import get_config + >>> get_config() + {'print_changed_only': True, 'display': 'text'} + """ + return _get_threadlocal_config().copy() + + +def set_config( + print_changed_only: Optional[bool] = None, + display: Literal["text", "diagram"] = None, + local_threadsafe: bool = False, +) -> None: + """Set global configuration. + + Allows the ``skbase`` global configuration to be updated. + + Parameters + ---------- + print_changed_only : bool, default=None + If True, only the parameters that were set to non-default + values will be printed when printing a BaseObject instance. For example, + ``print(SVC())`` while True will only print 'SVC()', but would print + 'SVC(C=1.0, cache_size=200, ...)' with all the non-changed parameters + when False. If None, the existing value won't change. + display : {'text', 'diagram'}, default=None + If 'diagram', instances inheritting from BaseOBject will be displayed + as a diagram in a Jupyter lab or notebook context. If 'text', instances + inheritting from BaseObject will be displayed as text. If None, the + existing value won't change. + local_threadsafe : bool, default=False + If False, set the backend as default for all threads. + + Returns + ------- + None + No output returned. + + See Also + -------- + config_context : + Configuration context manager. + get_default_config : + Retrieve ``skbase``'s default configuration. + get_config : + Retrieve current global configuration values. + get_config_os_env_names : + Retrieve os environment names that can be used to set configuration. + reset_config : + Reset configuration to default. + + Examples + -------- + >>> from skbase.config import get_config, set_config + >>> get_config() + {'print_changed_only': True, 'display': 'text'} + >>> set_config(display='diagram') + >>> get_config() + {'print_changed_only': True, 'display': 'diagram'} + """ + local_config = _get_threadlocal_config() + + if print_changed_only is not None: + local_config["print_changed_only"] = _CONFIG_REGISTRY[ + "print_changed_only" + ].get_valid_param_or_default(print_changed_only) + if display is not None: + local_config["display"] = _CONFIG_REGISTRY[ + "display" + ].get_valid_param_or_default(display) + + if not local_threadsafe: + global_config.update(local_config) + + return None + + +def reset_config() -> None: + """Reset the global configuration to the default. + + Will remove any user updates to the global configuration and reset the values + back to the ``skbase`` defaults. + + Returns + ------- + None + No output returned. + + See Also + -------- + config_context : + Configuration context manager. + get_default_config : + Retrieve ``skbase``'s default configuration. + get_config : + Retrieve current global configuration values. + get_config_os_env_names : + Retrieve os environment names that can be used to set configuration. + set_config : + Set global configuration. + + Examples + -------- + >>> from skbase.config import get_config, set_config, reset_config + >>> get_config() + {'print_changed_only': True, 'display': 'text'} + >>> set_config(display='diagram') + >>> get_config() + {'print_changed_only': True, 'display': 'diagram'} + >>> reset_config() + >>> get_config() + {'print_changed_only': True, 'display': 'text'} + """ + default_config = get_default_config() + set_config(**default_config) + return None + + +@contextmanager +def config_context( + print_changed_only: Optional[bool] = None, + display: Literal["text", "diagram"] = None, + local_threadsafe: bool = False, +) -> Iterator[None]: + """Context manager for global configuration. + + Parameters + ---------- + print_changed_only : bool, default=None + If True, only the parameters that were set to non-default + values will be printed when printing a BaseObject instance. For example, + ``print(SVC())`` while True will only print 'SVC()', but would print + 'SVC(C=1.0, cache_size=200, ...)' with all the non-changed parameters + when False. If None, the existing value won't change. + display : {'text', 'diagram'}, default=None + If 'diagram', instances inheritting from BaseOBject will be displayed + as a diagram in a Jupyter lab or notebook context. If 'text', instances + inheritting from BaseObject will be displayed as text. If None, the + existing value won't change. + local_threadsafe : bool, default=False + If False, set the config as default for all threads. + + Yields + ------ + None + + See Also + -------- + get_default_config : + Retrieve ``skbase``'s default configuration. + get_config : + Retrieve current values of the global configuration. + get_config_os_env_names : + Retrieve os environment names that can be used to set configuration. + set_config : + Set global configuration. + reset_config : + Reset configuration to ``skbase`` default. + + Notes + ----- + All settings, not just those presently modified, will be returned to + their previous values when the context manager is exited. + + Examples + -------- + >>> from skbase.config import config_context + >>> with config_context(display='diagram'): + ... pass + """ + old_config = get_config() + set_config( + print_changed_only=print_changed_only, + display=display, + local_threadsafe=local_threadsafe, + ) + + try: + yield + finally: + set_config(**old_config) diff --git a/skbase/config/tests/__init__.py b/skbase/config/tests/__init__.py new file mode 100644 index 00000000..ac71ad0c --- /dev/null +++ b/skbase/config/tests/__init__.py @@ -0,0 +1,4 @@ +#!/usr/bin/env python3 -u +# -*- coding: utf-8 -*- +# copyright: skbase developers, BSD-3-Clause License (see LICENSE file) +"""Test functionality of :mod:`skbase.config`.""" diff --git a/skbase/config/tests/test_config.py b/skbase/config/tests/test_config.py new file mode 100644 index 00000000..67a104ff --- /dev/null +++ b/skbase/config/tests/test_config.py @@ -0,0 +1,154 @@ +# -*- coding: utf-8 -*- +"""Test configuration functionality.""" +import pytest + +from skbase.config import ( + _CONFIG_REGISTRY, + _DEFAULT_GLOBAL_CONFIG, + GlobalConfigParamSetting, + config_context, + get_config, + get_config_os_env_names, + get_default_config, + reset_config, + set_config, +) + +PRINT_CHANGE_ONLY_VALUES = _CONFIG_REGISTRY["print_changed_only"].get_allowed_values() +DISPLAY_VALUES = _CONFIG_REGISTRY["display"].get_allowed_values() + + +@pytest.fixture +def config_registry(): + """Config registry fixture.""" + return _CONFIG_REGISTRY + + +@pytest.fixture +def global_config_default(): + """Config registry fixture.""" + return _DEFAULT_GLOBAL_CONFIG + + +@pytest.mark.parametrize("allowed_values", (None, (), "something", range(1, 8))) +def test_global_config_param_get_allowed_values(allowed_values): + """Test GlobalConfigParamSetting behavior works as expected.""" + some_config_param = GlobalConfigParamSetting( + name="some_param", + os_environ_name="SKBASE_OBJECT_DISPLAY", + expected_type=str, + allowed_values=allowed_values, + default_value="text", + ) + # Verify we always coerce output of get_allowed_values to tuple + values = some_config_param.get_allowed_values() + assert isinstance(values, tuple) + + +@pytest.mark.parametrize("value", (None, (), "wrong_string", "text", range(1, 8))) +def test_global_config_param_is_valid_param_value(value): + """Test GlobalConfigParamSetting behavior works as expected.""" + some_config_param = GlobalConfigParamSetting( + name="some_param", + os_environ_name="SKBASE_OBJECT_DISPLAY", + expected_type=str, + allowed_values=("text", "diagram"), + default_value="text", + ) + # Verify we correctly identify invalid parameters + if value in ("text", "diagram"): + expected_valid = True + else: + expected_valid = False + assert some_config_param.is_valid_param_value(value) == expected_valid + + +def test_get_config_os_env_names(config_registry): + """Verify that get_config_os_env_names returns expected values.""" + os_env_names = get_config_os_env_names() + expected_os_env_names = [ + config_info.os_environ_name for config_info in config_registry.values() + ] + msg = "`get_config_os_env_names` does not return expected value.\n" + msg += f"Expected {expected_os_env_names}, but returned {os_env_names}." + assert os_env_names == expected_os_env_names, msg + + +def test_get_default_config(global_config_default): + """Verify that get_default_config returns default global config values.""" + retrieved_default = get_default_config() + msg = "`get_default_config` does not return expected values.\n" + msg += f"Expected {global_config_default}, but returned {retrieved_default}." + assert retrieved_default == global_config_default, msg + + +@pytest.mark.parametrize("print_changed_only", PRINT_CHANGE_ONLY_VALUES) +@pytest.mark.parametrize("display", DISPLAY_VALUES) +def test_set_config_then_get_config_returns_expected_value(print_changed_only, display): + """Verify that get_config returns set config values if set_config run.""" + set_config(print_changed_only=print_changed_only, display=display) + retrieved_default = get_config() + expected_config = {"print_changed_only": print_changed_only, "display": display} + msg = "`get_config` used after `set_config` does not return expected values.\n" + msg += "After set_config is run, get_config should return the set values.\n " + msg += f"Expected {expected_config}, but returned {retrieved_default}." + assert retrieved_default == expected_config, msg + + +@pytest.mark.parametrize("print_changed_only", PRINT_CHANGE_ONLY_VALUES) +@pytest.mark.parametrize("display", DISPLAY_VALUES) +def test_reset_config_resets_the_config(print_changed_only, display): + """Verify that get_config returns default config if reset_config run.""" + default_config = get_default_config() + set_config(print_changed_only=print_changed_only, display=display) + reset_config() + retrieved_config = get_config() + + msg = "`get_config` does not return expected values after `reset_config`.\n" + msg += "`After reset_config is run, get_config` should return defaults.\n" + msg += f"Expected {default_config}, but returned {retrieved_config}." + assert retrieved_config == default_config, msg + reset_config() + + +@pytest.mark.parametrize("print_changed_only", PRINT_CHANGE_ONLY_VALUES) +@pytest.mark.parametrize("display", DISPLAY_VALUES) +def test_config_context(print_changed_only, display): + """Verify that config_context affects context but not overall configuration.""" + # Make sure config is reset to default values then retrieve it + reset_config() + retrieved_config = get_config() + # Now lets make sure the config_context is changing the context of those values + # within the scope of the context manager as expected + for print_changed_only in (True, False): + with config_context(print_changed_only=print_changed_only, display=display): + retrieved_context_config = get_config() + expected_config = {"print_changed_only": print_changed_only, "display": display} + msg = "`get_config` does not return expected values within `config_context`.\n" + msg += "`get_config` should return config defined by `config_context`.\n" + msg += f"Expected {expected_config}, but returned {retrieved_context_config}." + assert retrieved_context_config == expected_config, msg + + # Outside of the config_context we should have not affected the retrieved config + # set by call to reset_config() + config_post_config_context = get_config() + msg = "`get_config` does not return expected values after `config_context`a.\n" + msg += "`config_context` should not affect configuration outside its context.\n" + msg += f"Expected {config_post_config_context}, but returned {retrieved_config}." + assert retrieved_config == config_post_config_context, msg + reset_config() + + +def test_set_config_behavior_invalid_value(): + """Test set_config uses default and raises warning when setting invalid value.""" + reset_config() + with pytest.warns(UserWarning, match=r"Attempting to set an invalid value.*"): + set_config(print_changed_only="False") + + assert get_config() == get_default_config() + + with pytest.warns(UserWarning, match=r"Attempting to set an invalid value.*"): + set_config(print_changed_only=7) + + assert get_config() == get_default_config() + reset_config() From 34a9c5acb5e9ae5525d0fe083a155c8622cf8610 Mon Sep 17 00:00:00 2001 From: rnkuhns Date: Sun, 12 Mar 2023 23:51:10 -0400 Subject: [PATCH 02/12] Update global config tests --- skbase/config/__init__.py | 40 ------------------------------ skbase/config/tests/test_config.py | 30 +++++++--------------- skbase/tests/conftest.py | 19 +++++++++++++- 3 files changed, 27 insertions(+), 62 deletions(-) diff --git a/skbase/config/__init__.py b/skbase/config/__init__.py index 3f38839b..55d58a3c 100644 --- a/skbase/config/__init__.py +++ b/skbase/config/__init__.py @@ -134,36 +134,6 @@ def _get_threadlocal_config() -> Dict[str, Any]: return _THREAD_LOCAL_DATA.global_config -def get_config_os_env_names() -> List[str]: - """Retrieve the os environment names for configurable settings. - - Returns - ------- - env_names : list - The os environment names that can be used to set configurable settings. - - See Also - -------- - config_context : - Configuration context manager. - get_default_config : - Retrieve ``skbase``'s default configuration. - get_config : - Retrieve current global configuration values. - set_config : - Set global configuration. - reset_config : - Reset configuration to ``skbase`` default. - - Examples - -------- - >>> from skbase.config import get_config_os_env_names - >>> get_config_os_env_names() - ['SKBASE_PRINT_CHANGED_ONLY', 'SKBASE_OBJECT_DISPLAY'] - """ - return [config_info.os_environ_name for config_info in _CONFIG_REGISTRY.values()] - - def get_default_config() -> Dict[str, Any]: """Retrive the default global configuration. @@ -180,8 +150,6 @@ def get_default_config() -> Dict[str, Any]: Configuration context manager. get_config : Retrieve current global configuration values. - get_config_os_env_names : - Retrieve os environment names that can be used to set configuration. set_config : Set global configuration. reset_config : @@ -213,8 +181,6 @@ def get_config() -> Dict[str, Any]: Configuration context manager. get_default_config : Retrieve ``skbase``'s default configuration. - get_config_os_env_names : - Retrieve os environment names that can be used to set configuration. set_config : Set global configuration. reset_config : @@ -267,8 +233,6 @@ def set_config( Retrieve ``skbase``'s default configuration. get_config : Retrieve current global configuration values. - get_config_os_env_names : - Retrieve os environment names that can be used to set configuration. reset_config : Reset configuration to default. @@ -317,8 +281,6 @@ def reset_config() -> None: Retrieve ``skbase``'s default configuration. get_config : Retrieve current global configuration values. - get_config_os_env_names : - Retrieve os environment names that can be used to set configuration. set_config : Set global configuration. @@ -373,8 +335,6 @@ def config_context( Retrieve ``skbase``'s default configuration. get_config : Retrieve current values of the global configuration. - get_config_os_env_names : - Retrieve os environment names that can be used to set configuration. set_config : Set global configuration. reset_config : diff --git a/skbase/config/tests/test_config.py b/skbase/config/tests/test_config.py index 67a104ff..11f3b5cf 100644 --- a/skbase/config/tests/test_config.py +++ b/skbase/config/tests/test_config.py @@ -8,7 +8,6 @@ GlobalConfigParamSetting, config_context, get_config, - get_config_os_env_names, get_default_config, reset_config, set_config, @@ -63,23 +62,11 @@ def test_global_config_param_is_valid_param_value(value): assert some_config_param.is_valid_param_value(value) == expected_valid -def test_get_config_os_env_names(config_registry): - """Verify that get_config_os_env_names returns expected values.""" - os_env_names = get_config_os_env_names() - expected_os_env_names = [ - config_info.os_environ_name for config_info in config_registry.values() - ] - msg = "`get_config_os_env_names` does not return expected value.\n" - msg += f"Expected {expected_os_env_names}, but returned {os_env_names}." - assert os_env_names == expected_os_env_names, msg - - def test_get_default_config(global_config_default): - """Verify that get_default_config returns default global config values.""" - retrieved_default = get_default_config() - msg = "`get_default_config` does not return expected values.\n" - msg += f"Expected {global_config_default}, but returned {retrieved_default}." - assert retrieved_default == global_config_default, msg + """Test get_default_config alwasy returns the default config.""" + assert get_default_config() == global_config_default + set_config(print_changed_only=False) + assert get_default_config() == global_config_default @pytest.mark.parametrize("print_changed_only", PRINT_CHANGE_ONLY_VALUES) @@ -97,17 +84,18 @@ def test_set_config_then_get_config_returns_expected_value(print_changed_only, d @pytest.mark.parametrize("print_changed_only", PRINT_CHANGE_ONLY_VALUES) @pytest.mark.parametrize("display", DISPLAY_VALUES) -def test_reset_config_resets_the_config(print_changed_only, display): +def test_reset_config_resets_the_config( + print_changed_only, display, global_config_default +): """Verify that get_config returns default config if reset_config run.""" - default_config = get_default_config() set_config(print_changed_only=print_changed_only, display=display) reset_config() retrieved_config = get_config() msg = "`get_config` does not return expected values after `reset_config`.\n" msg += "`After reset_config is run, get_config` should return defaults.\n" - msg += f"Expected {default_config}, but returned {retrieved_config}." - assert retrieved_config == default_config, msg + msg += f"Expected {global_config_default}, but returned {retrieved_config}." + assert retrieved_config == global_config_default, msg reset_config() diff --git a/skbase/tests/conftest.py b/skbase/tests/conftest.py index 4f36034a..f4e0e847 100644 --- a/skbase/tests/conftest.py +++ b/skbase/tests/conftest.py @@ -23,6 +23,7 @@ "skbase.base._base", "skbase.base._meta", "skbase.base._tagmanager", + "skbase.config", "skbase.lookup", "skbase.lookup.tests", "skbase.lookup.tests.test_lookup", @@ -51,6 +52,7 @@ SKBASE_PUBLIC_MODULES = ( "skbase", "skbase.base", + "skbase.config", "skbase.lookup", "skbase.lookup.tests", "skbase.lookup.tests.test_lookup", @@ -74,6 +76,7 @@ "skbase.base": ("BaseEstimator", "BaseMetaEstimator", "BaseObject"), "skbase.base._base": ("BaseEstimator", "BaseObject"), "skbase.base._meta": ("BaseMetaEstimator",), + "skbase.config": ("GlobalConfigParamSetting",), "skbase.lookup._lookup": ("ClassInfo", "FunctionInfo", "ModuleInfo"), "skbase.testing": ("BaseFixtureGenerator", "QuickTester", "TestAllObjects"), "skbase.testing.test_all_objects": ( @@ -85,11 +88,17 @@ SKBASE_CLASSES_BY_MODULE = SKBASE_PUBLIC_CLASSES_BY_MODULE.copy() SKBASE_CLASSES_BY_MODULE.update( { - "skbase.base._meta": ("BaseMetaEstimator",), "skbase.base._tagmanager": ("_FlagManager",), } ) SKBASE_PUBLIC_FUNCTIONS_BY_MODULE = { + "skbase.config": ( + "get_config", + "get_default_config", + "set_config", + "reset_config", + "config_context", + ), "skbase.lookup": ("all_objects", "get_package_metadata"), "skbase.lookup._lookup": ("all_objects", "get_package_metadata"), "skbase.testing.utils._conditional_fixtures": ( @@ -124,6 +133,14 @@ SKBASE_FUNCTIONS_BY_MODULE = SKBASE_PUBLIC_FUNCTIONS_BY_MODULE.copy() SKBASE_FUNCTIONS_BY_MODULE.update( { + "skbase.config": ( + "_get_threadlocal_config", + "get_config", + "get_default_config", + "set_config", + "reset_config", + "config_context", + ), "skbase.lookup._lookup": ( "_determine_module_path", "_get_return_tags", From fc5bee32ff0d2d6481ca343401baeae417fcadde Mon Sep 17 00:00:00 2001 From: rnkuhns Date: Mon, 13 Mar 2023 00:24:06 -0400 Subject: [PATCH 03/12] Add updated tests of local configs --- skbase/base/_base.py | 24 ++- skbase/tests/test_base.py | 321 ++++++++++++++++++++++++++++++-------- 2 files changed, 275 insertions(+), 70 deletions(-) diff --git a/skbase/base/_base.py b/skbase/base/_base.py index 970bf350..a7216434 100644 --- a/skbase/base/_base.py +++ b/skbase/base/_base.py @@ -63,6 +63,7 @@ class name: BaseEstimator from skbase._exceptions import NotFittedError from skbase.base._tagmanager import _FlagManager +from skbase.config import get_config __author__: List[str] = ["mloning", "RNKuhns", "fkiraly"] __all__: List[str] = ["BaseEstimator", "BaseObject"] @@ -446,7 +447,28 @@ def get_config(self): class attribute via nested inheritance and then any overrides and new tags from _onfig_dynamic object attribute. """ - return self._get_flags(flag_attr_name="_config") + config = get_config().copy() + + # Get any extension configuration interface defined in the class + # for example if downstream package wants to extend skbase to retrieve + # their own config + if hasattr(self, "__skbase_get_config__") and callable( + self.__skbase_get_config__ + ): + skbase_get_config_extension_dict = self.__skbase_get_config__() + else: + skbase_get_config_extension_dict = {} + if isinstance(skbase_get_config_extension_dict, dict): + config.update(skbase_get_config_extension_dict) + else: + msg = "Use of `__skbase_get_config__` to extend the interface for local " + msg += "overrides of the global configuration must return a dictionary.\n" + msg += f"But a {type(skbase_get_config_extension_dict)} was found." + warnings.warn(msg, UserWarning, stacklevel=2) + local_config = self._get_flags(flag_attr_name="_config") + config.update(local_config) + + return config def set_config(self, **config_dict): """Set config flags to given values. diff --git a/skbase/tests/test_base.py b/skbase/tests/test_base.py index 38d2d5a9..66dcf45b 100644 --- a/skbase/tests/test_base.py +++ b/skbase/tests/test_base.py @@ -69,6 +69,7 @@ import inspect from copy import deepcopy +from typing import Any, Dict, Type import numpy as np import pytest @@ -79,6 +80,7 @@ from sklearn.base import clone from skbase.base import BaseEstimator, BaseObject +from skbase.config import get_config from skbase.tests.conftest import Child, Parent from skbase.tests.mock_package.test_mock_package import CompositionDummy @@ -200,13 +202,13 @@ def fixture_reset_tester(): @pytest.fixture -def fixture_class_child_tags(fixture_class_child): +def fixture_class_child_tags(fixture_class_child: Type[Child]): """Pytest fixture for tags of Child.""" return fixture_class_child.get_class_tags() @pytest.fixture -def fixture_object_instance_set_tags(fixture_tag_class_object): +def fixture_object_instance_set_tags(fixture_tag_class_object: Child): """Fixture class instance to test tag setting.""" fixture_tag_set = {"A": 42424243, "E": 3} return fixture_tag_class_object.set_tags(**fixture_tag_set) @@ -266,7 +268,9 @@ def fixture_class_instance_no_param_interface(): return NoParamInterface() -def test_get_class_tags(fixture_class_child, fixture_class_child_tags): +def test_get_class_tags( + fixture_class_child: Type[Child], fixture_class_child_tags: Any +): """Test get_class_tags class method of BaseObject for correctness. Raises @@ -280,7 +284,7 @@ def test_get_class_tags(fixture_class_child, fixture_class_child_tags): assert child_tags == fixture_class_child_tags, msg -def test_get_class_tag(fixture_class_child, fixture_class_child_tags): +def test_get_class_tag(fixture_class_child: Type[Child], fixture_class_child_tags: Any): """Test get_class_tag class method of BaseObject for correctness. Raises @@ -307,7 +311,7 @@ def test_get_class_tag(fixture_class_child, fixture_class_child_tags): assert child_tag_default_none is None, msg -def test_get_tags(fixture_tag_class_object, fixture_object_tags): +def test_get_tags(fixture_tag_class_object: Child, fixture_object_tags: Dict[str, Any]): """Test get_tags method of BaseObject for correctness. Raises @@ -321,7 +325,7 @@ def test_get_tags(fixture_tag_class_object, fixture_object_tags): assert object_tags == fixture_object_tags, msg -def test_get_tag(fixture_tag_class_object, fixture_object_tags): +def test_get_tag(fixture_tag_class_object: Child, fixture_object_tags: Dict[str, Any]): """Test get_tag method of BaseObject for correctness. Raises @@ -351,7 +355,7 @@ def test_get_tag(fixture_tag_class_object, fixture_object_tags): assert object_tag_default_none is None, msg -def test_get_tag_raises(fixture_tag_class_object): +def test_get_tag_raises(fixture_tag_class_object: Child): """Test that get_tag method raises error for unknown tag. Raises @@ -363,9 +367,9 @@ def test_get_tag_raises(fixture_tag_class_object): def test_set_tags( - fixture_object_instance_set_tags, - fixture_object_set_tags, - fixture_object_dynamic_tags, + fixture_object_instance_set_tags: Any, + fixture_object_set_tags: Dict[str, Any], + fixture_object_dynamic_tags: Dict[str, int], ): """Test set_tags method of BaseObject for correctness. @@ -381,7 +385,9 @@ def test_set_tags( assert fixture_object_instance_set_tags.get_tags() == fixture_object_set_tags, msg -def test_set_tags_works_with_missing_tags_dynamic_attribute(fixture_tag_class_object): +def test_set_tags_works_with_missing_tags_dynamic_attribute( + fixture_tag_class_object: Child, +): """Test set_tags will still work if _tags_dynamic is missing.""" base_obj = deepcopy(fixture_tag_class_object) delattr(base_obj, "_tags_dynamic") @@ -460,7 +466,7 @@ class AnotherTestClass(BaseObject): assert test_obj_tags.get(tag) == another_base_obj_tags[tag] -def test_is_composite(fixture_composition_dummy): +def test_is_composite(fixture_composition_dummy: Type[CompositionDummy]): """Test is_composite tag for correctness. Raises @@ -474,7 +480,11 @@ def test_is_composite(fixture_composition_dummy): assert composite.is_composite() -def test_components(fixture_object, fixture_class_parent, fixture_composition_dummy): +def test_components( + fixture_object: Type[BaseObject], + fixture_class_parent: Type[Parent], + fixture_composition_dummy: Type[CompositionDummy], +): """Test component retrieval. Raises @@ -507,7 +517,7 @@ def test_components(fixture_object, fixture_class_parent, fixture_composition_du def test_components_raises_error_base_class_is_not_class( - fixture_object, fixture_composition_dummy + fixture_object: Type[BaseObject], fixture_composition_dummy: Type[CompositionDummy] ): """Test _component method raises error if base_class param is not class.""" non_composite = fixture_composition_dummy(foo=42) @@ -526,7 +536,7 @@ def test_components_raises_error_base_class_is_not_class( def test_components_raises_error_base_class_is_not_baseobject_subclass( - fixture_composition_dummy, + fixture_composition_dummy: Type[CompositionDummy], ): """Test _component method raises error if base_class is not BaseObject subclass.""" @@ -540,7 +550,7 @@ class SomeClass: # Test parameter interface (get_params, set_params, reset and related methods) # Some tests of get_params and set_params are adapted from sklearn tests -def test_reset(fixture_reset_tester): +def test_reset(fixture_reset_tester: Type[ResetTester]): """Test reset method for correct behaviour, on a simple estimator. Raises @@ -567,7 +577,7 @@ def test_reset(fixture_reset_tester): assert hasattr(x, "foo") -def test_reset_composite(fixture_reset_tester): +def test_reset_composite(fixture_reset_tester: Type[ResetTester]): """Test reset method for correct behaviour, on a composite estimator.""" y = fixture_reset_tester(42) x = fixture_reset_tester(a=y) @@ -582,7 +592,7 @@ def test_reset_composite(fixture_reset_tester): assert not hasattr(x.a, "d") -def test_get_init_signature(fixture_class_parent): +def test_get_init_signature(fixture_class_parent: Type[Parent]): """Test error is raised when invalid init signature is used.""" init_sig = fixture_class_parent._get_init_signature() init_sig_is_list = isinstance(init_sig, list) @@ -594,14 +604,18 @@ def test_get_init_signature(fixture_class_parent): ), "`_get_init_signature` is not returning expected result." -def test_get_init_signature_raises_error_for_invalid_signature(fixture_invalid_init): +def test_get_init_signature_raises_error_for_invalid_signature( + fixture_invalid_init: Type[InvalidInitSignatureTester], +): """Test error is raised when invalid init signature is used.""" with pytest.raises(RuntimeError): fixture_invalid_init._get_init_signature() def test_get_param_names( - fixture_object, fixture_class_parent, fixture_class_parent_expected_params + fixture_object: Type[BaseObject], + fixture_class_parent: Type[Parent], + fixture_class_parent_expected_params: Dict[str, Any], ): """Test that get_param_names returns list of string parameter names.""" param_names = fixture_class_parent.get_param_names() @@ -612,10 +626,10 @@ def test_get_param_names( def test_get_params( - fixture_class_parent, - fixture_class_parent_expected_params, - fixture_class_instance_no_param_interface, - fixture_composition_dummy, + fixture_class_parent: Type[Parent], + fixture_class_parent_expected_params: Dict[str, Any], + fixture_class_instance_no_param_interface: NoParamInterface, + fixture_composition_dummy: Type[CompositionDummy], ): """Test get_params returns expected parameters.""" # Simple test of returned params @@ -638,7 +652,10 @@ def test_get_params( assert "foo" in params and "bar" in params and len(params) == 2 -def test_get_params_invariance(fixture_class_parent, fixture_composition_dummy): +def test_get_params_invariance( + fixture_class_parent: Type[Parent], + fixture_composition_dummy: Type[CompositionDummy], +): """Test that get_params(deep=False) is subset of get_params(deep=True).""" composite = fixture_composition_dummy(foo=fixture_class_parent(), bar=84) shallow_params = composite.get_params(deep=False) @@ -646,7 +663,7 @@ def test_get_params_invariance(fixture_class_parent, fixture_composition_dummy): assert all(item in deep_params.items() for item in shallow_params.items()) -def test_get_params_after_set_params(fixture_class_parent): +def test_get_params_after_set_params(fixture_class_parent: Type[Parent]): """Test that get_params returns the same thing before and after set_params. Based on scikit-learn check in check_estimator. @@ -687,9 +704,9 @@ def test_get_params_after_set_params(fixture_class_parent): def test_set_params( - fixture_class_parent, - fixture_class_parent_expected_params, - fixture_composition_dummy, + fixture_class_parent: Type[Parent], + fixture_class_parent_expected_params: Dict[str, Any], + fixture_composition_dummy: Type[CompositionDummy], ): """Test set_params works as expected.""" # Simple case of setting a parameter @@ -711,7 +728,8 @@ def test_set_params( def test_set_params_raises_error_non_existent_param( - fixture_class_parent_instance, fixture_composition_dummy + fixture_class_parent_instance: Parent, + fixture_composition_dummy: Type[CompositionDummy], ): """Test set_params raises an error when passed a non-existent parameter name.""" # non-existing parameter in svc @@ -727,7 +745,8 @@ def test_set_params_raises_error_non_existent_param( def test_set_params_raises_error_non_interface_composite( - fixture_class_instance_no_param_interface, fixture_composition_dummy + fixture_class_instance_no_param_interface: NoParamInterface, + fixture_composition_dummy: Type[CompositionDummy], ): """Test set_params raises error when setting param of non-conforming composite.""" # When a composite is made up of a class that doesn't have the BaseObject @@ -753,7 +772,9 @@ def __init__(self, param=5): est.get_params() -def test_set_params_with_no_param_to_set_returns_object(fixture_class_parent): +def test_set_params_with_no_param_to_set_returns_object( + fixture_class_parent: Type[Parent], +): """Test set_params correctly returns self when no parameters are set.""" base_obj = fixture_class_parent() orig_params = deepcopy(base_obj.get_params()) @@ -767,7 +788,7 @@ def test_set_params_with_no_param_to_set_returns_object(fixture_class_parent): # This section tests the clone functionality # These have been adapted from sklearn's tests of clone to use the clone # method that is included as part of the BaseObject interface -def test_clone(fixture_class_parent_instance): +def test_clone(fixture_class_parent_instance: Parent): """Test that clone is making a deep copy as expected.""" # Creates a BaseObject and makes a copy of its original state # (which, in this case, is the current state of the BaseObject), @@ -777,7 +798,7 @@ def test_clone(fixture_class_parent_instance): assert fixture_class_parent_instance.get_params() == new_base_obj.get_params() -def test_clone_2(fixture_class_parent_instance): +def test_clone_2(fixture_class_parent_instance: Parent): """Test that clone does not copy attributes not set in constructor.""" # We first create an estimator, give it an own attribute, and # make a copy of its original state. Then we check that the copy doesn't @@ -790,7 +811,9 @@ def test_clone_2(fixture_class_parent_instance): def test_clone_raises_error_for_nonconforming_objects( - fixture_invalid_init, fixture_buggy, fixture_modify_param + fixture_invalid_init: Type[InvalidInitSignatureTester], + fixture_buggy: Type[Buggy], + fixture_modify_param: Type[ModifyParam], ): """Test that clone raises an error on nonconforming BaseObjects.""" buggy = fixture_buggy() @@ -807,7 +830,7 @@ def test_clone_raises_error_for_nonconforming_objects( obj_that_modifies.clone() -def test_clone_param_is_none(fixture_class_parent): +def test_clone_param_is_none(fixture_class_parent: Type[Parent]): """Test clone with keyword parameter set to None.""" base_obj = fixture_class_parent(c=None) new_base_obj = clone(base_obj) @@ -816,7 +839,7 @@ def test_clone_param_is_none(fixture_class_parent): assert base_obj.c is new_base_obj2.c -def test_clone_empty_array(fixture_class_parent): +def test_clone_empty_array(fixture_class_parent: Type[Parent]): """Test clone with keyword parameter is scipy sparse matrix. This test is based on scikit-learn regression test to make sure clone @@ -830,7 +853,7 @@ def test_clone_empty_array(fixture_class_parent): np.testing.assert_array_equal(base_obj.c, new_base_obj2.c) -def test_clone_sparse_matrix(fixture_class_parent): +def test_clone_sparse_matrix(fixture_class_parent: Type[Parent]): """Test clone with keyword parameter is scipy sparse matrix. This test is based on scikit-learn regression test to make sure clone @@ -843,7 +866,7 @@ def test_clone_sparse_matrix(fixture_class_parent): np.testing.assert_array_equal(base_obj.c, new_base_obj2.c) -def test_clone_nan(fixture_class_parent): +def test_clone_nan(fixture_class_parent: Type[Parent]): """Test clone with keyword parameter is np.nan. This test is based on scikit-learn regression test to make sure clone @@ -858,7 +881,7 @@ def test_clone_nan(fixture_class_parent): assert base_obj.c is new_base_obj2.c -def test_clone_estimator_types(fixture_class_parent): +def test_clone_estimator_types(fixture_class_parent: Type[Parent]): """Test clone works for parameters that are types rather than instances.""" base_obj = fixture_class_parent(c=fixture_class_parent) new_base_obj = base_obj.clone() @@ -866,7 +889,9 @@ def test_clone_estimator_types(fixture_class_parent): assert base_obj.c == new_base_obj.c -def test_clone_class_rather_than_instance_raises_error(fixture_class_parent): +def test_clone_class_rather_than_instance_raises_error( + fixture_class_parent: Type[Parent], +): """Test clone raises expected error when cloning a class instead of an instance.""" msg = "You should provide an instance of scikit-learn estimator" with pytest.raises(TypeError, match=msg): @@ -874,7 +899,10 @@ def test_clone_class_rather_than_instance_raises_error(fixture_class_parent): # Tests of BaseObject pretty printing representation inspired by sklearn -def test_baseobject_repr(fixture_class_parent, fixture_composition_dummy): +def test_baseobject_repr( + fixture_class_parent: Type[Parent], + fixture_composition_dummy: Type[CompositionDummy], +): """Test BaseObject repr works as expected.""" # Simple test where all parameters are left at defaults # Should not see parameters and values in printed representation @@ -893,12 +921,12 @@ def test_baseobject_repr(fixture_class_parent, fixture_composition_dummy): assert len(repr(long_base_obj_repr)) == 535 -def test_baseobject_str(fixture_class_parent_instance): +def test_baseobject_str(fixture_class_parent_instance: Parent): """Test BaseObject string representation works.""" str(fixture_class_parent_instance) -def test_baseobject_repr_mimebundle_(fixture_class_parent_instance): +def test_baseobject_repr_mimebundle_(fixture_class_parent_instance: Parent): """Test display configuration controls output.""" # Checks the display configuration flag controls the json output with config_context(display="diagram"): @@ -912,7 +940,7 @@ def test_baseobject_repr_mimebundle_(fixture_class_parent_instance): assert "text/html" not in output -def test_repr_html_wraps(fixture_class_parent_instance): +def test_repr_html_wraps(fixture_class_parent_instance: Parent): """Test display configuration flag controls the html output.""" with config_context(display="diagram"): output = fixture_class_parent_instance._repr_html_() @@ -925,20 +953,24 @@ def test_repr_html_wraps(fixture_class_parent_instance): # Test BaseObject's ability to generate test instances -def test_get_test_params(fixture_class_parent_instance): +def test_get_test_params(fixture_class_parent_instance: Parent): """Test get_test_params returns empty dictionary.""" base_obj = fixture_class_parent_instance test_params = base_obj.get_test_params() assert isinstance(test_params, dict) and len(test_params) == 0 -def test_get_test_params_raises_error_when_params_required(fixture_required_param): +def test_get_test_params_raises_error_when_params_required( + fixture_required_param: Type[RequiredParam], +): """Test get_test_params raises an error when parameters are required.""" with pytest.raises(ValueError): fixture_required_param(7).get_test_params() -def test_create_test_instance(fixture_class_parent, fixture_class_parent_instance): +def test_create_test_instance( + fixture_class_parent: Type[Parent], fixture_class_parent_instance: Parent +): """Test first that create_test_instance logic works.""" base_obj = fixture_class_parent.create_test_instance() @@ -957,7 +989,7 @@ def test_create_test_instance(fixture_class_parent, fixture_class_parent_instanc assert hasattr(base_obj, "_tags_dynamic"), msg -def test_create_test_instances_and_names(fixture_class_parent_instance): +def test_create_test_instances_and_names(fixture_class_parent_instance: Parent): """Test that create_test_instances_and_names works.""" base_objs, names = fixture_class_parent_instance.create_test_instances_and_names() @@ -990,7 +1022,7 @@ def test_create_test_instances_and_names(fixture_class_parent_instance): # Tests _has_implementation_of interface def test_has_implementation_of( - fixture_class_parent_instance, fixture_class_child_instance + fixture_class_parent_instance: Parent, fixture_class_child_instance: Child ): """Test _has_implementation_of detects methods in class with overrides in mro.""" # When the class overrides a parent classes method should return True @@ -1014,33 +1046,184 @@ def __init__(self, a, b=42): self.c = 84 -def test_set_get_config(): - """Test logic behind get_config, set_config. +class AnotherConfigTester(BaseObject): + _config = {"print_changed_only": False, "bar": "a"} - Raises - ------ - AssertionError if logic behind get_config, set_config is incorrect, logic tested: - calling get_fitted_params on a non-composite fittable returns the fitted param - calling get_fitted_params on a composite returns all nested params + clsvar = 210 + + def __init__(self, a, b=42): + self.a = a + self.b = b + self.c = 84 + + +class ConfigExtensionInterfaceTester(BaseObject): + _config = {"print_changed_only": False, "bar": "a"} + + clsvar = 210 + + def __init__(self, a, b=42): + self.a = a + self.b = b + self.c = 84 + + def __skbase_get_config__(self): + """Return get_config extension.""" + return {"print_changed_only": True, "some_other_config": 70} + + +def test_local_config_without_use_of_extension_interface(): + """Test ``BaseObject().get_config`` and ``BaseObject().set_config``. + + ``BaseObject.get_config()`` should return the global dict updated with any local + configs defined in ``BaseObject._config`` or set on the instance with + ``BaseObject().set_config()``. """ - obj = ConfigTester(4242) + # Initially test that We can retrieve the local config with local configs + # as defined in BaseObject._config - config_start = obj.get_config() - assert isinstance(config_start, dict) - assert set(config_start.keys()) == {"foo_config", "bar"} - assert config_start["foo_config"] == 42 - assert config_start["bar"] == "a" + # Case 1: local configs are not part of global config param names + obj = ConfigTester(4242) + obj_config = obj.get_config() + current_global_config = get_config().copy() + expected_global_config = set(current_global_config.keys()) + assert isinstance(obj_config, dict) + assert set(obj_config.keys()) == expected_global_config | {"foo_config", "bar"} + assert obj_config["foo_config"] == 42 + assert obj_config["bar"] == "a" + for param_name in current_global_config: + assert obj_config[param_name] == current_global_config[param_name] + + # Case 2: local configs overlap with (will override) global params + obj = AnotherConfigTester(4242) + obj_config = obj.get_config() + current_global_config = get_config().copy() + expected_global_config = set(current_global_config.keys()) + assert isinstance(obj_config, dict) + assert set(obj_config.keys()) == expected_global_config | {"bar"} + assert obj_config["bar"] == "a" + # Should have overrided global config value which is set to True + assert obj_config["print_changed_only"] is False + for param_name in current_global_config: + if param_name != "print_changed_only": + assert obj_config[param_name] == current_global_config[param_name] + + # Case 3: local configs are not part of global config param names and we also + # make use of dynamic BaseObject.set_config() + obj = ConfigTester(4242) + obj_config = obj.get_config() + current_global_config = get_config().copy() + expected_global_config = set(current_global_config.keys()) + # Verify set config returns the original object setconfig_return = obj.set_config(foobar=126) assert obj is setconfig_return obj.set_config(**{"bar": "b"}) - config_end = obj.get_config() - assert isinstance(config_end, dict) - assert set(config_end.keys()) == {"foo_config", "bar", "foobar"} - assert config_end["foo_config"] == 42 - assert config_end["bar"] == "b" - assert config_end["foobar"] == 126 + updated_obj_config = obj.get_config() + assert isinstance(updated_obj_config, dict) + assert set(updated_obj_config.keys()) == ( + expected_global_config | {"foo_config", "bar", "foobar"} + ) + assert updated_obj_config["foo_config"] == 42 + assert updated_obj_config["bar"] == "b" + assert updated_obj_config["foobar"] == 126 + + # Case 4: local configs are not part of global config param names and we also + # make use of dynamic BaseObject.set_config() to update a config that is also + # part of global config + obj = ConfigTester(4242) + obj_config = obj.get_config() + current_global_config = get_config().copy() + expected_global_config = set(current_global_config.keys()) + + # Verify set config returns the original object + setconfig_return = obj.set_config(print_changed_only=False) + assert obj is setconfig_return + + updated_obj_config = obj.get_config() + assert isinstance(updated_obj_config, dict) + assert set(updated_obj_config.keys()) == ( + expected_global_config | {"foo_config", "bar"} + ) + assert updated_obj_config["foo_config"] == 42 + assert updated_obj_config["bar"] == "a" + assert updated_obj_config["print_changed_only"] is False + for param_name in current_global_config: + if param_name != "print_changed_only": + assert updated_obj_config[param_name] == current_global_config[param_name] + + # Case 5: local configs overlap with (will override) global params + # Then the local config defined in AnotherConfigTester._config is overrode again + # by calling AnotherConfigTester().set_config() + obj = AnotherConfigTester(4242) + obj.set_config(print_changed_only=True) + obj_config = obj.get_config() + current_global_config = get_config().copy() + expected_global_config = set(current_global_config.keys()) + assert isinstance(obj_config, dict) + assert set(obj_config.keys()) == expected_global_config | {"bar"} + assert obj_config["bar"] == "a" + # Should have overrided global config value which is set to True + assert obj_config["print_changed_only"] is True + for param_name in current_global_config: + if param_name != "print_changed_only": + assert obj_config[param_name] == current_global_config[param_name] + + +def test_local_config_with_use_of_extension_interface(): + """Test BaseObject local config interface when ``__skbase_get_config__`` defined. + + BaseObject.get_config() should return the global dict updated in the following + order: + + - Any config returned by ``BaseObject.__skbase_get_config__``. + - Any configs defined in ``BaseObject._config`` or set on the instance with + ``BaseObject().set_config()``. + """ + current_global_config = get_config().copy() + obj = ConfigExtensionInterfaceTester(4242) + obj_config = obj.get_config() + assert "some_other_config" in obj_config + assert obj_config["print_changed_only"] is False + + expected_global_config = set(current_global_config.keys()) + assert isinstance(obj_config, dict) + assert set(obj_config.keys()) == expected_global_config | { + "some_other_config", + "bar", + } + assert obj_config["bar"] == "a" + # Should have overrided global config value which is set to True + + for param_name in current_global_config: + if param_name != "print_changed_only": + assert obj_config[param_name] == current_global_config[param_name] + + # Now lets verify we can override the config items only returned by + # __skbase_get_config__ extension interface + obj.set_config(some_other_config=22) + obj_config_updated = obj.get_config() + assert obj_config_updated["some_other_config"] == 22 + for param_name in obj_config: + if param_name != "some_other_config": + assert obj_config_updated[param_name] == obj_config[param_name] + + +def local_config_interface_does_not_affect_global_config_interface(): + """Test that calls to instance config interface doesn't impact global config.""" + from skbase.config import get_default_config, global_config + + obj = AnotherConfigTester(4242) + _global_config_before = global_config.copy() + _default_config_before = get_default_config() + global_config_before = get_config().copy() + obj.set_config(print_changed_only=False, some_other_param=7) + global_config_after = get_config().copy() + assert global_config_before == global_config_after + assert "some_other_param" not in global_config_after + assert _global_config_before == global_config + assert _default_config_before == get_default_config() class FittableCompositionDummy(BaseEstimator): From d89e7caec80a59db40a5dae657b30b351f3d0bc3 Mon Sep 17 00:00:00 2001 From: rnkuhns Date: Mon, 13 Mar 2023 00:54:21 -0400 Subject: [PATCH 04/12] Add global config to docs --- docs/source/api_reference.rst | 21 +++++++++++++++++++ docs/source/user_documentation/user_guide.rst | 15 +++++++++++++ .../user_guide/configuration.rst | 11 ++++++++++ skbase/config/__init__.py | 11 ++++++++-- 4 files changed, 56 insertions(+), 2 deletions(-) create mode 100644 docs/source/user_documentation/user_guide/configuration.rst diff --git a/docs/source/api_reference.rst b/docs/source/api_reference.rst index 03026bdf..9508d39e 100644 --- a/docs/source/api_reference.rst +++ b/docs/source/api_reference.rst @@ -28,6 +28,27 @@ Base Classes BaseObject BaseEstimator +.. _global_config: + +Configure ``skbase`` +==================== + +.. automodule:: skbase.config + :no-members: + :no-inherited-members: + +.. currentmodule:: skbase.config + +.. autosummary:: + :toctree: api_reference/auto_generated/ + :template: function.rst + + get_config + get_default_config + set_config + reset_config + config_context + .. _obj_retrieval: Object Retrieval diff --git a/docs/source/user_documentation/user_guide.rst b/docs/source/user_documentation/user_guide.rst index c8febd31..d571554b 100644 --- a/docs/source/user_documentation/user_guide.rst +++ b/docs/source/user_documentation/user_guide.rst @@ -29,6 +29,7 @@ that ``skbase`` provides, see the :ref:`api_ref`. user_guide/lookup user_guide/validate user_guide/testing + user_guide/configuration .. grid:: 1 2 2 2 @@ -103,3 +104,17 @@ that ``skbase`` provides, see the :ref:`api_ref`. :expand: Testing + + .. grid-item-card:: Configuration + :text-align: center + + Configure ``skbase``. + + +++ + + .. button-ref:: user_guide/testing + :color: primary + :click-parent: + :expand: + + Configuration diff --git a/docs/source/user_documentation/user_guide/configuration.rst b/docs/source/user_documentation/user_guide/configuration.rst new file mode 100644 index 00000000..05b76843 --- /dev/null +++ b/docs/source/user_documentation/user_guide/configuration.rst @@ -0,0 +1,11 @@ +.. _user_guide_global_config: + +==================== +Configure ``skbase`` +==================== + +.. note:: + + The user guide is under development. We have created a basic + structure and are looking for contributions to develop the user guide + further. diff --git a/skbase/config/__init__.py b/skbase/config/__init__.py index 55d58a3c..b8c8d3c0 100644 --- a/skbase/config/__init__.py +++ b/skbase/config/__init__.py @@ -1,5 +1,9 @@ # -*- coding: utf-8 -*- -""":mod:`skbase.config` provides tools for global configuration of ``skbase``.""" +""":mod:`skbase.config` provides tools for the global configuration of ``skbase``. + +For more information on configuration usage patterns see the +:ref:`user guide `. +""" # -*- coding: utf-8 -*- # copyright: skbase developers, BSD-3-Clause License (see LICENSE file) # Includes functionality like get_config, set_config, and config_context @@ -167,7 +171,7 @@ def get_default_config() -> Dict[str, Any]: def get_config() -> Dict[str, Any]: """Retrieve current values for configuration set by :meth:`set_config`. - Willr return the default configuration if know updated configuration has + Will return the default configuration if know updated configuration has been set by :meth:`set_config`. Returns @@ -309,6 +313,9 @@ def config_context( ) -> Iterator[None]: """Context manager for global configuration. + Provides the ability to run code using different configuration without + having to update the global config. + Parameters ---------- print_changed_only : bool, default=None From 68560d36dc47d5cbc7ac2a17dd37f839d7e2f56e Mon Sep 17 00:00:00 2001 From: rnkuhns Date: Mon, 13 Mar 2023 00:59:20 -0400 Subject: [PATCH 05/12] Fix user_guide panel for configuration --- docs/source/user_documentation/user_guide.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/user_documentation/user_guide.rst b/docs/source/user_documentation/user_guide.rst index d571554b..4e15f3ec 100644 --- a/docs/source/user_documentation/user_guide.rst +++ b/docs/source/user_documentation/user_guide.rst @@ -112,7 +112,7 @@ that ``skbase`` provides, see the :ref:`api_ref`. +++ - .. button-ref:: user_guide/testing + .. button-ref:: user_guide/configuration :color: primary :click-parent: :expand: From ddaa90e03d0bd7fd424817527ea2b8ccc065867e Mon Sep 17 00:00:00 2001 From: rnkuhns Date: Mon, 13 Mar 2023 01:01:24 -0400 Subject: [PATCH 06/12] Fix docstring typo --- skbase/config/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skbase/config/__init__.py b/skbase/config/__init__.py index b8c8d3c0..a7ca56b8 100644 --- a/skbase/config/__init__.py +++ b/skbase/config/__init__.py @@ -139,7 +139,7 @@ def _get_threadlocal_config() -> Dict[str, Any]: def get_default_config() -> Dict[str, Any]: - """Retrive the default global configuration. + """Retrieve the default global configuration. This will always return the default ``skbase`` global configuration. From eeee466a7bb52ae725c3f27a2cf965cab19715dc Mon Sep 17 00:00:00 2001 From: rnkuhns Date: Mon, 13 Mar 2023 02:01:51 -0400 Subject: [PATCH 07/12] Tweak behavior when invalid values encountered --- skbase/base/_base.py | 15 +++++++-- skbase/config/__init__.py | 52 +++++++++++++++++++++--------- skbase/config/tests/test_config.py | 8 +++-- skbase/tests/test_base.py | 1 - 4 files changed, 55 insertions(+), 21 deletions(-) diff --git a/skbase/base/_base.py b/skbase/base/_base.py index a7216434..ae10cf29 100644 --- a/skbase/base/_base.py +++ b/skbase/base/_base.py @@ -63,7 +63,7 @@ class name: BaseEstimator from skbase._exceptions import NotFittedError from skbase.base._tagmanager import _FlagManager -from skbase.config import get_config +from skbase.config import _CONFIG_REGISTRY, get_config __author__: List[str] = ["mloning", "RNKuhns", "fkiraly"] __all__: List[str] = ["BaseEstimator", "BaseObject"] @@ -465,7 +465,18 @@ class attribute via nested inheritance and then any overrides msg += "overrides of the global configuration must return a dictionary.\n" msg += f"But a {type(skbase_get_config_extension_dict)} was found." warnings.warn(msg, UserWarning, stacklevel=2) - local_config = self._get_flags(flag_attr_name="_config") + local_config = self._get_flags(flag_attr_name="_config").copy() + # IF the local config is one of + for config_param, config_value in local_config.items(): + if config_param in _CONFIG_REGISTRY: + msg = "Invalid value encountered for global configuration parameter " + msg += f"{config_param}. Using global parameter configuration value.\n" + config_value = _CONFIG_REGISTRY[ + config_param + ].get_valid_param_or_default( + config_value, default_value=config[config_param] + ) + local_config[config_param] = config_value config.update(local_config) return config diff --git a/skbase/config/__init__.py b/skbase/config/__init__.py index a7ca56b8..2d1abfde 100644 --- a/skbase/config/__init__.py +++ b/skbase/config/__init__.py @@ -35,10 +35,10 @@ class GlobalConfigParamSetting: name: str os_environ_name: str expected_type: Union[type, Tuple[type]] - allowed_values: Optional[Tuple[Any, ...]] + allowed_values: Optional[Union[Tuple[Any, ...], List[Any]]] default_value: Any - def get_allowed_values(self): + def get_allowed_values(self) -> List[Any]: """Get `allowed_values` or empty tuple if `allowed_values` is None. Returns @@ -47,15 +47,15 @@ def get_allowed_values(self): Allowable values if any. """ if self.allowed_values is None: - return () - elif isinstance(self.allowed_values, tuple): + return [] + elif isinstance(self.allowed_values, list): return self.allowed_values elif isinstance( self.allowed_values, collections.abc.Iterable ) and not isinstance(self.allowed_values, str): - return tuple(self.allowed_values) + return list(self.allowed_values) else: - return (self.allowed_values,) + return [self.allowed_values] def is_valid_param_value(self, value): """Validate that a global configuration value is valid. @@ -79,22 +79,36 @@ def is_valid_param_value(self, value): valid_param = True return valid_param - def get_valid_param_or_default(self, value): - """Validate `value` and return default if it is not valid.""" + def get_valid_param_or_default(self, value, default_value=None, msg=None): + """Validate `value` and return default if it is not valid. + + Parameters + ---------- + value : Any + The configuration parameter value to set. + default_value : Any, default=None + An optional default value to use to set the configuration parameter + if `value` is not valid based on defined expected type and allowed + values. If None, and `value` is invalid then the classes `default_value` + parameter is used. + msg : str, default=None + An optional message to be used as start of the UserWarning message. + """ if self.is_valid_param_value(value): return value else: - msg = "Attempting to set an invalid value for a global configuration.\n" - msg += "Using default configuration value of parameter as a result.\n" + if msg is None: + msg = "" msg + f"When setting global config values for `{self.name}`, the values " - msg += f"should be of type {self.expected_type}." + msg += f"should be of type {self.expected_type}.\n" if self.allowed_values is not None: values_str = _format_seq_to_str( self.get_allowed_values(), last_sep="or", remove_type_text=True ) - msg += f"Allowed values are be one of {values_str}." + msg += f"Allowed values are be one of {values_str}. " + msg += f"But found {value}." warnings.warn(msg, UserWarning, stacklevel=2) - return self.default_value + return default_value if default_value is not None else self.default_value _CONFIG_REGISTRY: Dict[str, GlobalConfigParamSetting] = { @@ -251,14 +265,22 @@ def set_config( """ local_config = _get_threadlocal_config() + msg = "Attempting to set an invalid value for a global configuration.\n" + msg += "Using current configuration value of parameter as a result.\n" if print_changed_only is not None: local_config["print_changed_only"] = _CONFIG_REGISTRY[ "print_changed_only" - ].get_valid_param_or_default(print_changed_only) + ].get_valid_param_or_default( + print_changed_only, + default_value=local_config["print_changed_only"], + msg=msg, + ) if display is not None: local_config["display"] = _CONFIG_REGISTRY[ "display" - ].get_valid_param_or_default(display) + ].get_valid_param_or_default( + display, default_value=local_config["display"], msg=msg + ) if not local_threadsafe: global_config.update(local_config) diff --git a/skbase/config/tests/test_config.py b/skbase/config/tests/test_config.py index 11f3b5cf..0569125e 100644 --- a/skbase/config/tests/test_config.py +++ b/skbase/config/tests/test_config.py @@ -41,7 +41,7 @@ def test_global_config_param_get_allowed_values(allowed_values): ) # Verify we always coerce output of get_allowed_values to tuple values = some_config_param.get_allowed_values() - assert isinstance(values, tuple) + assert isinstance(values, list) @pytest.mark.parametrize("value", (None, (), "wrong_string", "text", range(1, 8))) @@ -130,13 +130,15 @@ def test_config_context(print_changed_only, display): def test_set_config_behavior_invalid_value(): """Test set_config uses default and raises warning when setting invalid value.""" reset_config() + original_config = get_config().copy() with pytest.warns(UserWarning, match=r"Attempting to set an invalid value.*"): set_config(print_changed_only="False") - assert get_config() == get_default_config() + assert get_config() == original_config + original_config = get_config().copy() with pytest.warns(UserWarning, match=r"Attempting to set an invalid value.*"): set_config(print_changed_only=7) - assert get_config() == get_default_config() + assert get_config() == original_config reset_config() diff --git a/skbase/tests/test_base.py b/skbase/tests/test_base.py index 66dcf45b..f934b2c2 100644 --- a/skbase/tests/test_base.py +++ b/skbase/tests/test_base.py @@ -1133,7 +1133,6 @@ def test_local_config_without_use_of_extension_interface(): # make use of dynamic BaseObject.set_config() to update a config that is also # part of global config obj = ConfigTester(4242) - obj_config = obj.get_config() current_global_config = get_config().copy() expected_global_config = set(current_global_config.keys()) From 91a72902888e035ce7de6eac049df071ea2eebfa Mon Sep 17 00:00:00 2001 From: rnkuhns Date: Mon, 13 Mar 2023 02:08:52 -0400 Subject: [PATCH 08/12] Remove unused code line --- skbase/tests/test_base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/skbase/tests/test_base.py b/skbase/tests/test_base.py index f934b2c2..27ca97bf 100644 --- a/skbase/tests/test_base.py +++ b/skbase/tests/test_base.py @@ -1111,7 +1111,6 @@ def test_local_config_without_use_of_extension_interface(): # Case 3: local configs are not part of global config param names and we also # make use of dynamic BaseObject.set_config() obj = ConfigTester(4242) - obj_config = obj.get_config() current_global_config = get_config().copy() expected_global_config = set(current_global_config.keys()) From 289d0ab3604517bb345055d151c997ef668393c9 Mon Sep 17 00:00:00 2001 From: rnkuhns Date: Mon, 13 Mar 2023 02:23:04 -0400 Subject: [PATCH 09/12] Move config implementation out of __init__ --- skbase/base/_base.py | 3 +- skbase/config/__init__.py | 395 ++-------------------------- skbase/config/_config.py | 397 +++++++++++++++++++++++++++++ skbase/config/tests/test_config.py | 3 +- skbase/tests/conftest.py | 11 +- 5 files changed, 427 insertions(+), 382 deletions(-) create mode 100644 skbase/config/_config.py diff --git a/skbase/base/_base.py b/skbase/base/_base.py index ae10cf29..455782a1 100644 --- a/skbase/base/_base.py +++ b/skbase/base/_base.py @@ -63,7 +63,8 @@ class name: BaseEstimator from skbase._exceptions import NotFittedError from skbase.base._tagmanager import _FlagManager -from skbase.config import _CONFIG_REGISTRY, get_config +from skbase.config import get_config +from skbase.config._config import _CONFIG_REGISTRY __author__: List[str] = ["mloning", "RNKuhns", "fkiraly"] __all__: List[str] = ["BaseEstimator", "BaseObject"] diff --git a/skbase/config/__init__.py b/skbase/config/__init__.py index 2d1abfde..fd8a750a 100644 --- a/skbase/config/__init__.py +++ b/skbase/config/__init__.py @@ -10,384 +10,23 @@ # that is similar to scikit-learn. These elements are copyrighted by the # scikit-learn developers, BSD-3-Clause License. For conditions see # https://github.com/scikit-learn/scikit-learn/blob/main/COPYING -import collections -import sys -import threading -import warnings -from contextlib import contextmanager -from dataclasses import dataclass -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union +from typing import List -if sys.version_info >= (3, 8): - from typing import Literal -else: - from typing_extensions import Literal # type: ignore - -from skbase.utils._iter import _format_seq_to_str +from skbase.config._config import ( + GlobalConfigParamSetting, + config_context, + get_config, + get_default_config, + reset_config, + set_config, +) __author__: List[str] = ["RNKuhns"] - - -@dataclass -class GlobalConfigParamSetting: - """Define types of the setting information for a given config parameter.""" - - name: str - os_environ_name: str - expected_type: Union[type, Tuple[type]] - allowed_values: Optional[Union[Tuple[Any, ...], List[Any]]] - default_value: Any - - def get_allowed_values(self) -> List[Any]: - """Get `allowed_values` or empty tuple if `allowed_values` is None. - - Returns - ------- - tuple - Allowable values if any. - """ - if self.allowed_values is None: - return [] - elif isinstance(self.allowed_values, list): - return self.allowed_values - elif isinstance( - self.allowed_values, collections.abc.Iterable - ) and not isinstance(self.allowed_values, str): - return list(self.allowed_values) - else: - return [self.allowed_values] - - def is_valid_param_value(self, value): - """Validate that a global configuration value is valid. - - Verifies that the value set for a global configuration parameter is valid - based on the its configuration settings. - - Returns - ------- - bool - Whether a parameter value is valid. - """ - allowed_values = self.get_allowed_values() - - valid_param: bool - if not isinstance(value, self.expected_type): - valid_param = False - elif allowed_values is not None and value not in allowed_values: - valid_param = False - else: - valid_param = True - return valid_param - - def get_valid_param_or_default(self, value, default_value=None, msg=None): - """Validate `value` and return default if it is not valid. - - Parameters - ---------- - value : Any - The configuration parameter value to set. - default_value : Any, default=None - An optional default value to use to set the configuration parameter - if `value` is not valid based on defined expected type and allowed - values. If None, and `value` is invalid then the classes `default_value` - parameter is used. - msg : str, default=None - An optional message to be used as start of the UserWarning message. - """ - if self.is_valid_param_value(value): - return value - else: - if msg is None: - msg = "" - msg + f"When setting global config values for `{self.name}`, the values " - msg += f"should be of type {self.expected_type}.\n" - if self.allowed_values is not None: - values_str = _format_seq_to_str( - self.get_allowed_values(), last_sep="or", remove_type_text=True - ) - msg += f"Allowed values are be one of {values_str}. " - msg += f"But found {value}." - warnings.warn(msg, UserWarning, stacklevel=2) - return default_value if default_value is not None else self.default_value - - -_CONFIG_REGISTRY: Dict[str, GlobalConfigParamSetting] = { - "print_changed_only": GlobalConfigParamSetting( - name="print_changed_only", - os_environ_name="SKBASE_PRINT_CHANGED_ONLY", - expected_type=bool, - allowed_values=(True, False), - default_value=True, - ), - "display": GlobalConfigParamSetting( - name="display", - os_environ_name="SKBASE_OBJECT_DISPLAY", - expected_type=str, - allowed_values=("text", "diagram"), - default_value="text", - ), -} - -_DEFAULT_GLOBAL_CONFIG: Dict[str, Any] = { - config_name: config_info.default_value - for config_name, config_info in _CONFIG_REGISTRY.items() -} - -global_config = _DEFAULT_GLOBAL_CONFIG.copy() -_THREAD_LOCAL_DATA = threading.local() - - -def _get_threadlocal_config() -> Dict[str, Any]: - """Get a threadlocal **mutable** configuration. - - If the configuration does not exist, copy the default global configuration. - - Returns - ------- - dict - Threadlocal global config or copy of default global configuration. - """ - if not hasattr(_THREAD_LOCAL_DATA, "global_config"): - _THREAD_LOCAL_DATA.global_config = global_config.copy() - return _THREAD_LOCAL_DATA.global_config - - -def get_default_config() -> Dict[str, Any]: - """Retrieve the default global configuration. - - This will always return the default ``skbase`` global configuration. - - Returns - ------- - config : dict - The default configurable settings (keys) and their default values (values). - - See Also - -------- - config_context : - Configuration context manager. - get_config : - Retrieve current global configuration values. - set_config : - Set global configuration. - reset_config : - Reset configuration to ``skbase`` default. - - Examples - -------- - >>> from skbase.config import get_default_config - >>> get_default_config() - {'print_changed_only': True, 'display': 'text'} - """ - return _DEFAULT_GLOBAL_CONFIG.copy() - - -def get_config() -> Dict[str, Any]: - """Retrieve current values for configuration set by :meth:`set_config`. - - Will return the default configuration if know updated configuration has - been set by :meth:`set_config`. - - Returns - ------- - config : dict - The configurable settings (keys) and their default values (values). - - See Also - -------- - config_context : - Configuration context manager. - get_default_config : - Retrieve ``skbase``'s default configuration. - set_config : - Set global configuration. - reset_config : - Reset configuration to ``skbase`` default. - - Examples - -------- - >>> from skbase.config import get_config - >>> get_config() - {'print_changed_only': True, 'display': 'text'} - """ - return _get_threadlocal_config().copy() - - -def set_config( - print_changed_only: Optional[bool] = None, - display: Literal["text", "diagram"] = None, - local_threadsafe: bool = False, -) -> None: - """Set global configuration. - - Allows the ``skbase`` global configuration to be updated. - - Parameters - ---------- - print_changed_only : bool, default=None - If True, only the parameters that were set to non-default - values will be printed when printing a BaseObject instance. For example, - ``print(SVC())`` while True will only print 'SVC()', but would print - 'SVC(C=1.0, cache_size=200, ...)' with all the non-changed parameters - when False. If None, the existing value won't change. - display : {'text', 'diagram'}, default=None - If 'diagram', instances inheritting from BaseOBject will be displayed - as a diagram in a Jupyter lab or notebook context. If 'text', instances - inheritting from BaseObject will be displayed as text. If None, the - existing value won't change. - local_threadsafe : bool, default=False - If False, set the backend as default for all threads. - - Returns - ------- - None - No output returned. - - See Also - -------- - config_context : - Configuration context manager. - get_default_config : - Retrieve ``skbase``'s default configuration. - get_config : - Retrieve current global configuration values. - reset_config : - Reset configuration to default. - - Examples - -------- - >>> from skbase.config import get_config, set_config - >>> get_config() - {'print_changed_only': True, 'display': 'text'} - >>> set_config(display='diagram') - >>> get_config() - {'print_changed_only': True, 'display': 'diagram'} - """ - local_config = _get_threadlocal_config() - - msg = "Attempting to set an invalid value for a global configuration.\n" - msg += "Using current configuration value of parameter as a result.\n" - if print_changed_only is not None: - local_config["print_changed_only"] = _CONFIG_REGISTRY[ - "print_changed_only" - ].get_valid_param_or_default( - print_changed_only, - default_value=local_config["print_changed_only"], - msg=msg, - ) - if display is not None: - local_config["display"] = _CONFIG_REGISTRY[ - "display" - ].get_valid_param_or_default( - display, default_value=local_config["display"], msg=msg - ) - - if not local_threadsafe: - global_config.update(local_config) - - return None - - -def reset_config() -> None: - """Reset the global configuration to the default. - - Will remove any user updates to the global configuration and reset the values - back to the ``skbase`` defaults. - - Returns - ------- - None - No output returned. - - See Also - -------- - config_context : - Configuration context manager. - get_default_config : - Retrieve ``skbase``'s default configuration. - get_config : - Retrieve current global configuration values. - set_config : - Set global configuration. - - Examples - -------- - >>> from skbase.config import get_config, set_config, reset_config - >>> get_config() - {'print_changed_only': True, 'display': 'text'} - >>> set_config(display='diagram') - >>> get_config() - {'print_changed_only': True, 'display': 'diagram'} - >>> reset_config() - >>> get_config() - {'print_changed_only': True, 'display': 'text'} - """ - default_config = get_default_config() - set_config(**default_config) - return None - - -@contextmanager -def config_context( - print_changed_only: Optional[bool] = None, - display: Literal["text", "diagram"] = None, - local_threadsafe: bool = False, -) -> Iterator[None]: - """Context manager for global configuration. - - Provides the ability to run code using different configuration without - having to update the global config. - - Parameters - ---------- - print_changed_only : bool, default=None - If True, only the parameters that were set to non-default - values will be printed when printing a BaseObject instance. For example, - ``print(SVC())`` while True will only print 'SVC()', but would print - 'SVC(C=1.0, cache_size=200, ...)' with all the non-changed parameters - when False. If None, the existing value won't change. - display : {'text', 'diagram'}, default=None - If 'diagram', instances inheritting from BaseOBject will be displayed - as a diagram in a Jupyter lab or notebook context. If 'text', instances - inheritting from BaseObject will be displayed as text. If None, the - existing value won't change. - local_threadsafe : bool, default=False - If False, set the config as default for all threads. - - Yields - ------ - None - - See Also - -------- - get_default_config : - Retrieve ``skbase``'s default configuration. - get_config : - Retrieve current values of the global configuration. - set_config : - Set global configuration. - reset_config : - Reset configuration to ``skbase`` default. - - Notes - ----- - All settings, not just those presently modified, will be returned to - their previous values when the context manager is exited. - - Examples - -------- - >>> from skbase.config import config_context - >>> with config_context(display='diagram'): - ... pass - """ - old_config = get_config() - set_config( - print_changed_only=print_changed_only, - display=display, - local_threadsafe=local_threadsafe, - ) - - try: - yield - finally: - set_config(**old_config) +__all__: List[str] = [ + "GlobalConfigParamSetting", + "get_default_config", + "get_config", + "set_config", + "reset_config", + "config_context", +] diff --git a/skbase/config/_config.py b/skbase/config/_config.py new file mode 100644 index 00000000..ec2fc13c --- /dev/null +++ b/skbase/config/_config.py @@ -0,0 +1,397 @@ +# -*- coding: utf-8 -*- +"""Implement logic for global configuration of skbase.""" +# -*- coding: utf-8 -*- +# copyright: skbase developers, BSD-3-Clause License (see LICENSE file) +# Includes functionality like get_config, set_config, and config_context +# that is similar to scikit-learn. These elements are copyrighted by the +# scikit-learn developers, BSD-3-Clause License. For conditions see +# https://github.com/scikit-learn/scikit-learn/blob/main/COPYING +import collections +import sys +import threading +import warnings +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union + +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal # type: ignore + +from skbase.utils._iter import _format_seq_to_str + +__author__: List[str] = ["RNKuhns"] +__all__: List[str] = [ + "GlobalConfigParamSetting", + "get_default_config", + "get_config", + "set_config", + "reset_config", + "config_context", +] + + +@dataclass +class GlobalConfigParamSetting: + """Define types of the setting information for a given config parameter.""" + + name: str + os_environ_name: str + expected_type: Union[type, Tuple[type]] + allowed_values: Optional[Union[Tuple[Any, ...], List[Any]]] + default_value: Any + + def get_allowed_values(self) -> List[Any]: + """Get `allowed_values` or empty tuple if `allowed_values` is None. + + Returns + ------- + tuple + Allowable values if any. + """ + if self.allowed_values is None: + return [] + elif isinstance(self.allowed_values, list): + return self.allowed_values + elif isinstance( + self.allowed_values, collections.abc.Iterable + ) and not isinstance(self.allowed_values, str): + return list(self.allowed_values) + else: + return [self.allowed_values] + + def is_valid_param_value(self, value): + """Validate that a global configuration value is valid. + + Verifies that the value set for a global configuration parameter is valid + based on the its configuration settings. + + Returns + ------- + bool + Whether a parameter value is valid. + """ + allowed_values = self.get_allowed_values() + + valid_param: bool + if not isinstance(value, self.expected_type): + valid_param = False + elif allowed_values is not None and value not in allowed_values: + valid_param = False + else: + valid_param = True + return valid_param + + def get_valid_param_or_default(self, value, default_value=None, msg=None): + """Validate `value` and return default if it is not valid. + + Parameters + ---------- + value : Any + The configuration parameter value to set. + default_value : Any, default=None + An optional default value to use to set the configuration parameter + if `value` is not valid based on defined expected type and allowed + values. If None, and `value` is invalid then the classes `default_value` + parameter is used. + msg : str, default=None + An optional message to be used as start of the UserWarning message. + """ + if self.is_valid_param_value(value): + return value + else: + if msg is None: + msg = "" + msg + f"When setting global config values for `{self.name}`, the values " + msg += f"should be of type {self.expected_type}.\n" + if self.allowed_values is not None: + values_str = _format_seq_to_str( + self.get_allowed_values(), last_sep="or", remove_type_text=True + ) + msg += f"Allowed values are be one of {values_str}. " + msg += f"But found {value}." + warnings.warn(msg, UserWarning, stacklevel=2) + return default_value if default_value is not None else self.default_value + + +_CONFIG_REGISTRY: Dict[str, GlobalConfigParamSetting] = { + "print_changed_only": GlobalConfigParamSetting( + name="print_changed_only", + os_environ_name="SKBASE_PRINT_CHANGED_ONLY", + expected_type=bool, + allowed_values=(True, False), + default_value=True, + ), + "display": GlobalConfigParamSetting( + name="display", + os_environ_name="SKBASE_OBJECT_DISPLAY", + expected_type=str, + allowed_values=("text", "diagram"), + default_value="text", + ), +} + +_DEFAULT_GLOBAL_CONFIG: Dict[str, Any] = { + config_name: config_info.default_value + for config_name, config_info in _CONFIG_REGISTRY.items() +} + +global_config = _DEFAULT_GLOBAL_CONFIG.copy() +_THREAD_LOCAL_DATA = threading.local() + + +def _get_threadlocal_config() -> Dict[str, Any]: + """Get a threadlocal **mutable** configuration. + + If the configuration does not exist, copy the default global configuration. + + Returns + ------- + dict + Threadlocal global config or copy of default global configuration. + """ + if not hasattr(_THREAD_LOCAL_DATA, "global_config"): + _THREAD_LOCAL_DATA.global_config = global_config.copy() + return _THREAD_LOCAL_DATA.global_config + + +def get_default_config() -> Dict[str, Any]: + """Retrieve the default global configuration. + + This will always return the default ``skbase`` global configuration. + + Returns + ------- + config : dict + The default configurable settings (keys) and their default values (values). + + See Also + -------- + config_context : + Configuration context manager. + get_config : + Retrieve current global configuration values. + set_config : + Set global configuration. + reset_config : + Reset configuration to ``skbase`` default. + + Examples + -------- + >>> from skbase.config import get_default_config + >>> get_default_config() + {'print_changed_only': True, 'display': 'text'} + """ + return _DEFAULT_GLOBAL_CONFIG.copy() + + +def get_config() -> Dict[str, Any]: + """Retrieve current values for configuration set by :meth:`set_config`. + + Will return the default configuration if know updated configuration has + been set by :meth:`set_config`. + + Returns + ------- + config : dict + The configurable settings (keys) and their default values (values). + + See Also + -------- + config_context : + Configuration context manager. + get_default_config : + Retrieve ``skbase``'s default configuration. + set_config : + Set global configuration. + reset_config : + Reset configuration to ``skbase`` default. + + Examples + -------- + >>> from skbase.config import get_config + >>> get_config() + {'print_changed_only': True, 'display': 'text'} + """ + return _get_threadlocal_config().copy() + + +def set_config( + print_changed_only: Optional[bool] = None, + display: Literal["text", "diagram"] = None, + local_threadsafe: bool = False, +) -> None: + """Set global configuration. + + Allows the ``skbase`` global configuration to be updated. + + Parameters + ---------- + print_changed_only : bool, default=None + If True, only the parameters that were set to non-default + values will be printed when printing a BaseObject instance. For example, + ``print(SVC())`` while True will only print 'SVC()', but would print + 'SVC(C=1.0, cache_size=200, ...)' with all the non-changed parameters + when False. If None, the existing value won't change. + display : {'text', 'diagram'}, default=None + If 'diagram', instances inheritting from BaseOBject will be displayed + as a diagram in a Jupyter lab or notebook context. If 'text', instances + inheritting from BaseObject will be displayed as text. If None, the + existing value won't change. + local_threadsafe : bool, default=False + If False, set the backend as default for all threads. + + Returns + ------- + None + No output returned. + + See Also + -------- + config_context : + Configuration context manager. + get_default_config : + Retrieve ``skbase``'s default configuration. + get_config : + Retrieve current global configuration values. + reset_config : + Reset configuration to default. + + Examples + -------- + >>> from skbase.config import get_config, set_config + >>> get_config() + {'print_changed_only': True, 'display': 'text'} + >>> set_config(display='diagram') + >>> get_config() + {'print_changed_only': True, 'display': 'diagram'} + """ + local_config = _get_threadlocal_config() + + msg = "Attempting to set an invalid value for a global configuration.\n" + msg += "Using current configuration value of parameter as a result.\n" + if print_changed_only is not None: + local_config["print_changed_only"] = _CONFIG_REGISTRY[ + "print_changed_only" + ].get_valid_param_or_default( + print_changed_only, + default_value=local_config["print_changed_only"], + msg=msg, + ) + if display is not None: + local_config["display"] = _CONFIG_REGISTRY[ + "display" + ].get_valid_param_or_default( + display, default_value=local_config["display"], msg=msg + ) + + if not local_threadsafe: + global_config.update(local_config) + + return None + + +def reset_config() -> None: + """Reset the global configuration to the default. + + Will remove any user updates to the global configuration and reset the values + back to the ``skbase`` defaults. + + Returns + ------- + None + No output returned. + + See Also + -------- + config_context : + Configuration context manager. + get_default_config : + Retrieve ``skbase``'s default configuration. + get_config : + Retrieve current global configuration values. + set_config : + Set global configuration. + + Examples + -------- + >>> from skbase.config import get_config, set_config, reset_config + >>> get_config() + {'print_changed_only': True, 'display': 'text'} + >>> set_config(display='diagram') + >>> get_config() + {'print_changed_only': True, 'display': 'diagram'} + >>> reset_config() + >>> get_config() + {'print_changed_only': True, 'display': 'text'} + """ + default_config = get_default_config() + set_config(**default_config) + return None + + +@contextmanager +def config_context( + print_changed_only: Optional[bool] = None, + display: Literal["text", "diagram"] = None, + local_threadsafe: bool = False, +) -> Iterator[None]: + """Context manager for global configuration. + + Provides the ability to run code using different configuration without + having to update the global config. + + Parameters + ---------- + print_changed_only : bool, default=None + If True, only the parameters that were set to non-default + values will be printed when printing a BaseObject instance. For example, + ``print(SVC())`` while True will only print 'SVC()', but would print + 'SVC(C=1.0, cache_size=200, ...)' with all the non-changed parameters + when False. If None, the existing value won't change. + display : {'text', 'diagram'}, default=None + If 'diagram', instances inheritting from BaseOBject will be displayed + as a diagram in a Jupyter lab or notebook context. If 'text', instances + inheritting from BaseObject will be displayed as text. If None, the + existing value won't change. + local_threadsafe : bool, default=False + If False, set the config as default for all threads. + + Yields + ------ + None + + See Also + -------- + get_default_config : + Retrieve ``skbase``'s default configuration. + get_config : + Retrieve current values of the global configuration. + set_config : + Set global configuration. + reset_config : + Reset configuration to ``skbase`` default. + + Notes + ----- + All settings, not just those presently modified, will be returned to + their previous values when the context manager is exited. + + Examples + -------- + >>> from skbase.config import config_context + >>> with config_context(display='diagram'): + ... pass + """ + old_config = get_config() + set_config( + print_changed_only=print_changed_only, + display=display, + local_threadsafe=local_threadsafe, + ) + + try: + yield + finally: + set_config(**old_config) diff --git a/skbase/config/tests/test_config.py b/skbase/config/tests/test_config.py index 0569125e..1af163e4 100644 --- a/skbase/config/tests/test_config.py +++ b/skbase/config/tests/test_config.py @@ -3,8 +3,6 @@ import pytest from skbase.config import ( - _CONFIG_REGISTRY, - _DEFAULT_GLOBAL_CONFIG, GlobalConfigParamSetting, config_context, get_config, @@ -12,6 +10,7 @@ reset_config, set_config, ) +from skbase.config._config import _CONFIG_REGISTRY, _DEFAULT_GLOBAL_CONFIG PRINT_CHANGE_ONLY_VALUES = _CONFIG_REGISTRY["print_changed_only"].get_allowed_values() DISPLAY_VALUES = _CONFIG_REGISTRY["display"].get_allowed_values() diff --git a/skbase/tests/conftest.py b/skbase/tests/conftest.py index f4e0e847..a39e2251 100644 --- a/skbase/tests/conftest.py +++ b/skbase/tests/conftest.py @@ -24,6 +24,7 @@ "skbase.base._meta", "skbase.base._tagmanager", "skbase.config", + "skbase.config._config", "skbase.lookup", "skbase.lookup.tests", "skbase.lookup.tests.test_lookup", @@ -77,6 +78,7 @@ "skbase.base._base": ("BaseEstimator", "BaseObject"), "skbase.base._meta": ("BaseMetaEstimator",), "skbase.config": ("GlobalConfigParamSetting",), + "skbase.config._config": ("GlobalConfigParamSetting",), "skbase.lookup._lookup": ("ClassInfo", "FunctionInfo", "ModuleInfo"), "skbase.testing": ("BaseFixtureGenerator", "QuickTester", "TestAllObjects"), "skbase.testing.test_all_objects": ( @@ -99,6 +101,13 @@ "reset_config", "config_context", ), + "skbase.config._config": ( + "get_config", + "get_default_config", + "set_config", + "reset_config", + "config_context", + ), "skbase.lookup": ("all_objects", "get_package_metadata"), "skbase.lookup._lookup": ("all_objects", "get_package_metadata"), "skbase.testing.utils._conditional_fixtures": ( @@ -133,7 +142,7 @@ SKBASE_FUNCTIONS_BY_MODULE = SKBASE_PUBLIC_FUNCTIONS_BY_MODULE.copy() SKBASE_FUNCTIONS_BY_MODULE.update( { - "skbase.config": ( + "skbase.config._config": ( "_get_threadlocal_config", "get_config", "get_default_config", From 112d78a4f2543a2adf47d74979bf2bf07df03840 Mon Sep 17 00:00:00 2001 From: rnkuhns Date: Wed, 15 Mar 2023 20:53:38 -0400 Subject: [PATCH 10/12] Update docstring and comments around BaseObject.get_config --- skbase/base/_base.py | 37 +++++++++++++++++++++++++++++-------- 1 file changed, 29 insertions(+), 8 deletions(-) diff --git a/skbase/base/_base.py b/skbase/base/_base.py index 455782a1..29fce57c 100644 --- a/skbase/base/_base.py +++ b/skbase/base/_base.py @@ -439,35 +439,56 @@ def clone_tags(self, estimator, tag_names=None): return self def get_config(self): - """Get config flags for self. + """Get configuration parameters impacting the object. + + The configuration is retrieved in the following order: + + - ``skbase`` global configuration, + - downstream package configurations, and + - local configuration set via the object's _config class variable + or the object's `set_config` parameter. Returns ------- config_dict : dict - Dictionary of config name : config value pairs. Collected from _config - class attribute via nested inheritance and then any overrides - and new tags from _onfig_dynamic object attribute. + Dictionary of config name : config value pairs. """ + # Configuration is collected in a specific order from the farthest to + # most local. skbase global config -> downstream package (optional) -> local + # Start by collecting skbase's global config config = get_config().copy() - # Get any extension configuration interface defined in the class - # for example if downstream package wants to extend skbase to retrieve - # their own config + # Use the object config extension interface to optionally retrieve the + # configuration of any downstream package that is subclassing BaseObject if hasattr(self, "__skbase_get_config__") and callable( self.__skbase_get_config__ ): skbase_get_config_extension_dict = self.__skbase_get_config__() else: skbase_get_config_extension_dict = {} + + # If the config extension dunder returned a dict, use it to update + # the dict of configs that have been retrieved so far if isinstance(skbase_get_config_extension_dict, dict): config.update(skbase_get_config_extension_dict) + # Otherwise warn the user that a dict wasn't returned (extension not + # done properly) and ignore the returned result else: msg = "Use of `__skbase_get_config__` to extend the interface for local " msg += "overrides of the global configuration must return a dictionary.\n" msg += f"But a {type(skbase_get_config_extension_dict)} was found." + msg += "Ignoring result returned from `__skbase_get_config__`." warnings.warn(msg, UserWarning, stacklevel=2) + + # Finally get the local config in case any optional instance config + # overrides were made) local_config = self._get_flags(flag_attr_name="_config").copy() - # IF the local config is one of + + # If the local config param name is one of the skbase global config + # options (as opposed to config of downstream package) then we want + # to make sure we don't return an invalid value. In this case, we'll + # fallback to the default value if an invalid value was set as the local + # override of the global config for config_param, config_value in local_config.items(): if config_param in _CONFIG_REGISTRY: msg = "Invalid value encountered for global configuration parameter " From 6a69f1a04adbf7aae49dc575b62972b0f81e785b Mon Sep 17 00:00:00 2001 From: rnkuhns Date: Wed, 15 Mar 2023 20:55:25 -0400 Subject: [PATCH 11/12] Move GlobalConfigParamSetting to its own file and expand test coverage --- skbase/config/_config.py | 105 ++----------------------- skbase/config/_config_param_setting.py | 91 +++++++++++++++++++++ skbase/config/tests/test_config.py | 46 +++++++---- skbase/tests/conftest.py | 6 +- 4 files changed, 133 insertions(+), 115 deletions(-) create mode 100644 skbase/config/_config_param_setting.py diff --git a/skbase/config/_config.py b/skbase/config/_config.py index ec2fc13c..b5581b31 100644 --- a/skbase/config/_config.py +++ b/skbase/config/_config.py @@ -6,24 +6,20 @@ # that is similar to scikit-learn. These elements are copyrighted by the # scikit-learn developers, BSD-3-Clause License. For conditions see # https://github.com/scikit-learn/scikit-learn/blob/main/COPYING -import collections import sys import threading -import warnings from contextlib import contextmanager -from dataclasses import dataclass -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union +from typing import Any, Dict, Iterator, List, Optional if sys.version_info >= (3, 8): from typing import Literal else: from typing_extensions import Literal # type: ignore -from skbase.utils._iter import _format_seq_to_str +from skbase.config._config_param_setting import GlobalConfigParamSetting __author__: List[str] = ["RNKuhns"] __all__: List[str] = [ - "GlobalConfigParamSetting", "get_default_config", "get_config", "set_config", @@ -32,112 +28,28 @@ ] -@dataclass -class GlobalConfigParamSetting: - """Define types of the setting information for a given config parameter.""" - - name: str - os_environ_name: str - expected_type: Union[type, Tuple[type]] - allowed_values: Optional[Union[Tuple[Any, ...], List[Any]]] - default_value: Any - - def get_allowed_values(self) -> List[Any]: - """Get `allowed_values` or empty tuple if `allowed_values` is None. - - Returns - ------- - tuple - Allowable values if any. - """ - if self.allowed_values is None: - return [] - elif isinstance(self.allowed_values, list): - return self.allowed_values - elif isinstance( - self.allowed_values, collections.abc.Iterable - ) and not isinstance(self.allowed_values, str): - return list(self.allowed_values) - else: - return [self.allowed_values] - - def is_valid_param_value(self, value): - """Validate that a global configuration value is valid. - - Verifies that the value set for a global configuration parameter is valid - based on the its configuration settings. - - Returns - ------- - bool - Whether a parameter value is valid. - """ - allowed_values = self.get_allowed_values() - - valid_param: bool - if not isinstance(value, self.expected_type): - valid_param = False - elif allowed_values is not None and value not in allowed_values: - valid_param = False - else: - valid_param = True - return valid_param - - def get_valid_param_or_default(self, value, default_value=None, msg=None): - """Validate `value` and return default if it is not valid. - - Parameters - ---------- - value : Any - The configuration parameter value to set. - default_value : Any, default=None - An optional default value to use to set the configuration parameter - if `value` is not valid based on defined expected type and allowed - values. If None, and `value` is invalid then the classes `default_value` - parameter is used. - msg : str, default=None - An optional message to be used as start of the UserWarning message. - """ - if self.is_valid_param_value(value): - return value - else: - if msg is None: - msg = "" - msg + f"When setting global config values for `{self.name}`, the values " - msg += f"should be of type {self.expected_type}.\n" - if self.allowed_values is not None: - values_str = _format_seq_to_str( - self.get_allowed_values(), last_sep="or", remove_type_text=True - ) - msg += f"Allowed values are be one of {values_str}. " - msg += f"But found {value}." - warnings.warn(msg, UserWarning, stacklevel=2) - return default_value if default_value is not None else self.default_value - - _CONFIG_REGISTRY: Dict[str, GlobalConfigParamSetting] = { "print_changed_only": GlobalConfigParamSetting( name="print_changed_only", - os_environ_name="SKBASE_PRINT_CHANGED_ONLY", expected_type=bool, allowed_values=(True, False), default_value=True, ), "display": GlobalConfigParamSetting( name="display", - os_environ_name="SKBASE_OBJECT_DISPLAY", expected_type=str, allowed_values=("text", "diagram"), default_value="text", ), } -_DEFAULT_GLOBAL_CONFIG: Dict[str, Any] = { - config_name: config_info.default_value - for config_name, config_info in _CONFIG_REGISTRY.items() +_GLOBAL_CONFIG_DEFAULT: Dict[str, Any] = { + config_settings.name: config_settings.default_value + for config_name, config_settings in _CONFIG_REGISTRY.items() } -global_config = _DEFAULT_GLOBAL_CONFIG.copy() +global_config = _GLOBAL_CONFIG_DEFAULT.copy() + _THREAD_LOCAL_DATA = threading.local() @@ -183,7 +95,7 @@ def get_default_config() -> Dict[str, Any]: >>> get_default_config() {'print_changed_only': True, 'display': 'text'} """ - return _DEFAULT_GLOBAL_CONFIG.copy() + return _GLOBAL_CONFIG_DEFAULT.copy() def get_config() -> Dict[str, Any]: @@ -268,7 +180,6 @@ def set_config( {'print_changed_only': True, 'display': 'diagram'} """ local_config = _get_threadlocal_config() - msg = "Attempting to set an invalid value for a global configuration.\n" msg += "Using current configuration value of parameter as a result.\n" if print_changed_only is not None: diff --git a/skbase/config/_config_param_setting.py b/skbase/config/_config_param_setting.py new file mode 100644 index 00000000..43dee432 --- /dev/null +++ b/skbase/config/_config_param_setting.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- +"""Class to hold information on a configurable parameter setting.""" +import collections +import warnings +from dataclasses import dataclass +from typing import Any, List, Optional, Tuple, Union + +from skbase.utils._iter import _format_seq_to_str + +__author__: List[str] = ["RNKuhns"] +__all__: List[str] = ["GlobalConfigParamSetting"] + + +@dataclass +class GlobalConfigParamSetting: + """Define types of the setting information for a given config parameter.""" + + name: str + expected_type: Union[type, Tuple[type]] + allowed_values: Optional[Union[Tuple[Any, ...], List[Any]]] + default_value: Any + + def get_allowed_values(self) -> List[Any]: + """Get `allowed_values` or empty tuple if `allowed_values` is None. + + Returns + ------- + tuple + Allowable values if any. + """ + if self.allowed_values is None: + return [] + elif isinstance( + self.allowed_values, collections.abc.Iterable + ) and not isinstance(self.allowed_values, str): + return list(self.allowed_values) + else: + return [self.allowed_values] + + def is_valid_param_value(self, value): + """Validate that a global configuration value is valid. + + Verifies that the value set for a global configuration parameter is valid + based on the its configuration settings. + + Returns + ------- + bool + Whether a parameter value is valid. + """ + allowed_values = self.get_allowed_values() + + valid_param: bool + if not isinstance(value, self.expected_type): + valid_param = False + elif allowed_values is not None and value not in allowed_values: + valid_param = False + else: + valid_param = True + return valid_param + + def get_valid_param_or_default(self, value, default_value=None, msg=None): + """Validate `value` and return default if it is not valid. + + Parameters + ---------- + value : Any + The configuration parameter value to set. + default_value : Any, default=None + An optional default value to use to set the configuration parameter + if `value` is not valid based on defined expected type and allowed + values. If None, and `value` is invalid then the classes `default_value` + parameter is used. + msg : str, default=None + An optional message to be used as start of the UserWarning message. + """ + if self.is_valid_param_value(value): + return value + else: + if msg is None: + msg = "" + msg += f"When setting global config values for `{self.name}`, the values " + msg += f"should be of type {self.expected_type}.\n" + if self.allowed_values is not None: + values_str = _format_seq_to_str( + self.get_allowed_values(), last_sep="or", remove_type_text=True + ) + msg += f"Allowed values should be one of {values_str}. " + msg += f"But found {value}." + warnings.warn(msg, UserWarning, stacklevel=2) + return default_value if default_value is not None else self.default_value diff --git a/skbase/config/tests/test_config.py b/skbase/config/tests/test_config.py index 1af163e4..75111211 100644 --- a/skbase/config/tests/test_config.py +++ b/skbase/config/tests/test_config.py @@ -3,29 +3,23 @@ import pytest from skbase.config import ( - GlobalConfigParamSetting, config_context, get_config, get_default_config, reset_config, set_config, ) -from skbase.config._config import _CONFIG_REGISTRY, _DEFAULT_GLOBAL_CONFIG +from skbase.config._config import _CONFIG_REGISTRY, _GLOBAL_CONFIG_DEFAULT +from skbase.config._config_param_setting import GlobalConfigParamSetting PRINT_CHANGE_ONLY_VALUES = _CONFIG_REGISTRY["print_changed_only"].get_allowed_values() DISPLAY_VALUES = _CONFIG_REGISTRY["display"].get_allowed_values() -@pytest.fixture -def config_registry(): - """Config registry fixture.""" - return _CONFIG_REGISTRY - - @pytest.fixture def global_config_default(): """Config registry fixture.""" - return _DEFAULT_GLOBAL_CONFIG + return _GLOBAL_CONFIG_DEFAULT @pytest.mark.parametrize("allowed_values", (None, (), "something", range(1, 8))) @@ -33,7 +27,6 @@ def test_global_config_param_get_allowed_values(allowed_values): """Test GlobalConfigParamSetting behavior works as expected.""" some_config_param = GlobalConfigParamSetting( name="some_param", - os_environ_name="SKBASE_OBJECT_DISPLAY", expected_type=str, allowed_values=allowed_values, default_value="text", @@ -48,7 +41,6 @@ def test_global_config_param_is_valid_param_value(value): """Test GlobalConfigParamSetting behavior works as expected.""" some_config_param = GlobalConfigParamSetting( name="some_param", - os_environ_name="SKBASE_OBJECT_DISPLAY", expected_type=str, allowed_values=("text", "diagram"), default_value="text", @@ -61,10 +53,36 @@ def test_global_config_param_is_valid_param_value(value): assert some_config_param.is_valid_param_value(value) == expected_valid -def test_get_default_config(global_config_default): - """Test get_default_config alwasy returns the default config.""" +def test_global_config_get_valid_or_default(): + """Test GlobalConfigParamSetting.get_valid_param_or_default works as expected.""" + some_config_param = GlobalConfigParamSetting( + name="some_param", + expected_type=str, + allowed_values=("text", "diagram"), + default_value="text", + ) + + # When calling the method with invalid value it should raise user warning + with pytest.warns(UserWarning, match=r"When setting global.*"): + returned_value = some_config_param.get_valid_param_or_default(7, msg=None) + + # And it should return the default value instead of the passed value + assert returned_value == some_config_param.default_value + + # When calling the method with invalid value it should raise user warning + # And the warning should start with `msg` if it is passed + with pytest.warns(UserWarning, match=r"some message.*"): + returned_value = some_config_param.get_valid_param_or_default( + 7, msg="some message" + ) + + +def test_get_default_config_always_returns_default(global_config_default): + """Test get_default_config always returns the default config.""" assert get_default_config() == global_config_default - set_config(print_changed_only=False) + # set config to non-default value to make sure it didn't change + # what is returned by get_default_config + set_config(print_changed_only=not get_default_config()["print_changed_only"]) assert get_default_config() == global_config_default diff --git a/skbase/tests/conftest.py b/skbase/tests/conftest.py index a39e2251..9e80f46e 100644 --- a/skbase/tests/conftest.py +++ b/skbase/tests/conftest.py @@ -25,6 +25,7 @@ "skbase.base._tagmanager", "skbase.config", "skbase.config._config", + "skbase.config._config_param_setting", "skbase.lookup", "skbase.lookup.tests", "skbase.lookup.tests.test_lookup", @@ -78,7 +79,7 @@ "skbase.base._base": ("BaseEstimator", "BaseObject"), "skbase.base._meta": ("BaseMetaEstimator",), "skbase.config": ("GlobalConfigParamSetting",), - "skbase.config._config": ("GlobalConfigParamSetting",), + "skbase.config._config_param_setting": ("GlobalConfigParamSetting",), "skbase.lookup._lookup": ("ClassInfo", "FunctionInfo", "ModuleInfo"), "skbase.testing": ("BaseFixtureGenerator", "QuickTester", "TestAllObjects"), "skbase.testing.test_all_objects": ( @@ -203,9 +204,6 @@ "check_type", "is_sequence", "_convert_scalar_seq_type_input_to_tuple", - # "_check_iterable_of_class_or_error", - # "_check_list_of_str", - # "_check_list_of_str_or_error", ), } ) From 91cbe4f479ef6db7e9b9f132da2a47f47fbab58e Mon Sep 17 00:00:00 2001 From: rnkuhns Date: Wed, 22 Mar 2023 01:15:11 -0400 Subject: [PATCH 12/12] Add test coverage for threadsafe and no positional config params --- skbase/config/_config.py | 2 + skbase/config/tests/test_config.py | 154 ++++++++++++++++++++++++++--- 2 files changed, 145 insertions(+), 11 deletions(-) diff --git a/skbase/config/_config.py b/skbase/config/_config.py index b5581b31..f7e3b07d 100644 --- a/skbase/config/_config.py +++ b/skbase/config/_config.py @@ -130,6 +130,7 @@ def get_config() -> Dict[str, Any]: def set_config( + *, print_changed_only: Optional[bool] = None, display: Literal["text", "diagram"] = None, local_threadsafe: bool = False, @@ -244,6 +245,7 @@ def reset_config() -> None: @contextmanager def config_context( + *, print_changed_only: Optional[bool] = None, display: Literal["text", "diagram"] = None, local_threadsafe: bool = False, diff --git a/skbase/config/tests/test_config.py b/skbase/config/tests/test_config.py index 75111211..364b5bb1 100644 --- a/skbase/config/tests/test_config.py +++ b/skbase/config/tests/test_config.py @@ -1,6 +1,13 @@ # -*- coding: utf-8 -*- +# This code reuses some tests in scikit-learn tests/test_config.py +# The code is copyrighted by the respective scikit-learn developers (BSD-3-Clause +# License): https://github.com/scikit-learn/scikit-learn/blob/main/COPYING """Test configuration functionality.""" +import time +from concurrent.futures import ThreadPoolExecutor + import pytest +from joblib import Parallel, delayed from skbase.config import ( config_context, @@ -84,6 +91,7 @@ def test_get_default_config_always_returns_default(global_config_default): # what is returned by get_default_config set_config(print_changed_only=not get_default_config()["print_changed_only"]) assert get_default_config() == global_config_default + reset_config() @pytest.mark.parametrize("print_changed_only", PRINT_CHANGE_ONLY_VALUES) @@ -97,6 +105,37 @@ def test_set_config_then_get_config_returns_expected_value(print_changed_only, d msg += "After set_config is run, get_config should return the set values.\n " msg += f"Expected {expected_config}, but returned {retrieved_default}." assert retrieved_default == expected_config, msg + reset_config() + + +def test_set_config_with_none_value(): + """Verify that setting config param with None value does not update the config.""" + assert get_config()["print_changed_only"] is True + set_config(print_changed_only=None) + assert get_config()["print_changed_only"] is True + + set_config(print_changed_only=False) + assert get_config()["print_changed_only"] is False + + set_config(print_changed_only=None) + assert get_config()["print_changed_only"] is False + set_config(print_changed_only=True) + assert get_config()["print_changed_only"] is True + + reset_config() + + +def test_config_context_exception(): + """Test config_context followed by exception.""" + assert get_config()["print_changed_only"] is True + try: + with config_context(print_changed_only=False): + assert get_config()["print_changed_only"] is False + raise ValueError() + except ValueError: + pass + assert get_config()["print_changed_only"] is True + reset_config() @pytest.mark.parametrize("print_changed_only", PRINT_CHANGE_ONLY_VALUES) @@ -123,16 +162,20 @@ def test_config_context(print_changed_only, display): # Make sure config is reset to default values then retrieve it reset_config() retrieved_config = get_config() + + # Verify that nothing happens if not used as context manager + config_context(print_changed_only=print_changed_only) + assert get_config() == retrieved_config + # Now lets make sure the config_context is changing the context of those values # within the scope of the context manager as expected - for print_changed_only in (True, False): - with config_context(print_changed_only=print_changed_only, display=display): - retrieved_context_config = get_config() - expected_config = {"print_changed_only": print_changed_only, "display": display} - msg = "`get_config` does not return expected values within `config_context`.\n" - msg += "`get_config` should return config defined by `config_context`.\n" - msg += f"Expected {expected_config}, but returned {retrieved_context_config}." - assert retrieved_context_config == expected_config, msg + with config_context(print_changed_only=print_changed_only, display=display): + retrieved_context_config = get_config() + expected_config = {"print_changed_only": print_changed_only, "display": display} + msg = "`get_config` does not return expected values within `config_context`.\n" + msg += "`get_config` should return config defined by `config_context`.\n" + msg += f"Expected {expected_config}, but returned {retrieved_context_config}." + assert retrieved_context_config == expected_config, msg # Outside of the config_context we should have not affected the retrieved config # set by call to reset_config() @@ -141,21 +184,110 @@ def test_config_context(print_changed_only, display): msg += "`config_context` should not affect configuration outside its context.\n" msg += f"Expected {config_post_config_context}, but returned {retrieved_config}." assert retrieved_config == config_post_config_context, msg - reset_config() + + # positional arguments not allowed + with pytest.raises(TypeError): + config_context(True) + + # Unknown arguments raise exception + with pytest.raises(TypeError): + config_context(do_something_else=True).__enter__() + + +def test_nested_config_context(): + """Verify that nested config_context works as expected.""" + with config_context(print_changed_only=False): + with config_context(print_changed_only=None): + assert get_config()["print_changed_only"] is False + + assert get_config()["print_changed_only"] is False + + with config_context(print_changed_only=True): + assert get_config()["print_changed_only"] is True + + with config_context(print_changed_only=None): + assert get_config()["print_changed_only"] is True + + # global setting will not be retained outside of context that + # did not modify this setting + set_config(print_changed_only=False) + assert get_config()["print_changed_only"] is False + + assert get_config()["print_changed_only"] is True + + assert get_config()["print_changed_only"] is False + assert get_config()["print_changed_only"] is True def test_set_config_behavior_invalid_value(): """Test set_config uses default and raises warning when setting invalid value.""" reset_config() - original_config = get_config().copy() + original_config = get_config() with pytest.warns(UserWarning, match=r"Attempting to set an invalid value.*"): set_config(print_changed_only="False") assert get_config() == original_config - original_config = get_config().copy() + original_config = get_config() with pytest.warns(UserWarning, match=r"Attempting to set an invalid value.*"): set_config(print_changed_only=7) assert get_config() == original_config reset_config() + + +def test_set_config_invalid_keyword_argument(): + """Test set_config behavior when invalid keyword argument passed.""" + with pytest.raises(TypeError): + set_config(do_something_else=True) + + with pytest.raises(TypeError): + set_config(True) + + +def _set_print_changed_only(print_changed_only, sleep_duration): + """Return the value of print_changed_only after waiting `sleep_duration`.""" + with config_context(print_changed_only=print_changed_only): + time.sleep(sleep_duration) + return get_config()["print_changed_only"] + + +def test_config_threadsafe(): + """Test config is actually threadsafe. + + Uses threads directly to test that the global config does not change + between threads. Same test as `test_config_threadsafe_joblib` but with + `ThreadPoolExecutor`. + """ + + print_changed_only_vals = [False, True, False, True] + sleep_durations = [0.1, 0.2, 0.1, 0.2] + + with ThreadPoolExecutor(max_workers=2) as e: + items = list( + e.map(_set_print_changed_only, print_changed_only_vals, sleep_durations) + ) + + assert items == [False, True, False, True] + + +@pytest.mark.parametrize("backend", ["loky", "multiprocessing", "threading"]) +def test_config_threadsafe_joblib(backend): + """Test that the global config is threadsafe with all joblib backends. + + Two jobs are spawned and sets assume_finite to two different values. + When the job with a duration 0.1s completes, the assume_finite value + should be the same as the value passed to the function. In other words, + it is not influenced by the other job setting assume_finite to True. + """ + print_changed_only_vals = [False, True, False, True] + sleep_durations = [0.1, 0.2, 0.1, 0.2] + + items = Parallel(backend=backend, n_jobs=2)( + delayed(_set_print_changed_only)(print_changed_only, sleep_dur) + for print_changed_only, sleep_dur in zip( + print_changed_only_vals, sleep_durations + ) + ) + + assert items == [False, True, False, True]