Skip to content

Commit

Permalink
equal nans
Browse files Browse the repository at this point in the history
  • Loading branch information
TonyBagnall committed Feb 3, 2025
1 parent 87b6544 commit 1e0942a
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions aeon/testing/estimator_checking/_yield_estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@ def check_raises_not_fitted_error(estimator, datatype):
def _equal_outputs(output1, output2):
"""Test whether two outputs from an estimator are logically identical.
Valid data strutures are:
Valid data structures are:
1. numpy array: stores an equal length collection or series
2. dict: a histogram of counts, usually of discrete series
3. pd.DataFrame: series stored in dataframe
Expand All @@ -619,7 +619,7 @@ def _equal_outputs(output1, output2):
if type(output1) is not type(output2):
return False
if isinstance(output1, np.ndarray): # 1. X an equal length collection or series
return np.allclose(output1, output2)
return np.allclose(output1, output2, equal_nan=True)
if isinstance(output1, dict): # 2. X a dictionary, dense collection or series
if output1.keys() != output2.keys():
return False
Expand All @@ -629,7 +629,7 @@ def _equal_outputs(output1, output2):
return True
if isinstance(output1, pd.DataFrame) or isinstance(output1, pd.Series):
# 3. X a dataframe
return np.allclose(output1.values, output2.values)
return np.allclose(output1.values, output2.values, equal_nan=True)
if isinstance(output1, list): # X a possibly unequal length collection
if len(output1) != len(output2):
return False
Expand Down Expand Up @@ -670,7 +670,8 @@ def check_persistence_via_pickle(estimator, datatype):
if not _equal_outputs(output, results[i]):
raise ValueError(
f"Running {method} after serialisation parameters gives "
f"different results."
f"different results. "
f"{type(estimator)} returns data as {type(output)}"
)
i += 1

Expand All @@ -696,7 +697,8 @@ def check_fit_deterministic(estimator, datatype):
output = _run_estimator_method(estimator, method, datatype, "test")
if not _equal_outputs(output, results[i]):
raise ValueError(
f"Running {method} after fit twice with test "
f"parameters gives different results."
f"Running {method} with test parameters after two calls to fit "
f"gives different results."
f"{type(estimator)} returns data as {type(output)}"
)
i += 1

0 comments on commit 1e0942a

Please sign in to comment.