From daf5d705492126160b2a920e7b691595c784cfb8 Mon Sep 17 00:00:00 2001 From: "Edwin (Ed) Onuonga" Date: Sat, 13 Apr 2024 17:12:28 +0100 Subject: [PATCH] fix: call `KNNMixin._dtw1d` when `independent=True` (#251) --- sequentia/models/knn/base.py | 2 +- tests/unit/test_models/knn/test_classifier.py | 20 ++++++++++++++++--- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/sequentia/models/knn/base.py b/sequentia/models/knn/base.py index f60687c..d2d91e9 100644 --- a/sequentia/models/knn/base.py +++ b/sequentia/models/knn/base.py @@ -206,7 +206,7 @@ def _dtwi(self: KNNMixin, A: FloatArray, B: FloatArray) -> float: def dtw(a: FloatArray, b: FloatArray) -> float: """Windowed DTW wrapper function.""" - return self._dtw(a, b, window=window) + return self._dtw1d(a, b, window=window) return np.sum([dtw(A[:, i], B[:, i]) for i in range(A.shape[1])]) diff --git a/tests/unit/test_models/knn/test_classifier.py b/tests/unit/test_models/knn/test_classifier.py index ce3b0e8..15f4544 100644 --- a/tests/unit/test_models/knn/test_classifier.py +++ b/tests/unit/test_models/knn/test_classifier.py @@ -46,18 +46,32 @@ def assert_fit(clf: KNNClassifier, /, *, data: SequentialDataset) -> None: @pytest.mark.parametrize("k", [1, 2, 5]) @pytest.mark.parametrize("weighting", [None, lambda x: np.exp(-x)]) +@pytest.mark.parametrize("independent", [False, True]) def test_classifier_e2e( helpers: t.Any, request: SubRequest, - k: int, - weighting: t.Callable | None, dataset: SequentialDataset, random_state: np.random.RandomState, + *, + k: int, + weighting: t.Callable | None, + independent: bool, ) -> None: - clf = KNNClassifier(k=k, weighting=weighting, random_state=random_state) + clf = KNNClassifier( + k=k, + weighting=weighting, + independent=independent, + random_state=random_state, + ) assert clf.k == k assert clf.weighting == weighting + assert clf.independent == independent + + if independent: + assert clf._dtw().__name__ == "_dtwi" + else: + assert clf._dtw().__name__ == "_dtwd" data = dataset.copy() data._X = data._X[:, :1] # only use one feature