Skip to content

Commit

Permalink
Reproduce old np.array_equal behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
thequilo committed Aug 23, 2024
1 parent e3dbea0 commit fc407e8
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions sacred/config/custom_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,10 +300,19 @@ def type_changed(old_value, new_value):
def is_different(old_value, new_value):
"""Numpy aware comparison between two values."""
if opt.has_numpy:
# np.array_equal raises an exception when the arguments are not array in numpy 2.0.
# This issue is only present in 2.0, not in <2.0 or >=2.1
if isinstance(old_value, opt.np.ndarray) and isinstance(new_value, opt.np.ndarray):
return not opt.np.array_equal(old_value, new_value)
elif isinstance(old_value, opt.np.ndarray) or isinstance(new_value, opt.np.ndarray):
# Reproduces np.array_equal from numpy<2
# np.array_equal raises an exception when the arguments are scalar and
# differ in type (e.g. int and str) in numpy>=2.0
try:
old_value = opt.np.asarray(old_value)
new_value = opt.np.asarray(new_value)
except:
return False
else:
result = old_value == new_value
if isinstance(result, bool):
return result
else:
return result.all()

return old_value != new_value

0 comments on commit fc407e8

Please sign in to comment.