From c3d501e8723a91d770e067343dd51de1c548a5d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sat, 18 Nov 2023 09:40:56 +0100 Subject: [PATCH] [BUG] ensure `deep_equals` plugins are passed on to all recursions (#243) This PR fixes an unreported bug that prevented `deep_equals` plugins to be passed in some recursions. The recursions have been fixed, and test cases have been added. --- skbase/utils/deep_equals/_deep_equals.py | 30 ++++++++++++------------ skbase/utils/tests/test_deep_equals.py | 4 ++++ 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/skbase/utils/deep_equals/_deep_equals.py b/skbase/utils/deep_equals/_deep_equals.py index c4bbe160..8ad87649 100644 --- a/skbase/utils/deep_equals/_deep_equals.py +++ b/skbase/utils/deep_equals/_deep_equals.py @@ -7,12 +7,11 @@ lists, tuples, or dicts of a valid type (recursive) """ from inspect import isclass, signature -from typing import List from skbase.utils.deep_equals._common import _make_ret -__author__: List[str] = ["fkiraly"] -__all__: List[str] = ["deep_equals"] +__author__ = ["fkiraly"] +__all__ = ["deep_equals"] # flag variables for available soft dependencies @@ -96,8 +95,9 @@ def _is_pandas(x): def _is_npndarray(x): - clstr = type(x).__name__ - return clstr == "ndarray" + import numpy as np + + return isinstance(x, np.ndarray) def _is_npnan(x): @@ -198,7 +198,7 @@ def _pandas_equals(x, y, return_msg=False, deep_equals=None): ) -def _tuple_equals(x, y, return_msg=False): +def _tuple_equals(x, y, return_msg=False, deep_equals=None): """Test two tuples or lists for equality. Correct if tuples/lists contain the following valid types: @@ -243,7 +243,7 @@ def _tuple_equals(x, y, return_msg=False): return ret(True, "") -def _dict_equals(x, y, return_msg=False): +def _dict_equals(x, y, return_msg=False, deep_equals=None): """Test two dicts for equality. Correct if dicts contain the following valid types: @@ -378,11 +378,17 @@ def deep_equals_custom(x, y, return_msg=False, plugins=None): # we now know all types are the same # so now we compare values + # we need to pass in the same plugins, so we curry + def deep_equals_curried(x, y, return_msg=False): + return deep_equals_custom(x, y, return_msg=return_msg, plugins=plugins) + # recursion through lists, tuples and dicts if isinstance(x, (list, tuple)): - return ret(*_tuple_equals(x, y, return_msg=True)) + dec = deep_equals_curried + return ret(*_tuple_equals(x, y, return_msg=True, deep_equals=dec)) elif isinstance(x, dict): - return ret(*_dict_equals(x, y, return_msg=True)) + dec = deep_equals_curried + return ret(*_dict_equals(x, y, return_msg=True, deep_equals=dec)) elif _is_npnan(x): return ret(_is_npnan(y), f"type(x)={type(x)} != type(y)={type(y)}") elif isclass(x): @@ -398,12 +404,6 @@ def deep_equals_custom(x, y, return_msg=False, plugins=None): sig = signature(plugin) # check if deep_equals is an argument of the plugin if "deep_equals" in sig.parameters: - # we need to pass in the same plugins, so we curry - def deep_equals_curried(x, y, return_msg=False): - return deep_equals_custom( - x, y, return_msg=return_msg, plugins=plugins - ) - kwargs = {"deep_equals": deep_equals_curried} else: kwargs = {} diff --git a/skbase/utils/tests/test_deep_equals.py b/skbase/utils/tests/test_deep_equals.py index 7f708479..271cf135 100644 --- a/skbase/utils/tests/test_deep_equals.py +++ b/skbase/utils/tests/test_deep_equals.py @@ -24,6 +24,10 @@ np.array([2, 3, 4]), np.array([2, 4, 5]), np.nan, + # these cases test that plugins are passed to recursions + # in this case, the numpy equality plugin + {"a": np.array([2, 3, 4]), "b": np.array([4, 3, 2])}, + [np.array([2, 3, 4]), np.array([4, 3, 2])], ] if _check_soft_dependencies("pandas", severity="none"):