Skip to content

Commit

Permalink
[BUG] ensure deep_equals plugins are passed on to all recursions (#243
Browse files Browse the repository at this point in the history
)

This PR fixes an unreported bug that prevented `deep_equals` plugins to
be passed in some recursions.

The recursions have been fixed, and test cases have been added.
  • Loading branch information
fkiraly authored Nov 18, 2023
1 parent c0a674b commit c3d501e
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 15 deletions.
30 changes: 15 additions & 15 deletions skbase/utils/deep_equals/_deep_equals.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@
lists, tuples, or dicts of a valid type (recursive)
"""
from inspect import isclass, signature
from typing import List

from skbase.utils.deep_equals._common import _make_ret

__author__: List[str] = ["fkiraly"]
__all__: List[str] = ["deep_equals"]
__author__ = ["fkiraly"]
__all__ = ["deep_equals"]


# flag variables for available soft dependencies
Expand Down Expand Up @@ -96,8 +95,9 @@ def _is_pandas(x):


def _is_npndarray(x):
clstr = type(x).__name__
return clstr == "ndarray"
import numpy as np

return isinstance(x, np.ndarray)


def _is_npnan(x):
Expand Down Expand Up @@ -198,7 +198,7 @@ def _pandas_equals(x, y, return_msg=False, deep_equals=None):
)


def _tuple_equals(x, y, return_msg=False):
def _tuple_equals(x, y, return_msg=False, deep_equals=None):
"""Test two tuples or lists for equality.
Correct if tuples/lists contain the following valid types:
Expand Down Expand Up @@ -243,7 +243,7 @@ def _tuple_equals(x, y, return_msg=False):
return ret(True, "")


def _dict_equals(x, y, return_msg=False):
def _dict_equals(x, y, return_msg=False, deep_equals=None):
"""Test two dicts for equality.
Correct if dicts contain the following valid types:
Expand Down Expand Up @@ -378,11 +378,17 @@ def deep_equals_custom(x, y, return_msg=False, plugins=None):
# we now know all types are the same
# so now we compare values

# we need to pass in the same plugins, so we curry
def deep_equals_curried(x, y, return_msg=False):
return deep_equals_custom(x, y, return_msg=return_msg, plugins=plugins)

# recursion through lists, tuples and dicts
if isinstance(x, (list, tuple)):
return ret(*_tuple_equals(x, y, return_msg=True))
dec = deep_equals_curried
return ret(*_tuple_equals(x, y, return_msg=True, deep_equals=dec))
elif isinstance(x, dict):
return ret(*_dict_equals(x, y, return_msg=True))
dec = deep_equals_curried
return ret(*_dict_equals(x, y, return_msg=True, deep_equals=dec))
elif _is_npnan(x):
return ret(_is_npnan(y), f"type(x)={type(x)} != type(y)={type(y)}")
elif isclass(x):
Expand All @@ -398,12 +404,6 @@ def deep_equals_custom(x, y, return_msg=False, plugins=None):
sig = signature(plugin)
# check if deep_equals is an argument of the plugin
if "deep_equals" in sig.parameters:
# we need to pass in the same plugins, so we curry
def deep_equals_curried(x, y, return_msg=False):
return deep_equals_custom(
x, y, return_msg=return_msg, plugins=plugins
)

kwargs = {"deep_equals": deep_equals_curried}
else:
kwargs = {}
Expand Down
4 changes: 4 additions & 0 deletions skbase/utils/tests/test_deep_equals.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
np.array([2, 3, 4]),
np.array([2, 4, 5]),
np.nan,
# these cases test that plugins are passed to recursions
# in this case, the numpy equality plugin
{"a": np.array([2, 3, 4]), "b": np.array([4, 3, 2])},
[np.array([2, 3, 4]), np.array([4, 3, 2])],
]

if _check_soft_dependencies("pandas", severity="none"):
Expand Down

0 comments on commit c3d501e

Please sign in to comment.