Skip to content

Commit

Permalink
Merge branch 'main' into ajb/testing
Browse files Browse the repository at this point in the history
  • Loading branch information
TonyBagnall committed Feb 3, 2025
2 parents f301acf + 874478c commit ce75b85
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 24 deletions.
37 changes: 19 additions & 18 deletions aeon/testing/estimator_checking/_yield_estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import joblib
import numpy as np
from numpy.testing import assert_array_almost_equal
from sklearn.exceptions import NotFittedError

from aeon.anomaly_detection.base import BaseAnomalyDetector
Expand Down Expand Up @@ -625,20 +624,19 @@ def check_persistence_via_pickle(estimator, datatype):
for method in NON_STATE_CHANGING_METHODS_ARRAYLIKE:
if hasattr(estimator, method) and callable(getattr(estimator, method)):
output = _run_estimator_method(estimator, method, datatype, "test")
assert_array_almost_equal(
output,
results[i],
err_msg=f"Running {method} after fit twice with test "
f"parameters gives different results.",
)
same, msg = deep_equals(output, results[i], return_msg=True)
if not same:
raise ValueError(
f"Running {method} after serialisation parameters gives "
f"different results. "
f"{type(estimator)} returns data as {type(output)}: test "
f"equivalence message: {msg}"
)
i += 1


def check_fit_deterministic(estimator, datatype):
"""Test that fit is deterministic.
Check that calling fit twice is equivalent to calling it once.
"""
"""Check that calling fit twice is equivalent to calling it once."""
estimator = _clone_estimator(estimator, random_state=0)
_run_estimator_method(estimator, "fit", datatype, "train")

Expand All @@ -648,17 +646,20 @@ def check_fit_deterministic(estimator, datatype):
output = _run_estimator_method(estimator, method, datatype, "test")
results.append(output)

# run fit and other methods a second time
# run fit a second time
_run_estimator_method(estimator, "fit", datatype, "train")

# check output of predict/transform etc does not change
i = 0
for method in NON_STATE_CHANGING_METHODS_ARRAYLIKE:
if hasattr(estimator, method) and callable(getattr(estimator, method)):
output = _run_estimator_method(estimator, method, datatype, "test")
assert_array_almost_equal(
output,
results[i],
err_msg=f"Running {method} after fit twice with test "
f"parameters gives different results.",
)
same, msg = deep_equals(output, results[i], return_msg=True)
if not same:
raise ValueError(
f"Running {method} with test parameters after two calls to fit "
f"gives different results."
f"{type(estimator)} returns data as {type(output)}: test "
f"equivalence message: {msg}"
)
i += 1
4 changes: 2 additions & 2 deletions aeon/testing/testing_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
NUMBA_DISABLED = os.environ.get("NUMBA_DISABLE_JIT") == "1"

# exclude estimators here for short term fixes
EXCLUDE_ESTIMATORS = ["REDCOMETS"]
# Hydra excluded because it returns a pytorch Tensor
EXCLUDE_ESTIMATORS = ["REDCOMETS", "HydraTransformer"]

# Exclude specific tests for estimators here
EXCLUDED_TESTS = {
Expand All @@ -49,7 +50,6 @@
"RSASTClassifier": ["check_fit_deterministic"],
"SAST": ["check_fit_deterministic"],
"RSAST": ["check_fit_deterministic"],
"SFA": ["check_persistence_via_pickle", "check_fit_deterministic"],
# missed in legacy testing, changes state in predict/transform
"HMMSegmenter": ["check_non_state_changing_method"],
"RSTSF": ["check_non_state_changing_method"],
Expand Down
10 changes: 6 additions & 4 deletions aeon/testing/utils/deep_equals.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def _deep_equals(x, y, depth, ignore_index):
eq = np.isnan(y)
msg = "" if eq else f"x ({x}) != y ({y}), depth={depth}"
return eq, msg
elif isinstance(x == y, bool):
elif isinstance(x == y, (bool, np.bool_)):
eq = x == y
msg = "" if eq else f"x ({x}) != y ({y}), depth={depth}"
return eq, msg
Expand Down Expand Up @@ -131,9 +131,11 @@ def _dataframe_equals(x, y, depth, ignore_index):
def _numpy_equals(x, y, depth):
if x.dtype != y.dtype:
return False, f"x.dtype ({x.dtype}) != y.dtype ({y.dtype})"

eq = np.allclose(x, y, equal_nan=True)
msg = "" if eq else f"x ({x}) != y ({y}), depth={depth}"
if x.dtype == "object":
eq, msg = _deep_equals(x.tolist(), y.tolist(), depth, ignore_index=True)
else:
eq = np.allclose(x, y, equal_nan=True)
msg = "" if eq else f"x ({x}) != y ({y}), depth={depth}"
return eq, msg


Expand Down

0 comments on commit ce75b85

Please sign in to comment.