From a2591c7e17ba4fa20048da100b2e1eaedce40893 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= Date: Fri, 13 Oct 2023 18:32:53 -0300 Subject: [PATCH 01/12] [FEAT] Add the parameter 'type_as' to the backends --- ot/backend.py | 50 ++++++++++++++++++++++++++++++++++---------------- 1 file changed, 34 insertions(+), 16 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index e9750ee0a..ae3d11456 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -86,15 +86,15 @@ # # License: MIT License -import numpy as np import os -import scipy -import scipy.linalg -from scipy.sparse import issparse, coo_matrix, csr_matrix -import scipy.special as special import time import warnings +import numpy as np +import scipy +import scipy.linalg +import scipy.special as special +from scipy.sparse import coo_matrix, csr_matrix, issparse DISABLE_TORCH_KEY = 'POT_BACKEND_DISABLE_PYTORCH' DISABLE_JAX_KEY = 'POT_BACKEND_DISABLE_JAX' @@ -650,7 +650,7 @@ def std(self, a, axis=None): """ raise NotImplementedError() - def linspace(self, start, stop, num): + def linspace(self, start, stop, num, type_as=None): r""" Returns a specified number of evenly spaced values over a given interval. @@ -1208,8 +1208,11 @@ def median(self, a, axis=None): def std(self, a, axis=None): return np.std(a, axis=axis) - def linspace(self, start, stop, num): - return np.linspace(start, stop, num) + def linspace(self, start, stop, num, type_as=None): + if type_as is None: + return np.linspace(start, stop, num) + else: + return np.linspace(start, stop, num, dtype=type_as.dtype) def meshgrid(self, a, b): return np.meshgrid(a, b) @@ -1579,8 +1582,11 @@ def median(self, a, axis=None): def std(self, a, axis=None): return jnp.std(a, axis=axis) - def linspace(self, start, stop, num): - return jnp.linspace(start, stop, num) + def linspace(self, start, stop, num, type_as=None): + if type_as is None: + return jnp.linspace(start, stop, num) + else: + return self._change_device(jnp.linspace(start, stop, num, dtype=type_as.dtype), type_as) def meshgrid(self, a, b): return jnp.meshgrid(a, b) @@ -1986,6 +1992,7 @@ def concatenate(self, arrays, axis=0): def zero_pad(self, a, pad_width, value=0): from torch.nn.functional import pad + # pad_width is an array of ndim tuples indicating how many 0 before and after # we need to add. We first need to make it compliant with torch syntax, that # starts with the last dim, then second last, etc. @@ -2006,6 +2013,7 @@ def mean(self, a, axis=None): def median(self, a, axis=None): from packaging import version + # Since version 1.11.0, interpolation is available if version.parse(torch.__version__) >= version.parse("1.11.0"): if axis is not None: @@ -2026,8 +2034,11 @@ def std(self, a, axis=None): else: return torch.std(a, unbiased=False) - def linspace(self, start, stop, num): - return torch.linspace(start, stop, num, dtype=torch.float64) + def linspace(self, start, stop, num, type_as=None): + if type_as is None: + return torch.linspace(start, stop, num, dtype=torch.float64) + else: + return torch.linspace(start, stop, num, dtype=torch.float64, device=type_as.device) def meshgrid(self, a, b): try: @@ -2427,8 +2438,12 @@ def median(self, a, axis=None): def std(self, a, axis=None): return cp.std(a, axis=axis) - def linspace(self, start, stop, num): - return cp.linspace(start, stop, num) + def linspace(self, start, stop, num, type_as=None): + if type_as is None: + return cp.linspace(start, stop, num) + else: + with cp.cuda.Device(type_as.device): + return cp.linspace(start, stop, num, dtype=type_as.dtype) def meshgrid(self, a, b): return cp.meshgrid(a, b) @@ -2834,8 +2849,11 @@ def median(self, a, axis=None): def std(self, a, axis=None): return tnp.std(a, axis=axis) - def linspace(self, start, stop, num): - return tnp.linspace(start, stop, num) + def linspace(self, start, stop, num, type_as=None): + if type_as is None: + return tnp.linspace(start, stop, num) + else: + return tnp.linspace(start, stop, num, dtype=type_as.dtype) def meshgrid(self, a, b): return tnp.meshgrid(a, b) From 2afd2317858df27a3e66eb8c525999039e4f58ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= Date: Fri, 13 Oct 2023 18:33:11 -0300 Subject: [PATCH 02/12] [TEST] add tests for the 'type_as' backends --- test/test_backend.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/test/test_backend.py b/test/test_backend.py index 8ab861078..605e30ad8 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -6,16 +6,13 @@ # # License: MIT License -import ot -import ot.backend -from ot.backend import torch, jax, tf - -import pytest - import numpy as np +import pytest from numpy.testing import assert_array_almost_equal_nulp -from ot.backend import get_backend, get_backend_list, to_numpy +import ot +import ot.backend +from ot.backend import get_backend, get_backend_list, jax, tf, to_numpy, torch def test_get_backend_list(): @@ -507,6 +504,7 @@ def test_func_backends(nx): lst_name.append('std') A = nx.linspace(0, 1, 50) + A = nx.linspace(0, 1, 50, type_as=Mb) lst_b.append(nx.to_numpy(A)) lst_name.append('linspace') From 5e81fefbca455e43addb8ff5823b0b49fa44d52c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= Date: Fri, 13 Oct 2023 19:21:46 -0300 Subject: [PATCH 03/12] [DEBUG] Debug dtype in pytorch --- ot/backend.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index ae3d11456..80eb7604e 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -2036,9 +2036,9 @@ def std(self, a, axis=None): def linspace(self, start, stop, num, type_as=None): if type_as is None: - return torch.linspace(start, stop, num, dtype=torch.float64) + return torch.linspace(start, stop, num) else: - return torch.linspace(start, stop, num, dtype=torch.float64, device=type_as.device) + return torch.linspace(start, stop, num, dtype=type_as.dtype, device=type_as.device) def meshgrid(self, a, b): try: From f4f2d4409455daa636ae756abf93643e012424f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= Date: Fri, 13 Oct 2023 19:22:33 -0300 Subject: [PATCH 04/12] [FIX] Add type_as every time linspace is called --- ot/bregman.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/ot/bregman.py b/ot/bregman.py index 29bcd5867..c90d89986 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -20,7 +20,8 @@ import numpy as np from scipy.optimize import fmin_l_bfgs_b -from ot.utils import unif, dist, list_to_array +from ot.utils import dist, list_to_array, unif + from .backend import get_backend @@ -2217,11 +2218,11 @@ def _convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, # build the convolution operator # this is equivalent to blurring on horizontal then vertical directions - t = nx.linspace(0, 1, A.shape[1]) + t = nx.linspace(0, 1, A.shape[1], type_as=A) [Y, X] = nx.meshgrid(t, t) K1 = nx.exp(-(X - Y) ** 2 / reg) - t = nx.linspace(0, 1, A.shape[2]) + t = nx.linspace(0, 1, A.shape[2], type_as=A) [Y, X] = nx.meshgrid(t, t) K2 = nx.exp(-(X - Y) ** 2 / reg) @@ -2295,11 +2296,11 @@ def _convolutional_barycenter2d_log(A, reg, weights=None, numItermax=10000, err = 1 # build the convolution operator # this is equivalent to blurring on horizontal then vertical directions - t = nx.linspace(0, 1, width) + t = nx.linspace(0, 1, width, type_as=A) [Y, X] = nx.meshgrid(t, t) M1 = - (X - Y) ** 2 / reg - t = nx.linspace(0, 1, height) + t = nx.linspace(0, 1, height, type_as=A) [Y, X] = nx.meshgrid(t, t) M2 = - (X - Y) ** 2 / reg @@ -2452,11 +2453,11 @@ def _convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000, # build the convolution operator # this is equivalent to blurring on horizontal then vertical directions - t = nx.linspace(0, 1, width) + t = nx.linspace(0, 1, width, type_as=A) [Y, X] = nx.meshgrid(t, t) K1 = nx.exp(-(X - Y) ** 2 / reg) - t = nx.linspace(0, 1, height) + t = nx.linspace(0, 1, height, type_as=A) [Y, X] = nx.meshgrid(t, t) K2 = nx.exp(-(X - Y) ** 2 / reg) @@ -2532,11 +2533,11 @@ def _convolutional_barycenter2d_debiased_log(A, reg, weights=None, numItermax=10 err = 1 # build the convolution operator # this is equivalent to blurring on horizontal then vertical directions - t = nx.linspace(0, 1, width) + t = nx.linspace(0, 1, width, type_as=A) [Y, X] = nx.meshgrid(t, t) M1 = - (X - Y) ** 2 / reg - t = nx.linspace(0, 1, height) + t = nx.linspace(0, 1, height, type_as=A) [Y, X] = nx.meshgrid(t, t) M2 = - (X - Y) ** 2 / reg From 6a9b8f90e33deecc92967fabb03631b32eed85c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= Date: Fri, 13 Oct 2023 19:23:19 -0300 Subject: [PATCH 05/12] [TEST] Add test for the convolutional_barycenter2d algorithms (they are the only ones use linspace) --- test/test_bregman.py | 263 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 261 insertions(+), 2 deletions(-) diff --git a/test/test_bregman.py b/test/test_bregman.py index 8355cda95..80f50d265 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -14,7 +14,7 @@ import pytest import ot -from ot.backend import torch, tf +from ot.backend import tf, torch @pytest.mark.parametrize("verbose, warn", product([True, False], [True, False])) @@ -726,6 +726,135 @@ def test_wasserstein_bary_2d(nx, method): ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) +def test_wasserstein_bary_2d_dtype_device(nx, method): + rng = np.random.RandomState(42) + size = 20 # size of a square image + + # First image + a1 = rng.rand(size, size) + a1 += a1.min() + a1 = a1 / np.sum(a1) # Ensure that it is a probability distribution + + # Second image + a2 = rng.rand(size, size) + a2 += a2.min() + a2 = a2 / np.sum(a2) # Ensure that it is a probability distribution + + # creating matrix A containing all distributions + A = np.zeros((2, size, size)) + A[0, :, :] = a1 + A[1, :, :] = a2 + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + Ab = nx.from_numpy(A, type_as=tp) + + # wasserstein + reg = 1e-2 + if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.convolutional_barycenter2d(Ab, reg, method=method) + else: + # Compute the barycenter with numpy + bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d( + A, reg, method=method, verbose=True, log=True) + # Compute the barycenter with the backend + bary_wass_b = ot.bregman.convolutional_barycenter2d(Ab, reg, method=method) + # Convert the backend result to numpy, to compare with the numpy result + bary_wass = nx.to_numpy(bary_wass_b) + + np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) + + # help in checking if log and verbose do not bug the function + ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) + + # Test that the dtype and device are the same after the computation + nx.assert_same_dtype_device(Ab, bary_wass_b) + + +@pytest.mark.skipif(not tf, reason="tf not installed") +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) +def test_wasserstein_bary_2d_device_tf(method): + # Using the Tensorflow backend + nx = ot.backend.TensorflowBackend() + + rng = np.random.RandomState(42) + size = 20 # size of a square image + + # First image + a1 = rng.rand(size, size) + a1 += a1.min() + a1 = a1 / np.sum(a1) # Ensure that it is a probability distribution + + # Second image + a2 = rng.rand(size, size) + a2 += a2.min() + a2 = a2 / np.sum(a2) # Ensure that it is a probability distribution + + # creating matrix A containing all distributions + A = np.zeros((2, size, size)) + A[0, :, :] = a1 + A[1, :, :] = a2 + + # Check that everything stays on the CPU + with tf.device("/CPU:0"): + Ab = nx.from_numpy(A) + + # wasserstein + reg = 1e-2 + if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.convolutional_barycenter2d(Ab, reg, method=method) + else: + # Compute the barycenter with numpy + bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d( + A, reg, method=method, verbose=True, log=True) + # Compute the barycenter with the backend + bary_wass_b = ot.bregman.convolutional_barycenter2d(Ab, reg, method=method) + # Convert the backend result to numpy, to compare with the numpy result + bary_wass = nx.to_numpy(bary_wass_b) + + np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) + + # help in checking if log and verbose do not bug the function + ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) + + # Test that the dtype and device are the same after the computation + nx.assert_same_dtype_device(Ab, bary_wass_b) + + if len(tf.config.list_physical_devices('GPU')) > 0: + # Check that everything happens on the GPU + Ab = nx.from_numpy(A) + + # wasserstein + reg = 1e-2 + if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.convolutional_barycenter2d(Ab, reg, method=method) + else: + # Compute the barycenter with numpy + bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d( + A, reg, method=method, verbose=True, log=True) + # Compute the barycenter with the backend + bary_wass_b = ot.bregman.convolutional_barycenter2d(Ab, reg, method=method) + # Convert the backend result to numpy, to compare with the numpy result + bary_wass = nx.to_numpy(bary_wass_b) + + np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) + + # help in checking if log and verbose do not bug the function + ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) + + # Test that the dtype and device are the same after the computation + nx.assert_same_dtype_device(Ab, bary_wass_b) + assert nx.dtype_device(bary_wass_b)[1].startswith("GPU") + + @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) def test_wasserstein_bary_2d_debiased(nx, method): rng = np.random.RandomState(42) @@ -759,7 +888,137 @@ def test_wasserstein_bary_2d_debiased(nx, method): np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) # help in checking if log and verbose do not bug the function - ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) + ot.bregman.convolutional_barycenter2d_debiased(A, reg, log=True, verbose=True) + + +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) +def test_wasserstein_bary_2d_debiased_dtype_device(nx, method): + rng = np.random.RandomState(42) + size = 20 # size of a square image + + # First image + a1 = rng.rand(size, size) + a1 += a1.min() + a1 = a1 / np.sum(a1) # Ensure that it is a probability distribution + + # Second image + a2 = rng.rand(size, size) + a2 += a2.min() + a2 = a2 / np.sum(a2) # Ensure that it is a probability distribution + + # creating matrix A containing all distributions + A = np.zeros((2, size, size)) + A[0, :, :] = a1 + A[1, :, :] = a2 + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + Ab = nx.from_numpy(A, type_as=tp) + + # wasserstein + reg = 1e-2 + if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method) + else: + # Compute the barycenter with numpy + bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased( + A, reg, method=method, verbose=True, log=True) + # Compute the barycenter with the backend + bary_wass_b = ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method) + # Convert the backend result to numpy, to compare with the numpy result + bary_wass = nx.to_numpy(bary_wass_b) + + np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) + + # help in checking if log and verbose do not bug the function + ot.bregman.convolutional_barycenter2d_debiased(A, reg, log=True, verbose=True) + + # Test that the dtype and device are the same after the computation + nx.assert_same_dtype_device(Ab, bary_wass_b) + + +@pytest.mark.skipif(not tf, reason="tf not installed") +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) +def test_wasserstein_bary_2d_debiased_device_tf(method): + # Using the Tensorflow backend + nx = ot.backend.TensorflowBackend() + + rng = np.random.RandomState(42) + size = 20 # size of a square image + + # First image + a1 = rng.rand(size, size) + a1 += a1.min() + a1 = a1 / np.sum(a1) # Ensure that it is a probability distribution + + # Second image + a2 = rng.rand(size, size) + a2 += a2.min() + a2 = a2 / np.sum(a2) # Ensure that it is a probability distribution + + # creating matrix A containing all distributions + A = np.zeros((2, size, size)) + A[0, :, :] = a1 + A[1, :, :] = a2 + + # Check that everything stays on the CPU + with tf.device("/CPU:0"): + Ab = nx.from_numpy(A) + + # wasserstein + reg = 1e-2 + if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method) + else: + # Compute the barycenter with numpy + bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased( + A, reg, method=method, verbose=True, log=True) + # Compute the barycenter with the backend + bary_wass_b = ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method) + # Convert the backend result to numpy, to compare with the numpy result + bary_wass = nx.to_numpy(bary_wass_b) + + np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) + + # help in checking if log and verbose do not bug the function + ot.bregman.convolutional_barycenter2d_debiased(A, reg, log=True, verbose=True) + + # Test that the dtype and device are the same after the computation + nx.assert_same_dtype_device(Ab, bary_wass_b) + + if len(tf.config.list_physical_devices('GPU')) > 0: + # Check that everything happens on the GPU + Ab = nx.from_numpy(A) + + # wasserstein + reg = 1e-2 + if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method) + else: + # Compute the barycenter with numpy + bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased( + A, reg, method=method, verbose=True, log=True) + # Compute the barycenter with the backend + bary_wass_b = ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method) + # Convert the backend result to numpy, to compare with the numpy result + bary_wass = nx.to_numpy(bary_wass_b) + + np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) + + # help in checking if log and verbose do not bug the function + ot.bregman.convolutional_barycenter2d_debiased(A, reg, log=True, verbose=True) + + # Test that the dtype and device are the same after the computation + nx.assert_same_dtype_device(Ab, bary_wass_b) + assert nx.dtype_device(bary_wass_b)[1].startswith("GPU") + def test_unmix(nx): From ccd7c35591a9a5e22bb70fc654f791f2f77b3216 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= Date: Fri, 13 Oct 2023 19:52:28 -0300 Subject: [PATCH 06/12] [DEBUG] PEP 8 --- test/test_bregman.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/test_bregman.py b/test/test_bregman.py index 80f50d265..6351df1b9 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -945,7 +945,7 @@ def test_wasserstein_bary_2d_debiased_dtype_device(nx, method): def test_wasserstein_bary_2d_debiased_device_tf(method): # Using the Tensorflow backend nx = ot.backend.TensorflowBackend() - + rng = np.random.RandomState(42) size = 20 # size of a square image @@ -1020,7 +1020,6 @@ def test_wasserstein_bary_2d_debiased_device_tf(method): assert nx.dtype_device(bary_wass_b)[1].startswith("GPU") - def test_unmix(nx): n_bins = 50 # nb bins From bfaf2d167e137bc2730d939c85cacf8e45f74219 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= Date: Fri, 13 Oct 2023 19:52:52 -0300 Subject: [PATCH 07/12] [DOC] Add the new changes to RELEASES.md --- RELEASES.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/RELEASES.md b/RELEASES.md index 65ad8b35f..df5ade483 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -6,6 +6,8 @@ + Added support for [Nearest Brenier Potentials (SSNB)](http://proceedings.mlr.press/v108/paty20a/paty20a.pdf) (PR #526) + Tweaked `get_backend` to ignore `None` inputs (PR #525) + Callbacks for generalized conditional gradient in `ot.da.sinkhorn_l1l2_gl` are now vectorized to improve performance (PR #507) ++ The `linspace` method of the backends now has the `type_as` argument to convert to the same dtype and device. (PR #533) ++ The `convolutional_barycenter2d` and `convolutional_barycenter2d_debiased` functions now work with different devices.. (PR #533) #### Closed issues - Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504) From 196644826c4c37ba8394e791710b6b54282aeed5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= Date: Fri, 13 Oct 2023 21:07:15 -0300 Subject: [PATCH 08/12] [REFACTOR] Minor refactor that checks the GPU on the last line --- test/test_bregman.py | 111 +++++++++++++++++++++++-------------------- 1 file changed, 59 insertions(+), 52 deletions(-) diff --git a/test/test_bregman.py b/test/test_bregman.py index 6351df1b9..25ab6cfe6 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -352,13 +352,15 @@ def test_sinkhorn2_variants_device_tf(method): nx.assert_same_dtype_device(Mb, Gb) nx.assert_same_dtype_device(Mb, lossb) + # Check that everything happens on the GPU + ub, Mb = nx.from_numpy(u, M) + Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10) + lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10) + nx.assert_same_dtype_device(Mb, Gb) + nx.assert_same_dtype_device(Mb, lossb) + + # Check this only if GPU is available if len(tf.config.list_physical_devices('GPU')) > 0: - # Check that everything happens on the GPU - ub, Mb = nx.from_numpy(u, M) - Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10) - lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10) - nx.assert_same_dtype_device(Mb, Gb) - nx.assert_same_dtype_device(Mb, lossb) assert nx.dtype_device(Gb)[1].startswith("GPU") @@ -805,7 +807,7 @@ def test_wasserstein_bary_2d_device_tf(method): # wasserstein reg = 1e-2 - if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": + if method == "sinkhorn_log": with pytest.raises(NotImplementedError): ot.bregman.convolutional_barycenter2d(Ab, reg, method=method) else: @@ -826,32 +828,34 @@ def test_wasserstein_bary_2d_device_tf(method): # Test that the dtype and device are the same after the computation nx.assert_same_dtype_device(Ab, bary_wass_b) - if len(tf.config.list_physical_devices('GPU')) > 0: - # Check that everything happens on the GPU - Ab = nx.from_numpy(A) + # Check that everything happens on the GPU + Ab = nx.from_numpy(A) - # wasserstein - reg = 1e-2 - if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": - with pytest.raises(NotImplementedError): - ot.bregman.convolutional_barycenter2d(Ab, reg, method=method) - else: - # Compute the barycenter with numpy - bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d( - A, reg, method=method, verbose=True, log=True) - # Compute the barycenter with the backend - bary_wass_b = ot.bregman.convolutional_barycenter2d(Ab, reg, method=method) - # Convert the backend result to numpy, to compare with the numpy result - bary_wass = nx.to_numpy(bary_wass_b) + # wasserstein + reg = 1e-2 + if method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.convolutional_barycenter2d(Ab, reg, method=method) + else: + # Compute the barycenter with numpy + bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d( + A, reg, method=method, verbose=True, log=True) + # Compute the barycenter with the backend + bary_wass_b = ot.bregman.convolutional_barycenter2d(Ab, reg, method=method) + # Convert the backend result to numpy, to compare with the numpy result + bary_wass = nx.to_numpy(bary_wass_b) - np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) - np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) + np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) - # help in checking if log and verbose do not bug the function - ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) + # help in checking if log and verbose do not bug the function + ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) - # Test that the dtype and device are the same after the computation - nx.assert_same_dtype_device(Ab, bary_wass_b) + # Test that the dtype and device are the same after the computation + nx.assert_same_dtype_device(Ab, bary_wass_b) + + # Check this only if GPU is available + if len(tf.config.list_physical_devices('GPU')) > 0: assert nx.dtype_device(bary_wass_b)[1].startswith("GPU") @@ -970,7 +974,7 @@ def test_wasserstein_bary_2d_debiased_device_tf(method): # wasserstein reg = 1e-2 - if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": + if method == "sinkhorn_log": with pytest.raises(NotImplementedError): ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method) else: @@ -991,32 +995,35 @@ def test_wasserstein_bary_2d_debiased_device_tf(method): # Test that the dtype and device are the same after the computation nx.assert_same_dtype_device(Ab, bary_wass_b) - if len(tf.config.list_physical_devices('GPU')) > 0: - # Check that everything happens on the GPU - Ab = nx.from_numpy(A) + + # Check that everything happens on the GPU + Ab = nx.from_numpy(A) - # wasserstein - reg = 1e-2 - if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": - with pytest.raises(NotImplementedError): - ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method) - else: - # Compute the barycenter with numpy - bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased( - A, reg, method=method, verbose=True, log=True) - # Compute the barycenter with the backend - bary_wass_b = ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method) - # Convert the backend result to numpy, to compare with the numpy result - bary_wass = nx.to_numpy(bary_wass_b) + # wasserstein + reg = 1e-2 + if method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method) + else: + # Compute the barycenter with numpy + bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased( + A, reg, method=method, verbose=True, log=True) + # Compute the barycenter with the backend + bary_wass_b = ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method) + # Convert the backend result to numpy, to compare with the numpy result + bary_wass = nx.to_numpy(bary_wass_b) - np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) - np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) + np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) - # help in checking if log and verbose do not bug the function - ot.bregman.convolutional_barycenter2d_debiased(A, reg, log=True, verbose=True) + # help in checking if log and verbose do not bug the function + ot.bregman.convolutional_barycenter2d_debiased(A, reg, log=True, verbose=True) - # Test that the dtype and device are the same after the computation - nx.assert_same_dtype_device(Ab, bary_wass_b) + # Test that the dtype and device are the same after the computation + nx.assert_same_dtype_device(Ab, bary_wass_b) + + # Check this only if there is a GPU + if len(tf.config.list_physical_devices('GPU')) > 0: assert nx.dtype_device(bary_wass_b)[1].startswith("GPU") From c7c6d6e7edb1e5cd7c3d1fad34d646f0b24b5cb9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= Date: Fri, 13 Oct 2023 21:11:35 -0300 Subject: [PATCH 09/12] [DEBUG] pep8 --- test/test_bregman.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_bregman.py b/test/test_bregman.py index 25ab6cfe6..49c1bc2ac 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -995,7 +995,6 @@ def test_wasserstein_bary_2d_debiased_device_tf(method): # Test that the dtype and device are the same after the computation nx.assert_same_dtype_device(Ab, bary_wass_b) - # Check that everything happens on the GPU Ab = nx.from_numpy(A) From 2d05fceb7658f64e24c1212268da55522330af5d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= Date: Mon, 16 Oct 2023 13:24:24 -0300 Subject: [PATCH 10/12] [REFACTOR] Add a function to generalize the creation of random images --- test/test_bregman.py | 694 ++++++++++++++++++++++++------------------- 1 file changed, 391 insertions(+), 303 deletions(-) diff --git a/test/test_bregman.py b/test/test_bregman.py index 49c1bc2ac..ae17a1a88 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -31,19 +31,23 @@ def test_sinkhorn(verbose, warn): G = ot.sinkhorn(u, u, M, 1, stopThr=1e-10, verbose=verbose, warn=warn) # check constraints - np.testing.assert_allclose( - u, G.sum(1), atol=1e-05) # cf convergence sinkhorn - np.testing.assert_allclose( - u, G.sum(0), atol=1e-05) # cf convergence sinkhorn + np.testing.assert_allclose(u, G.sum(1), atol=1e-05) # cf convergence sinkhorn + np.testing.assert_allclose(u, G.sum(0), atol=1e-05) # cf convergence sinkhorn with pytest.warns(UserWarning): ot.sinkhorn(u, u, M, 1, stopThr=0, numItermax=1) -@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", - "sinkhorn_epsilon_scaling", - "greenkhorn", - "sinkhorn_log"]) +@pytest.mark.parametrize( + "method", + [ + "sinkhorn", + "sinkhorn_stabilized", + "sinkhorn_epsilon_scaling", + "greenkhorn", + "sinkhorn_log", + ], +) def test_convergence_warning(method): # test sinkhorn n = 100 @@ -53,24 +57,26 @@ def test_convergence_warning(method): M = ot.utils.dist0(n) with pytest.warns(UserWarning): - ot.sinkhorn(a1, a2, M, 1., method=method, stopThr=0, numItermax=1) + ot.sinkhorn(a1, a2, M, 1.0, method=method, stopThr=0, numItermax=1) if method in ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"]: with pytest.warns(UserWarning): ot.barycenter(A, M, 1, method=method, stopThr=0, numItermax=1) with pytest.warns(UserWarning): - ot.sinkhorn2(a1, a2, M, 1, method=method, - stopThr=0, numItermax=1, warn=True) + ot.sinkhorn2( + a1, a2, M, 1, method=method, stopThr=0, numItermax=1, warn=True + ) with warnings.catch_warnings(): warnings.simplefilter("error") - ot.sinkhorn2(a1, a2, M, 1, method=method, - stopThr=0, numItermax=1, warn=False) + ot.sinkhorn2( + a1, a2, M, 1, method=method, stopThr=0, numItermax=1, warn=False + ) def test_not_implemented_method(): # test sinkhorn w = 10 - n = w ** 2 + n = w**2 rng = np.random.RandomState(42) A_img = rng.rand(2, w, w) A_flat = A_img.reshape(n, 2) @@ -85,14 +91,13 @@ def test_not_implemented_method(): with pytest.raises(ValueError): ot.barycenter(A_flat, M_flat, reg, method=not_implemented) with pytest.raises(ValueError): - ot.bregman.barycenter_debiased(A_flat, M_flat, reg, - method=not_implemented) + ot.bregman.barycenter_debiased(A_flat, M_flat, reg, method=not_implemented) with pytest.raises(ValueError): - ot.bregman.convolutional_barycenter2d(A_img, reg, - method=not_implemented) + ot.bregman.convolutional_barycenter2d(A_img, reg, method=not_implemented) with pytest.raises(ValueError): - ot.bregman.convolutional_barycenter2d_debiased(A_img, reg, - method=not_implemented) + ot.bregman.convolutional_barycenter2d_debiased( + A_img, reg, method=not_implemented + ) @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) @@ -118,14 +123,17 @@ def test_sinkhorn_stabilization(): reg = 1e-5 loss1 = ot.sinkhorn2(a1, a2, M, reg, method="sinkhorn_log") loss2 = ot.sinkhorn2(a1, a2, M, reg, tau=1, method="sinkhorn_stabilized") - np.testing.assert_allclose( - loss1, loss2, atol=1e-06) # cf convergence sinkhorn + np.testing.assert_allclose(loss1, loss2, atol=1e-06) # cf convergence sinkhorn -@pytest.mark.parametrize("method, verbose, warn", - product(["sinkhorn", "sinkhorn_stabilized", - "sinkhorn_log"], - [True, False], [True, False])) +@pytest.mark.parametrize( + "method, verbose, warn", + product( + ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"], + [True, False], + [True, False], + ), +) def test_sinkhorn_multi_b(method, verbose, warn): # test sinkhorn n = 10 @@ -139,14 +147,16 @@ def test_sinkhorn_multi_b(method, verbose, warn): M = ot.dist(x, x) - loss0, log = ot.sinkhorn(u, b, M, .1, method=method, stopThr=1e-10, - log=True) + loss0, log = ot.sinkhorn(u, b, M, 0.1, method=method, stopThr=1e-10, log=True) - loss = [ot.sinkhorn2(u, b[:, k], M, .1, method=method, stopThr=1e-10, - verbose=verbose, warn=warn) for k in range(3)] + loss = [ + ot.sinkhorn2( + u, b[:, k], M, 0.1, method=method, stopThr=1e-10, verbose=verbose, warn=warn + ) + for k in range(3) + ] # check constraints - np.testing.assert_allclose( - loss0, loss, atol=1e-4) # cf convergence sinkhorn + np.testing.assert_allclose(loss0, loss, atol=1e-4) # cf convergence sinkhorn def test_sinkhorn_backends(nx): @@ -201,7 +211,6 @@ def test_sinkhorn2_gradients(): M = ot.dist(x, y) if torch: - a1 = torch.tensor(a, requires_grad=True) b1 = torch.tensor(a, requires_grad=True) M1 = torch.tensor(M, requires_grad=True) @@ -225,8 +234,9 @@ def test_sinkhorn_empty(): M = ot.dist(x, x) - G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, method="sinkhorn_log", - verbose=True, log=True) + G, log = ot.sinkhorn( + [], [], M, 1, stopThr=1e-10, method="sinkhorn_log", verbose=True, log=True + ) # check constraints np.testing.assert_allclose(u, G.sum(1), atol=1e-05) np.testing.assert_allclose(u, G.sum(0), atol=1e-05) @@ -236,24 +246,39 @@ def test_sinkhorn_empty(): np.testing.assert_allclose(u, G.sum(1), atol=1e-05) np.testing.assert_allclose(u, G.sum(0), atol=1e-05) - G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, - method='sinkhorn_stabilized', verbose=True, log=True) + G, log = ot.sinkhorn( + [], + [], + M, + 1, + stopThr=1e-10, + method="sinkhorn_stabilized", + verbose=True, + log=True, + ) # check constraints np.testing.assert_allclose(u, G.sum(1), atol=1e-05) np.testing.assert_allclose(u, G.sum(0), atol=1e-05) G, log = ot.sinkhorn( - [], [], M, 1, stopThr=1e-10, method='sinkhorn_epsilon_scaling', - verbose=True, log=True) + [], + [], + M, + 1, + stopThr=1e-10, + method="sinkhorn_epsilon_scaling", + verbose=True, + log=True, + ) # check constraints np.testing.assert_allclose(u, G.sum(1), atol=1e-05) np.testing.assert_allclose(u, G.sum(0), atol=1e-05) # test empty weights greenkhorn - ot.sinkhorn([], [], M, 1, method='greenkhorn', stopThr=1e-10, log=True) + ot.sinkhorn([], [], M, 1, method="greenkhorn", stopThr=1e-10, log=True) -@pytest.skip_backend('tf') +@pytest.skip_backend("tf") @pytest.skip_backend("jax") def test_sinkhorn_variants(nx): # test sinkhorn @@ -267,17 +292,18 @@ def test_sinkhorn_variants(nx): ub, M_nx = nx.from_numpy(u, M) - G = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10) - Gl = nx.to_numpy(ot.sinkhorn( - ub, ub, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) - G0 = nx.to_numpy(ot.sinkhorn( - ub, ub, M_nx, 1, method='sinkhorn', stopThr=1e-10)) - Gs = nx.to_numpy(ot.sinkhorn( - ub, ub, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10)) - Ges = nx.to_numpy(ot.sinkhorn( - ub, ub, M_nx, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10)) - G_green = nx.to_numpy(ot.sinkhorn( - ub, ub, M_nx, 1, method='greenkhorn', stopThr=1e-10)) + G = ot.sinkhorn(u, u, M, 1, method="sinkhorn", stopThr=1e-10) + Gl = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method="sinkhorn_log", stopThr=1e-10)) + G0 = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method="sinkhorn", stopThr=1e-10)) + Gs = nx.to_numpy( + ot.sinkhorn(ub, ub, M_nx, 1, method="sinkhorn_stabilized", stopThr=1e-10) + ) + Ges = nx.to_numpy( + ot.sinkhorn(ub, ub, M_nx, 1, method="sinkhorn_epsilon_scaling", stopThr=1e-10) + ) + G_green = nx.to_numpy( + ot.sinkhorn(ub, ub, M_nx, 1, method="greenkhorn", stopThr=1e-10) + ) # check values np.testing.assert_allclose(G, G0, atol=1e-05) @@ -287,14 +313,40 @@ def test_sinkhorn_variants(nx): np.testing.assert_allclose(G0, G_green, atol=1e-5) -@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", - "sinkhorn_epsilon_scaling", - "greenkhorn", - "sinkhorn_log"]) -@pytest.skip_arg(("nx", "method"), ("tf", "sinkhorn_epsilon_scaling"), reason="tf does not support sinkhorn_epsilon_scaling", getter=str) -@pytest.skip_arg(("nx", "method"), ("tf", "greenkhorn"), reason="tf does not support greenkhorn", getter=str) -@pytest.skip_arg(("nx", "method"), ("jax", "sinkhorn_epsilon_scaling"), reason="jax does not support sinkhorn_epsilon_scaling", getter=str) -@pytest.skip_arg(("nx", "method"), ("jax", "greenkhorn"), reason="jax does not support greenkhorn", getter=str) +@pytest.mark.parametrize( + "method", + [ + "sinkhorn", + "sinkhorn_stabilized", + "sinkhorn_epsilon_scaling", + "greenkhorn", + "sinkhorn_log", + ], +) +@pytest.skip_arg( + ("nx", "method"), + ("tf", "sinkhorn_epsilon_scaling"), + reason="tf does not support sinkhorn_epsilon_scaling", + getter=str, +) +@pytest.skip_arg( + ("nx", "method"), + ("tf", "greenkhorn"), + reason="tf does not support greenkhorn", + getter=str, +) +@pytest.skip_arg( + ("nx", "method"), + ("jax", "sinkhorn_epsilon_scaling"), + reason="jax does not support sinkhorn_epsilon_scaling", + getter=str, +) +@pytest.skip_arg( + ("nx", "method"), + ("jax", "greenkhorn"), + reason="jax does not support greenkhorn", + getter=str, +) def test_sinkhorn_variants_dtype_device(nx, method): n = 100 @@ -360,11 +412,11 @@ def test_sinkhorn2_variants_device_tf(method): nx.assert_same_dtype_device(Mb, lossb) # Check this only if GPU is available - if len(tf.config.list_physical_devices('GPU')) > 0: + if len(tf.config.list_physical_devices("GPU")) > 0: assert nx.dtype_device(Gb)[1].startswith("GPU") -@pytest.skip_backend('tf') +@pytest.skip_backend("tf") @pytest.skip_backend("jax") def test_sinkhorn_variants_multi_b(nx): # test sinkhorn @@ -381,13 +433,12 @@ def test_sinkhorn_variants_multi_b(nx): ub, bb, M_nx = nx.from_numpy(u, b, M) - G = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10) - Gl = nx.to_numpy(ot.sinkhorn( - ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) - G0 = nx.to_numpy(ot.sinkhorn( - ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10)) - Gs = nx.to_numpy(ot.sinkhorn( - ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10)) + G = ot.sinkhorn(u, b, M, 1, method="sinkhorn", stopThr=1e-10) + Gl = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method="sinkhorn_log", stopThr=1e-10)) + G0 = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method="sinkhorn", stopThr=1e-10)) + Gs = nx.to_numpy( + ot.sinkhorn(ub, bb, M_nx, 1, method="sinkhorn_stabilized", stopThr=1e-10) + ) # check values np.testing.assert_allclose(G, G0, atol=1e-05) @@ -395,7 +446,7 @@ def test_sinkhorn_variants_multi_b(nx): np.testing.assert_allclose(G0, Gs, atol=1e-05) -@pytest.skip_backend('tf') +@pytest.skip_backend("tf") @pytest.skip_backend("jax") def test_sinkhorn2_variants_multi_b(nx): # test sinkhorn @@ -412,13 +463,14 @@ def test_sinkhorn2_variants_multi_b(nx): ub, bb, M_nx = nx.from_numpy(u, b, M) - G = ot.sinkhorn2(u, b, M, 1, method='sinkhorn', stopThr=1e-10) - Gl = nx.to_numpy(ot.sinkhorn2( - ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) - G0 = nx.to_numpy(ot.sinkhorn2( - ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10)) - Gs = nx.to_numpy(ot.sinkhorn2( - ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10)) + G = ot.sinkhorn2(u, b, M, 1, method="sinkhorn", stopThr=1e-10) + Gl = nx.to_numpy( + ot.sinkhorn2(ub, bb, M_nx, 1, method="sinkhorn_log", stopThr=1e-10) + ) + G0 = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method="sinkhorn", stopThr=1e-10)) + Gs = nx.to_numpy( + ot.sinkhorn2(ub, bb, M_nx, 1, method="sinkhorn_stabilized", stopThr=1e-10) + ) # check values np.testing.assert_allclose(G, G0, atol=1e-05) @@ -436,16 +488,23 @@ def test_sinkhorn_variants_log(): M = ot.dist(x, x) - G0, log0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', - stopThr=1e-10, log=True) - Gl, logl = ot.sinkhorn( - u, u, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True) + G0, log0 = ot.sinkhorn(u, u, M, 1, method="sinkhorn", stopThr=1e-10, log=True) + Gl, logl = ot.sinkhorn(u, u, M, 1, method="sinkhorn_log", stopThr=1e-10, log=True) Gs, logs = ot.sinkhorn( - u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True) + u, u, M, 1, method="sinkhorn_stabilized", stopThr=1e-10, log=True + ) Ges, loges = ot.sinkhorn( - u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True,) + u, + u, + M, + 1, + method="sinkhorn_epsilon_scaling", + stopThr=1e-10, + log=True, + ) G_green, loggreen = ot.sinkhorn( - u, u, M, 1, method='greenkhorn', stopThr=1e-10, log=True) + u, u, M, 1, method="greenkhorn", stopThr=1e-10, log=True + ) # check values np.testing.assert_allclose(G0, Gs, atol=1e-05) @@ -467,21 +526,43 @@ def test_sinkhorn_variants_log_multib(verbose, warn): M = ot.dist(x, x) - G0, log0 = ot.sinkhorn(u, b, M, 1, method='sinkhorn', - stopThr=1e-10, log=True) - Gl, logl = ot.sinkhorn(u, b, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True, - verbose=verbose, warn=warn) - Gs, logs = ot.sinkhorn(u, b, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True, - verbose=verbose, warn=warn) + G0, log0 = ot.sinkhorn(u, b, M, 1, method="sinkhorn", stopThr=1e-10, log=True) + Gl, logl = ot.sinkhorn( + u, + b, + M, + 1, + method="sinkhorn_log", + stopThr=1e-10, + log=True, + verbose=verbose, + warn=warn, + ) + Gs, logs = ot.sinkhorn( + u, + b, + M, + 1, + method="sinkhorn_stabilized", + stopThr=1e-10, + log=True, + verbose=verbose, + warn=warn, + ) # check values np.testing.assert_allclose(G0, Gs, atol=1e-05) np.testing.assert_allclose(G0, Gl, atol=1e-05) -@pytest.mark.parametrize("method, verbose, warn", - product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"], - [True, False], [True, False])) +@pytest.mark.parametrize( + "method, verbose, warn", + product( + ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"], + [True, False], + [True, False], + ), +) def test_barycenter(nx, method, verbose, warn): n_bins = 100 # nb bins @@ -508,9 +589,11 @@ def test_barycenter(nx, method, verbose, warn): else: # wasserstein bary_wass_np = ot.bregman.barycenter( - A, M, reg, weights, method=method, verbose=verbose, warn=warn) + A, M, reg, weights, method=method, verbose=verbose, warn=warn + ) bary_wass, _ = ot.bregman.barycenter( - A_nx, M_nx, reg, weights_nx, method=method, log=True) + A_nx, M_nx, reg, weights_nx, method=method, log=True + ) bary_wass = nx.to_numpy(bary_wass) np.testing.assert_allclose(1, np.sum(bary_wass)) @@ -521,33 +604,39 @@ def test_barycenter(nx, method, verbose, warn): def test_free_support_sinkhorn_barycenter(): measures_locations = [ - np.array([-1.]).reshape((1, 1)), # First dirac support - np.array([1.]).reshape((1, 1)) # Second dirac support + np.array([-1.0]).reshape((1, 1)), # First dirac support + np.array([1.0]).reshape((1, 1)), # Second dirac support ] measures_weights = [ - np.array([1.]), # First dirac sample weights - np.array([1.]) # Second dirac sample weights + np.array([1.0]), # First dirac sample weights + np.array([1.0]), # Second dirac sample weights ] # Barycenter initialization - X_init = np.array([-12.]).reshape((1, 1)) + X_init = np.array([-12.0]).reshape((1, 1)) # Obvious barycenter locations. Take a look on test_ot.py, test_free_support_barycenter - bar_locations = np.array([0.]).reshape((1, 1)) + bar_locations = np.array([0.0]).reshape((1, 1)) # Calculate free support barycenter w/ Sinkhorn algorithm. We set the entropic regularization # term to 1, but this should be, in general, fine-tuned to the problem. X = ot.bregman.free_support_sinkhorn_barycenter( - measures_locations, measures_weights, X_init, reg=1) + measures_locations, measures_weights, X_init, reg=1 + ) # Verifies if calculated barycenter matches ground-truth np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7) -@pytest.mark.parametrize("method, verbose, warn", - product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"], - [True, False], [True, False])) +@pytest.mark.parametrize( + "method, verbose, warn", + product( + ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"], + [True, False], + [True, False], + ), +) def test_barycenter_assymetric_cost(nx, method, verbose, warn): n_bins = 20 # nb bins @@ -571,9 +660,9 @@ def test_barycenter_assymetric_cost(nx, method, verbose, warn): else: # wasserstein bary_wass_np = ot.bregman.barycenter( - A, M, reg, method=method, verbose=verbose, warn=warn) - bary_wass, _ = ot.bregman.barycenter( - A_nx, M_nx, reg, method=method, log=True) + A, M, reg, method=method, verbose=verbose, warn=warn + ) + bary_wass, _ = ot.bregman.barycenter(A_nx, M_nx, reg, method=method, log=True) bary_wass = nx.to_numpy(bary_wass) np.testing.assert_allclose(1, np.sum(bary_wass)) @@ -582,9 +671,10 @@ def test_barycenter_assymetric_cost(nx, method, verbose, warn): ot.bregman.barycenter(A_nx, M_nx, reg, log=True) -@pytest.mark.parametrize("method, verbose, warn", - product(["sinkhorn", "sinkhorn_log"], - [True, False], [True, False])) +@pytest.mark.parametrize( + "method, verbose, warn", + product(["sinkhorn", "sinkhorn_log"], [True, False], [True, False]), +) def test_barycenter_debiased(nx, method, verbose, warn): n_bins = 100 # nb bins @@ -608,26 +698,26 @@ def test_barycenter_debiased(nx, method, verbose, warn): reg = 1e-2 if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": with pytest.raises(NotImplementedError): - ot.bregman.barycenter_debiased( - A_nx, M_nx, reg, weights, method=method) + ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights, method=method) else: - bary_wass_np = ot.bregman.barycenter_debiased(A, M, reg, weights, method=method, - verbose=verbose, warn=warn) + bary_wass_np = ot.bregman.barycenter_debiased( + A, M, reg, weights, method=method, verbose=verbose, warn=warn + ) bary_wass, _ = ot.bregman.barycenter_debiased( - A_nx, M_nx, reg, weights_nx, method=method, log=True) + A_nx, M_nx, reg, weights_nx, method=method, log=True + ) bary_wass = nx.to_numpy(bary_wass) np.testing.assert_allclose(1, np.sum(bary_wass), atol=1e-3) np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-5) - ot.bregman.barycenter_debiased( - A_nx, M_nx, reg, log=True, verbose=False) + ot.bregman.barycenter_debiased(A_nx, M_nx, reg, log=True, verbose=False) @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) def test_convergence_warning_barycenters(method): w = 10 - n_bins = w ** 2 # nb bins + n_bins = w**2 # nb bins # Gaussian distributions a1 = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std @@ -646,16 +736,17 @@ def test_convergence_warning_barycenters(method): weights = np.array([1 - alpha, alpha]) reg = 0.1 with pytest.warns(UserWarning): - ot.bregman.barycenter_debiased( - A, M, reg, weights, method=method, numItermax=1) + ot.bregman.barycenter_debiased(A, M, reg, weights, method=method, numItermax=1) with pytest.warns(UserWarning): ot.bregman.barycenter(A, M, reg, weights, method=method, numItermax=1) with pytest.warns(UserWarning): - ot.bregman.convolutional_barycenter2d(A_img, reg, weights, - method=method, numItermax=1) + ot.bregman.convolutional_barycenter2d( + A_img, reg, weights, method=method, numItermax=1 + ) with pytest.warns(UserWarning): - ot.bregman.convolutional_barycenter2d_debiased(A_img, reg, weights, - method=method, numItermax=1) + ot.bregman.convolutional_barycenter2d_debiased( + A_img, reg, weights, method=method, numItermax=1 + ) def test_barycenter_stabilization(nx): @@ -680,34 +771,55 @@ def test_barycenter_stabilization(nx): # wasserstein reg = 1e-2 bar_np = ot.bregman.barycenter( - A, M, reg, weights, method="sinkhorn", stopThr=1e-8, verbose=True) - bar_stable = nx.to_numpy(ot.bregman.barycenter( - A_nx, M_nx, reg, weights_b, method="sinkhorn_stabilized", - stopThr=1e-8, verbose=True - )) - bar = nx.to_numpy(ot.bregman.barycenter( - A_nx, M_nx, reg, weights_b, method="sinkhorn", - stopThr=1e-8, verbose=True - )) + A, M, reg, weights, method="sinkhorn", stopThr=1e-8, verbose=True + ) + bar_stable = nx.to_numpy( + ot.bregman.barycenter( + A_nx, + M_nx, + reg, + weights_b, + method="sinkhorn_stabilized", + stopThr=1e-8, + verbose=True, + ) + ) + bar = nx.to_numpy( + ot.bregman.barycenter( + A_nx, M_nx, reg, weights_b, method="sinkhorn", stopThr=1e-8, verbose=True + ) + ) np.testing.assert_allclose(bar, bar_stable) np.testing.assert_allclose(bar, bar_np) -@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) -def test_wasserstein_bary_2d(nx, method): - rng = np.random.RandomState(42) - size = 20 # size of a square image +def create_random_images_dist(seed, size=20): + """Creates an array of two random images of size (size, size). Returns an array of shape (2, size, size).""" + rng = np.random.RandomState(seed) + + # First image a1 = rng.rand(size, size) a1 += a1.min() - a1 = a1 / np.sum(a1) + a1 = a1 / np.sum(a1) # Ensure that it is a probability distribution + + # Second image a2 = rng.rand(size, size) a2 += a2.min() a2 = a2 / np.sum(a2) - # creating matrix A containing all distributions + + # Creating matrix A containing all distributions A = np.zeros((2, size, size)) A[0, :, :] = a1 A[1, :, :] = a2 + return A + + +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) +def test_wasserstein_bary_2d(nx, method): + # Create the array of images to test + A = create_random_images_dist(42, size=20) + A_nx = nx.from_numpy(A) # wasserstein @@ -717,9 +829,11 @@ def test_wasserstein_bary_2d(nx, method): ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method) else: bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d( - A, reg, method=method, verbose=True, log=True) + A, reg, method=method, verbose=True, log=True + ) bary_wass = nx.to_numpy( - ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)) + ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method) + ) np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) @@ -730,23 +844,8 @@ def test_wasserstein_bary_2d(nx, method): @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) def test_wasserstein_bary_2d_dtype_device(nx, method): - rng = np.random.RandomState(42) - size = 20 # size of a square image - - # First image - a1 = rng.rand(size, size) - a1 += a1.min() - a1 = a1 / np.sum(a1) # Ensure that it is a probability distribution - - # Second image - a2 = rng.rand(size, size) - a2 += a2.min() - a2 = a2 / np.sum(a2) # Ensure that it is a probability distribution - - # creating matrix A containing all distributions - A = np.zeros((2, size, size)) - A[0, :, :] = a1 - A[1, :, :] = a2 + # Create the array of images to test + A = create_random_images_dist(42, size=20) for tp in nx.__type_list__: print(nx.dtype_device(tp)) @@ -761,7 +860,8 @@ def test_wasserstein_bary_2d_dtype_device(nx, method): else: # Compute the barycenter with numpy bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d( - A, reg, method=method, verbose=True, log=True) + A, reg, method=method, verbose=True, log=True + ) # Compute the barycenter with the backend bary_wass_b = ot.bregman.convolutional_barycenter2d(Ab, reg, method=method) # Convert the backend result to numpy, to compare with the numpy result @@ -783,23 +883,8 @@ def test_wasserstein_bary_2d_device_tf(method): # Using the Tensorflow backend nx = ot.backend.TensorflowBackend() - rng = np.random.RandomState(42) - size = 20 # size of a square image - - # First image - a1 = rng.rand(size, size) - a1 += a1.min() - a1 = a1 / np.sum(a1) # Ensure that it is a probability distribution - - # Second image - a2 = rng.rand(size, size) - a2 += a2.min() - a2 = a2 / np.sum(a2) # Ensure that it is a probability distribution - - # creating matrix A containing all distributions - A = np.zeros((2, size, size)) - A[0, :, :] = a1 - A[1, :, :] = a2 + # Create the array of images to test + A = create_random_images_dist(42, size=20) # Check that everything stays on the CPU with tf.device("/CPU:0"): @@ -813,7 +898,8 @@ def test_wasserstein_bary_2d_device_tf(method): else: # Compute the barycenter with numpy bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d( - A, reg, method=method, verbose=True, log=True) + A, reg, method=method, verbose=True, log=True + ) # Compute the barycenter with the backend bary_wass_b = ot.bregman.convolutional_barycenter2d(Ab, reg, method=method) # Convert the backend result to numpy, to compare with the numpy result @@ -839,7 +925,8 @@ def test_wasserstein_bary_2d_device_tf(method): else: # Compute the barycenter with numpy bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d( - A, reg, method=method, verbose=True, log=True) + A, reg, method=method, verbose=True, log=True + ) # Compute the barycenter with the backend bary_wass_b = ot.bregman.convolutional_barycenter2d(Ab, reg, method=method) # Convert the backend result to numpy, to compare with the numpy result @@ -855,24 +942,14 @@ def test_wasserstein_bary_2d_device_tf(method): nx.assert_same_dtype_device(Ab, bary_wass_b) # Check this only if GPU is available - if len(tf.config.list_physical_devices('GPU')) > 0: + if len(tf.config.list_physical_devices("GPU")) > 0: assert nx.dtype_device(bary_wass_b)[1].startswith("GPU") @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) def test_wasserstein_bary_2d_debiased(nx, method): - rng = np.random.RandomState(42) - size = 20 # size of a square image - a1 = rng.rand(size, size) - a1 += a1.min() - a1 = a1 / np.sum(a1) - a2 = rng.rand(size, size) - a2 += a2.min() - a2 = a2 / np.sum(a2) - # creating matrix A containing all distributions - A = np.zeros((2, size, size)) - A[0, :, :] = a1 - A[1, :, :] = a2 + # Create the array of images to test + A = create_random_images_dist(42, size=20) A_nx = nx.from_numpy(A) @@ -880,13 +957,14 @@ def test_wasserstein_bary_2d_debiased(nx, method): reg = 1e-2 if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": with pytest.raises(NotImplementedError): - ot.bregman.convolutional_barycenter2d_debiased( - A_nx, reg, method=method) + ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method) else: bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased( - A, reg, method=method, verbose=True, log=True) + A, reg, method=method, verbose=True, log=True + ) bary_wass = nx.to_numpy( - ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)) + ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method) + ) np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) @@ -897,23 +975,8 @@ def test_wasserstein_bary_2d_debiased(nx, method): @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) def test_wasserstein_bary_2d_debiased_dtype_device(nx, method): - rng = np.random.RandomState(42) - size = 20 # size of a square image - - # First image - a1 = rng.rand(size, size) - a1 += a1.min() - a1 = a1 / np.sum(a1) # Ensure that it is a probability distribution - - # Second image - a2 = rng.rand(size, size) - a2 += a2.min() - a2 = a2 / np.sum(a2) # Ensure that it is a probability distribution - - # creating matrix A containing all distributions - A = np.zeros((2, size, size)) - A[0, :, :] = a1 - A[1, :, :] = a2 + # Create the array of images to test + A = create_random_images_dist(42, size=20) for tp in nx.__type_list__: print(nx.dtype_device(tp)) @@ -928,9 +991,12 @@ def test_wasserstein_bary_2d_debiased_dtype_device(nx, method): else: # Compute the barycenter with numpy bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased( - A, reg, method=method, verbose=True, log=True) + A, reg, method=method, verbose=True, log=True + ) # Compute the barycenter with the backend - bary_wass_b = ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method) + bary_wass_b = ot.bregman.convolutional_barycenter2d_debiased( + Ab, reg, method=method + ) # Convert the backend result to numpy, to compare with the numpy result bary_wass = nx.to_numpy(bary_wass_b) @@ -938,7 +1004,9 @@ def test_wasserstein_bary_2d_debiased_dtype_device(nx, method): np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) # help in checking if log and verbose do not bug the function - ot.bregman.convolutional_barycenter2d_debiased(A, reg, log=True, verbose=True) + ot.bregman.convolutional_barycenter2d_debiased( + A, reg, log=True, verbose=True + ) # Test that the dtype and device are the same after the computation nx.assert_same_dtype_device(Ab, bary_wass_b) @@ -950,23 +1018,8 @@ def test_wasserstein_bary_2d_debiased_device_tf(method): # Using the Tensorflow backend nx = ot.backend.TensorflowBackend() - rng = np.random.RandomState(42) - size = 20 # size of a square image - - # First image - a1 = rng.rand(size, size) - a1 += a1.min() - a1 = a1 / np.sum(a1) # Ensure that it is a probability distribution - - # Second image - a2 = rng.rand(size, size) - a2 += a2.min() - a2 = a2 / np.sum(a2) # Ensure that it is a probability distribution - - # creating matrix A containing all distributions - A = np.zeros((2, size, size)) - A[0, :, :] = a1 - A[1, :, :] = a2 + # Create the array of images to test + A = create_random_images_dist(42, size=20) # Check that everything stays on the CPU with tf.device("/CPU:0"): @@ -980,9 +1033,12 @@ def test_wasserstein_bary_2d_debiased_device_tf(method): else: # Compute the barycenter with numpy bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased( - A, reg, method=method, verbose=True, log=True) + A, reg, method=method, verbose=True, log=True + ) # Compute the barycenter with the backend - bary_wass_b = ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method) + bary_wass_b = ot.bregman.convolutional_barycenter2d_debiased( + Ab, reg, method=method + ) # Convert the backend result to numpy, to compare with the numpy result bary_wass = nx.to_numpy(bary_wass_b) @@ -990,7 +1046,9 @@ def test_wasserstein_bary_2d_debiased_device_tf(method): np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) # help in checking if log and verbose do not bug the function - ot.bregman.convolutional_barycenter2d_debiased(A, reg, log=True, verbose=True) + ot.bregman.convolutional_barycenter2d_debiased( + A, reg, log=True, verbose=True + ) # Test that the dtype and device are the same after the computation nx.assert_same_dtype_device(Ab, bary_wass_b) @@ -1006,9 +1064,12 @@ def test_wasserstein_bary_2d_debiased_device_tf(method): else: # Compute the barycenter with numpy bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased( - A, reg, method=method, verbose=True, log=True) + A, reg, method=method, verbose=True, log=True + ) # Compute the barycenter with the backend - bary_wass_b = ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method) + bary_wass_b = ot.bregman.convolutional_barycenter2d_debiased( + Ab, reg, method=method + ) # Convert the backend result to numpy, to compare with the numpy result bary_wass = nx.to_numpy(bary_wass_b) @@ -1022,7 +1083,7 @@ def test_wasserstein_bary_2d_debiased_device_tf(method): nx.assert_same_dtype_device(Ab, bary_wass_b) # Check this only if there is a GPU - if len(tf.config.list_physical_devices('GPU')) > 0: + if len(tf.config.list_physical_devices("GPU")) > 0: assert nx.dtype_device(bary_wass_b)[1].startswith("GPU") @@ -1051,15 +1112,13 @@ def test_unmix(nx): # wasserstein reg = 1e-3 um_np = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01) - um = nx.to_numpy(ot.bregman.unmix( - ab, Db, M_nx, M0b, h0b, reg, 1, alpha=0.01)) + um = nx.to_numpy(ot.bregman.unmix(ab, Db, M_nx, M0b, h0b, reg, 1, alpha=0.01)) np.testing.assert_allclose(1, np.sum(um), rtol=1e-03, atol=1e-03) np.testing.assert_allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03) np.testing.assert_allclose(um, um_np) - ot.bregman.unmix(ab, Db, M_nx, M0b, h0b, reg, - 1, alpha=0.01, log=True, verbose=True) + ot.bregman.unmix(ab, Db, M_nx, M0b, h0b, reg, 1, alpha=0.01, log=True, verbose=True) def test_empirical_sinkhorn(nx): @@ -1071,7 +1130,7 @@ def test_empirical_sinkhorn(nx): X_s = np.reshape(1.0 * np.arange(n), (n, 1)) X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) M = ot.dist(X_s, X_t) - M_m = ot.dist(X_s, X_t, metric='euclidean') + M_m = ot.dist(X_s, X_t, metric="euclidean") ab, bb, X_sb, X_tb, M_nx, M_mb = nx.from_numpy(a, b, X_s, X_t, M, M_m) @@ -1083,27 +1142,27 @@ def test_empirical_sinkhorn(nx): sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True) sinkhorn_log = nx.to_numpy(sinkhorn_log) - G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn( - X_sb, X_tb, 1, metric='euclidean')) + G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric="euclidean")) sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1)) - loss_emp_sinkhorn = nx.to_numpy( - ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1)) + loss_emp_sinkhorn = nx.to_numpy(ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1)) loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1)) # check constraints np.testing.assert_allclose( - sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian - np.testing.assert_allclose( - sinkhorn_sqe.sum(0), G_sqe.sum(0), atol=1e-05) # metric sqeuclidian + sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05 + ) # metric sqeuclidian np.testing.assert_allclose( - sinkhorn_log.sum(1), G_log.sum(1), atol=1e-05) # log + sinkhorn_sqe.sum(0), G_sqe.sum(0), atol=1e-05 + ) # metric sqeuclidian + np.testing.assert_allclose(sinkhorn_log.sum(1), G_log.sum(1), atol=1e-05) # log + np.testing.assert_allclose(sinkhorn_log.sum(0), G_log.sum(0), atol=1e-05) # log np.testing.assert_allclose( - sinkhorn_log.sum(0), G_log.sum(0), atol=1e-05) # log + sinkhorn_m.sum(1), G_m.sum(1), atol=1e-05 + ) # metric euclidian np.testing.assert_allclose( - sinkhorn_m.sum(1), G_m.sum(1), atol=1e-05) # metric euclidian - np.testing.assert_allclose( - sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05) # metric euclidian + sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05 + ) # metric euclidian np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05) @@ -1117,47 +1176,65 @@ def test_lazy_empirical_sinkhorn(nx): X_s = np.reshape(np.arange(n, dtype=np.float64), (n, 1)) X_t = np.reshape(np.arange(0, n, dtype=np.float64), (n, 1)) M = ot.dist(X_s, X_t) - M_m = ot.dist(X_s, X_t, metric='euclidean') + M_m = ot.dist(X_s, X_t, metric="euclidean") ab, bb, X_sb, X_tb, M_nx, M_mb = nx.from_numpy(a, b, X_s, X_t, M, M_m) f, g = ot.bregman.empirical_sinkhorn( - X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True) + X_sb, + X_tb, + 1, + numIterMax=numIterMax, + isLazy=True, + batchSize=(1, 3), + verbose=True, + ) f, g = nx.to_numpy(f), nx.to_numpy(g) G_sqe = np.exp(f[:, None] + g[None, :] - M / 1) sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1)) f, g, log_es = ot.bregman.empirical_sinkhorn( - X_sb, X_tb, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) + X_sb, X_tb, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True + ) f, g = nx.to_numpy(f), nx.to_numpy(g) G_log = np.exp(f[:, None] + g[None, :] - M / 0.1) sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True) sinkhorn_log = nx.to_numpy(sinkhorn_log) f, g = ot.bregman.empirical_sinkhorn( - X_sb, X_tb, 1, metric='euclidean', numIterMax=numIterMax, isLazy=True, batchSize=1) + X_sb, + X_tb, + 1, + metric="euclidean", + numIterMax=numIterMax, + isLazy=True, + batchSize=1, + ) f, g = nx.to_numpy(f), nx.to_numpy(g) G_m = np.exp(f[:, None] + g[None, :] - M_m / 1) sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1)) loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2( - X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) + X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True + ) loss_emp_sinkhorn = nx.to_numpy(loss_emp_sinkhorn) loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1)) # check constraints np.testing.assert_allclose( - sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian - np.testing.assert_allclose( - sinkhorn_sqe.sum(0), G_sqe.sum(0), atol=1e-05) # metric sqeuclidian + sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05 + ) # metric sqeuclidian np.testing.assert_allclose( - sinkhorn_log.sum(1), G_log.sum(1), atol=1e-05) # log + sinkhorn_sqe.sum(0), G_sqe.sum(0), atol=1e-05 + ) # metric sqeuclidian + np.testing.assert_allclose(sinkhorn_log.sum(1), G_log.sum(1), atol=1e-05) # log + np.testing.assert_allclose(sinkhorn_log.sum(0), G_log.sum(0), atol=1e-05) # log np.testing.assert_allclose( - sinkhorn_log.sum(0), G_log.sum(0), atol=1e-05) # log + sinkhorn_m.sum(1), G_m.sum(1), atol=1e-05 + ) # metric euclidian np.testing.assert_allclose( - sinkhorn_m.sum(1), G_m.sum(1), atol=1e-05) # metric euclidian - np.testing.assert_allclose( - sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05) # metric euclidian + sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05 + ) # metric euclidian np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05) @@ -1173,27 +1250,27 @@ def test_empirical_sinkhorn_divergence(nx): M_s = ot.dist(X_s, X_s) M_t = ot.dist(X_t, X_t) - ab, bb, X_sb, X_tb, M_nx, M_sb, M_tb = nx.from_numpy( - a, b, X_s, X_t, M, M_s, M_t) + ab, bb, X_sb, X_tb, M_nx, M_sb, M_tb = nx.from_numpy(a, b, X_s, X_t, M, M_s, M_t) emp_sinkhorn_div = nx.to_numpy( - ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb)) + ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb) + ) sinkhorn_div = nx.to_numpy( ot.sinkhorn2(ab, bb, M_nx, 1) - 1 / 2 * ot.sinkhorn2(ab, ab, M_sb, 1) - 1 / 2 * ot.sinkhorn2(bb, bb, M_tb, 1) ) emp_sinkhorn_div_np = ot.bregman.empirical_sinkhorn_divergence( - X_s, X_t, 1, a=a, b=b) + X_s, X_t, 1, a=a, b=b + ) # check constraints + np.testing.assert_allclose(emp_sinkhorn_div, emp_sinkhorn_div_np, atol=1e-05) np.testing.assert_allclose( - emp_sinkhorn_div, emp_sinkhorn_div_np, atol=1e-05) - np.testing.assert_allclose( - emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn + emp_sinkhorn_div, sinkhorn_div, atol=1e-05 + ) # cf conv emp sinkhorn - ot.bregman.empirical_sinkhorn_divergence( - X_sb, X_tb, 1, a=ab, b=bb, log=True) + ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb, log=True) @pytest.mark.skipif(not torch, reason="No torch available") @@ -1216,7 +1293,8 @@ def test_empirical_sinkhorn_divergence_gradient(): X_tb.requires_grad = True emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence( - X_sb, X_tb, 1, a=ab, b=bb) + X_sb, X_tb, 1, a=ab, b=bb + ) emp_sinkhorn_div.backward() @@ -1245,14 +1323,12 @@ def test_stabilized_vs_sinkhorn_multidim(nx): ab, bb, M_nx = nx.from_numpy(a, b, M) - G_np, _ = ot.bregman.sinkhorn( - a, b, M, reg=epsilon, method="sinkhorn", log=True) - G, log = ot.bregman.sinkhorn(ab, bb, M_nx, reg=epsilon, - method="sinkhorn_stabilized", - log=True) + G_np, _ = ot.bregman.sinkhorn(a, b, M, reg=epsilon, method="sinkhorn", log=True) + G, log = ot.bregman.sinkhorn( + ab, bb, M_nx, reg=epsilon, method="sinkhorn_stabilized", log=True + ) G = nx.to_numpy(G) - G2, log2 = ot.bregman.sinkhorn(ab, bb, M_nx, epsilon, - method="sinkhorn", log=True) + G2, log2 = ot.bregman.sinkhorn(ab, bb, M_nx, epsilon, method="sinkhorn", log=True) G2 = nx.to_numpy(G2) np.testing.assert_allclose(G_np, G2) @@ -1260,9 +1336,9 @@ def test_stabilized_vs_sinkhorn_multidim(nx): def test_implemented_methods(): - IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized'] - ONLY_1D_methods = ['greenkhorn', 'sinkhorn_epsilon_scaling'] - NOT_VALID_TOKENS = ['foo'] + IMPLEMENTED_METHODS = ["sinkhorn", "sinkhorn_stabilized"] + ONLY_1D_methods = ["greenkhorn", "sinkhorn_epsilon_scaling"] + NOT_VALID_TOKENS = ["foo"] # test generalized sinkhorn for unbalanced OT barycenter n = 3 rng = np.random.RandomState(42) @@ -1292,7 +1368,7 @@ def test_implemented_methods(): ot.bregman.sinkhorn2(a, b, M, epsilon, method=method) -@pytest.skip_backend('tf') +@pytest.skip_backend("tf") @pytest.skip_backend("cupy") @pytest.skip_backend("jax") @pytest.mark.filterwarnings("ignore:Bottleneck") @@ -1311,8 +1387,9 @@ def test_screenkhorn(nx): # sinkhorn G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-1)) # screenkhorn - G_screen = nx.to_numpy(ot.bregman.screenkhorn( - ab, bb, M_nx, 1e-1, uniform=True, verbose=True)) + G_screen = nx.to_numpy( + ot.bregman.screenkhorn(ab, bb, M_nx, 1e-1, uniform=True, verbose=True) + ) # check marginals np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02) np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02) @@ -1348,16 +1425,21 @@ def test_sinkhorn_warmstart(): # Optimal plan with uniform warmstart pi_unif, _ = ot.bregman.sinkhorn( - a, b, M, reg, method="sinkhorn", log=True, warmstart=None) + a, b, M, reg, method="sinkhorn", log=True, warmstart=None + ) # Optimal plan with warmstart generated from unregularized OT pi_sh, _ = ot.bregman.sinkhorn( - a, b, M, reg, method="sinkhorn", log=True, warmstart=warmstart) + a, b, M, reg, method="sinkhorn", log=True, warmstart=warmstart + ) pi_sh_log, _ = ot.bregman.sinkhorn( - a, b, M, reg, method="sinkhorn_log", log=True, warmstart=warmstart) + a, b, M, reg, method="sinkhorn_log", log=True, warmstart=warmstart + ) pi_sh_stab, _ = ot.bregman.sinkhorn( - a, b, M, reg, method="sinkhorn_stabilized", log=True, warmstart=warmstart) + a, b, M, reg, method="sinkhorn_stabilized", log=True, warmstart=warmstart + ) pi_sh_sc, _ = ot.bregman.sinkhorn( - a, b, M, reg, method="sinkhorn_epsilon_scaling", log=True, warmstart=warmstart) + a, b, M, reg, method="sinkhorn_epsilon_scaling", log=True, warmstart=warmstart + ) np.testing.assert_allclose(pi_unif, pi_sh, atol=1e-05) np.testing.assert_allclose(pi_unif, pi_sh_log, atol=1e-05) @@ -1381,14 +1463,17 @@ def test_empirical_sinkhorn_warmstart(): # Optimal plan with uniform warmstart f, g, _ = ot.bregman.empirical_sinkhorn( - X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=None) + X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=None + ) pi_unif = np.exp(f[:, None] + g[None, :] - M / reg) # Optimal plan with warmstart generated from unregularized OT f, g, _ = ot.bregman.empirical_sinkhorn( - X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=warmstart) + X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=warmstart + ) pi_ws_lazy = np.exp(f[:, None] + g[None, :] - M / reg) pi_ws_not_lazy, _ = ot.bregman.empirical_sinkhorn( - X_s=Xs, X_t=Xt, reg=reg, isLazy=False, log=True, warmstart=warmstart) + X_s=Xs, X_t=Xt, reg=reg, isLazy=False, log=True, warmstart=warmstart + ) np.testing.assert_allclose(pi_unif, pi_ws_lazy, atol=1e-05) np.testing.assert_allclose(pi_unif, pi_ws_not_lazy, atol=1e-05) @@ -1410,12 +1495,15 @@ def test_empirical_sinkhorn_divergence_warmstart(): # Optimal plan with uniform warmstart sd_unif, _ = ot.bregman.empirical_sinkhorn_divergence( - X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=None) + X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=None + ) # Optimal plan with warmstart generated from unregularized OT sd_ws_lazy, _ = ot.bregman.empirical_sinkhorn_divergence( - X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=warmstart) + X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=warmstart + ) sd_ws_not_lazy, _ = ot.bregman.empirical_sinkhorn_divergence( - X_s=Xs, X_t=Xt, reg=reg, isLazy=False, log=True, warmstart=warmstart) + X_s=Xs, X_t=Xt, reg=reg, isLazy=False, log=True, warmstart=warmstart + ) np.testing.assert_allclose(sd_unif, sd_ws_lazy, atol=1e-05) np.testing.assert_allclose(sd_unif, sd_ws_not_lazy, atol=1e-05) From d2216a030930aa1c1154f5756d247ada8ac891c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= Date: Mon, 16 Oct 2023 15:32:25 -0300 Subject: [PATCH 11/12] [REFACTOR] Mantain th same style as before --- test/test_bregman.py | 534 ++++++++++++++++--------------------------- 1 file changed, 198 insertions(+), 336 deletions(-) diff --git a/test/test_bregman.py b/test/test_bregman.py index ae17a1a88..8627df3c6 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -31,23 +31,19 @@ def test_sinkhorn(verbose, warn): G = ot.sinkhorn(u, u, M, 1, stopThr=1e-10, verbose=verbose, warn=warn) # check constraints - np.testing.assert_allclose(u, G.sum(1), atol=1e-05) # cf convergence sinkhorn - np.testing.assert_allclose(u, G.sum(0), atol=1e-05) # cf convergence sinkhorn + np.testing.assert_allclose( + u, G.sum(1), atol=1e-05) # cf convergence sinkhorn + np.testing.assert_allclose( + u, G.sum(0), atol=1e-05) # cf convergence sinkhorn with pytest.warns(UserWarning): ot.sinkhorn(u, u, M, 1, stopThr=0, numItermax=1) -@pytest.mark.parametrize( - "method", - [ - "sinkhorn", - "sinkhorn_stabilized", - "sinkhorn_epsilon_scaling", - "greenkhorn", - "sinkhorn_log", - ], -) +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", + "sinkhorn_epsilon_scaling", + "greenkhorn", + "sinkhorn_log"]) def test_convergence_warning(method): # test sinkhorn n = 100 @@ -57,26 +53,24 @@ def test_convergence_warning(method): M = ot.utils.dist0(n) with pytest.warns(UserWarning): - ot.sinkhorn(a1, a2, M, 1.0, method=method, stopThr=0, numItermax=1) + ot.sinkhorn(a1, a2, M, 1., method=method, stopThr=0, numItermax=1) if method in ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"]: with pytest.warns(UserWarning): ot.barycenter(A, M, 1, method=method, stopThr=0, numItermax=1) with pytest.warns(UserWarning): - ot.sinkhorn2( - a1, a2, M, 1, method=method, stopThr=0, numItermax=1, warn=True - ) + ot.sinkhorn2(a1, a2, M, 1, method=method, + stopThr=0, numItermax=1, warn=True) with warnings.catch_warnings(): warnings.simplefilter("error") - ot.sinkhorn2( - a1, a2, M, 1, method=method, stopThr=0, numItermax=1, warn=False - ) + ot.sinkhorn2(a1, a2, M, 1, method=method, + stopThr=0, numItermax=1, warn=False) def test_not_implemented_method(): # test sinkhorn w = 10 - n = w**2 + n = w ** 2 rng = np.random.RandomState(42) A_img = rng.rand(2, w, w) A_flat = A_img.reshape(n, 2) @@ -91,13 +85,14 @@ def test_not_implemented_method(): with pytest.raises(ValueError): ot.barycenter(A_flat, M_flat, reg, method=not_implemented) with pytest.raises(ValueError): - ot.bregman.barycenter_debiased(A_flat, M_flat, reg, method=not_implemented) + ot.bregman.barycenter_debiased(A_flat, M_flat, reg, + method=not_implemented) with pytest.raises(ValueError): - ot.bregman.convolutional_barycenter2d(A_img, reg, method=not_implemented) + ot.bregman.convolutional_barycenter2d(A_img, reg, + method=not_implemented) with pytest.raises(ValueError): - ot.bregman.convolutional_barycenter2d_debiased( - A_img, reg, method=not_implemented - ) + ot.bregman.convolutional_barycenter2d_debiased(A_img, reg, + method=not_implemented) @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) @@ -123,17 +118,14 @@ def test_sinkhorn_stabilization(): reg = 1e-5 loss1 = ot.sinkhorn2(a1, a2, M, reg, method="sinkhorn_log") loss2 = ot.sinkhorn2(a1, a2, M, reg, tau=1, method="sinkhorn_stabilized") - np.testing.assert_allclose(loss1, loss2, atol=1e-06) # cf convergence sinkhorn + np.testing.assert_allclose( + loss1, loss2, atol=1e-06) # cf convergence sinkhorn -@pytest.mark.parametrize( - "method, verbose, warn", - product( - ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"], - [True, False], - [True, False], - ), -) +@pytest.mark.parametrize("method, verbose, warn", + product(["sinkhorn", "sinkhorn_stabilized", + "sinkhorn_log"], + [True, False], [True, False])) def test_sinkhorn_multi_b(method, verbose, warn): # test sinkhorn n = 10 @@ -147,16 +139,14 @@ def test_sinkhorn_multi_b(method, verbose, warn): M = ot.dist(x, x) - loss0, log = ot.sinkhorn(u, b, M, 0.1, method=method, stopThr=1e-10, log=True) + loss0, log = ot.sinkhorn(u, b, M, .1, method=method, stopThr=1e-10, + log=True) - loss = [ - ot.sinkhorn2( - u, b[:, k], M, 0.1, method=method, stopThr=1e-10, verbose=verbose, warn=warn - ) - for k in range(3) - ] + loss = [ot.sinkhorn2(u, b[:, k], M, .1, method=method, stopThr=1e-10, + verbose=verbose, warn=warn) for k in range(3)] # check constraints - np.testing.assert_allclose(loss0, loss, atol=1e-4) # cf convergence sinkhorn + np.testing.assert_allclose( + loss0, loss, atol=1e-4) # cf convergence sinkhorn def test_sinkhorn_backends(nx): @@ -211,6 +201,7 @@ def test_sinkhorn2_gradients(): M = ot.dist(x, y) if torch: + a1 = torch.tensor(a, requires_grad=True) b1 = torch.tensor(a, requires_grad=True) M1 = torch.tensor(M, requires_grad=True) @@ -234,9 +225,8 @@ def test_sinkhorn_empty(): M = ot.dist(x, x) - G, log = ot.sinkhorn( - [], [], M, 1, stopThr=1e-10, method="sinkhorn_log", verbose=True, log=True - ) + G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, method="sinkhorn_log", + verbose=True, log=True) # check constraints np.testing.assert_allclose(u, G.sum(1), atol=1e-05) np.testing.assert_allclose(u, G.sum(0), atol=1e-05) @@ -246,39 +236,24 @@ def test_sinkhorn_empty(): np.testing.assert_allclose(u, G.sum(1), atol=1e-05) np.testing.assert_allclose(u, G.sum(0), atol=1e-05) - G, log = ot.sinkhorn( - [], - [], - M, - 1, - stopThr=1e-10, - method="sinkhorn_stabilized", - verbose=True, - log=True, - ) + G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, + method='sinkhorn_stabilized', verbose=True, log=True) # check constraints np.testing.assert_allclose(u, G.sum(1), atol=1e-05) np.testing.assert_allclose(u, G.sum(0), atol=1e-05) G, log = ot.sinkhorn( - [], - [], - M, - 1, - stopThr=1e-10, - method="sinkhorn_epsilon_scaling", - verbose=True, - log=True, - ) + [], [], M, 1, stopThr=1e-10, method='sinkhorn_epsilon_scaling', + verbose=True, log=True) # check constraints np.testing.assert_allclose(u, G.sum(1), atol=1e-05) np.testing.assert_allclose(u, G.sum(0), atol=1e-05) # test empty weights greenkhorn - ot.sinkhorn([], [], M, 1, method="greenkhorn", stopThr=1e-10, log=True) + ot.sinkhorn([], [], M, 1, method='greenkhorn', stopThr=1e-10, log=True) -@pytest.skip_backend("tf") +@pytest.skip_backend('tf') @pytest.skip_backend("jax") def test_sinkhorn_variants(nx): # test sinkhorn @@ -292,18 +267,17 @@ def test_sinkhorn_variants(nx): ub, M_nx = nx.from_numpy(u, M) - G = ot.sinkhorn(u, u, M, 1, method="sinkhorn", stopThr=1e-10) - Gl = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method="sinkhorn_log", stopThr=1e-10)) - G0 = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method="sinkhorn", stopThr=1e-10)) - Gs = nx.to_numpy( - ot.sinkhorn(ub, ub, M_nx, 1, method="sinkhorn_stabilized", stopThr=1e-10) - ) - Ges = nx.to_numpy( - ot.sinkhorn(ub, ub, M_nx, 1, method="sinkhorn_epsilon_scaling", stopThr=1e-10) - ) - G_green = nx.to_numpy( - ot.sinkhorn(ub, ub, M_nx, 1, method="greenkhorn", stopThr=1e-10) - ) + G = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10) + Gl = nx.to_numpy(ot.sinkhorn( + ub, ub, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) + G0 = nx.to_numpy(ot.sinkhorn( + ub, ub, M_nx, 1, method='sinkhorn', stopThr=1e-10)) + Gs = nx.to_numpy(ot.sinkhorn( + ub, ub, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10)) + Ges = nx.to_numpy(ot.sinkhorn( + ub, ub, M_nx, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10)) + G_green = nx.to_numpy(ot.sinkhorn( + ub, ub, M_nx, 1, method='greenkhorn', stopThr=1e-10)) # check values np.testing.assert_allclose(G, G0, atol=1e-05) @@ -313,40 +287,14 @@ def test_sinkhorn_variants(nx): np.testing.assert_allclose(G0, G_green, atol=1e-5) -@pytest.mark.parametrize( - "method", - [ - "sinkhorn", - "sinkhorn_stabilized", - "sinkhorn_epsilon_scaling", - "greenkhorn", - "sinkhorn_log", - ], -) -@pytest.skip_arg( - ("nx", "method"), - ("tf", "sinkhorn_epsilon_scaling"), - reason="tf does not support sinkhorn_epsilon_scaling", - getter=str, -) -@pytest.skip_arg( - ("nx", "method"), - ("tf", "greenkhorn"), - reason="tf does not support greenkhorn", - getter=str, -) -@pytest.skip_arg( - ("nx", "method"), - ("jax", "sinkhorn_epsilon_scaling"), - reason="jax does not support sinkhorn_epsilon_scaling", - getter=str, -) -@pytest.skip_arg( - ("nx", "method"), - ("jax", "greenkhorn"), - reason="jax does not support greenkhorn", - getter=str, -) +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", + "sinkhorn_epsilon_scaling", + "greenkhorn", + "sinkhorn_log"]) +@pytest.skip_arg(("nx", "method"), ("tf", "sinkhorn_epsilon_scaling"), reason="tf does not support sinkhorn_epsilon_scaling", getter=str) +@pytest.skip_arg(("nx", "method"), ("tf", "greenkhorn"), reason="tf does not support greenkhorn", getter=str) +@pytest.skip_arg(("nx", "method"), ("jax", "sinkhorn_epsilon_scaling"), reason="jax does not support sinkhorn_epsilon_scaling", getter=str) +@pytest.skip_arg(("nx", "method"), ("jax", "greenkhorn"), reason="jax does not support greenkhorn", getter=str) def test_sinkhorn_variants_dtype_device(nx, method): n = 100 @@ -412,11 +360,11 @@ def test_sinkhorn2_variants_device_tf(method): nx.assert_same_dtype_device(Mb, lossb) # Check this only if GPU is available - if len(tf.config.list_physical_devices("GPU")) > 0: + if len(tf.config.list_physical_devices('GPU')) > 0: assert nx.dtype_device(Gb)[1].startswith("GPU") -@pytest.skip_backend("tf") +@pytest.skip_backend('tf') @pytest.skip_backend("jax") def test_sinkhorn_variants_multi_b(nx): # test sinkhorn @@ -433,12 +381,13 @@ def test_sinkhorn_variants_multi_b(nx): ub, bb, M_nx = nx.from_numpy(u, b, M) - G = ot.sinkhorn(u, b, M, 1, method="sinkhorn", stopThr=1e-10) - Gl = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method="sinkhorn_log", stopThr=1e-10)) - G0 = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method="sinkhorn", stopThr=1e-10)) - Gs = nx.to_numpy( - ot.sinkhorn(ub, bb, M_nx, 1, method="sinkhorn_stabilized", stopThr=1e-10) - ) + G = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10) + Gl = nx.to_numpy(ot.sinkhorn( + ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) + G0 = nx.to_numpy(ot.sinkhorn( + ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10)) + Gs = nx.to_numpy(ot.sinkhorn( + ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10)) # check values np.testing.assert_allclose(G, G0, atol=1e-05) @@ -446,7 +395,7 @@ def test_sinkhorn_variants_multi_b(nx): np.testing.assert_allclose(G0, Gs, atol=1e-05) -@pytest.skip_backend("tf") +@pytest.skip_backend('tf') @pytest.skip_backend("jax") def test_sinkhorn2_variants_multi_b(nx): # test sinkhorn @@ -463,14 +412,13 @@ def test_sinkhorn2_variants_multi_b(nx): ub, bb, M_nx = nx.from_numpy(u, b, M) - G = ot.sinkhorn2(u, b, M, 1, method="sinkhorn", stopThr=1e-10) - Gl = nx.to_numpy( - ot.sinkhorn2(ub, bb, M_nx, 1, method="sinkhorn_log", stopThr=1e-10) - ) - G0 = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method="sinkhorn", stopThr=1e-10)) - Gs = nx.to_numpy( - ot.sinkhorn2(ub, bb, M_nx, 1, method="sinkhorn_stabilized", stopThr=1e-10) - ) + G = ot.sinkhorn2(u, b, M, 1, method='sinkhorn', stopThr=1e-10) + Gl = nx.to_numpy(ot.sinkhorn2( + ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) + G0 = nx.to_numpy(ot.sinkhorn2( + ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10)) + Gs = nx.to_numpy(ot.sinkhorn2( + ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10)) # check values np.testing.assert_allclose(G, G0, atol=1e-05) @@ -488,23 +436,16 @@ def test_sinkhorn_variants_log(): M = ot.dist(x, x) - G0, log0 = ot.sinkhorn(u, u, M, 1, method="sinkhorn", stopThr=1e-10, log=True) - Gl, logl = ot.sinkhorn(u, u, M, 1, method="sinkhorn_log", stopThr=1e-10, log=True) + G0, log0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', + stopThr=1e-10, log=True) + Gl, logl = ot.sinkhorn( + u, u, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True) Gs, logs = ot.sinkhorn( - u, u, M, 1, method="sinkhorn_stabilized", stopThr=1e-10, log=True - ) + u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True) Ges, loges = ot.sinkhorn( - u, - u, - M, - 1, - method="sinkhorn_epsilon_scaling", - stopThr=1e-10, - log=True, - ) + u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True,) G_green, loggreen = ot.sinkhorn( - u, u, M, 1, method="greenkhorn", stopThr=1e-10, log=True - ) + u, u, M, 1, method='greenkhorn', stopThr=1e-10, log=True) # check values np.testing.assert_allclose(G0, Gs, atol=1e-05) @@ -526,43 +467,21 @@ def test_sinkhorn_variants_log_multib(verbose, warn): M = ot.dist(x, x) - G0, log0 = ot.sinkhorn(u, b, M, 1, method="sinkhorn", stopThr=1e-10, log=True) - Gl, logl = ot.sinkhorn( - u, - b, - M, - 1, - method="sinkhorn_log", - stopThr=1e-10, - log=True, - verbose=verbose, - warn=warn, - ) - Gs, logs = ot.sinkhorn( - u, - b, - M, - 1, - method="sinkhorn_stabilized", - stopThr=1e-10, - log=True, - verbose=verbose, - warn=warn, - ) + G0, log0 = ot.sinkhorn(u, b, M, 1, method='sinkhorn', + stopThr=1e-10, log=True) + Gl, logl = ot.sinkhorn(u, b, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True, + verbose=verbose, warn=warn) + Gs, logs = ot.sinkhorn(u, b, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True, + verbose=verbose, warn=warn) # check values np.testing.assert_allclose(G0, Gs, atol=1e-05) np.testing.assert_allclose(G0, Gl, atol=1e-05) -@pytest.mark.parametrize( - "method, verbose, warn", - product( - ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"], - [True, False], - [True, False], - ), -) +@pytest.mark.parametrize("method, verbose, warn", + product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"], + [True, False], [True, False])) def test_barycenter(nx, method, verbose, warn): n_bins = 100 # nb bins @@ -589,11 +508,9 @@ def test_barycenter(nx, method, verbose, warn): else: # wasserstein bary_wass_np = ot.bregman.barycenter( - A, M, reg, weights, method=method, verbose=verbose, warn=warn - ) + A, M, reg, weights, method=method, verbose=verbose, warn=warn) bary_wass, _ = ot.bregman.barycenter( - A_nx, M_nx, reg, weights_nx, method=method, log=True - ) + A_nx, M_nx, reg, weights_nx, method=method, log=True) bary_wass = nx.to_numpy(bary_wass) np.testing.assert_allclose(1, np.sum(bary_wass)) @@ -604,39 +521,33 @@ def test_barycenter(nx, method, verbose, warn): def test_free_support_sinkhorn_barycenter(): measures_locations = [ - np.array([-1.0]).reshape((1, 1)), # First dirac support - np.array([1.0]).reshape((1, 1)), # Second dirac support + np.array([-1.]).reshape((1, 1)), # First dirac support + np.array([1.]).reshape((1, 1)) # Second dirac support ] measures_weights = [ - np.array([1.0]), # First dirac sample weights - np.array([1.0]), # Second dirac sample weights + np.array([1.]), # First dirac sample weights + np.array([1.]) # Second dirac sample weights ] # Barycenter initialization - X_init = np.array([-12.0]).reshape((1, 1)) + X_init = np.array([-12.]).reshape((1, 1)) # Obvious barycenter locations. Take a look on test_ot.py, test_free_support_barycenter - bar_locations = np.array([0.0]).reshape((1, 1)) + bar_locations = np.array([0.]).reshape((1, 1)) # Calculate free support barycenter w/ Sinkhorn algorithm. We set the entropic regularization # term to 1, but this should be, in general, fine-tuned to the problem. X = ot.bregman.free_support_sinkhorn_barycenter( - measures_locations, measures_weights, X_init, reg=1 - ) + measures_locations, measures_weights, X_init, reg=1) # Verifies if calculated barycenter matches ground-truth np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7) -@pytest.mark.parametrize( - "method, verbose, warn", - product( - ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"], - [True, False], - [True, False], - ), -) +@pytest.mark.parametrize("method, verbose, warn", + product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"], + [True, False], [True, False])) def test_barycenter_assymetric_cost(nx, method, verbose, warn): n_bins = 20 # nb bins @@ -660,9 +571,9 @@ def test_barycenter_assymetric_cost(nx, method, verbose, warn): else: # wasserstein bary_wass_np = ot.bregman.barycenter( - A, M, reg, method=method, verbose=verbose, warn=warn - ) - bary_wass, _ = ot.bregman.barycenter(A_nx, M_nx, reg, method=method, log=True) + A, M, reg, method=method, verbose=verbose, warn=warn) + bary_wass, _ = ot.bregman.barycenter( + A_nx, M_nx, reg, method=method, log=True) bary_wass = nx.to_numpy(bary_wass) np.testing.assert_allclose(1, np.sum(bary_wass)) @@ -671,10 +582,9 @@ def test_barycenter_assymetric_cost(nx, method, verbose, warn): ot.bregman.barycenter(A_nx, M_nx, reg, log=True) -@pytest.mark.parametrize( - "method, verbose, warn", - product(["sinkhorn", "sinkhorn_log"], [True, False], [True, False]), -) +@pytest.mark.parametrize("method, verbose, warn", + product(["sinkhorn", "sinkhorn_log"], + [True, False], [True, False])) def test_barycenter_debiased(nx, method, verbose, warn): n_bins = 100 # nb bins @@ -698,26 +608,26 @@ def test_barycenter_debiased(nx, method, verbose, warn): reg = 1e-2 if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": with pytest.raises(NotImplementedError): - ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights, method=method) + ot.bregman.barycenter_debiased( + A_nx, M_nx, reg, weights, method=method) else: - bary_wass_np = ot.bregman.barycenter_debiased( - A, M, reg, weights, method=method, verbose=verbose, warn=warn - ) + bary_wass_np = ot.bregman.barycenter_debiased(A, M, reg, weights, method=method, + verbose=verbose, warn=warn) bary_wass, _ = ot.bregman.barycenter_debiased( - A_nx, M_nx, reg, weights_nx, method=method, log=True - ) + A_nx, M_nx, reg, weights_nx, method=method, log=True) bary_wass = nx.to_numpy(bary_wass) np.testing.assert_allclose(1, np.sum(bary_wass), atol=1e-3) np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-5) - ot.bregman.barycenter_debiased(A_nx, M_nx, reg, log=True, verbose=False) + ot.bregman.barycenter_debiased( + A_nx, M_nx, reg, log=True, verbose=False) @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) def test_convergence_warning_barycenters(method): w = 10 - n_bins = w**2 # nb bins + n_bins = w ** 2 # nb bins # Gaussian distributions a1 = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std @@ -736,17 +646,16 @@ def test_convergence_warning_barycenters(method): weights = np.array([1 - alpha, alpha]) reg = 0.1 with pytest.warns(UserWarning): - ot.bregman.barycenter_debiased(A, M, reg, weights, method=method, numItermax=1) + ot.bregman.barycenter_debiased( + A, M, reg, weights, method=method, numItermax=1) with pytest.warns(UserWarning): ot.bregman.barycenter(A, M, reg, weights, method=method, numItermax=1) with pytest.warns(UserWarning): - ot.bregman.convolutional_barycenter2d( - A_img, reg, weights, method=method, numItermax=1 - ) + ot.bregman.convolutional_barycenter2d(A_img, reg, weights, + method=method, numItermax=1) with pytest.warns(UserWarning): - ot.bregman.convolutional_barycenter2d_debiased( - A_img, reg, weights, method=method, numItermax=1 - ) + ot.bregman.convolutional_barycenter2d_debiased(A_img, reg, weights, + method=method, numItermax=1) def test_barycenter_stabilization(nx): @@ -771,24 +680,15 @@ def test_barycenter_stabilization(nx): # wasserstein reg = 1e-2 bar_np = ot.bregman.barycenter( - A, M, reg, weights, method="sinkhorn", stopThr=1e-8, verbose=True - ) - bar_stable = nx.to_numpy( - ot.bregman.barycenter( - A_nx, - M_nx, - reg, - weights_b, - method="sinkhorn_stabilized", - stopThr=1e-8, - verbose=True, - ) - ) - bar = nx.to_numpy( - ot.bregman.barycenter( - A_nx, M_nx, reg, weights_b, method="sinkhorn", stopThr=1e-8, verbose=True - ) - ) + A, M, reg, weights, method="sinkhorn", stopThr=1e-8, verbose=True) + bar_stable = nx.to_numpy(ot.bregman.barycenter( + A_nx, M_nx, reg, weights_b, method="sinkhorn_stabilized", + stopThr=1e-8, verbose=True + )) + bar = nx.to_numpy(ot.bregman.barycenter( + A_nx, M_nx, reg, weights_b, method="sinkhorn", + stopThr=1e-8, verbose=True + )) np.testing.assert_allclose(bar, bar_stable) np.testing.assert_allclose(bar, bar_np) @@ -1074,17 +974,6 @@ def test_wasserstein_bary_2d_debiased_device_tf(method): bary_wass = nx.to_numpy(bary_wass_b) np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) - np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) - - # help in checking if log and verbose do not bug the function - ot.bregman.convolutional_barycenter2d_debiased(A, reg, log=True, verbose=True) - - # Test that the dtype and device are the same after the computation - nx.assert_same_dtype_device(Ab, bary_wass_b) - - # Check this only if there is a GPU - if len(tf.config.list_physical_devices("GPU")) > 0: - assert nx.dtype_device(bary_wass_b)[1].startswith("GPU") def test_unmix(nx): @@ -1112,13 +1001,15 @@ def test_unmix(nx): # wasserstein reg = 1e-3 um_np = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01) - um = nx.to_numpy(ot.bregman.unmix(ab, Db, M_nx, M0b, h0b, reg, 1, alpha=0.01)) + um = nx.to_numpy(ot.bregman.unmix( + ab, Db, M_nx, M0b, h0b, reg, 1, alpha=0.01)) np.testing.assert_allclose(1, np.sum(um), rtol=1e-03, atol=1e-03) np.testing.assert_allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03) np.testing.assert_allclose(um, um_np) - ot.bregman.unmix(ab, Db, M_nx, M0b, h0b, reg, 1, alpha=0.01, log=True, verbose=True) + ot.bregman.unmix(ab, Db, M_nx, M0b, h0b, reg, + 1, alpha=0.01, log=True, verbose=True) def test_empirical_sinkhorn(nx): @@ -1130,7 +1021,7 @@ def test_empirical_sinkhorn(nx): X_s = np.reshape(1.0 * np.arange(n), (n, 1)) X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) M = ot.dist(X_s, X_t) - M_m = ot.dist(X_s, X_t, metric="euclidean") + M_m = ot.dist(X_s, X_t, metric='euclidean') ab, bb, X_sb, X_tb, M_nx, M_mb = nx.from_numpy(a, b, X_s, X_t, M, M_m) @@ -1142,27 +1033,27 @@ def test_empirical_sinkhorn(nx): sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True) sinkhorn_log = nx.to_numpy(sinkhorn_log) - G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric="euclidean")) + G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn( + X_sb, X_tb, 1, metric='euclidean')) sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1)) - loss_emp_sinkhorn = nx.to_numpy(ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1)) + loss_emp_sinkhorn = nx.to_numpy( + ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1)) loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1)) # check constraints np.testing.assert_allclose( - sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05 - ) # metric sqeuclidian + sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian + np.testing.assert_allclose( + sinkhorn_sqe.sum(0), G_sqe.sum(0), atol=1e-05) # metric sqeuclidian np.testing.assert_allclose( - sinkhorn_sqe.sum(0), G_sqe.sum(0), atol=1e-05 - ) # metric sqeuclidian - np.testing.assert_allclose(sinkhorn_log.sum(1), G_log.sum(1), atol=1e-05) # log - np.testing.assert_allclose(sinkhorn_log.sum(0), G_log.sum(0), atol=1e-05) # log + sinkhorn_log.sum(1), G_log.sum(1), atol=1e-05) # log np.testing.assert_allclose( - sinkhorn_m.sum(1), G_m.sum(1), atol=1e-05 - ) # metric euclidian + sinkhorn_log.sum(0), G_log.sum(0), atol=1e-05) # log np.testing.assert_allclose( - sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05 - ) # metric euclidian + sinkhorn_m.sum(1), G_m.sum(1), atol=1e-05) # metric euclidian + np.testing.assert_allclose( + sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05) # metric euclidian np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05) @@ -1176,65 +1067,47 @@ def test_lazy_empirical_sinkhorn(nx): X_s = np.reshape(np.arange(n, dtype=np.float64), (n, 1)) X_t = np.reshape(np.arange(0, n, dtype=np.float64), (n, 1)) M = ot.dist(X_s, X_t) - M_m = ot.dist(X_s, X_t, metric="euclidean") + M_m = ot.dist(X_s, X_t, metric='euclidean') ab, bb, X_sb, X_tb, M_nx, M_mb = nx.from_numpy(a, b, X_s, X_t, M, M_m) f, g = ot.bregman.empirical_sinkhorn( - X_sb, - X_tb, - 1, - numIterMax=numIterMax, - isLazy=True, - batchSize=(1, 3), - verbose=True, - ) + X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True) f, g = nx.to_numpy(f), nx.to_numpy(g) G_sqe = np.exp(f[:, None] + g[None, :] - M / 1) sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1)) f, g, log_es = ot.bregman.empirical_sinkhorn( - X_sb, X_tb, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True - ) + X_sb, X_tb, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) f, g = nx.to_numpy(f), nx.to_numpy(g) G_log = np.exp(f[:, None] + g[None, :] - M / 0.1) sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True) sinkhorn_log = nx.to_numpy(sinkhorn_log) f, g = ot.bregman.empirical_sinkhorn( - X_sb, - X_tb, - 1, - metric="euclidean", - numIterMax=numIterMax, - isLazy=True, - batchSize=1, - ) + X_sb, X_tb, 1, metric='euclidean', numIterMax=numIterMax, isLazy=True, batchSize=1) f, g = nx.to_numpy(f), nx.to_numpy(g) G_m = np.exp(f[:, None] + g[None, :] - M_m / 1) sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1)) loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2( - X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True - ) + X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) loss_emp_sinkhorn = nx.to_numpy(loss_emp_sinkhorn) loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1)) # check constraints np.testing.assert_allclose( - sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05 - ) # metric sqeuclidian + sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian np.testing.assert_allclose( - sinkhorn_sqe.sum(0), G_sqe.sum(0), atol=1e-05 - ) # metric sqeuclidian - np.testing.assert_allclose(sinkhorn_log.sum(1), G_log.sum(1), atol=1e-05) # log - np.testing.assert_allclose(sinkhorn_log.sum(0), G_log.sum(0), atol=1e-05) # log + sinkhorn_sqe.sum(0), G_sqe.sum(0), atol=1e-05) # metric sqeuclidian np.testing.assert_allclose( - sinkhorn_m.sum(1), G_m.sum(1), atol=1e-05 - ) # metric euclidian + sinkhorn_log.sum(1), G_log.sum(1), atol=1e-05) # log np.testing.assert_allclose( - sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05 - ) # metric euclidian + sinkhorn_log.sum(0), G_log.sum(0), atol=1e-05) # log + np.testing.assert_allclose( + sinkhorn_m.sum(1), G_m.sum(1), atol=1e-05) # metric euclidian + np.testing.assert_allclose( + sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05) # metric euclidian np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05) @@ -1250,27 +1123,27 @@ def test_empirical_sinkhorn_divergence(nx): M_s = ot.dist(X_s, X_s) M_t = ot.dist(X_t, X_t) - ab, bb, X_sb, X_tb, M_nx, M_sb, M_tb = nx.from_numpy(a, b, X_s, X_t, M, M_s, M_t) + ab, bb, X_sb, X_tb, M_nx, M_sb, M_tb = nx.from_numpy( + a, b, X_s, X_t, M, M_s, M_t) emp_sinkhorn_div = nx.to_numpy( - ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb) - ) + ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb)) sinkhorn_div = nx.to_numpy( ot.sinkhorn2(ab, bb, M_nx, 1) - 1 / 2 * ot.sinkhorn2(ab, ab, M_sb, 1) - 1 / 2 * ot.sinkhorn2(bb, bb, M_tb, 1) ) emp_sinkhorn_div_np = ot.bregman.empirical_sinkhorn_divergence( - X_s, X_t, 1, a=a, b=b - ) + X_s, X_t, 1, a=a, b=b) # check constraints - np.testing.assert_allclose(emp_sinkhorn_div, emp_sinkhorn_div_np, atol=1e-05) np.testing.assert_allclose( - emp_sinkhorn_div, sinkhorn_div, atol=1e-05 - ) # cf conv emp sinkhorn + emp_sinkhorn_div, emp_sinkhorn_div_np, atol=1e-05) + np.testing.assert_allclose( + emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn - ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb, log=True) + ot.bregman.empirical_sinkhorn_divergence( + X_sb, X_tb, 1, a=ab, b=bb, log=True) @pytest.mark.skipif(not torch, reason="No torch available") @@ -1293,8 +1166,7 @@ def test_empirical_sinkhorn_divergence_gradient(): X_tb.requires_grad = True emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence( - X_sb, X_tb, 1, a=ab, b=bb - ) + X_sb, X_tb, 1, a=ab, b=bb) emp_sinkhorn_div.backward() @@ -1323,12 +1195,14 @@ def test_stabilized_vs_sinkhorn_multidim(nx): ab, bb, M_nx = nx.from_numpy(a, b, M) - G_np, _ = ot.bregman.sinkhorn(a, b, M, reg=epsilon, method="sinkhorn", log=True) - G, log = ot.bregman.sinkhorn( - ab, bb, M_nx, reg=epsilon, method="sinkhorn_stabilized", log=True - ) + G_np, _ = ot.bregman.sinkhorn( + a, b, M, reg=epsilon, method="sinkhorn", log=True) + G, log = ot.bregman.sinkhorn(ab, bb, M_nx, reg=epsilon, + method="sinkhorn_stabilized", + log=True) G = nx.to_numpy(G) - G2, log2 = ot.bregman.sinkhorn(ab, bb, M_nx, epsilon, method="sinkhorn", log=True) + G2, log2 = ot.bregman.sinkhorn(ab, bb, M_nx, epsilon, + method="sinkhorn", log=True) G2 = nx.to_numpy(G2) np.testing.assert_allclose(G_np, G2) @@ -1336,9 +1210,9 @@ def test_stabilized_vs_sinkhorn_multidim(nx): def test_implemented_methods(): - IMPLEMENTED_METHODS = ["sinkhorn", "sinkhorn_stabilized"] - ONLY_1D_methods = ["greenkhorn", "sinkhorn_epsilon_scaling"] - NOT_VALID_TOKENS = ["foo"] + IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized'] + ONLY_1D_methods = ['greenkhorn', 'sinkhorn_epsilon_scaling'] + NOT_VALID_TOKENS = ['foo'] # test generalized sinkhorn for unbalanced OT barycenter n = 3 rng = np.random.RandomState(42) @@ -1368,7 +1242,7 @@ def test_implemented_methods(): ot.bregman.sinkhorn2(a, b, M, epsilon, method=method) -@pytest.skip_backend("tf") +@pytest.skip_backend('tf') @pytest.skip_backend("cupy") @pytest.skip_backend("jax") @pytest.mark.filterwarnings("ignore:Bottleneck") @@ -1387,9 +1261,8 @@ def test_screenkhorn(nx): # sinkhorn G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-1)) # screenkhorn - G_screen = nx.to_numpy( - ot.bregman.screenkhorn(ab, bb, M_nx, 1e-1, uniform=True, verbose=True) - ) + G_screen = nx.to_numpy(ot.bregman.screenkhorn( + ab, bb, M_nx, 1e-1, uniform=True, verbose=True)) # check marginals np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02) np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02) @@ -1425,21 +1298,16 @@ def test_sinkhorn_warmstart(): # Optimal plan with uniform warmstart pi_unif, _ = ot.bregman.sinkhorn( - a, b, M, reg, method="sinkhorn", log=True, warmstart=None - ) + a, b, M, reg, method="sinkhorn", log=True, warmstart=None) # Optimal plan with warmstart generated from unregularized OT pi_sh, _ = ot.bregman.sinkhorn( - a, b, M, reg, method="sinkhorn", log=True, warmstart=warmstart - ) + a, b, M, reg, method="sinkhorn", log=True, warmstart=warmstart) pi_sh_log, _ = ot.bregman.sinkhorn( - a, b, M, reg, method="sinkhorn_log", log=True, warmstart=warmstart - ) + a, b, M, reg, method="sinkhorn_log", log=True, warmstart=warmstart) pi_sh_stab, _ = ot.bregman.sinkhorn( - a, b, M, reg, method="sinkhorn_stabilized", log=True, warmstart=warmstart - ) + a, b, M, reg, method="sinkhorn_stabilized", log=True, warmstart=warmstart) pi_sh_sc, _ = ot.bregman.sinkhorn( - a, b, M, reg, method="sinkhorn_epsilon_scaling", log=True, warmstart=warmstart - ) + a, b, M, reg, method="sinkhorn_epsilon_scaling", log=True, warmstart=warmstart) np.testing.assert_allclose(pi_unif, pi_sh, atol=1e-05) np.testing.assert_allclose(pi_unif, pi_sh_log, atol=1e-05) @@ -1463,17 +1331,14 @@ def test_empirical_sinkhorn_warmstart(): # Optimal plan with uniform warmstart f, g, _ = ot.bregman.empirical_sinkhorn( - X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=None - ) + X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=None) pi_unif = np.exp(f[:, None] + g[None, :] - M / reg) # Optimal plan with warmstart generated from unregularized OT f, g, _ = ot.bregman.empirical_sinkhorn( - X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=warmstart - ) + X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=warmstart) pi_ws_lazy = np.exp(f[:, None] + g[None, :] - M / reg) pi_ws_not_lazy, _ = ot.bregman.empirical_sinkhorn( - X_s=Xs, X_t=Xt, reg=reg, isLazy=False, log=True, warmstart=warmstart - ) + X_s=Xs, X_t=Xt, reg=reg, isLazy=False, log=True, warmstart=warmstart) np.testing.assert_allclose(pi_unif, pi_ws_lazy, atol=1e-05) np.testing.assert_allclose(pi_unif, pi_ws_not_lazy, atol=1e-05) @@ -1495,15 +1360,12 @@ def test_empirical_sinkhorn_divergence_warmstart(): # Optimal plan with uniform warmstart sd_unif, _ = ot.bregman.empirical_sinkhorn_divergence( - X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=None - ) + X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=None) # Optimal plan with warmstart generated from unregularized OT sd_ws_lazy, _ = ot.bregman.empirical_sinkhorn_divergence( - X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=warmstart - ) + X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=warmstart) sd_ws_not_lazy, _ = ot.bregman.empirical_sinkhorn_divergence( - X_s=Xs, X_t=Xt, reg=reg, isLazy=False, log=True, warmstart=warmstart - ) + X_s=Xs, X_t=Xt, reg=reg, isLazy=False, log=True, warmstart=warmstart) np.testing.assert_allclose(sd_unif, sd_ws_lazy, atol=1e-05) np.testing.assert_allclose(sd_unif, sd_ws_not_lazy, atol=1e-05) From 16987c85bb656e0d7916abcc37b4200a2d3c5a42 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= Date: Tue, 17 Oct 2023 10:59:54 -0300 Subject: [PATCH 12/12] Update gitignore --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index b44ea43d7..dd5860a94 100644 --- a/.gitignore +++ b/.gitignore @@ -97,6 +97,7 @@ celerybeat-schedule # virtualenv venv/ ENV/ +.venv/ # Spyder project settings .spyderproject @@ -120,4 +121,4 @@ debug .vscode # pytest cahche -.pytest_cache \ No newline at end of file +.pytest_cache