Skip to content

Commit

Permalink
[ENH] deep_equals - clearer return on pd.DataFrame dtypes not b…
Browse files Browse the repository at this point in the history
…eing equal (#246)

This PR improves error messaging in `deep_equals` - providing a clearer
return on `pd.DataFrame` `dtypes` no being equal.

Mirror of sktime/sktime#5560, for links to
`sktime` `dask` failures see there.
  • Loading branch information
fkiraly authored Dec 22, 2023
1 parent e87c07e commit ee7e589
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 4 deletions.
71 changes: 67 additions & 4 deletions skbase/utils/deep_equals/_deep_equals.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def _pandas_equals_plugin(x, y, return_msg=False, deep_equals=None):


def _pandas_equals(x, y, return_msg=False, deep_equals=None):
import numpy as np # pandas depends on numpy, so this import is fine
import pandas as pd

ret = _make_ret(return_msg)
Expand All @@ -173,13 +174,68 @@ def _pandas_equals(x, y, return_msg=False, deep_equals=None):
else:
return ret(x.equals(y), ".series_equals, x = {} != y = {}", [x, y])
elif isinstance(x, pd.DataFrame):
# check column names for equality
if not x.columns.equals(y.columns):
return ret(
False,
".columns, x.columns = {} != y.columns = {}",
[x.columns, y.columns],
False, f".columns, x.columns = {x.columns} != y.columns = {y.columns}"
)
# if columns are equal and at least one is object, recurse over Series
# check dtypes for equality
if not x.dtypes.equals(y.dtypes):
return ret(
False, f".dtypes, x.dtypes = {x.dtypes} != y.dtypes = {y.dtypes}"
)
# check index for equality
# we are not recursing due to ambiguity in integer index types
# which may differ from pandas version to pandas version
# and would upset the type check, e.g., RangeIndex(2) vs Index([0, 1])
xix = x.index
yix = y.index
if hasattr(xix, "dtype") and hasattr(xix, "dtype"):
if not xix.dtype == yix.dtype:
return ret(
False,
".index.dtype, x.index.dtype = {} != y.index.dtype = {}",
[xix.dtype, yix.dtype],
)
if hasattr(xix, "dtypes") and hasattr(yix, "dtypes"):
if not x.dtypes.equals(y.dtypes):
return ret(
False,
".index.dtypes, x.dtypes = {} != y.index.dtypes = {}",
[xix.dtypes, yix.dtypes],
)
ix_eq = xix.equals(yix)
if not ix_eq:
if not len(xix) == len(yix):
return ret(
False,
".index.len, x.index.len = {} != y.index.len = {}",
[len(xix), len(yix)],
)
if hasattr(xix, "name") and hasattr(yix, "name"):
if not xix.name == yix.name:
return ret(
False,
".index.name, x.index.name = {} != y.index.name = {}",
[xix.name, yix.name],
)
if hasattr(xix, "names") and hasattr(yix, "names"):
if not len(xix.names) == len(yix.names):
return ret(
False,
".index.names, x.index.names = {} != y.index.name = {}",
[xix.names, yix.names],
)
if not np.all(xix.names == yix.names):
return ret(
False,
".index.names, x.index.names = {} != y.index.name = {}",
[xix.names, yix.names],
)
elts_eq = np.all(xix == yix)
return ret(elts_eq, ".index.equals, x = {} != y = {}", [xix, yix])
# if columns, dtypes are equal and at least one is object, recurse over Series
if sum(x.dtypes == "object") > 0:
for c in x.columns:
is_equal, msg = deep_equals(x[c], y[c], return_msg=True)
Expand All @@ -189,7 +245,14 @@ def _pandas_equals(x, y, return_msg=False, deep_equals=None):
else:
return ret(x.equals(y), ".df_equals, x = {} != y = {}", [x, y])
elif isinstance(x, pd.Index):
return ret(x.equals(y), ".index_equals, x = {} != y = {}", [x, y])
if hasattr(x, "dtype") and hasattr(y, "dtype"):
if not x.dtype == y.dtype:
return ret(False, f".dtype, x.dtype = {x.dtype} != y.dtype = {y.dtype}")
if hasattr(x, "dtypes") and hasattr(y, "dtypes"):
if not x.dtypes.equals(y.dtypes):
return ret(
False, f".dtypes, x.dtypes = {x.dtypes} != y.dtypes = {y.dtypes}"
)
else:
raise RuntimeError(
f"Unexpected type of pandas object in _pandas_equals: type(x)={type(x)},"
Expand Down
1 change: 1 addition & 0 deletions skbase/utils/tests/test_deep_equals.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
EXAMPLES += [
pd.DataFrame({"a": [4, 2]}),
pd.DataFrame({"a": [4, 3]}),
pd.DataFrame({"a": ["4", "3"]}),
(np.array([1, 2, 4]), [pd.DataFrame({"a": [4, 2]})]),
{"foo": [42], "bar": pd.Series([1, 2])},
{"bar": [42], "foo": pd.Series([1, 2])},
Expand Down

0 comments on commit ee7e589

Please sign in to comment.