diff --git a/ax/utils/common/equality.py b/ax/utils/common/equality.py index fe2830c2ada..2d433a6dbf3 100644 --- a/ax/utils/common/equality.py +++ b/ax/utils/common/equality.py @@ -183,7 +183,7 @@ def object_attribute_dicts_find_unequal_fields( list(one_val.values()), list(other_val.values()) ) elif isinstance(one_val, np.ndarray): - equal = np.array_equal(one_val, other_val) + equal = np.array_equal(one_val, other_val, equal_nan=True) elif isinstance(one_val, datetime): equal = datetime_equals(one_val, other_val) elif isinstance(one_val, float): diff --git a/ax/utils/common/tests/test_equality.py b/ax/utils/common/tests/test_equality.py index e0aaabc0db0..57f1da89b0d 100644 --- a/ax/utils/common/tests/test_equality.py +++ b/ax/utils/common/tests/test_equality.py @@ -6,11 +6,13 @@ from datetime import datetime +import numpy as np import pandas as pd from ax.utils.common.equality import ( dataframe_equals, datetime_equals, equality_typechecker, + object_attribute_dicts_find_unequal_fields, same_elements, ) from ax.utils.common.testutils import TestCase @@ -46,3 +48,27 @@ def testDataframeEquals(self) -> None: self.assertTrue(dataframe_equals(pd.DataFrame(), pd.DataFrame())) self.assertTrue(dataframe_equals(pd1, pd2)) self.assertFalse(dataframe_equals(pd1, pd3)) + + def test_numpy_equals(self) -> None: + # Simple check. + np_0 = {"cov": np.array([[0.1, 0.0], [0.0, 0.1]])} + np_1 = {"cov": np.array([[0.1, 0.0], [0.0, 0.1]])} + self.assertEqual( + object_attribute_dicts_find_unequal_fields(np_0, np_1), ({}, {}) + ) + # Unequal. + np_1 = {"cov": np.array([[0.1, 0.0], [0.1, 0.1]])} + self.assertEqual( + object_attribute_dicts_find_unequal_fields(np_0, np_1), + ({}, {"cov": (np_0["cov"], np_1["cov"])}), + ) + # With NaNs. + np_1 = {"cov": np.array([[0.1, float("nan")], [float("nan"), 0.1]])} + self.assertEqual( + object_attribute_dicts_find_unequal_fields(np_0, np_1), + ({}, {"cov": (np_0["cov"], np_1["cov"])}), + ) + np_0 = {"cov": np.array([[0.1, float("nan")], [float("nan"), 0.1]])} + self.assertEqual( + object_attribute_dicts_find_unequal_fields(np_0, np_1), ({}, {}) + )