Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] safer comparison in deep_equals if np.any(x != y) does not result in boolean #323

Merged
merged 5 commits into from
May 9, 2024
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 38 additions & 6 deletions skbase/utils/deep_equals/_deep_equals.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,18 +503,50 @@
if isinstance(x == y, bool):
return ret(x == y, f" !=, {x} != {y}")

# check if numpy is available
numpy_available = _softdep_available("numpy")
if numpy_available:
import numpy as np

# deal with the case where != returns a vector
if numpy_available and np.any(x != y) or np.any(_coerce_list(x != y)):
if _safe_any_unequal(x, y):
return ret(False, f" !=, {x} != {y}")

return ret(True, "")


def _safe_any_unequal(x, y):
"""Return whether any of x != y, if != results in iterable, False on exception.

Written very defensively to avoid exceptions, as exceptions may be raised
since any(x != y) or the safer np.any(x != y) may not be boolean,
e.g., in pathological cases of nested objects.
"""
try:
unequal = x != y
except Exception:
Dismissed Show dismissed Hide dismissed
return False

# check if numpy is available
numpy_available = _softdep_available("numpy")

if not numpy_available:
try:
any_un = any(unequal)
if isinstance(any_un, bool):
return any_un
else:
return False
except Exception:
return False

import numpy as np

try:
any_un = np.any(x != y) or np.any(_coerce_list(x != y))
if isinstance(any_un, bool):
return any_un
else:
return False
except Exception:
return False


def _safe_len(x):
"""Return length of x if len(x) does not result in exception, else -1."""
if hasattr(x, "__len__"):
Expand Down
Loading