diff --git a/skbase/utils/deep_equals/_deep_equals.py b/skbase/utils/deep_equals/_deep_equals.py index 425408ec..32c3bc51 100644 --- a/skbase/utils/deep_equals/_deep_equals.py +++ b/skbase/utils/deep_equals/_deep_equals.py @@ -150,6 +150,7 @@ def _pandas_equals_plugin(x, y, return_msg=False, deep_equals=None): def _pandas_equals(x, y, return_msg=False, deep_equals=None): + import numpy as np # pandas depends on numpy, so this import is fine import pandas as pd ret = _make_ret(return_msg) @@ -173,13 +174,68 @@ def _pandas_equals(x, y, return_msg=False, deep_equals=None): else: return ret(x.equals(y), ".series_equals, x = {} != y = {}", [x, y]) elif isinstance(x, pd.DataFrame): + # check column names for equality if not x.columns.equals(y.columns): return ret( - False, - ".columns, x.columns = {} != y.columns = {}", - [x.columns, y.columns], + False, f".columns, x.columns = {x.columns} != y.columns = {y.columns}" ) # if columns are equal and at least one is object, recurse over Series + # check dtypes for equality + if not x.dtypes.equals(y.dtypes): + return ret( + False, f".dtypes, x.dtypes = {x.dtypes} != y.dtypes = {y.dtypes}" + ) + # check index for equality + # we are not recursing due to ambiguity in integer index types + # which may differ from pandas version to pandas version + # and would upset the type check, e.g., RangeIndex(2) vs Index([0, 1]) + xix = x.index + yix = y.index + if hasattr(xix, "dtype") and hasattr(xix, "dtype"): + if not xix.dtype == yix.dtype: + return ret( + False, + ".index.dtype, x.index.dtype = {} != y.index.dtype = {}", + [xix.dtype, yix.dtype], + ) + if hasattr(xix, "dtypes") and hasattr(yix, "dtypes"): + if not x.dtypes.equals(y.dtypes): + return ret( + False, + ".index.dtypes, x.dtypes = {} != y.index.dtypes = {}", + [xix.dtypes, yix.dtypes], + ) + ix_eq = xix.equals(yix) + if not ix_eq: + if not len(xix) == len(yix): + return ret( + False, + ".index.len, x.index.len = {} != y.index.len = {}", + [len(xix), len(yix)], + ) + if hasattr(xix, "name") and hasattr(yix, "name"): + if not xix.name == yix.name: + return ret( + False, + ".index.name, x.index.name = {} != y.index.name = {}", + [xix.name, yix.name], + ) + if hasattr(xix, "names") and hasattr(yix, "names"): + if not len(xix.names) == len(yix.names): + return ret( + False, + ".index.names, x.index.names = {} != y.index.name = {}", + [xix.names, yix.names], + ) + if not np.all(xix.names == yix.names): + return ret( + False, + ".index.names, x.index.names = {} != y.index.name = {}", + [xix.names, yix.names], + ) + elts_eq = np.all(xix == yix) + return ret(elts_eq, ".index.equals, x = {} != y = {}", [xix, yix]) + # if columns, dtypes are equal and at least one is object, recurse over Series if sum(x.dtypes == "object") > 0: for c in x.columns: is_equal, msg = deep_equals(x[c], y[c], return_msg=True) @@ -189,7 +245,14 @@ def _pandas_equals(x, y, return_msg=False, deep_equals=None): else: return ret(x.equals(y), ".df_equals, x = {} != y = {}", [x, y]) elif isinstance(x, pd.Index): - return ret(x.equals(y), ".index_equals, x = {} != y = {}", [x, y]) + if hasattr(x, "dtype") and hasattr(y, "dtype"): + if not x.dtype == y.dtype: + return ret(False, f".dtype, x.dtype = {x.dtype} != y.dtype = {y.dtype}") + if hasattr(x, "dtypes") and hasattr(y, "dtypes"): + if not x.dtypes.equals(y.dtypes): + return ret( + False, f".dtypes, x.dtypes = {x.dtypes} != y.dtypes = {y.dtypes}" + ) else: raise RuntimeError( f"Unexpected type of pandas object in _pandas_equals: type(x)={type(x)}," diff --git a/skbase/utils/tests/test_deep_equals.py b/skbase/utils/tests/test_deep_equals.py index 271cf135..d8a3b219 100644 --- a/skbase/utils/tests/test_deep_equals.py +++ b/skbase/utils/tests/test_deep_equals.py @@ -36,6 +36,7 @@ EXAMPLES += [ pd.DataFrame({"a": [4, 2]}), pd.DataFrame({"a": [4, 3]}), + pd.DataFrame({"a": ["4", "3"]}), (np.array([1, 2, 4]), [pd.DataFrame({"a": [4, 2]})]), {"foo": [42], "bar": pd.Series([1, 2])}, {"bar": [42], "foo": pd.Series([1, 2])},