Skip to content

Commit

Permalink
[ENH] improved deep_equals utility - plugins for custom types (#238)
Browse files Browse the repository at this point in the history
This PR improves the `deep_equals` utility by introducing the ability to
provide custom plugins.

Ultimately, this enables more flexibility in soft dependencies for data
containers such as `polars` (see, e.g.,
sktime/sktime#5423) throughout the `sktime`
package ecosystem, as `deep_equals` is used for extrinsic testing and
validity checking.
Similarly, the `ForecastingHorizon` check should be moved as a plugin
into `sktime`.

In addition, makes the following changes:

* splits up the `deep_equals` file into a module with multiple files

Depends on #239.
  • Loading branch information
fkiraly authored Oct 26, 2023
1 parent d55abc6 commit e2271b5
Show file tree
Hide file tree
Showing 5 changed files with 221 additions and 87 deletions.
14 changes: 10 additions & 4 deletions skbase/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
"skbase.utils._nested_iter",
"skbase.utils._utils",
"skbase.utils.deep_equals",
"skbase.utils.deep_equals._common",
"skbase.utils.deep_equals._deep_equals",
"skbase.utils.dependencies",
"skbase.utils.dependencies._dependencies",
"skbase.validate",
Expand Down Expand Up @@ -160,6 +162,7 @@
),
"skbase.utils._utils": ("subset_dict_keys",),
"skbase.utils.deep_equals": ("deep_equals",),
"skbase.utils.deep_equals._deep_equals": ("deep_equals", "deep_equals_custom"),
"skbase.validate._types": ("check_sequence", "check_type", "is_sequence"),
}
SKBASE_FUNCTIONS_BY_MODULE = SKBASE_PUBLIC_FUNCTIONS_BY_MODULE.copy()
Expand Down Expand Up @@ -207,19 +210,22 @@
"unflatten",
),
"skbase.utils._utils": ("subset_dict_keys",),
"skbase.utils.deep_equals": (
"skbase.utils.deep_equals": ("deep_equals",),
"skbase.utils.deep_equals._common": ("_make_ret", "_ret"),
"skbase.utils.deep_equals._deep_equals": (
"_coerce_list",
"_dict_equals",
"_fh_equals",
"_fh_equals_plugin",
"_is_npnan",
"_is_npndarray",
"_is_pandas",
"_make_ret",
"_numpy_equals_plugin",
"_pandas_equals",
"_ret",
"_pandas_equals_plugin",
"_softdep_available",
"_tuple_equals",
"deep_equals",
"deep_equals_custom",
),
"skbase.utils.dependencies._dependencies": (
"_check_soft_dependencies",
Expand Down
8 changes: 8 additions & 0 deletions skbase/utils/deep_equals/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# -*- coding: utf-8 -*-
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
"""Module for nested equality checking."""
from skbase.utils.deep_equals._deep_equals import deep_equals

__all__ = [
"deep_equals",
]
41 changes: 41 additions & 0 deletions skbase/utils/deep_equals/_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# -*- coding: utf-8 -*-
"""Common utility functions for skbase.utils.deep_equals."""


def _ret(is_equal, msg="", string_arguments: list = None, return_msg=False):
"""Return is_equal and msg, formatted with string_arguments if return_msg=True.
Parameters
----------
is_equal : bool
msg : str, optional, default=""
message to return if is_equal=False
string_arguments : list, optional, default=None
list of arguments to format msg with
Returns
-------
is_equal : bool
identical to input ``is_equal``, always returned
msg : str, only returned if return_msg=True
if ``is_equal=True``, ``msg`` is always ``""``
if ``is_equal=False``, ``msg`` is formatted with ``string_arguments``
via ``msg.format(*string_arguments)``
"""
if return_msg:
if is_equal:
msg = ""
elif isinstance(string_arguments, (list, tuple)) and len(string_arguments) > 0:
msg = msg.format(*string_arguments)
return is_equal, msg
else:
return is_equal


def _make_ret(return_msg):
"""Curry _ret with return_msg."""

def ret(is_equal, msg, string_arguments=None):
return _ret(is_equal, msg, string_arguments, return_msg)

return ret
Loading

0 comments on commit e2271b5

Please sign in to comment.