Skip to content

Commit

Permalink
Add test coverage for threadsafe and no positional config params
Browse files Browse the repository at this point in the history
  • Loading branch information
RNKuhns committed Mar 22, 2023
1 parent 39a4648 commit 91cbe4f
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 11 deletions.
2 changes: 2 additions & 0 deletions skbase/config/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def get_config() -> Dict[str, Any]:


def set_config(
*,
print_changed_only: Optional[bool] = None,
display: Literal["text", "diagram"] = None,
local_threadsafe: bool = False,
Expand Down Expand Up @@ -244,6 +245,7 @@ def reset_config() -> None:

@contextmanager
def config_context(
*,
print_changed_only: Optional[bool] = None,
display: Literal["text", "diagram"] = None,
local_threadsafe: bool = False,
Expand Down
154 changes: 143 additions & 11 deletions skbase/config/tests/test_config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
# -*- coding: utf-8 -*-
# This code reuses some tests in scikit-learn tests/test_config.py
# The code is copyrighted by the respective scikit-learn developers (BSD-3-Clause
# License): https://github.com/scikit-learn/scikit-learn/blob/main/COPYING
"""Test configuration functionality."""
import time
from concurrent.futures import ThreadPoolExecutor

import pytest
from joblib import Parallel, delayed

from skbase.config import (
config_context,
Expand Down Expand Up @@ -84,6 +91,7 @@ def test_get_default_config_always_returns_default(global_config_default):
# what is returned by get_default_config
set_config(print_changed_only=not get_default_config()["print_changed_only"])
assert get_default_config() == global_config_default
reset_config()


@pytest.mark.parametrize("print_changed_only", PRINT_CHANGE_ONLY_VALUES)
Expand All @@ -97,6 +105,37 @@ def test_set_config_then_get_config_returns_expected_value(print_changed_only, d
msg += "After set_config is run, get_config should return the set values.\n "
msg += f"Expected {expected_config}, but returned {retrieved_default}."
assert retrieved_default == expected_config, msg
reset_config()


def test_set_config_with_none_value():
"""Verify that setting config param with None value does not update the config."""
assert get_config()["print_changed_only"] is True
set_config(print_changed_only=None)
assert get_config()["print_changed_only"] is True

set_config(print_changed_only=False)
assert get_config()["print_changed_only"] is False

set_config(print_changed_only=None)
assert get_config()["print_changed_only"] is False
set_config(print_changed_only=True)
assert get_config()["print_changed_only"] is True

reset_config()


def test_config_context_exception():
"""Test config_context followed by exception."""
assert get_config()["print_changed_only"] is True
try:
with config_context(print_changed_only=False):
assert get_config()["print_changed_only"] is False
raise ValueError()
except ValueError:
pass
assert get_config()["print_changed_only"] is True
reset_config()


@pytest.mark.parametrize("print_changed_only", PRINT_CHANGE_ONLY_VALUES)
Expand All @@ -123,16 +162,20 @@ def test_config_context(print_changed_only, display):
# Make sure config is reset to default values then retrieve it
reset_config()
retrieved_config = get_config()

# Verify that nothing happens if not used as context manager
config_context(print_changed_only=print_changed_only)
assert get_config() == retrieved_config

# Now lets make sure the config_context is changing the context of those values
# within the scope of the context manager as expected
for print_changed_only in (True, False):
with config_context(print_changed_only=print_changed_only, display=display):
retrieved_context_config = get_config()
expected_config = {"print_changed_only": print_changed_only, "display": display}
msg = "`get_config` does not return expected values within `config_context`.\n"
msg += "`get_config` should return config defined by `config_context`.\n"
msg += f"Expected {expected_config}, but returned {retrieved_context_config}."
assert retrieved_context_config == expected_config, msg
with config_context(print_changed_only=print_changed_only, display=display):
retrieved_context_config = get_config()
expected_config = {"print_changed_only": print_changed_only, "display": display}
msg = "`get_config` does not return expected values within `config_context`.\n"
msg += "`get_config` should return config defined by `config_context`.\n"
msg += f"Expected {expected_config}, but returned {retrieved_context_config}."
assert retrieved_context_config == expected_config, msg

# Outside of the config_context we should have not affected the retrieved config
# set by call to reset_config()
Expand All @@ -141,21 +184,110 @@ def test_config_context(print_changed_only, display):
msg += "`config_context` should not affect configuration outside its context.\n"
msg += f"Expected {config_post_config_context}, but returned {retrieved_config}."
assert retrieved_config == config_post_config_context, msg
reset_config()

# positional arguments not allowed
with pytest.raises(TypeError):
config_context(True)

# Unknown arguments raise exception
with pytest.raises(TypeError):
config_context(do_something_else=True).__enter__()


def test_nested_config_context():
"""Verify that nested config_context works as expected."""
with config_context(print_changed_only=False):
with config_context(print_changed_only=None):
assert get_config()["print_changed_only"] is False

assert get_config()["print_changed_only"] is False

with config_context(print_changed_only=True):
assert get_config()["print_changed_only"] is True

with config_context(print_changed_only=None):
assert get_config()["print_changed_only"] is True

# global setting will not be retained outside of context that
# did not modify this setting
set_config(print_changed_only=False)
assert get_config()["print_changed_only"] is False

assert get_config()["print_changed_only"] is True

assert get_config()["print_changed_only"] is False
assert get_config()["print_changed_only"] is True


def test_set_config_behavior_invalid_value():
"""Test set_config uses default and raises warning when setting invalid value."""
reset_config()
original_config = get_config().copy()
original_config = get_config()
with pytest.warns(UserWarning, match=r"Attempting to set an invalid value.*"):
set_config(print_changed_only="False")

assert get_config() == original_config

original_config = get_config().copy()
original_config = get_config()
with pytest.warns(UserWarning, match=r"Attempting to set an invalid value.*"):
set_config(print_changed_only=7)

assert get_config() == original_config
reset_config()


def test_set_config_invalid_keyword_argument():
"""Test set_config behavior when invalid keyword argument passed."""
with pytest.raises(TypeError):
set_config(do_something_else=True)

with pytest.raises(TypeError):
set_config(True)


def _set_print_changed_only(print_changed_only, sleep_duration):
"""Return the value of print_changed_only after waiting `sleep_duration`."""
with config_context(print_changed_only=print_changed_only):
time.sleep(sleep_duration)
return get_config()["print_changed_only"]


def test_config_threadsafe():
"""Test config is actually threadsafe.
Uses threads directly to test that the global config does not change
between threads. Same test as `test_config_threadsafe_joblib` but with
`ThreadPoolExecutor`.
"""

print_changed_only_vals = [False, True, False, True]
sleep_durations = [0.1, 0.2, 0.1, 0.2]

with ThreadPoolExecutor(max_workers=2) as e:
items = list(
e.map(_set_print_changed_only, print_changed_only_vals, sleep_durations)
)

assert items == [False, True, False, True]


@pytest.mark.parametrize("backend", ["loky", "multiprocessing", "threading"])
def test_config_threadsafe_joblib(backend):
"""Test that the global config is threadsafe with all joblib backends.
Two jobs are spawned and sets assume_finite to two different values.
When the job with a duration 0.1s completes, the assume_finite value
should be the same as the value passed to the function. In other words,
it is not influenced by the other job setting assume_finite to True.
"""
print_changed_only_vals = [False, True, False, True]
sleep_durations = [0.1, 0.2, 0.1, 0.2]

items = Parallel(backend=backend, n_jobs=2)(
delayed(_set_print_changed_only)(print_changed_only, sleep_dur)
for print_changed_only, sleep_dur in zip(
print_changed_only_vals, sleep_durations
)
)

assert items == [False, True, False, True]

0 comments on commit 91cbe4f

Please sign in to comment.