Skip to content

Commit

Permalink
test scalar output
Browse files Browse the repository at this point in the history
  • Loading branch information
TonyBagnall committed Feb 3, 2025
1 parent 1e0942a commit 7555f8a
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion aeon/testing/estimator_checking/_yield_estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,8 @@ def _equal_outputs(output1, output2):
"""Test whether two outputs from an estimator are logically identical.
Valid data structures are:
1. numpy array: stores an equal length collection or series
1. float
2. numpy array: stores an equal length collection or series
2. dict: a histogram of counts, usually of discrete series
3. pd.DataFrame: series stored in dataframe
4. list: stores unequal length series in a format 1-3
Expand All @@ -618,6 +619,8 @@ def _equal_outputs(output1, output2):
"""
if type(output1) is not type(output2):
return False
if np.issubdtype(type(output1), np.floating):
return np.isclose(output1, output2)
if isinstance(output1, np.ndarray): # 1. X an equal length collection or series
return np.allclose(output1, output2, equal_nan=True)
if isinstance(output1, dict): # 2. X a dictionary, dense collection or series
Expand Down

0 comments on commit 7555f8a

Please sign in to comment.