From a1ab63ddf0ded728e0ef26297296d7e013504e84 Mon Sep 17 00:00:00 2001 From: Eloi Date: Tue, 3 May 2022 16:52:08 +0200 Subject: [PATCH 1/9] GWB first solver version --- README.md | 4 +- ot/lp/__init__.py | 116 +++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 107 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index e2b33d957..c9d838784 100644 --- a/README.md +++ b/README.md @@ -288,4 +288,6 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer [40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger, G., & Weed, J. (2019, April). [Statistical optimal transport via factored couplings](http://proceedings.mlr.press/v89/forrow19a/forrow19a.pdf). In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2454-2465). PMLR. -[41] Chapel*, L., Flamary*, R., Wu, H., Févotte, C., Gasso, G. (2021). [Unbalanced Optimal Transport through Non-negative Penalized Linear Regression](https://proceedings.neurips.cc/paper/2021/file/c3c617a9b80b3ae1ebd868b0017cc349-Paper.pdf) Advances in Neural Information Processing Systems (NeurIPS), 2020. (Two first co-authors) \ No newline at end of file +[41] Chapel*, L., Flamary*, R., Wu, H., Févotte, C., Gasso, G. (2021). [Unbalanced Optimal Transport through Non-negative Penalized Linear Regression](https://proceedings.neurips.cc/paper/2021/file/c3c617a9b80b3ae1ebd868b0017cc349-Paper.pdf) Advances in Neural Information Processing Systems (NeurIPS), 2020. (Two first co-authors) + +[42] DELON, Julie, GOZLAN, Nathael, et SAINT-DIZIER, Alexandre. [Generalized Wasserstein barycenters between probability measures living on different subspaces](https://arxiv.org/pdf/2105.09755). arXiv preprint arXiv:2105.09755, 2021. \ No newline at end of file diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 390c32dfb..e7850cb5a 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -26,8 +26,6 @@ from ..utils import parmap from ..backend import get_backend - - __all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted', 'emd_1d', 'emd2_1d', 'wasserstein_1d'] @@ -517,8 +515,8 @@ def f(b): log['warning'] = result_code_string log['result_code'] = result_code cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as), - (a0, b0, M0), (log['u'] - nx.mean(log['u']), - log['v'] - nx.mean(log['v']), G)) + (a0, b0, M0), (log['u'] - nx.mean(log['u']), + log['v'] - nx.mean(log['v']), G)) return [cost, log] else: def f(b): @@ -629,7 +627,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None """ - nx = get_backend(*measures_locations,*measures_weights,X_init) + nx = get_backend(*measures_locations, *measures_weights, X_init) iter_count = 0 @@ -637,9 +635,9 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None k = X_init.shape[0] d = X_init.shape[1] if b is None: - b = nx.ones((k,),type_as=X_init) / k + b = nx.ones((k,), type_as=X_init) / k if weights is None: - weights = nx.ones((N,),type_as=X_init) / N + weights = nx.ones((N,), type_as=X_init) / N X = X_init @@ -650,15 +648,14 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None while (displacement_square_norm > stopThr and iter_count < numItermax): - T_sum = nx.zeros((k, d),type_as=X_init) - + T_sum = nx.zeros((k, d), type_as=X_init) - for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights): + for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights): M_i = dist(X, measure_locations_i) T_i = emd(b, measure_weights_i, M_i, numThreads=numThreads) - T_sum = T_sum + weight_i * 1. / b[:,None] * nx.dot(T_i, measure_locations_i) + T_sum = T_sum + weight_i * 1. / b[:, None] * nx.dot(T_i, measure_locations_i) - displacement_square_norm = nx.sum((T_sum - X)**2) + displacement_square_norm = nx.sum((T_sum - X) ** 2) if log: displacement_square_norms.append(displacement_square_norm) @@ -675,3 +672,98 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None else: return X + +def generalised_free_support_barycenter(X, a, P, L, Y_init=None, b=None, weights=None, + numItermax=100, stopThr=1e-7, verbose=False, log=None, numThreads=1): + r""" + Solves the free support (locations of the barycenters are optimized, not the weights) + generalised Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance, + with linear maps), formally: + + .. math:: + \min_\gamma \quad \sum_{i=1}^N w_i W_2^2(\nu_i, P_i\#\gamma) + + where : + + - :math:`\gamma = \sum_[l=1}^L b_l\delta_{y_l}` is the desired barycenter with each :math:`y_l \in \mathbb{R}^d` + - :math:`\mathbf{b} \in \mathbb{R}^{L}` is the desired weights vector of the barycenter + - The input measures are :math:`\nu_i = \sum_{j=1}^{k_i}a_{i,j}\delta_{x_{i,j}}` + - the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}` are the empirical measures weights and sum to one for each :math:`i` + - the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d_i}` are the empirical measures atoms locations + - :math:`w = (w_1, \cdots w_p)` are the barycenter coefficients (on the simplex) + - Each :math:`P_i \in \mathbb{R}^{d, d_i}`, and :math:`P_i\#\nu_i = \sum_{j=1}^{k_i}a_{i,j}\delta_{P_ix_{i,j}}` + + As show by :ref:`[42]`, this problem can be re-written as a Wasserstein Barycenter problem, + which we solve using the free support method :ref:`[20] ` + (Algorithm 2). + + Parameters + ---------- + X : list of p (k_i,d_i) array-like + Discrete supports of the input measures: each consists of :math:`k_i` locations of a `d_i`-dimensional space + (:math:`k_i` can be different for each element of the list) + a : list of p (k_i,) array-like + Measure weights: each element is a vector (k_i) on the simplex + P : list of p (d_i,d) array-like + Each :math: `P_i` is a linear map `\mathbb{R}^{d} \rightarrow \mathbb{R}^{d_i}` + L : int + Number of barycenter points + Y_init : (L,d) array-like + Initialization of the support locations (on `k` atoms) of the barycenter + b : (L,) array-like + Initialization of the weights of the barycenter (non-negatives, sum to 1) + weights : (p,) array-like + Initialization of the coefficients of the barycenter (non-negatives, sum to 1) + + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + numThreads: int or "max", optional (default=1, i.e. OpenMP is not used) + If compiled with OpenMP, chooses the number of threads to parallelize. + "max" selects the highest number possible. + + + Returns + ------- + X : (L,d) array-like + Support locations (on L atoms) of the barycenter + + + .. _references-generalised-free-support-barycenter: + References + ---------- + .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. + + .. [42] DELON, Julie, GOZLAN, Nathael, et SAINT-DIZIER, Alexandre. Generalized Wasserstein barycenters between probability measures living on different subspaces. arXiv preprint arXiv:2105.09755, 2021. + + """ + nx = get_backend(*X, *a, *P) + d = P[0].shape[1] + A = nx.zeros((d, d), type_as=X[0]) # variable change matrix to reduce the problem to a Wasserstein Barycenter (WB) + for (P_i, lambda_i) in zip(P, weights): + A = A + lambda_i * P_i.T @ P_i + B = nx.inv(nx.sqrtm(A)) + + Z = [x @ Pi @ B.T for (x, Pi) in zip(X, P)] # change of variables -> (WB) problem on Z + + if Y_init is None: + Y_init = nx.randn((L, d), type_as=X[0]) + + if b is None: + b = nx.ones(L, type_as=X[0]) / L # not optimised + + out = free_support_barycenter(Z, a, Y_init, b, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, numThreads=numThreads) + + bar, log_dict = out if log else out, None # if log is not None the out = (bar, log_dict) + bar = bar @ B.T # return to the Generalised WB formulation + + if log: + return bar, log_dict + else: + return bar From 1446081345fdab614bbc350678e97dca632f80fb Mon Sep 17 00:00:00 2001 From: Eloi Date: Wed, 4 May 2022 10:16:58 +0200 Subject: [PATCH 2/9] tests + example for gwb (untested) + free_bar doc fix --- README.md | 4 +- ...lot_generalised_free_support_barycenter.py | 71 +++++++++++++++++++ ot/lp/__init__.py | 14 ++-- test/test_ot.py | 39 ++++++++++ 4 files changed, 120 insertions(+), 8 deletions(-) create mode 100644 examples/barycenters/plot_generalised_free_support_barycenter.py diff --git a/README.md b/README.md index c9d838784..af33b9962 100644 --- a/README.md +++ b/README.md @@ -290,4 +290,6 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer [41] Chapel*, L., Flamary*, R., Wu, H., Févotte, C., Gasso, G. (2021). [Unbalanced Optimal Transport through Non-negative Penalized Linear Regression](https://proceedings.neurips.cc/paper/2021/file/c3c617a9b80b3ae1ebd868b0017cc349-Paper.pdf) Advances in Neural Information Processing Systems (NeurIPS), 2020. (Two first co-authors) -[42] DELON, Julie, GOZLAN, Nathael, et SAINT-DIZIER, Alexandre. [Generalized Wasserstein barycenters between probability measures living on different subspaces](https://arxiv.org/pdf/2105.09755). arXiv preprint arXiv:2105.09755, 2021. \ No newline at end of file +[42] DELON, Julie, GOZLAN, Nathael, et SAINT-DIZIER, Alexandre. [Generalized Wasserstein barycenters between probability measures living on different subspaces](https://arxiv.org/pdf/2105.09755). arXiv preprint arXiv:2105.09755, 2021. + +[43] Álvarez-Esteban, Pedro C., et al. [A fixed-point approach to barycenters in Wasserstein space.](https://arxiv.org/pdf/1511.05355.pdf) Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. \ No newline at end of file diff --git a/examples/barycenters/plot_generalised_free_support_barycenter.py b/examples/barycenters/plot_generalised_free_support_barycenter.py new file mode 100644 index 000000000..02b2b20f3 --- /dev/null +++ b/examples/barycenters/plot_generalised_free_support_barycenter.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +""" +======================================= +Generalised Wasserstein Barycenter Demo +======================================= + +This example illustrates the computation of Generalised Wasserstein Barycenter +as proposed in [42]. + + +[42] DELON, Julie, GOZLAN, Nathael, et SAINT-DIZIER, Alexandre. +Generalized Wasserstein barycenters between probability measures living on different subspaces. +arXiv preprint arXiv:2105.09755, 2021. + +""" + +# Author: Eloi Tanguy +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 1 + +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.pylab as pl +import ot +# necessary for 3d plot even if not used +from mpl_toolkits.mplot3d import Axes3D + +############################################################################## +# Generate data +# ------------- + +# Input measures +I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::4, ::4, 2] +I2 = pl.imread('../../data/tooth.png').astype(np.float64)[::4, ::4, 2] +I3 = pl.imread('../../data/heart.png').astype(np.float64)[::4, ::4, 2] + +sz = I1.shape[0] +UU, VV = np.meshgrid(np.arange(sz), np.arange(sz)) + +# Input measure locations in their respective 2D spaces +X = [np.stack((UU[I == 0], VV[I == 0]), 1) * 1.0 for I in [I1, I2, I3]] + +# Input measure weights +a = [ot.unif(x.shape[0]) for x in X] + +# Projections 3D -> 2D +P1 = np.array([[1,0,0],[0,1,0]]) +P2 = np.array([[0,1,0],[0,0,1]]) +P3 = np.array([[1,0,0],[0,0,1]]) +P = [P1,P2,P3] + +# Barycenter weights +weights = np.array([1/3, 1/3, 1/3]) + +# Number of barycenter points to compute +L = 500 + +############################################################################## +# Barycenter computation and plot +# ---------------------- +bar = ot.lp.generalized_free_support_barycenter(X, a, P, L) + +X_visu = [x @ Pi for (x, Pi) in zip(X, P)] # send measures to the global space for visu +fig = plt.figure(figsize=(7, 7)) +axis = fig.add_subplot(1, 1, 1, projection="3d") +for x in X_visu: + axis.scatter(x[:, 0], x[:, 1], x[:, 2], marker='.', alpha=.6) +axis.scatter(bar[:, 0], bar[:, 1], bar[:, 2], marker='.', alpha=.6) +plt.show() diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index e7850cb5a..f57afcca3 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -27,7 +27,7 @@ from ..backend import get_backend __all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted', - 'emd_1d', 'emd2_1d', 'wasserstein_1d'] + 'emd_1d', 'emd2_1d', 'wasserstein_1d', 'generalized_free_support_barycenter'] def check_number_threads(numThreads): @@ -574,14 +574,14 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None - the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}` are the empirical measures atoms locations - :math:`\mathbf{b} \in \mathbb{R}^{k}` is the desired weights vector of the barycenter - This problem is considered in :ref:`[1] ` (Algorithm 2). + This problem is considered in :ref:`[20] ` (Algorithm 2). There are two differences with the following codes: - we do not optimize over the weights - we do not do line search for the locations updates, we use i.e. :math:`\theta = 1` in - :ref:`[1] ` (Algorithm 2). This can be seen as a discrete + :ref:`[20] ` (Algorithm 2). This can be seen as a discrete implementation of the fixed-point algorithm of - :ref:`[2] ` proposed in the continuous setting. + :ref:`[43] ` proposed in the continuous setting. Parameters ---------- @@ -621,9 +621,9 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None .. _references-free-support-barycenter: References ---------- - .. [1] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. + .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. - .. [2] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. + .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. """ @@ -673,7 +673,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None return X -def generalised_free_support_barycenter(X, a, P, L, Y_init=None, b=None, weights=None, +def generalized_free_support_barycenter(X, a, P, L, Y_init=None, b=None, weights=None, numItermax=100, stopThr=1e-7, verbose=False, log=None, numThreads=1): r""" Solves the free support (locations of the barycenters are optimized, not the weights) diff --git a/test/test_ot.py b/test/test_ot.py index bf832f6cd..a2761a964 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -320,6 +320,45 @@ def test_free_support_barycenter_backends(nx): np.testing.assert_allclose(X, nx.to_numpy(X2)) +def test_generalised_free_support_barycenter(): + X = [np.array([-1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))] + a = [np.array([1.]), np.array([1.])] + + P = [np.array([1]), np.array([1])] + + Y_init = np.array([-12.]).reshape((1, 1)) + + # obvious barycenter location between two diracs + Y_true = np.array([0.]).reshape((1, 1)) + + # test without log and no init + Y = ot.lp.generalized_free_support_barycenter(X, a, P, 1) + np.testing.assert_allclose(Y, Y_true, rtol=1e-5, atol=1e-7) + + # test with log and init + Y, _ = ot.lp.generalized_free_support_barycenter(X, a, P, 1, Y_init=Y_init, b=np.array([1.])) + np.testing.assert_allclose(Y, Y_true, rtol=1e-5, atol=1e-7) + + +def test_generalised_free_support_barycenter_backends(nx): + + X = [np.array([-1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))] + a = [np.array([1.]), np.array([1.])] + P = [np.array([1]), np.array([1])] + Y_init = np.array([-12.]).reshape((1, 1)) + + Y = ot.lp.generalized_free_support_barycenter(X, a, P, 1, Y_init=Y_init) + + X2 = nx.from_numpy(*X) + a2 = nx.from_numpy(*a) + P2 = nx.from_numpy(*P) + Y_init2 = nx.from_numpy(Y_init) + + Y2 = ot.lp.generalized_free_support_barycenter(X2, a2, P2, 1, Y_init=Y_init2) + + np.testing.assert_allclose(Y, nx.to_numpy(Y2)) + + @pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available") def test_lp_barycenter_cvxopt(): a1 = np.array([1.0, 0, 0])[:, None] From d875aaf3156c9f9e5e440f09e7dafa5c568d8a98 Mon Sep 17 00:00:00 2001 From: Eloi Date: Thu, 5 May 2022 10:49:02 +0200 Subject: [PATCH 3/9] improved doc, fixed minor bugs, better example visu --- CONTRIBUTORS.md | 1 + RELEASES.md | 6 ++ ...lot_generalised_free_support_barycenter.py | 75 ++++++++++++----- ot/__init__.py | 2 +- ot/lp/__init__.py | 84 +++++++++++-------- test/test_ot.py | 17 ++-- 6 files changed, 121 insertions(+), 64 deletions(-) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index ab64fba72..0909b1402 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -37,6 +37,7 @@ The contributors to this library are: * [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance) * [Nathan Cassereau](https://github.com/ncassereau-idris) (Backends) * [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning) +* [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein Barycenters) ## Acknowledgments diff --git a/RELEASES.md b/RELEASES.md index be2192eb7..3f988fce9 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -1,5 +1,11 @@ # Releases +## 0.8.3dev + +#### New features + +- Added Generalized Wasserstein Barycenter solver + example + ## 0.8.2 diff --git a/examples/barycenters/plot_generalised_free_support_barycenter.py b/examples/barycenters/plot_generalised_free_support_barycenter.py index 02b2b20f3..41e44492f 100644 --- a/examples/barycenters/plot_generalised_free_support_barycenter.py +++ b/examples/barycenters/plot_generalised_free_support_barycenter.py @@ -18,54 +18,91 @@ # # License: MIT License -# sphinx_gallery_thumbnail_number = 1 +# sphinx_gallery_thumbnail_number = 2 import numpy as np import matplotlib.pyplot as plt import matplotlib.pylab as pl import ot -# necessary for 3d plot even if not used -from mpl_toolkits.mplot3d import Axes3D ############################################################################## -# Generate data +# Generate and plot data # ------------- # Input measures -I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::4, ::4, 2] -I2 = pl.imread('../../data/tooth.png').astype(np.float64)[::4, ::4, 2] -I3 = pl.imread('../../data/heart.png').astype(np.float64)[::4, ::4, 2] +sub_sample_factor = 8 +I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::sub_sample_factor, ::sub_sample_factor, 2] +I2 = pl.imread('../../data/tooth.png').astype(np.float64)[::sub_sample_factor, ::sub_sample_factor, 2] +I3 = pl.imread('../../data/heart.png').astype(np.float64)[::sub_sample_factor, ::sub_sample_factor, 2] sz = I1.shape[0] UU, VV = np.meshgrid(np.arange(sz), np.arange(sz)) # Input measure locations in their respective 2D spaces -X = [np.stack((UU[I == 0], VV[I == 0]), 1) * 1.0 for I in [I1, I2, I3]] +X_list = [np.stack((UU[I == 0], VV[I == 0]), 1) * 1.0 for I in [I1, I2, I3]] # Input measure weights -a = [ot.unif(x.shape[0]) for x in X] +a_list = [ot.unif(x.shape[0]) for x in X_list] # Projections 3D -> 2D P1 = np.array([[1,0,0],[0,1,0]]) P2 = np.array([[0,1,0],[0,0,1]]) P3 = np.array([[1,0,0],[0,0,1]]) -P = [P1,P2,P3] +P_list = [P1,P2,P3] # Barycenter weights weights = np.array([1/3, 1/3, 1/3]) # Number of barycenter points to compute -L = 500 +n_samples_bary = 100 -############################################################################## -# Barycenter computation and plot -# ---------------------- -bar = ot.lp.generalized_free_support_barycenter(X, a, P, L) +# Send the input measures into 3D space for visualisation +X_visu = [Xi @ Pi for (Xi, Pi) in zip(X_list, P_list)] -X_visu = [x @ Pi for (x, Pi) in zip(X, P)] # send measures to the global space for visu +# Plot the input data fig = plt.figure(figsize=(7, 7)) axis = fig.add_subplot(1, 1, 1, projection="3d") -for x in X_visu: - axis.scatter(x[:, 0], x[:, 1], x[:, 2], marker='.', alpha=.6) -axis.scatter(bar[:, 0], bar[:, 1], bar[:, 2], marker='.', alpha=.6) +for Xi in X_visu: + axis.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6) +axis.view_init(azim=45) +axis.set_xticks([]) +axis.set_yticks([]) +axis.set_zticks([]) +plt.show() + +############################################################################## +# Barycenter computation and plot +# ---------------------- +Y = ot.lp.generalized_free_support_barycenter(X_list, a_list, P_list, n_samples_bary) + +fig = plt.figure(figsize=(9, 3)) + +ax = fig.add_subplot(1, 3, 1, projection='3d') +for Xi in X_visu: + ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6) +ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6) +ax.view_init(elev=0, azim=0) +ax.set_xticks([]) +ax.set_yticks([]) +ax.set_zticks([]) + +ax = fig.add_subplot(1, 3, 2, projection='3d') +for Xi in X_visu: + ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6) +ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6) +ax.view_init(elev=0, azim=90) +ax.set_xticks([]) +ax.set_yticks([]) +ax.set_zticks([]) + +ax = fig.add_subplot(1, 3, 3, projection='3d') +for Xi in X_visu: + ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6) +ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6) +ax.view_init(elev=90, azim=0) +ax.set_xticks([]) +ax.set_yticks([]) +ax.set_zticks([]) + +plt.tight_layout() plt.show() diff --git a/ot/__init__.py b/ot/__init__.py index 86ed94eb4..15d83515e 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -51,7 +51,7 @@ # utils functions from .utils import dist, unif, tic, toc, toq -__version__ = "0.8.2" +__version__ = "0.8.3dev" __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov', diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index f57afcca3..bfad1719b 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -570,8 +570,8 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None where : - :math:`w \in \mathbb{(0, 1)}^{N}`'s are the barycenter weights and sum to one - - the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}` are the empirical measures weights and sum to one for each :math:`i` - - the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}` are the empirical measures atoms locations + - `measure_weights` denotes the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}`: empirical measures weights (on simplex) + - `measures_locations` denotes the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}`: empirical measures atoms locations - :math:`\mathbf{b} \in \mathbb{R}^{k}` is the desired weights vector of the barycenter This problem is considered in :ref:`[20] ` (Algorithm 2). @@ -673,25 +673,25 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None return X -def generalized_free_support_barycenter(X, a, P, L, Y_init=None, b=None, weights=None, - numItermax=100, stopThr=1e-7, verbose=False, log=None, numThreads=1): +def generalized_free_support_barycenter(X_list, a_list, P_list, n_samples_bary, Y_init=None, b=None, weights=None, + numItermax=100, stopThr=1e-7, verbose=False, log=None, numThreads=1, eps=0): r""" - Solves the free support (locations of the barycenters are optimized, not the weights) - generalised Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance, - with linear maps), formally: + Solves the free support generalised Wasserstein barycenter problem: finding a barycenter (a discrete measure with + a fixed amount of points of uniform weights) whose respective projections fit the input measures. + More formally: .. math:: - \min_\gamma \quad \sum_{i=1}^N w_i W_2^2(\nu_i, P_i\#\gamma) + \min_\gamma \quad \sum_{i=1}^p w_i W_2^2(\nu_i, \mathbf{P}_i\#\gamma) where : - - :math:`\gamma = \sum_[l=1}^L b_l\delta_{y_l}` is the desired barycenter with each :math:`y_l \in \mathbb{R}^d` - - :math:`\mathbf{b} \in \mathbb{R}^{L}` is the desired weights vector of the barycenter + - :math:`\gamma = \sum_{l=1}^n b_l\delta_{y_l}` is the desired barycenter with each :math:`y_l \in \mathbb{R}^d` + - :math:`\mathbf{b} \in \mathbb{R}^{n}` is the desired weights vector of the barycenter - The input measures are :math:`\nu_i = \sum_{j=1}^{k_i}a_{i,j}\delta_{x_{i,j}}` - - the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}` are the empirical measures weights and sum to one for each :math:`i` - - the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d_i}` are the empirical measures atoms locations + - The :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}` are the respective empirical measures weights (on the simplex) + - The :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d_i}` are the respective empirical measures atoms locations - :math:`w = (w_1, \cdots w_p)` are the barycenter coefficients (on the simplex) - - Each :math:`P_i \in \mathbb{R}^{d, d_i}`, and :math:`P_i\#\nu_i = \sum_{j=1}^{k_i}a_{i,j}\delta_{P_ix_{i,j}}` + - Each :math:`\mathbf{P}_i \in \mathbb{R}^{d, d_i}`, and :math:`P_i\#\nu_i = \sum_{j=1}^{k_i}a_{i,j}\delta_{P_ix_{i,j}}` As show by :ref:`[42]`, this problem can be re-written as a Wasserstein Barycenter problem, which we solve using the free support method :ref:`[20] ` @@ -699,22 +699,21 @@ def generalized_free_support_barycenter(X, a, P, L, Y_init=None, b=None, weights Parameters ---------- - X : list of p (k_i,d_i) array-like + X_list : list of p (k_i,d_i) array-like Discrete supports of the input measures: each consists of :math:`k_i` locations of a `d_i`-dimensional space (:math:`k_i` can be different for each element of the list) - a : list of p (k_i,) array-like + a_list : list of p (k_i,) array-like Measure weights: each element is a vector (k_i) on the simplex - P : list of p (d_i,d) array-like + P_list : list of p (d_i,d) array-like Each :math: `P_i` is a linear map `\mathbb{R}^{d} \rightarrow \mathbb{R}^{d_i}` - L : int + n_samples_bary : int Number of barycenter points - Y_init : (L,d) array-like + Y_init : (n_samples_bary,d) array-like Initialization of the support locations (on `k` atoms) of the barycenter - b : (L,) array-like - Initialization of the weights of the barycenter (non-negatives, sum to 1) + b : (n_samples_bary,) array-like + Initialization of the weights of the barycenter measure (on the simplex) weights : (p,) array-like - Initialization of the coefficients of the barycenter (non-negatives, sum to 1) - + Initialization of the coefficients of the barycenter (on the simplex) numItermax : int, optional Max number of iterations stopThr : float, optional @@ -726,12 +725,15 @@ def generalized_free_support_barycenter(X, a, P, L, Y_init=None, b=None, weights numThreads: int or "max", optional (default=1, i.e. OpenMP is not used) If compiled with OpenMP, chooses the number of threads to parallelize. "max" selects the highest number possible. + eps: Stability coefficient for the change of variable matrix inversion + If the :math:`\mathbf{P}_i^T` matrices don't span :math:`\mathbb{R}^d`, the problem is ill-defined and a matrix + inversion will fail. In this case one may set eps=1e-8 and get a solution anyway (which may make little sense) Returns ------- - X : (L,d) array-like - Support locations (on L atoms) of the barycenter + Y : (n_samples_bary,d) array-like + Support locations (on n_samples_bary atoms) of the barycenter .. _references-generalised-free-support-barycenter: @@ -742,28 +744,38 @@ def generalized_free_support_barycenter(X, a, P, L, Y_init=None, b=None, weights .. [42] DELON, Julie, GOZLAN, Nathael, et SAINT-DIZIER, Alexandre. Generalized Wasserstein barycenters between probability measures living on different subspaces. arXiv preprint arXiv:2105.09755, 2021. """ - nx = get_backend(*X, *a, *P) - d = P[0].shape[1] - A = nx.zeros((d, d), type_as=X[0]) # variable change matrix to reduce the problem to a Wasserstein Barycenter (WB) - for (P_i, lambda_i) in zip(P, weights): + nx = get_backend(*X_list, *a_list, *P_list) + d = P_list[0].shape[1] + p = len(X_list) + + if weights is None: + weights = nx.ones(p, type_as=X_list[0]) / p + + # variable change matrix to reduce the problem to a Wasserstein Barycenter (WB) + A = eps * nx.eye(d, type_as=X_list[0]) # if eps nonzero: will force the invertibility of A + for (P_i, lambda_i) in zip(P_list, weights): A = A + lambda_i * P_i.T @ P_i B = nx.inv(nx.sqrtm(A)) - Z = [x @ Pi @ B.T for (x, Pi) in zip(X, P)] # change of variables -> (WB) problem on Z + Z_list = [x @ Pi @ B.T for (x, Pi) in zip(X_list, P_list)] # change of variables -> (WB) problem on Z if Y_init is None: - Y_init = nx.randn((L, d), type_as=X[0]) + Y_init = nx.randn(n_samples_bary, d, type_as=X_list[0]) if b is None: - b = nx.ones(L, type_as=X[0]) / L # not optimised + b = nx.ones(n_samples_bary, type_as=X_list[0]) / n_samples_bary # not optimised - out = free_support_barycenter(Z, a, Y_init, b, numItermax=numItermax, + out = free_support_barycenter(Z_list, a_list, Y_init, b, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, numThreads=numThreads) - bar, log_dict = out if log else out, None # if log is not None the out = (bar, log_dict) - bar = bar @ B.T # return to the Generalised WB formulation + if log: # unpack + Y, log_dict = out + else: + Y = out + log_dict = None + Y = Y @ B.T # return to the Generalised WB formulation if log: - return bar, log_dict + return Y, log_dict else: - return bar + return Y diff --git a/test/test_ot.py b/test/test_ot.py index a2761a964..ba3ef6aca 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -321,30 +321,31 @@ def test_free_support_barycenter_backends(nx): def test_generalised_free_support_barycenter(): - X = [np.array([-1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))] + np.random.seed(42) # random inits + X = [np.array([-1., -1.]).reshape((1, 2)), np.array([1., 1.]).reshape((1, 2))] # two 2D points bar is obviously 0 a = [np.array([1.]), np.array([1.])] - P = [np.array([1]), np.array([1])] + P = [np.eye(2), np.eye(2)] - Y_init = np.array([-12.]).reshape((1, 1)) + Y_init = np.array([-12., 7.]).reshape((1, 2)) - # obvious barycenter location between two diracs - Y_true = np.array([0.]).reshape((1, 1)) + # obvious barycenter location between two 2D diracs + Y_true = np.array([0., .0]).reshape((1, 2)) # test without log and no init Y = ot.lp.generalized_free_support_barycenter(X, a, P, 1) np.testing.assert_allclose(Y, Y_true, rtol=1e-5, atol=1e-7) # test with log and init - Y, _ = ot.lp.generalized_free_support_barycenter(X, a, P, 1, Y_init=Y_init, b=np.array([1.])) + Y, _ = ot.lp.generalized_free_support_barycenter(X, a, P, 1, Y_init=Y_init, b=np.array([1.]), log=True) np.testing.assert_allclose(Y, Y_true, rtol=1e-5, atol=1e-7) def test_generalised_free_support_barycenter_backends(nx): - + np.random.seed(42) X = [np.array([-1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))] a = [np.array([1.]), np.array([1.])] - P = [np.array([1]), np.array([1])] + P = [np.array([1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))] Y_init = np.array([-12.]).reshape((1, 1)) Y = ot.lp.generalized_free_support_barycenter(X, a, P, 1, Y_init=Y_init) From 3205f8fcb2506a95dba1ea8fbb29d9a20d7fbfea Mon Sep 17 00:00:00 2001 From: Eloi Date: Thu, 5 May 2022 11:36:26 +0200 Subject: [PATCH 4/9] minor doc + visu fixes --- .../barycenters/plot_generalised_free_support_barycenter.py | 2 +- ot/lp/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/barycenters/plot_generalised_free_support_barycenter.py b/examples/barycenters/plot_generalised_free_support_barycenter.py index 41e44492f..53391ad38 100644 --- a/examples/barycenters/plot_generalised_free_support_barycenter.py +++ b/examples/barycenters/plot_generalised_free_support_barycenter.py @@ -60,7 +60,7 @@ X_visu = [Xi @ Pi for (Xi, Pi) in zip(X_list, P_list)] # Plot the input data -fig = plt.figure(figsize=(7, 7)) +fig = plt.figure(figsize=(3, 3)) axis = fig.add_subplot(1, 1, 1, projection="3d") for Xi in X_visu: axis.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6) diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index bfad1719b..1a79b52ce 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -705,7 +705,7 @@ def generalized_free_support_barycenter(X_list, a_list, P_list, n_samples_bary, a_list : list of p (k_i,) array-like Measure weights: each element is a vector (k_i) on the simplex P_list : list of p (d_i,d) array-like - Each :math: `P_i` is a linear map `\mathbb{R}^{d} \rightarrow \mathbb{R}^{d_i}` + Each :math:`P_i` is a linear map :math:`\mathbb{R}^{d} \rightarrow \mathbb{R}^{d_i}` n_samples_bary : int Number of barycenter points Y_init : (n_samples_bary,d) array-like From c51ab12003b6e215b6dcfff210b281aa0f0be264 Mon Sep 17 00:00:00 2001 From: Eloi Date: Thu, 5 May 2022 13:17:42 +0200 Subject: [PATCH 5/9] plot GWB pep8 fix --- .../plot_generalised_free_support_barycenter.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/barycenters/plot_generalised_free_support_barycenter.py b/examples/barycenters/plot_generalised_free_support_barycenter.py index 53391ad38..bfcfd69f5 100644 --- a/examples/barycenters/plot_generalised_free_support_barycenter.py +++ b/examples/barycenters/plot_generalised_free_support_barycenter.py @@ -39,19 +39,19 @@ UU, VV = np.meshgrid(np.arange(sz), np.arange(sz)) # Input measure locations in their respective 2D spaces -X_list = [np.stack((UU[I == 0], VV[I == 0]), 1) * 1.0 for I in [I1, I2, I3]] +X_list = [np.stack((UU[im == 0], VV[im == 0]), 1) * 1.0 for im in [I1, I2, I3]] # Input measure weights a_list = [ot.unif(x.shape[0]) for x in X_list] # Projections 3D -> 2D -P1 = np.array([[1,0,0],[0,1,0]]) -P2 = np.array([[0,1,0],[0,0,1]]) -P3 = np.array([[1,0,0],[0,0,1]]) -P_list = [P1,P2,P3] +P1 = np.array([[1, 0, 0], [0, 1, 0]]) +P2 = np.array([[0, 1, 0], [0, 0, 1]]) +P3 = np.array([[1, 0, 0], [0, 0, 1]]) +P_list = [P1, P2, P3] # Barycenter weights -weights = np.array([1/3, 1/3, 1/3]) +weights = np.array([1 / 3, 1 / 3, 1 / 3]) # Number of barycenter points to compute n_samples_bary = 100 From 833ed9fe34d2e71ecc635c5a64a9e9c0ce92329a Mon Sep 17 00:00:00 2001 From: Eloi Date: Thu, 5 May 2022 13:28:54 +0200 Subject: [PATCH 6/9] fixed partial gromov test reproductibility --- test/test_partial.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_partial.py b/test/test_partial.py index 97c611b1b..e07377b60 100755 --- a/test/test_partial.py +++ b/test/test_partial.py @@ -137,6 +137,7 @@ def test_partial_wasserstein(): def test_partial_gromov_wasserstein(): + np.random.seed(42) n_samples = 20 # nb samples n_noise = 10 # nb of samples (noise) From 15bad5aeaa651662133027e420d4e12427949d99 Mon Sep 17 00:00:00 2001 From: Eloi Date: Thu, 5 May 2022 13:56:14 +0200 Subject: [PATCH 7/9] added an animation for the GWB visu --- ...ot_generalized_free_support_barycenter.py} | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) rename examples/barycenters/{plot_generalised_free_support_barycenter.py => plot_generalized_free_support_barycenter.py} (81%) diff --git a/examples/barycenters/plot_generalised_free_support_barycenter.py b/examples/barycenters/plot_generalized_free_support_barycenter.py similarity index 81% rename from examples/barycenters/plot_generalised_free_support_barycenter.py rename to examples/barycenters/plot_generalized_free_support_barycenter.py index bfcfd69f5..ee10b4751 100644 --- a/examples/barycenters/plot_generalised_free_support_barycenter.py +++ b/examples/barycenters/plot_generalized_free_support_barycenter.py @@ -24,6 +24,7 @@ import matplotlib.pyplot as plt import matplotlib.pylab as pl import ot +import matplotlib.animation as animation ############################################################################## # Generate and plot data @@ -106,3 +107,29 @@ plt.tight_layout() plt.show() + +############################################## +# Rotation animation +# -------------------------------------------- + +fig = plt.figure(figsize=(7, 7)) +ax = fig.add_subplot(1, 1, 1, projection="3d") + + +def _init(): + for Xi in X_visu: + ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6) + ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6) + ax.view_init(elev=0, azim=0) + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_zticks([]) + return fig, + + +def _update_plot(i): + ax.view_init(elev=i, azim=4 * i) + return fig, + + +ani = animation.FuncAnimation(fig, _update_plot, init_func=_init, frames=90, interval=50, blit=True, repeat_delay=2000) From 03289661bc0a94d55fb7feed5037569238c39225 Mon Sep 17 00:00:00 2001 From: Eloi Date: Thu, 5 May 2022 14:46:46 +0200 Subject: [PATCH 8/9] added PR num --- RELEASES.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/RELEASES.md b/RELEASES.md index 3f988fce9..346183238 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -4,7 +4,7 @@ #### New features -- Added Generalized Wasserstein Barycenter solver + example +- Added Generalized Wasserstein Barycenter solver + example (PR #372) ## 0.8.2 From bc6534c60c55ed6d749fc7e8cdf1678a133c2b0c Mon Sep 17 00:00:00 2001 From: Eloi Date: Thu, 5 May 2022 16:30:53 +0200 Subject: [PATCH 9/9] minor doc fixes + better gwb logo --- README.md | 2 +- .../plot_sliced_wass_grad_flow_pytorch.py | 6 ++-- ...lot_generalized_free_support_barycenter.py | 33 ++++++++++++++----- ot/lp/__init__.py | 13 ++++---- 4 files changed, 36 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index af33b9962..12340d525 100644 --- a/README.md +++ b/README.md @@ -290,6 +290,6 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer [41] Chapel*, L., Flamary*, R., Wu, H., Févotte, C., Gasso, G. (2021). [Unbalanced Optimal Transport through Non-negative Penalized Linear Regression](https://proceedings.neurips.cc/paper/2021/file/c3c617a9b80b3ae1ebd868b0017cc349-Paper.pdf) Advances in Neural Information Processing Systems (NeurIPS), 2020. (Two first co-authors) -[42] DELON, Julie, GOZLAN, Nathael, et SAINT-DIZIER, Alexandre. [Generalized Wasserstein barycenters between probability measures living on different subspaces](https://arxiv.org/pdf/2105.09755). arXiv preprint arXiv:2105.09755, 2021. +[42] Delon, J., Gozlan, N., and Saint-Dizier, A. [Generalized Wasserstein barycenters between probability measures living on different subspaces](https://arxiv.org/pdf/2105.09755). arXiv preprint arXiv:2105.09755, 2021. [43] Álvarez-Esteban, Pedro C., et al. [A fixed-point approach to barycenters in Wasserstein space.](https://arxiv.org/pdf/1511.05355.pdf) Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. \ No newline at end of file diff --git a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py index cf5d64d54..59e004289 100644 --- a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py +++ b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py @@ -103,7 +103,7 @@ # %% # Animate trajectories of the gradient flow along iteration -# ------------------------------------------------------- +# --------------------------------------------------------- pl.figure(3, (8, 4)) @@ -122,7 +122,7 @@ def _update_plot(i): # %% # Compute the Sliced Wasserstein Barycenter -# +# ----------------------------------------- x1_torch = torch.tensor(x1).to(device=device) x3_torch = torch.tensor(x3).to(device=device) xbinit = np.random.randn(500, 2) * 10 + 16 @@ -169,7 +169,7 @@ def _update_plot(i): # %% # Animate trajectories of the barycenter along gradient descent -# ------------------------------------------------------- +# ------------------------------------------------------------- pl.figure(5, (8, 4)) diff --git a/examples/barycenters/plot_generalized_free_support_barycenter.py b/examples/barycenters/plot_generalized_free_support_barycenter.py index ee10b4751..9af1953fd 100644 --- a/examples/barycenters/plot_generalized_free_support_barycenter.py +++ b/examples/barycenters/plot_generalized_free_support_barycenter.py @@ -1,14 +1,14 @@ # -*- coding: utf-8 -*- """ ======================================= -Generalised Wasserstein Barycenter Demo +Generalized Wasserstein Barycenter Demo ======================================= -This example illustrates the computation of Generalised Wasserstein Barycenter +This example illustrates the computation of Generalized Wasserstein Barycenter as proposed in [42]. -[42] DELON, Julie, GOZLAN, Nathael, et SAINT-DIZIER, Alexandre. +[42] Delon, J., Gozlan, N., and Saint-Dizier, A.. Generalized Wasserstein barycenters between probability measures living on different subspaces. arXiv preprint arXiv:2105.09755, 2021. @@ -26,9 +26,9 @@ import ot import matplotlib.animation as animation -############################################################################## +######################## # Generate and plot data -# ------------- +# ---------------------- # Input measures sub_sample_factor = 8 @@ -55,7 +55,7 @@ weights = np.array([1 / 3, 1 / 3, 1 / 3]) # Number of barycenter points to compute -n_samples_bary = 100 +n_samples_bary = 150 # Send the input measures into 3D space for visualisation X_visu = [Xi @ Pi for (Xi, Pi) in zip(X_list, P_list)] @@ -71,10 +71,27 @@ axis.set_zticks([]) plt.show() -############################################################################## +################################# # Barycenter computation and plot -# ---------------------- +# ------------------------------- + Y = ot.lp.generalized_free_support_barycenter(X_list, a_list, P_list, n_samples_bary) +fig = plt.figure(figsize=(3, 3)) + +axis = fig.add_subplot(1, 1, 1, projection="3d") +for Xi in X_visu: + axis.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6) +axis.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6) +axis.view_init(azim=45) +axis.set_xticks([]) +axis.set_yticks([]) +axis.set_zticks([]) +plt.show() + + +############################# +# Plotting projection matches +# --------------------------- fig = plt.figure(figsize=(9, 3)) diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 1a79b52ce..572781dba 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -623,7 +623,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None ---------- .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. - .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. + .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. """ @@ -693,8 +693,9 @@ def generalized_free_support_barycenter(X_list, a_list, P_list, n_samples_bary, - :math:`w = (w_1, \cdots w_p)` are the barycenter coefficients (on the simplex) - Each :math:`\mathbf{P}_i \in \mathbb{R}^{d, d_i}`, and :math:`P_i\#\nu_i = \sum_{j=1}^{k_i}a_{i,j}\delta_{P_ix_{i,j}}` - As show by :ref:`[42]`, this problem can be re-written as a Wasserstein Barycenter problem, - which we solve using the free support method :ref:`[20] ` + As show by :ref:`[42] `, + this problem can be re-written as a Wasserstein Barycenter problem, + which we solve using the free support method :ref:`[20] ` (Algorithm 2). Parameters @@ -736,12 +737,12 @@ def generalized_free_support_barycenter(X_list, a_list, P_list, n_samples_bary, Support locations (on n_samples_bary atoms) of the barycenter - .. _references-generalised-free-support-barycenter: + .. _references-generalized-free-support-barycenter: References ---------- - .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. + .. [20] Cuturi, M. and Doucet, A.. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. - .. [42] DELON, Julie, GOZLAN, Nathael, et SAINT-DIZIER, Alexandre. Generalized Wasserstein barycenters between probability measures living on different subspaces. arXiv preprint arXiv:2105.09755, 2021. + .. [42] Delon, J., Gozlan, N., and Saint-Dizier, A.. Generalized Wasserstein barycenters between probability measures living on different subspaces. arXiv preprint arXiv:2105.09755, 2021. """ nx = get_backend(*X_list, *a_list, *P_list)