From a07687c1a148103bd986007559d6e7f27de6561d Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Fri, 6 May 2022 20:57:56 +0200 Subject: [PATCH 1/3] fix transpose in sinkhorn barycenters --- ot/bregman.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ot/bregman.py b/ot/bregman.py index c06af2fe4..34dcadb81 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1511,7 +1511,7 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, for ii in range(numItermax): - UKv = u * nx.dot(K, A / nx.dot(K, u)) + UKv = u * nx.dot(K.T, A / nx.dot(K, u)) u = (u.T * geometricBar(weights, UKv)).T / UKv if ii % 10 == 1: From fe49a660448ac84d77996aacac98b6c2b402e7cb Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Fri, 6 May 2022 20:58:09 +0200 Subject: [PATCH 2/3] add test for assymetric cost barycenters --- test/test_bregman.py | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/test/test_bregman.py b/test/test_bregman.py index 6c379844e..497a1d0cf 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -490,6 +490,43 @@ def test_barycenter(nx, method, verbose, warn): ot.bregman.barycenter(A_nx, M_nx, reg, log=True) + +@pytest.mark.parametrize("method, verbose, warn", + product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"], + [True, False], [True, False])) +def test_barycenter_assymetric_cost(nx, method, verbose, warn): + n_bins = 20 # nb bins + + # Gaussian distributions + A = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std + + # creating matrix A containing all distributions + A = A[:, None] + + # assymetric loss matrix + normalization + rng = np.random.RandomState(42) + M = rng.randn(n_bins, n_bins) ** 2 + M /= M.max() + + A_nx, M_nx = nx.from_numpy(A, M) + reg = 1e-2 + + if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.barycenter(A_nx, M_nx, reg, method=method) + else: + # wasserstein + bary_wass_np = ot.bregman.barycenter(A, M, reg, method=method, verbose=verbose, warn=warn) + bary_wass, _ = ot.bregman.barycenter(A_nx, M_nx, reg, method=method, log=True) + bary_wass = nx.to_numpy(bary_wass) + + np.testing.assert_allclose(1, np.sum(bary_wass)) + np.testing.assert_allclose(bary_wass, bary_wass_np) + + ot.bregman.barycenter(A_nx, M_nx, reg, log=True) + + + @pytest.mark.parametrize("method, verbose, warn", product(["sinkhorn", "sinkhorn_log"], [True, False], [True, False])) From d974f64a37fa3df69437c4d6a705518be975271c Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Fri, 6 May 2022 21:17:24 +0200 Subject: [PATCH 3/3] fix pep8 --- test/test_bregman.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/test_bregman.py b/test/test_bregman.py index 497a1d0cf..112bfca48 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -490,7 +490,6 @@ def test_barycenter(nx, method, verbose, warn): ot.bregman.barycenter(A_nx, M_nx, reg, log=True) - @pytest.mark.parametrize("method, verbose, warn", product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"], [True, False], [True, False])) @@ -526,7 +525,6 @@ def test_barycenter_assymetric_cost(nx, method, verbose, warn): ot.bregman.barycenter(A_nx, M_nx, reg, log=True) - @pytest.mark.parametrize("method, verbose, warn", product(["sinkhorn", "sinkhorn_log"], [True, False], [True, False]))