diff --git a/skbase/tests/conftest.py b/skbase/tests/conftest.py index 2ce9ea1f..27c591fc 100644 --- a/skbase/tests/conftest.py +++ b/skbase/tests/conftest.py @@ -49,6 +49,8 @@ "skbase.utils._nested_iter", "skbase.utils._utils", "skbase.utils.deep_equals", + "skbase.utils.deep_equals._common", + "skbase.utils.deep_equals._deep_equals", "skbase.utils.dependencies", "skbase.utils.dependencies._dependencies", "skbase.validate", @@ -160,6 +162,7 @@ ), "skbase.utils._utils": ("subset_dict_keys",), "skbase.utils.deep_equals": ("deep_equals",), + "skbase.utils.deep_equals._deep_equals": ("deep_equals", "deep_equals_custom"), "skbase.validate._types": ("check_sequence", "check_type", "is_sequence"), } SKBASE_FUNCTIONS_BY_MODULE = SKBASE_PUBLIC_FUNCTIONS_BY_MODULE.copy() @@ -207,19 +210,22 @@ "unflatten", ), "skbase.utils._utils": ("subset_dict_keys",), - "skbase.utils.deep_equals": ( + "skbase.utils.deep_equals": ("deep_equals",), + "skbase.utils.deep_equals._common": ("_make_ret", "_ret"), + "skbase.utils.deep_equals._deep_equals": ( "_coerce_list", "_dict_equals", - "_fh_equals", + "_fh_equals_plugin", "_is_npnan", "_is_npndarray", "_is_pandas", - "_make_ret", + "_numpy_equals_plugin", "_pandas_equals", - "_ret", + "_pandas_equals_plugin", "_softdep_available", "_tuple_equals", "deep_equals", + "deep_equals_custom", ), "skbase.utils.dependencies._dependencies": ( "_check_soft_dependencies", diff --git a/skbase/utils/deep_equals/__init__.py b/skbase/utils/deep_equals/__init__.py new file mode 100644 index 00000000..86eea814 --- /dev/null +++ b/skbase/utils/deep_equals/__init__.py @@ -0,0 +1,8 @@ +# -*- coding: utf-8 -*- +# copyright: skbase developers, BSD-3-Clause License (see LICENSE file) +"""Module for nested equality checking.""" +from skbase.utils.deep_equals._deep_equals import deep_equals + +__all__ = [ + "deep_equals", +] diff --git a/skbase/utils/deep_equals/_common.py b/skbase/utils/deep_equals/_common.py new file mode 100644 index 00000000..bc486f23 --- 
/dev/null +++ b/skbase/utils/deep_equals/_common.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +"""Common utility functions for skbase.utils.deep_equals.""" + + +def _ret(is_equal, msg="", string_arguments: list = None, return_msg=False): + """Return is_equal and msg, formatted with string_arguments if return_msg=True. + + Parameters + ---------- + is_equal : bool + msg : str, optional, default="" + message to return if is_equal=False + string_arguments : list, optional, default=None + list of arguments to format msg with + + Returns + ------- + is_equal : bool + identical to input ``is_equal``, always returned + msg : str, only returned if return_msg=True + if ``is_equal=True``, ``msg`` is always ``""`` + if ``is_equal=False``, ``msg`` is formatted with ``string_arguments`` + via ``msg.format(*string_arguments)`` + """ + if return_msg: + if is_equal: + msg = "" + elif isinstance(string_arguments, (list, tuple)) and len(string_arguments) > 0: + msg = msg.format(*string_arguments) + return is_equal, msg + else: + return is_equal + + +def _make_ret(return_msg): + """Curry _ret with return_msg.""" + + def ret(is_equal, msg, string_arguments=None): + return _ret(is_equal, msg, string_arguments, return_msg) + + return ret diff --git a/skbase/utils/deep_equals.py b/skbase/utils/deep_equals/_deep_equals.py similarity index 66% rename from skbase/utils/deep_equals.py rename to skbase/utils/deep_equals/_deep_equals.py index c8a9a887..c4bbe160 100644 --- a/skbase/utils/deep_equals.py +++ b/skbase/utils/deep_equals/_deep_equals.py @@ -6,9 +6,11 @@ pd.Series, pd.DataFrame, np.ndarray lists, tuples, or dicts of a valid type (recursive) """ -from inspect import isclass +from inspect import isclass, signature from typing import List +from skbase.utils.deep_equals._common import _make_ret + __author__: List[str] = ["fkiraly"] __all__: List[str] = ["deep_equals"] @@ -27,11 +29,7 @@ def _softdep_available(importname): return True -numpy_available = _softdep_available("numpy") 
-pandas_available = _softdep_available("pandas") - - -def deep_equals(x, y, return_msg=False): +def deep_equals(x, y, return_msg=False, plugins=None): """Test two objects for equality in value. Correct if x/y are one of the following valid types: @@ -49,6 +47,10 @@ def deep_equals(x, y, return_msg=False): y : object return_msg : bool, optional, default=False whether to return informative message about what is not equal + plugins : list, optional, default=None + optional additional deep_equals plugins to use + will be appended to the default plugins from ``deep_equals_custom`` + see ``deep_equals_custom`` for details of signature of plugins Returns ------- @@ -71,58 +73,26 @@ def deep_equals(x, y, return_msg=False): [colname] - if pandas.DataFrame: column with name colname is not equal != - call to generic != returns False """ - ret = _make_ret(return_msg) - - if type(x) is not type(y): - return ret(False, f".type, x.type = {type(x)} != y.type = {type(y)}") - - # we now know all types are the same - # so now we compare values - - if numpy_available: - import numpy as np - - # pandas is a soft dependency, so we compare pandas objects separately - # and only if pandas is installed in the environment - if _is_pandas(x) and pandas_available: - res = _pandas_equals(x, y, return_msg=return_msg) - if res is not None: - return _pandas_equals(x, y, return_msg=return_msg) - - if numpy_available and _is_npndarray(x): - if x.dtype != y.dtype: - return ret(False, f".dtype, x.dtype = {x.dtype} != y.dtype = {y.dtype}") - return ret(np.array_equal(x, y, equal_nan=True), ".values") - # recursion through lists, tuples and dicts - elif isinstance(x, (list, tuple)): - return ret(*_tuple_equals(x, y, return_msg=True)) - elif isinstance(x, dict): - return ret(*_dict_equals(x, y, return_msg=True)) - elif _is_npnan(x): - return ret(_is_npnan(y), f"type(x)={type(x)} != type(y)={type(y)}") - elif isclass(x): - return ret(x == y, f".class, x={x.__name__} != y={y.__name__}") - elif 
type(x).__name__ == "ForecastingHorizon": - return ret(*_fh_equals(x, y, return_msg=True)) - # this elif covers case where != is boolean - # some types return a vector upon !=, this is covered in the next elif - elif isinstance(x == y, bool): - return ret(x == y, f" !=, {x} != {y}") - # deal with the case where != returns a vector - elif numpy_available and np.any(x != y) or any(_coerce_list(x != y)): - return ret(False, f" !=, {x} != {y}") + # call deep_equals_custom with default plugins + plugins_default = [ + _numpy_equals_plugin, + _pandas_equals_plugin, + _fh_equals_plugin, + ] + + if plugins is not None: + plugins_inner = plugins_default + plugins + else: + plugins_inner = plugins_default - return ret(True, "") + res = deep_equals_custom(x, y, return_msg=return_msg, plugins=plugins_inner) + return res def _is_pandas(x): - clstr = type(x).__name__ - if clstr in ["DataFrame", "Series"]: - return True - if clstr.endswith("Index"): - return True - else: - return False + import pandas as pd + + return isinstance(x, (pd.Series, pd.DataFrame, pd.Index)) def _is_npndarray(x): @@ -131,6 +101,8 @@ def _is_npndarray(x): def _is_npnan(x): + numpy_available = _softdep_available("numpy") + if numpy_available: import numpy as np @@ -150,7 +122,34 @@ def _coerce_list(x): return x -def _pandas_equals(x, y, return_msg=False): +def _numpy_equals_plugin(x, y, return_msg=False): + numpy_available = _softdep_available("numpy") + + if not numpy_available or not _is_npndarray(x): + return None + else: + import numpy as np + + ret = _make_ret(return_msg) + + if x.dtype != y.dtype: + return ret(False, f".dtype, x.dtype = {x.dtype} != y.dtype = {y.dtype}") + return ret(np.array_equal(x, y, equal_nan=True), ".values") + + +def _pandas_equals_plugin(x, y, return_msg=False, deep_equals=None): + pandas_available = _softdep_available("pandas") + + if not pandas_available or not _is_pandas(x): + return None + + # pandas is a soft dependency, so we compare pandas objects separately + # and 
only if pandas is installed in the environment + res = _pandas_equals(x, y, return_msg=return_msg, deep_equals=deep_equals) + return res + + +def _pandas_equals(x, y, return_msg=False, deep_equals=None): import pandas as pd ret = _make_ret(return_msg) @@ -162,7 +161,7 @@ def _pandas_equals(x, y, return_msg=False): if x.dtype == "object": index_equal = x.index.equals(y.index) values_equal, values_msg = deep_equals( - list(x.to_array()), list(y.to_array()), return_msg=True + list(x.to_numpy()), list(y.to_numpy()), return_msg=True ) if not values_equal: msg = ".values" + values_msg @@ -191,6 +190,12 @@ def _pandas_equals(x, y, return_msg=False): return ret(x.equals(y), ".df_equals, x = {} != y = {}", [x, y]) elif isinstance(x, pd.Index): return ret(x.equals(y), ".index_equals, x = {} != y = {}", [x, y]) + else: + raise RuntimeError( + f"Unexpected type of pandas object in _pandas_equals: type(x)={type(x)}," + f" type(y)={type(y)}, both should be one of " + "pd.Series, pd.DataFrame, pd.Index" + ) def _tuple_equals(x, y, return_msg=False): @@ -291,7 +296,7 @@ def _dict_equals(x, y, return_msg=False): return ret(True, "") -def _fh_equals(x, y, return_msg=False): +def _fh_equals_plugin(x, y, return_msg=False, deep_equals=None): """Test two forecasting horizons for equality. Correct if both x and y are ForecastingHorizon @@ -313,6 +318,9 @@ def _fh_equals(x, y, return_msg=False): .is_relative - x is absolute and y is relative, or vice versa .values - values of x and y are not equal """ + if type(x).__name__ != "ForecastingHorizon": + return None + ret = _make_ret(return_msg) if x.is_relative != y.is_relative: @@ -326,40 +334,98 @@ def _fh_equals(x, y, return_msg=False): return ret(True, "") -def _ret(is_equal, msg="", string_arguments: list = None, return_msg=False): - """Return is_equal and msg, formatted with string_arguments if return_msg=True. +def deep_equals_custom(x, y, return_msg=False, plugins=None): + """Test two objects for equality in value. 
+ + Correct if x/y are one of the following valid types: + types compatible with != comparison + pd.Series, pd.DataFrame, np.ndarray + lists, tuples, or dicts of a valid type (recursive) + + Important note: + this function will return "not equal" if types of x,y are different + for instance, bool and numpy.bool are *not* considered equal Parameters ---------- - is_equal : bool - msg : str, optional, default="" - message to return if is_equal=False - string_arguments : list, optional, default=None - list of arguments to format msg with + x : object + y : object + return_msg : bool, optional, default=False + whether to return informative message about what is not equal + plugins : list, optional, default=None + list of plugins to use for custom deep_equals + entries must be functions with the signature: + ``(x, y, return_msg: bool) -> return`` + where return is: + ``None``, if the plugin does not apply, otherwise: + ``is_equal: bool`` if ``return_msg=False``, + ``(is_equal: bool, msg: str)`` if return_msg=True. 
+ Plugins can have an additional argument ``deep_equals=None`` + by which the parent function to be called recursively is passed Returns ------- - is_equal : bool - identical to input ``is_equal``, always returned - msg : str, only returned if return_msg=True - if ``is_equal=True``, ``msg`` is always ``""`` - if ``is_equal=False``, ``msg`` is formatted with ``string_arguments`` - via ``msg.format(*string_arguments)`` + is_equal: bool - True if x and y are equal in value + x and y do not need to be equal in reference + msg : str, only returned if return_msg = True + indication of what is the reason for not being equal """ - if return_msg: - if is_equal: - msg = "" - elif isinstance(string_arguments, (list, tuple)) and len(string_arguments) > 0: - msg = msg.format(*string_arguments) - return is_equal, msg - else: - return is_equal + ret = _make_ret(return_msg) + if type(x) is not type(y): + return ret(False, f".type, x.type = {type(x)} != y.type = {type(y)}") -def _make_ret(return_msg): - """Curry _ret with return_msg.""" + # we now know all types are the same + # so now we compare values - def ret(is_equal, msg, string_arguments=None): - return _ret(is_equal, msg, string_arguments, return_msg) + # recursion through lists, tuples and dicts + if isinstance(x, (list, tuple)): + return ret(*_tuple_equals(x, y, return_msg=True)) + elif isinstance(x, dict): + return ret(*_dict_equals(x, y, return_msg=True)) + elif _is_npnan(x): + return ret(_is_npnan(y), f"type(x)={type(x)} != type(y)={type(y)}") + elif isclass(x): + return ret(x == y, f".class, x={x.__name__} != y={y.__name__}") + + if plugins is not None: + for plugin in plugins: + # check if plugin has deep_equals argument + # if so, pass this function as argument to plugin + # this allows for recursive calls to deep_equals + + # get the signature of the plugin + sig = signature(plugin) + # check if deep_equals is an argument of the plugin + if "deep_equals" in sig.parameters: + # we need to pass in the same plugins, 
so we curry + def deep_equals_curried(x, y, return_msg=False): + return deep_equals_custom( + x, y, return_msg=return_msg, plugins=plugins + ) + + kwargs = {"deep_equals": deep_equals_curried} + else: + kwargs = {} - return ret + res = plugin(x, y, return_msg=return_msg, **kwargs) + + # if plugin does not apply, res is None + if res is not None: + return res + + # this if covers case where != is boolean + # some types return a vector upon !=, this is covered in the next elif + if isinstance(x == y, bool): + return ret(x == y, f" !=, {x} != {y}") + + # check if numpy is available + numpy_available = _softdep_available("numpy") + if numpy_available: + import numpy as np + + # deal with the case where != returns a vector + if numpy_available and np.any(x != y) or any(_coerce_list(x != y)): + return ret(False, f" !=, {x} != {y}") + + return ret(True, "") diff --git a/skbase/utils/tests/test_deep_equals.py b/skbase/utils/tests/test_deep_equals.py index 203ff65b..2ea2154f 100644 --- a/skbase/utils/tests/test_deep_equals.py +++ b/skbase/utils/tests/test_deep_equals.py @@ -33,6 +33,19 @@ {"bar": [42], "foo": pd.Series([1, 2])}, ] + # nested DataFrame example + cols = [f"var_{i}" for i in range(2)] + X = pd.DataFrame(columns=cols, index=[0, 1, 2]) + X["var_0"] = pd.Series( + [pd.Series([1, 2, 3]), pd.Series([1, 2, 3]), pd.Series([1, 2, 3])] + ) + + X["var_1"] = pd.Series( + [pd.Series([4, 5, 6]), pd.Series([4, 55, 6]), pd.Series([42, 5, 6])] + ) + + EXAMPLES += [X] + @pytest.mark.parametrize("fixture", EXAMPLES) def test_deep_equals_positive(fixture):