Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] improved deep_equals utility - plugins for custom types #238

Merged
merged 39 commits into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
680860c
Update deep_equals.py
fkiraly Oct 15, 2023
4f52869
Update deep_equals.py
fkiraly Oct 15, 2023
7bc7b00
refactor cont
fkiraly Oct 15, 2023
0b7042d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 15, 2023
f832b32
Update deep_equals.py
fkiraly Oct 15, 2023
ba7eb65
Merge branch 'deep_equals-improved' of https://github.com/sktime/skba…
fkiraly Oct 15, 2023
84073e8
docstr
fkiraly Oct 16, 2023
6d848d1
Update deep_equals.py
fkiraly Oct 16, 2023
53e17c8
split module
fkiraly Oct 16, 2023
b585cb1
Merge branch 'main' into deep_equals-improved-2
fkiraly Oct 16, 2023
64734c3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 16, 2023
77bc577
deduplicate _ret
fkiraly Oct 16, 2023
e508e84
Update conftest.py
fkiraly Oct 16, 2023
150b56c
exports
fkiraly Oct 16, 2023
8744358
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 16, 2023
509596e
lint
fkiraly Oct 16, 2023
188eb22
Merge branch 'deep_equals-improved-2' of https://github.com/sktime/sk…
fkiraly Oct 16, 2023
464cc67
Update _deep_equals.py
fkiraly Oct 21, 2023
c2272a0
Update _deep_equals.py
fkiraly Oct 21, 2023
5afff87
Update _deep_equals.py
fkiraly Oct 21, 2023
f5838cf
clearer test
fkiraly Oct 21, 2023
d18cb5b
Update test_lookup.py
fkiraly Oct 21, 2023
f3ce05d
Update test_lookup.py
fkiraly Oct 22, 2023
b5b811f
Merge branch 'better-test_lookup-error' into deep_equals-improved-2
fkiraly Oct 22, 2023
80fd5b2
Update test_lookup.py
fkiraly Oct 22, 2023
761385f
Merge branch 'better-test_lookup-error' into deep_equals-improved-2
fkiraly Oct 22, 2023
9784966
Update conftest.py
fkiraly Oct 22, 2023
6aa235d
Update test_lookup.py
fkiraly Oct 22, 2023
b1d2ca3
Merge branch 'better-test_lookup-error' into deep_equals-improved-2
fkiraly Oct 22, 2023
9ae06f1
conftest
fkiraly Oct 22, 2023
142d1a3
one more test example
fkiraly Oct 22, 2023
4c22e60
Update _deep_equals.py
fkiraly Oct 22, 2023
f8b26ad
fixes
fkiraly Oct 22, 2023
00e90cb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 22, 2023
83d2fb1
Update _deep_equals.py
fkiraly Oct 22, 2023
8db89b5
Merge branch 'deep_equals-improved-2' of https://github.com/sktime/sk…
fkiraly Oct 22, 2023
23f9bfc
.to_numpy()
fkiraly Oct 22, 2023
8bc94e6
Update _deep_equals.py
fkiraly Oct 25, 2023
04b048d
Merge branch 'main' into deep_equals-improved-2
fkiraly Oct 25, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions skbase/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
"skbase.utils._nested_iter",
"skbase.utils._utils",
"skbase.utils.deep_equals",
"skbase.utils.deep_equals._common",
"skbase.utils.deep_equals._deep_equals",
"skbase.utils.dependencies",
"skbase.utils.dependencies._dependencies",
"skbase.validate",
Expand Down Expand Up @@ -207,16 +209,15 @@
"unflatten",
),
"skbase.utils._utils": ("subset_dict_keys",),
"skbase.utils.deep_equals": (
"skbase.utils.deep_equals": ("deep_equals",),
"skbase.utils._deep_equals": (
"_coerce_list",
"_dict_equals",
"_fh_equals",
"_is_npnan",
"_is_npndarray",
"_is_pandas",
"_make_ret",
"_pandas_equals",
"_ret",
"_softdep_available",
"_tuple_equals",
"deep_equals",
Expand Down
8 changes: 8 additions & 0 deletions skbase/utils/deep_equals/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# -*- coding: utf-8 -*-
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
"""Module for nested equality checking."""
from skbase.utils.deep_equals._deep_equals import deep_equals

__all__ = [
"deep_equals",
]
41 changes: 41 additions & 0 deletions skbase/utils/deep_equals/_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# -*- coding: utf-8 -*-
"""Common utility functions for skbase.utils.deep_equals."""


def _ret(is_equal, msg="", string_arguments: list = None, return_msg=False):
"""Return is_equal and msg, formatted with string_arguments if return_msg=True.

Parameters
----------
is_equal : bool
msg : str, optional, default=""
message to return if is_equal=False
string_arguments : list, optional, default=None
list of arguments to format msg with

Returns
-------
is_equal : bool
identical to input ``is_equal``, always returned
msg : str, only returned if return_msg=True
if ``is_equal=True``, ``msg`` is always ``""``
if ``is_equal=False``, ``msg`` is formatted with ``string_arguments``
via ``msg.format(*string_arguments)``
"""
if return_msg:
if is_equal:
msg = ""
elif isinstance(string_arguments, (list, tuple)) and len(string_arguments) > 0:
msg = msg.format(*string_arguments)
return is_equal, msg
else:
return is_equal


def _make_ret(return_msg):
    """Curry _ret with return_msg.

    Parameters
    ----------
    return_msg : bool
        value of ``return_msg`` that is passed to every ``_ret`` call
        made through the returned function

    Returns
    -------
    ret : function
        function with signature ``ret(is_equal, msg, string_arguments=None)``,
        which returns ``_ret(is_equal, msg, string_arguments, return_msg)``
    """

    def ret(is_equal, msg, string_arguments=None):
        return _ret(is_equal, msg, string_arguments, return_msg)

    return ret
211 changes: 136 additions & 75 deletions skbase/utils/deep_equals.py → skbase/utils/deep_equals/_deep_equals.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
pd.Series, pd.DataFrame, np.ndarray
lists, tuples, or dicts of a valid type (recursive)
"""
from inspect import isclass
from inspect import isclass, signature
from typing import List

from skbase.utils.deep_equals._common import _make_ret

__author__: List[str] = ["fkiraly"]
__all__: List[str] = ["deep_equals"]

Expand All @@ -27,11 +29,7 @@ def _softdep_available(importname):
return True


numpy_available = _softdep_available("numpy")
pandas_available = _softdep_available("pandas")


def deep_equals(x, y, return_msg=False):
def deep_equals(x, y, return_msg=False, plugins=None):
"""Test two objects for equality in value.

Correct if x/y are one of the following valid types:
Expand All @@ -49,6 +47,10 @@ def deep_equals(x, y, return_msg=False):
y : object
return_msg : bool, optional, default=False
whether to return informative message about what is not equal
plugins : list, optional, default=None
optional additional deep_equals plugins to use
will be appended to the default plugins from ``deep_equals_custom``
see ``deep_equals_custom`` for details of signature of plugins

Returns
-------
Expand All @@ -71,48 +73,20 @@ def deep_equals(x, y, return_msg=False):
[colname] - if pandas.DataFrame: column with name colname is not equal
!= - call to generic != returns False
"""
ret = _make_ret(return_msg)

if type(x) is not type(y):
return ret(False, f".type, x.type = {type(x)} != y.type = {type(y)}")

# we now know all types are the same
# so now we compare values

if numpy_available:
import numpy as np

# pandas is a soft dependency, so we compare pandas objects separately
# and only if pandas is installed in the environment
if _is_pandas(x) and pandas_available:
res = _pandas_equals(x, y, return_msg=return_msg)
if res is not None:
return _pandas_equals(x, y, return_msg=return_msg)

if numpy_available and _is_npndarray(x):
if x.dtype != y.dtype:
return ret(False, f".dtype, x.dtype = {x.dtype} != y.dtype = {y.dtype}")
return ret(np.array_equal(x, y, equal_nan=True), ".values")
# recursion through lists, tuples and dicts
elif isinstance(x, (list, tuple)):
return ret(*_tuple_equals(x, y, return_msg=True))
elif isinstance(x, dict):
return ret(*_dict_equals(x, y, return_msg=True))
elif _is_npnan(x):
return ret(_is_npnan(y), f"type(x)={type(x)} != type(y)={type(y)}")
elif isclass(x):
return ret(x == y, f".class, x={x.__name__} != y={y.__name__}")
elif type(x).__name__ == "ForecastingHorizon":
return ret(*_fh_equals(x, y, return_msg=True))
# this elif covers case where != is boolean
# some types return a vector upon !=, this is covered in the next elif
elif isinstance(x == y, bool):
return ret(x == y, f" !=, {x} != {y}")
# deal with the case where != returns a vector
elif numpy_available and np.any(x != y) or any(_coerce_list(x != y)):
return ret(False, f" !=, {x} != {y}")
# call deep_equals_custom with default plugins
plugins_default = [
_numpy_equals_plugin,
_pandas_equals_plugin,
_fh_equals_plugin,
]

if plugins is not None:
plugins_inner = plugins_default + plugins
else:
plugins_inner = plugins_default

return ret(True, "")
res = deep_equals_custom(x, y, return_msg=return_msg, plugins=plugins_inner)
return res


def _is_pandas(x):
Expand All @@ -131,6 +105,8 @@ def _is_npndarray(x):


def _is_npnan(x):
numpy_available = _softdep_available("numpy")

if numpy_available:
import numpy as np

Expand All @@ -150,7 +126,37 @@ def _coerce_list(x):
return x


def _pandas_equals(x, y, return_msg=False):
def _numpy_equals_plugin(x, y, return_msg=False):
    """Plugin for deep_equals_custom, compares objects flagged by _is_npndarray.

    Parameters
    ----------
    x : object
    y : object
        NOTE(review): ``y.dtype`` is accessed without a type check on ``y``;
        this relies on the caller having established ``type(x) is type(y)``
    return_msg : bool, optional, default=False
        whether to return informative message about what is not equal

    Returns
    -------
    None, if numpy is not available or ``_is_npndarray(x)`` is false,
    i.e., if the plugin does not apply; otherwise:
    is_equal : bool
        False if dtypes differ, else ``np.array_equal(x, y, equal_nan=True)``
    msg : str, only returned if return_msg=True
        ``.dtype`` message if dtypes differ, ``.values`` if values differ
    """
    if not _softdep_available("numpy"):
        return None

    # None tells the caller that this plugin does not apply to x
    if not _is_npndarray(x):
        return None

    import numpy as np

    ret = _make_ret(return_msg)

    if x.dtype != y.dtype:
        return ret(False, f".dtype, x.dtype = {x.dtype} != y.dtype = {y.dtype}")
    return ret(np.array_equal(x, y, equal_nan=True), ".values")


def _pandas_equals_plugin(x, y, return_msg=False, deep_equals=None):
    """Plugin for deep_equals_custom, delegates pandas objects to _pandas_equals.

    Parameters
    ----------
    x : object
    y : object
    return_msg : bool, optional, default=False
        whether to return informative message about what is not equal
    deep_equals : function, optional, default=None
        parent comparison function, passed on to ``_pandas_equals``

    Returns
    -------
    None, if pandas is not available or ``_is_pandas(x)`` is false,
    i.e., if the plugin does not apply; otherwise the return of
    ``_pandas_equals``, passed through unchanged (including ``None``)
    """
    # pandas is a soft dependency, so we compare pandas objects separately
    # and only if pandas is installed in the environment
    if not _softdep_available("pandas"):
        return None

    # None tells the caller that this plugin does not apply to x
    if not _is_pandas(x):
        return None

    return _pandas_equals(x, y, return_msg=return_msg, deep_equals=deep_equals)


def _pandas_equals(x, y, return_msg=False, deep_equals=None):
import pandas as pd

ret = _make_ret(return_msg)
Expand Down Expand Up @@ -291,7 +297,7 @@ def _dict_equals(x, y, return_msg=False):
return ret(True, "")


def _fh_equals(x, y, return_msg=False):
def _fh_equals_plugin(x, y, return_msg=False, deep_equals=None):
"""Test two forecasting horizons for equality.

Correct if both x and y are ForecastingHorizon
Expand All @@ -313,6 +319,9 @@ def _fh_equals(x, y, return_msg=False):
.is_relative - x is absolute and y is relative, or vice versa
.values - values of x and y are not equal
"""
if type(x).__name__ != "ForecastingHorizon":
return None

ret = _make_ret(return_msg)

if x.is_relative != y.is_relative:
Expand All @@ -326,40 +335,92 @@ def _fh_equals(x, y, return_msg=False):
return ret(True, "")


def _ret(is_equal, msg="", string_arguments: list = None, return_msg=False):
"""Return is_equal and msg, formatted with string_arguments if return_msg=True.
def deep_equals_custom(x, y, return_msg=False, plugins=None):
"""Test two objects for equality in value.

Correct if x/y are one of the following valid types:
types compatible with != comparison
pd.Series, pd.DataFrame, np.ndarray
lists, tuples, or dicts of a valid type (recursive)

Important note:
this function will return "not equal" if types of x,y are different
for instance, bool and numpy.bool are *not* considered equal

Parameters
----------
is_equal : bool
msg : str, optional, default=""
message to return if is_equal=False
string_arguments : list, optional, default=None
list of arguments to format msg with
x : object
y : object
return_msg : bool, optional, default=False
whether to return informative message about what is not equal
plugins : list, optional, default=None
list of plugins to use for custom deep_equals
entries must be functions with the signature:
``(x, y, return_msg: bool) -> return``
where return is:
``None``, if the plugin does not apply, otherwise:
``is_equal: bool`` if ``return_msg=False``,
``(is_equal: bool, msg: str)`` if return_msg=True.
Plugins can have an additional argument ``deep_equals=None``
by which the parent function to be called recursively is passed

Returns
-------
is_equal : bool
identical to input ``is_equal``, always returned
msg : str, only returned if return_msg=True
if ``is_equal=True``, ``msg`` is always ``""``
if ``is_equal=False``, ``msg`` is formatted with ``string_arguments``
via ``msg.format(*string_arguments)``
is_equal: bool - True if x and y are equal in value
x and y do not need to be equal in reference
msg : str, only returned if return_msg = True
indication of what is the reason for not being equal
"""
if return_msg:
if is_equal:
msg = ""
elif isinstance(string_arguments, (list, tuple)) and len(string_arguments) > 0:
msg = msg.format(*string_arguments)
return is_equal, msg
else:
return is_equal
ret = _make_ret(return_msg)

if type(x) is not type(y):
return ret(False, f".type, x.type = {type(x)} != y.type = {type(y)}")

# we now know all types are the same
# so now we compare values

def _make_ret(return_msg):
"""Curry _ret with return_msg."""
# recursion through lists, tuples and dicts
if isinstance(x, (list, tuple)):
return ret(*_tuple_equals(x, y, return_msg=True))
elif isinstance(x, dict):
return ret(*_dict_equals(x, y, return_msg=True))
elif _is_npnan(x):
return ret(_is_npnan(y), f"type(x)={type(x)} != type(y)={type(y)}")
elif isclass(x):
return ret(x == y, f".class, x={x.__name__} != y={y.__name__}")

def ret(is_equal, msg, string_arguments=None):
return _ret(is_equal, msg, string_arguments, return_msg)
if plugins is not None:
for plugin in plugins:
# check if plugin has deep_equals argument
# if so, pass this function as argument to plugin
# this allows for recursive calls to deep_equals

# get the signature of the plugin
sig = signature(plugin)
# check if deep_equals is an argument of the plugin
if "deep_equals" in sig.parameters:
kwargs = {"deep_equals": deep_equals_custom}
else:
kwargs = {}

return ret
res = plugin(x, y, return_msg=return_msg, **kwargs)

# if plugin does not apply, res is None
if res is not None:
return res

# this if covers case where != is boolean
# some types return a vector upon !=, this is covered in the final check below
if isinstance(x == y, bool):
return ret(x == y, f" !=, {x} != {y}")

# check if numpy is available
numpy_available = _softdep_available("numpy")
if numpy_available:
import numpy as np

# deal with the case where != returns a vector
if numpy_available and np.any(x != y) or any(_coerce_list(x != y)):
return ret(False, f" !=, {x} != {y}")

return ret(True, "")