Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ It provides the following solvers:
* Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2] and stabilized version [9][10] with optional GPU implementation (requires cudamat).
* Smooth optimal transport solvers (dual and semi-dual) for KL and squared L2 regularizations [17].
* Non regularized Wasserstein barycenters [16] with LP solver (only small scale).
* Non regularized free support Wasserstein barycenters [20].
* Bregman projections for Wasserstein barycenter [3] and unmixing [4].
* Optimal transport for domain adaptation with group lasso regularization [5]
* Conditional gradient [6] and Generalized conditional gradient for regularized OT [7].
Expand Down Expand Up @@ -225,3 +226,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
[18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) [Stochastic Optimization for Large-scale Optimal Transport](arXiv preprint arxiv:1605.08527). Advances in Neural Information Processing Systems (2016).

[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A. & Blondel, M. [Large-scale Optimal Transport and Mapping Estimation](https://arxiv.org/pdf/1711.02283.pdf). International Conference on Learning Representations (2018)

[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International Conference on Machine Learning
69 changes: 69 additions & 0 deletions examples/plot_free_support_barycenter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# -*- coding: utf-8 -*-
"""
========================================================
2D free support Wasserstein barycenters of distributions
========================================================

Illustration of 2D Wasserstein barycenters of distributions that are weighted
sums of Diracs.

"""

# Author: Vivien Seguy <vivien.seguy@iip.ist.i.kyoto-u.ac.jp>
#
# License: MIT License

import numpy as np
import matplotlib.pylab as pl
import ot


##############################################################################
# Generate data
# -------------
#%% parameters and data generation
N = 3  # number of input measures
d = 2  # dimension of the ambient space
measures_locations = []
measures_weights = []

# Each input measure is a weighted sum of Diracs whose locations are drawn
# from a random 2D Gaussian and whose weights are random and normalized.
for _ in range(N):

    n_atoms = np.random.randint(low=1, high=20)  # nb samples

    mean = np.random.normal(0., 4., (d,))  # Gaussian mean

    root = np.random.rand(d, d)
    covariance = root.dot(root.T)  # Gaussian covariance matrix (root @ root^T)

    atoms = ot.datasets.make_2D_samples_gauss(n_atoms, mean, covariance)  # Dirac locations
    raw_weights = np.random.uniform(0., 1., (n_atoms,))

    measures_locations.append(atoms)
    measures_weights.append(raw_weights / np.sum(raw_weights))  # Dirac weights, sum to 1


##############################################################################
# Compute free support barycenter
# -------------------------------
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

---- needs to have the proper length for good documentation generation.


k = 10  # number of Diracs of the barycenter
X_init = np.random.normal(loc=0., scale=1., size=(k, d))  # initial Dirac locations
# Uniform barycenter weights: they stay fixed, only the locations are optimized.
b = np.full((k,), 1. / k)

X = ot.lp.free_support_barycenter(
    measures_locations,
    measures_weights,
    X_init,
    b,
)


##############################################################################
# Plot data
# ---------

pl.figure(1)
for (x_i, b_i) in zip(measures_locations, measures_weights):
    # Marker area is driven by each measure's OWN weights b_i (length n_i).
    # Using the barycenter weights b (length k) here would not match the
    # number of plotted points.
    pl.scatter(x_i[:, 0], x_i[:, 1], s=b_i * 1000, label='input measure')
# Barycenter atoms, sized by the (fixed) barycenter weights b.
pl.scatter(X[:, 0], X[:, 1], s=b * 1000, c='black', marker='^', label='2-Wasserstein barycenter')
pl.title('Data measures and their barycenter')
pl.legend(loc=0)
pl.show()
95 changes: 94 additions & 1 deletion ot/lp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
from .emd_wrap import emd_c, check_result
from ..utils import parmap
from .cvx import barycenter
from ..utils import dist

__all__=['emd', 'emd2', 'barycenter', 'cvx']
__all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx']


def emd(a, b, M, numItermax=100000, log=False):
Expand Down Expand Up @@ -216,3 +217,95 @@ def f(b):

res = parmap(f, [b[:, i] for i in range(nb)], processes)
return res



def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100, stopThr=1e-7, verbose=False, log=None):
    """
    Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance)

    The function solves the Wasserstein barycenter problem when the barycenter measure is constrained to be supported on k atoms.
    This problem is considered in [1] (Algorithm 2). There are two differences between this implementation and [1]:

    - we do not optimize over the weights
    - we do not do line search for the locations updates, i.e. we use theta = 1 in [1] (Algorithm 2). This can be seen as a discrete implementation of the fixed-point algorithm of [2] proposed in the continuous setting.

    Parameters
    ----------
    measures_locations : list of (k_i,d) np.ndarray
        The discrete support of a measure supported on k_i locations of a d-dimensional space (k_i can be different for each element of the list)
    measures_weights : list of (k_i,) np.ndarray
        Numpy arrays where each numpy array has k_i non-negatives values summing to one representing the weights of each discrete input measure

    X_init : (k,d) np.ndarray
        Initialization of the support locations (on k atoms) of the barycenter
    b : (k,) np.ndarray, optional
        Weights of the barycenter atoms (positive, sum to 1). They are kept fixed during the optimization. Uniform weights 1/k are used if None.
    weights : (N,) np.ndarray, optional
        Coefficients of the N input measures in the barycenter (non-negatives, sum to 1). Uniform coefficients 1/N are used if None.

    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on the squared norm of the displacement of the support between two iterations (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True

    Returns
    -------
    X : (k,d) np.ndarray
        Support locations (on k atoms) of the barycenter
    log_dict : dict
        Only returned if log is true. Its key 'displacement_square_norms' holds the list of squared displacement norms, one per iteration.

    References
    ----------

    .. [1] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.

    .. [2] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.

    """

    iter_count = 0

    N = len(measures_locations)
    k = X_init.shape[0]
    d = X_init.shape[1]
    if b is None:
        b = np.ones((k,)) / k
    else:
        # accept any array-like (e.g. a plain list) for the barycenter weights
        b = np.asarray(b, dtype=np.float64)
    if weights is None:
        weights = np.ones((N,)) / N
    else:
        # accept any array-like (e.g. a plain list) for the coefficients
        weights = np.asarray(weights)

    X = X_init

    log_dict = {}
    displacement_square_norms = []

    # start above the threshold so that the loop is entered at least once
    # (unless numItermax <= 0)
    displacement_square_norm = stopThr + 1.

    # loop-invariant column vector of inverse barycenter weights, used to
    # row-normalize each transport plan
    inv_b = np.reshape(1. / b, (-1, 1))

    while displacement_square_norm > stopThr and iter_count < numItermax:

        T_sum = np.zeros((k, d))

        for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights.tolist()):

            # cost matrix between the current support and the i-th measure
            M_i = dist(X, measure_locations_i)
            # transport plan between the barycenter and the i-th measure
            T_i = emd(b, measure_weights_i, M_i)
            # accumulate the weighted, row-normalized image of the i-th support
            T_sum = T_sum + weight_i * inv_b * np.matmul(T_i, measure_locations_i)

        displacement_square_norm = np.sum(np.square(T_sum - X))
        if log:
            displacement_square_norms.append(displacement_square_norm)

        # full step (theta = 1): the new support replaces the old one
        X = T_sum

        if verbose:
            print('iteration %d, displacement_square_norm=%f\n' % (iter_count, displacement_square_norm))

        iter_count += 1

    if log:
        log_dict['displacement_square_norms'] = displacement_square_norms
        return X, log_dict
    else:
        return X
3 changes: 2 additions & 1 deletion ot/lp/cvx.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import scipy as sp
import scipy.sparse as sps


try:
import cvxopt
from cvxopt import solvers, matrix, spmatrix
Expand All @@ -26,7 +27,7 @@ def scipy_sparse_to_spmatrix(A):


def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-point'):
"""Compute the entropic regularized wasserstein barycenter of distributions A
"""Compute the Wasserstein barycenter of distributions A
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch !


The function solves the following optimization problem [16]:

Expand Down
15 changes: 15 additions & 0 deletions test/test_ot.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,21 @@ def test_lp_barycenter():
np.testing.assert_allclose(bary.sum(), 1)


def test_free_support_barycenter():
    """Check the free support barycenter of two unit Diracs at -1 and +1 in 1D."""

    # two input measures, each a single Dirac of mass one
    dirac_positions = (-1., 1.)
    measures_locations = [np.array([[position]]) for position in dirac_positions]
    measures_weights = [np.array([1.]) for _ in dirac_positions]

    # single-atom barycenter, deliberately initialized far from the solution
    X_init = np.array([[-12.]])

    # obvious barycenter location between two diracs
    expected_locations = np.array([[0.]])

    X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init)

    np.testing.assert_allclose(X, expected_locations, rtol=1e-5, atol=1e-7)


@pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available")
def test_lp_barycenter_cvxopt():

Expand Down