From 9ae1c5593b78500c767f93a2c7c75f69b13c0308 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 18 Apr 2024 10:52:56 +0200 Subject: [PATCH 1/5] Fix computation of match value --- baybe/simulation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/baybe/simulation.py b/baybe/simulation.py index bf2aa4ad7..faa626c6a 100644 --- a/baybe/simulation.py +++ b/baybe/simulation.py @@ -549,7 +549,7 @@ def simulate_experiment( agg_fun = np.min cum_fun = np.minimum.accumulate elif target.mode is TargetMode.MATCH: - match_val = np.mean(target.bounds) + match_val = target.bounds.center agg_fun = partial(closest_element, target=match_val) cum_fun = lambda x: np.array( # noqa: E731 np.frompyfunc( From bf18fd2c909ec49047b7c6ef271132bb2a935fd5 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 18 Apr 2024 11:22:01 +0200 Subject: [PATCH 2/5] Fix closest_element function * Handle array-like and scalar input * Handle arbitrary dimensionality * Add corresponding tests --- baybe/utils/numerical.py | 10 +++++++--- tests/test_utils.py | 26 ++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 3 deletions(-) create mode 100644 tests/test_utils.py diff --git a/baybe/utils/numerical.py b/baybe/utils/numerical.py index e78140f65..5b0444475 100644 --- a/baybe/utils/numerical.py +++ b/baybe/utils/numerical.py @@ -3,6 +3,7 @@ from collections.abc import Sequence import numpy as np +import numpy.typing as npt DTypeFloatNumpy = np.float64 """Floating point data type used for numpy arrays.""" @@ -33,16 +34,19 @@ def geom_mean(arr: np.ndarray, weights: Sequence[float]) -> np.ndarray: return np.prod(np.power(arr, np.atleast_2d(weights) / np.sum(weights)), axis=1) -def closest_element(array: np.ndarray, target: float) -> float: +def closest_element(array: npt.ArrayLike, target: float) -> float: """Find the element of an array that is closest to a target value. Args: - array: The array in which the closest value should be found. + array: The array in which the closest value is to be found. target: The target value. Returns: - The closes element. + The closest element. """ + if np.ndim(array) == 0: + return float(array) + array = np.ravel(array) return array[np.abs(array - target).argmin()] diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 000000000..db4934587 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,26 @@ +"""Tests for utilities.""" + +import numpy as np +import pytest +from pytest import param + +from baybe.utils.numerical import closest_element + + +@pytest.mark.parametrize( + "as_ndarray", [param(False, id="list"), param(True, id="array")] +) +@pytest.mark.parametrize( + ("array", "target", "expected"), + [ + param(0, 0.1, 0, id="0D"), + param([0, 1], 0.1, 0, id="1D"), + param([[2, 3], [0, 1]], 0.1, 0, id="2D"), + ], +) +def test_closest_element(as_ndarray, array, target, expected): + """The closest element can be found irrespective of the input type.""" + if as_ndarray: + array = np.asarray(array) + actual = closest_element(array, target) + assert actual == expected, (actual, expected) From b51a795b3cc8b249a5b02096a7708f25a126d10d Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 18 Apr 2024 11:25:52 +0200 Subject: [PATCH 3/5] Update CHANGELOG.md --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 78b0bb45e..a70ccf0a1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `n_task_params` now evaluates to 1 if `task_idx == 0` - Simulation no longer fails in `ignore` mode when lookup dataframe contains duplicate parameter configurations +- Simulation no longer fails for targets in `MATCH` mode +- `closest_element` now works for array-like input of all kinds ### Deprecations - The former `baybe.objective.Objective` class has been replaced with From de2ba9ae101f1f4399372c77070eeb90654a9645 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 18 Apr 2024 12:06:01 +0200 Subject: [PATCH 4/5] Fix mypy error --- baybe/utils/numerical.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/baybe/utils/numerical.py b/baybe/utils/numerical.py index 5b0444475..ab4d9c089 100644 --- a/baybe/utils/numerical.py +++ b/baybe/utils/numerical.py @@ -45,7 +45,7 @@ def closest_element(array: npt.ArrayLike, target: float) -> float: The closest element. """ if np.ndim(array) == 0: - return float(array) + return np.asarray(array).item() array = np.ravel(array) return array[np.abs(array - target).argmin()] From 0d2cf7b4695f7b419ab7970a09f8330cf09e2e7b Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 18 Apr 2024 14:44:19 +0200 Subject: [PATCH 5/5] Refactor test parametrization and use nonzero target --- tests/test_utils.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index db4934587..a02c1825d 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -6,21 +6,24 @@ from baybe.utils.numerical import closest_element +_TARGET = 1337 +_CLOSEST = _TARGET + 0.1 + @pytest.mark.parametrize( "as_ndarray", [param(False, id="list"), param(True, id="array")] ) @pytest.mark.parametrize( - ("array", "target", "expected"), + "array", [ - param(0, 0.1, 0, id="0D"), - param([0, 1], 0.1, 0, id="1D"), - param([[2, 3], [0, 1]], 0.1, 0, id="2D"), + param(_CLOSEST, id="0D"), + param([0, _CLOSEST], id="1D"), + param([[2, 3], [0, _CLOSEST]], id="2D"), ], ) -def test_closest_element(as_ndarray, array, target, expected): +def test_closest_element(as_ndarray, array): """The closest element can be found irrespective of the input type.""" if as_ndarray: array = np.asarray(array) - actual = closest_element(array, target) - assert actual == expected, (actual, expected) + actual = closest_element(array, _TARGET) + assert actual == _CLOSEST, (actual, _CLOSEST)