Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Add global config interface #149

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
21 changes: 21 additions & 0 deletions docs/source/api_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,27 @@ Base Classes
BaseObject
BaseEstimator

.. _global_config:

Configure ``skbase``
====================

.. automodule:: skbase.config
:no-members:
:no-inherited-members:

.. currentmodule:: skbase.config

.. autosummary::
:toctree: api_reference/auto_generated/
:template: function.rst

get_config
get_default_config
set_config
reset_config
config_context

.. _obj_retrieval:

Object Retrieval
Expand Down
15 changes: 15 additions & 0 deletions docs/source/user_documentation/user_guide.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ that ``skbase`` provides, see the :ref:`api_ref`.
user_guide/lookup
user_guide/validate
user_guide/testing
user_guide/configuration


.. grid:: 1 2 2 2
Expand Down Expand Up @@ -103,3 +104,17 @@ that ``skbase`` provides, see the :ref:`api_ref`.
:expand:

Testing

.. grid-item-card:: Configuration
:text-align: center

Configure ``skbase``.

+++

.. button-ref:: user_guide/configuration
:color: primary
:click-parent:
:expand:

Configuration
11 changes: 11 additions & 0 deletions docs/source/user_documentation/user_guide/configuration.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
.. _user_guide_global_config:

====================
Configure ``skbase``
====================

.. note::

The user guide is under development. We have created a basic
structure and are looking for contributions to develop the user guide
further.
36 changes: 35 additions & 1 deletion skbase/base/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ class name: BaseEstimator

from skbase._exceptions import NotFittedError
from skbase.base._tagmanager import _FlagManager
from skbase.config import get_config
from skbase.config._config import _CONFIG_REGISTRY

__author__: List[str] = ["mloning", "RNKuhns", "fkiraly"]
__all__: List[str] = ["BaseEstimator", "BaseObject"]
Expand Down Expand Up @@ -446,7 +448,39 @@ def get_config(self):
class attribute via nested inheritance and then any overrides
and new tags from _onfig_dynamic object attribute.
"""
return self._get_flags(flag_attr_name="_config")
config = get_config().copy()

# Get any extension configuration interface defined in the class
# for example if downstream package wants to extend skbase to retrieve
# their own config
if hasattr(self, "__skbase_get_config__") and callable(
self.__skbase_get_config__
):
skbase_get_config_extension_dict = self.__skbase_get_config__()
else:
skbase_get_config_extension_dict = {}
if isinstance(skbase_get_config_extension_dict, dict):
config.update(skbase_get_config_extension_dict)
else:
msg = "Use of `__skbase_get_config__` to extend the interface for local "
msg += "overrides of the global configuration must return a dictionary.\n"
msg += f"But a {type(skbase_get_config_extension_dict)} was found."
warnings.warn(msg, UserWarning, stacklevel=2)
local_config = self._get_flags(flag_attr_name="_config").copy()
# IF the local config is one of
for config_param, config_value in local_config.items():
if config_param in _CONFIG_REGISTRY:
msg = "Invalid value encountered for global configuration parameter "
msg += f"{config_param}. Using global parameter configuration value.\n"
config_value = _CONFIG_REGISTRY[
config_param
].get_valid_param_or_default(
config_value, default_value=config[config_param]
)
local_config[config_param] = config_value
config.update(local_config)

return config

def set_config(self, **config_dict):
"""Set config flags to given values.
Expand Down
32 changes: 32 additions & 0 deletions skbase/config/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# -*- coding: utf-8 -*-
""":mod:`skbase.config` provides tools for the global configuration of ``skbase``.

For more information on configuration usage patterns see the
:ref:`user guide <user_guide_global_config>`.
"""
# -*- coding: utf-8 -*-
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
# Includes functionality like get_config, set_config, and config_context
# that is similar to scikit-learn. These elements are copyrighted by the
# scikit-learn developers, BSD-3-Clause License. For conditions see
# https://github.com/scikit-learn/scikit-learn/blob/main/COPYING
from typing import List

from skbase.config._config import (
GlobalConfigParamSetting,
config_context,
get_config,
get_default_config,
reset_config,
set_config,
)

__author__: List[str] = ["RNKuhns"]
__all__: List[str] = [
"GlobalConfigParamSetting",
"get_default_config",
"get_config",
"set_config",
"reset_config",
"config_context",
]
Loading