From 60258c2c66dc92ad15816a4789a343be73f7440e Mon Sep 17 00:00:00 2001 From: ncassereau Date: Mon, 8 Nov 2021 10:59:25 +0100 Subject: [PATCH 1/5] solve bug --- ot/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ot/utils.py b/ot/utils.py index c87856338..0c7596952 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -221,7 +221,9 @@ def dist(x1, x2=None, metric='sqeuclidean', p=2): if not get_backend(x1, x2).__name__ == 'numpy': raise NotImplementedError() else: - return cdist(x1, x2, metric=metric, p=p) + if metric.endswith("minkowski"): + return cdist(x1, x2, metric=metric, p=p) + return cdist(x1, x2, metric=metric) def dist0(n, method='lin_square'): From a20b09727d2c9c3fd43bbb7f075f23c68058f8f4 Mon Sep 17 00:00:00 2001 From: ncassereau Date: Mon, 8 Nov 2021 11:21:09 +0100 Subject: [PATCH 2/5] Weights & docs --- ot/utils.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/ot/utils.py b/ot/utils.py index 0c7596952..e6c93c8bd 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -182,7 +182,7 @@ def euclidean_distances(X, Y, squared=False): return c -def dist(x1, x2=None, metric='sqeuclidean', p=2): +def dist(x1, x2=None, metric='sqeuclidean', p=2, w=None): r"""Compute distance between samples in :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}` .. note:: This function is backend-compatible and will work on arrays @@ -202,6 +202,10 @@ def dist(x1, x2=None, metric='sqeuclidean', p=2): 'euclidean', 'hamming', 'jaccard', 'kulsinski', 'mahalanobis', 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean', 'sokalmichener', 'sokalsneath', 'sqeuclidean', 'wminkowski', 'yule'. + p : float, optional + p-norm for the Minkowski and the Weighted Minkowski metrics. Default value is 2. + w : array-like, rank 1 + Weights for the weighted metrics. Returns @@ -222,8 +226,8 @@ def dist(x1, x2=None, metric='sqeuclidean', p=2): raise NotImplementedError() else: if metric.endswith("minkowski"): - return cdist(x1, x2, metric=metric, p=p) - return cdist(x1, x2, metric=metric) + return cdist(x1, x2, metric=metric, p=p, w=w) + return cdist(x1, x2, metric=metric, w=w) def dist0(n, method='lin_square'): From d7e76884d6c27fbcefaa1c00091f42c167567e78 Mon Sep 17 00:00:00 2001 From: ncassereau Date: Mon, 8 Nov 2021 14:10:11 +0100 Subject: [PATCH 3/5] tests for dist --- test/test_utils.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/test/test_utils.py b/test/test_utils.py index 40f4e4926..d9e372ba6 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -117,6 +117,22 @@ def test_dist(): np.testing.assert_allclose(D, D2, atol=1e-14) np.testing.assert_allclose(D, D3, atol=1e-14) + # tests that every metric runs correctly + metrics = [ + 'braycurtis', 'canberra', 'chebyshev', 'cityblock', 'correlation', 'cosine', 'dice', + 'euclidean', 'hamming', 'jaccard', 'kulsinski', 'mahalanobis', + 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean', + 'sokalmichener', 'sokalsneath', 'sqeuclidean', 'wminkowski', 'yule' + ] + + for metric in metrics: + print(metric) + ot.dist(x, x, metric=metric, p=3, w=np.random.random((2, ))) + + # weighted minkowski but with no weights + with pytest.raises(ValueError): + ot.dist(x, x, metric="wminkowski") + def test_dist_backends(nx): From 2b12917db1a407bfdcec7fde02ff04366b7f5f78 Mon Sep 17 00:00:00 2001 From: ncassereau Date: Mon, 8 Nov 2021 14:25:23 +0100 Subject: [PATCH 4/5] test dist --- test/test_utils.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index d9e372ba6..cfcb1c664 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -118,16 +118,20 @@ def test_dist(): np.testing.assert_allclose(D, D3, atol=1e-14) # tests that every metric runs correctly - metrics = [ + metrics_w = [ 'braycurtis', 'canberra', 'chebyshev', 'cityblock', 'correlation', 'cosine', 'dice', - 'euclidean', 'hamming', 'jaccard', 'kulsinski', 'mahalanobis', - 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean', + 'euclidean', 'hamming', 'jaccard', 'kulsinski', + 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'sokalmichener', 'sokalsneath', 'sqeuclidean', 'wminkowski', 'yule' - ] + ] # those that support weights + metrics = ['mahalanobis', 'seuclidean'] # do not support weights depending on scipy's version - for metric in metrics: + for metric in metrics_w: print(metric) ot.dist(x, x, metric=metric, p=3, w=np.random.random((2, ))) + for metric in metrics: + print(metric) + ot.dist(x, x, metric=metric, p=3) # weighted minkowski but with no weights with pytest.raises(ValueError): From 4de4f54e4dfb496a890c49665a32286b7b9f47e7 Mon Sep 17 00:00:00 2001 From: ncassereau Date: Mon, 8 Nov 2021 14:27:35 +0100 Subject: [PATCH 5/5] pep8 --- test/test_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index cfcb1c664..6b476b2af 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -118,7 +118,7 @@ def test_dist(): np.testing.assert_allclose(D, D3, atol=1e-14) # tests that every metric runs correctly - metrics_w = [ + metrics_w = [ 'braycurtis', 'canberra', 'chebyshev', 'cityblock', 'correlation', 'cosine', 'dice', 'euclidean', 'hamming', 'jaccard', 'kulsinski', 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', @@ -126,7 +126,7 @@ def test_dist(): ] # those that support weights metrics = ['mahalanobis', 'seuclidean'] # do not support weights depending on scipy's version - for metric in metrics_w: + for metric in metrics_w: print(metric) ot.dist(x, x, metric=metric, p=3, w=np.random.random((2, ))) for metric in metrics: