Skip to content

Commit

Permalink
Fix equality check for np arrays with NaNs (#1343)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1343

Prior to this change, having NaN in a field would lead to the equality check failing. This is a real issue with observation data with multi-objective as the non-diagonal entries of covariance matrix get populated with NaNs.

Reviewed By: Balandat

Differential Revision: D42254565

fbshipit-source-id: 8798b89bd268def947e426d0b65859368d51b5a4
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Dec 28, 2022
1 parent e56ed36 commit 7b674bc
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 1 deletion.
2 changes: 1 addition & 1 deletion ax/utils/common/equality.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def object_attribute_dicts_find_unequal_fields(
list(one_val.values()), list(other_val.values())
)
elif isinstance(one_val, np.ndarray):
equal = np.array_equal(one_val, other_val)
equal = np.array_equal(one_val, other_val, equal_nan=True)
elif isinstance(one_val, datetime):
equal = datetime_equals(one_val, other_val)
elif isinstance(one_val, float):
Expand Down
26 changes: 26 additions & 0 deletions ax/utils/common/tests/test_equality.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@

from datetime import datetime

import numpy as np
import pandas as pd
from ax.utils.common.equality import (
dataframe_equals,
datetime_equals,
equality_typechecker,
object_attribute_dicts_find_unequal_fields,
same_elements,
)
from ax.utils.common.testutils import TestCase
Expand Down Expand Up @@ -46,3 +48,27 @@ def testDataframeEquals(self) -> None:
self.assertTrue(dataframe_equals(pd.DataFrame(), pd.DataFrame()))
self.assertTrue(dataframe_equals(pd1, pd2))
self.assertFalse(dataframe_equals(pd1, pd3))

def test_numpy_equals(self) -> None:
# Simple check.
np_0 = {"cov": np.array([[0.1, 0.0], [0.0, 0.1]])}
np_1 = {"cov": np.array([[0.1, 0.0], [0.0, 0.1]])}
self.assertEqual(
object_attribute_dicts_find_unequal_fields(np_0, np_1), ({}, {})
)
# Unequal.
np_1 = {"cov": np.array([[0.1, 0.0], [0.1, 0.1]])}
self.assertEqual(
object_attribute_dicts_find_unequal_fields(np_0, np_1),
({}, {"cov": (np_0["cov"], np_1["cov"])}),
)
# With NaNs.
np_1 = {"cov": np.array([[0.1, float("nan")], [float("nan"), 0.1]])}
self.assertEqual(
object_attribute_dicts_find_unequal_fields(np_0, np_1),
({}, {"cov": (np_0["cov"], np_1["cov"])}),
)
np_0 = {"cov": np.array([[0.1, float("nan")], [float("nan"), 0.1]])}
self.assertEqual(
object_attribute_dicts_find_unequal_fields(np_0, np_1), ({}, {})
)

0 comments on commit 7b674bc

Please sign in to comment.