Merged
Changes from 6 commits
3 changes: 3 additions & 0 deletions README.md
@@ -17,6 +17,7 @@ It provides the following solvers:
* Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2] and stabilized version [9][10] with optional GPU implementation (requires cudamat).
* Smooth optimal transport solvers (dual and semi-dual) for KL and squared L2 regularizations [17].
* Non regularized Wasserstein barycenters [16] with LP solver (only small scale).
* Non regularized free support Wasserstein barycenters [20].
* Bregman projections for Wasserstein barycenter [3] and unmixing [4].
* Optimal transport for domain adaptation with group lasso regularization [5]
* Conditional gradient [6] and Generalized conditional gradient for regularized OT [7].
@@ -225,3 +226,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
[18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) [Stochastic Optimization for Large-scale Optimal Transport](arXiv preprint arxiv:1605.08527). Advances in Neural Information Processing Systems (2016).

[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. [Large-scale Optimal Transport and Mapping Estimation](https://arxiv.org/pdf/1711.02283.pdf). International Conference on Learning Representation (2018)

[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International Conference on Machine Learning
70 changes: 70 additions & 0 deletions examples/plot_free_support_barycenter.py
@@ -0,0 +1,70 @@
# -*- coding: utf-8 -*-
"""
====================================================
2D Wasserstein barycenters of distributions
====================================================

Illustration of 2D Wasserstein barycenters of distributions that are weighted
sums of Diracs.

"""

# Author: Vivien Seguy <vivien.seguy@iip.ist.i.kyoto-u.ac.jp>
#
# License: MIT License

import numpy as np
import matplotlib.pylab as pl
import ot.plot

##############################################################################
# Generate data
# -------------
#%% parameters and data generation
N = 3
d = 2
measures_locations = []
measures_weights = []

for i in range(N):

    n = np.random.randint(low=1, high=20)  # nb samples

    mu = np.random.normal(0., 4., (d,))

    A = np.random.rand(d, d)
    cov = np.dot(A, A.transpose())

    x_i = ot.datasets.make_2D_samples_gauss(n, mu, cov)
    b_i = np.random.uniform(0., 1., (n,))
    b_i = b_i / np.sum(b_i)

    measures_locations.append(x_i)
    measures_weights.append(b_i)


##############################################################################
# Compute free support barycenter
# -------------
Collaborator (review comment):
The `-------------` underline needs to have the proper length (matching the heading) for good documentation generation.


k = 10
X_init = np.random.normal(0., 1., (k, d))
b = np.ones((k,)) / k

X = ot.lp.cvx.free_support_barycenter(measures_locations, measures_weights, X_init, b)
Collaborator (review comment):
`ot.lp.cvx.free_support_barycenter` is very long.

You should import the function in the `ot.lp` `__init__.py` and add it to `__all__`, like `barycenter`, so that you can call `ot.lp.free_support_barycenter`.



##############################################################################
# Plot data
# ---------

#%% plot samples

pl.figure(1)
for (x_i, b_i) in zip(measures_locations, measures_weights):
    color = np.random.randint(low=1, high=10 * N)
    # marker sizes follow the weights b_i of the current input measure
    pl.scatter(x_i[:, 0], x_i[:, 1], s=b_i * 1000, label='input measure')
pl.scatter(X[:, 0], X[:, 1], s=b * 1000, c='black', marker='^', label='2-Wasserstein barycenter')
pl.title('Data measures and their barycenter')
pl.legend(loc=0)
pl.show()
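
The review comment about the long `ot.lp.cvx.free_support_barycenter` path can be illustrated with a throwaway package. This is only a sketch of the re-export pattern: the package name `demo_lp` and the stub functions are hypothetical stand-ins, not the actual `ot.lp` sources.

```python
import importlib
import os
import sys
import tempfile

# Build a throwaway package on disk. `demo_lp` and its contents are
# hypothetical stand-ins, not the real ot.lp sources.
root = tempfile.mkdtemp()
pkg_dir = os.path.join(root, "demo_lp")
os.makedirs(pkg_dir)

# The submodule holds the actual implementations (stubs here).
with open(os.path.join(pkg_dir, "cvx.py"), "w") as f:
    f.write(
        "def barycenter():\n"
        "    return 'barycenter'\n"
        "\n"
        "def free_support_barycenter():\n"
        "    return 'free_support_barycenter'\n"
    )

# The package __init__.py re-exports them and lists them in __all__.
with open(os.path.join(pkg_dir, "__init__.py"), "w") as f:
    f.write(
        "from .cvx import barycenter, free_support_barycenter\n"
        "\n"
        "__all__ = ['barycenter', 'free_support_barycenter']\n"
    )

sys.path.insert(0, root)
demo_lp = importlib.import_module("demo_lp")

# Both paths resolve to the same function object.
short_name = demo_lp.free_support_barycenter
long_name = demo_lp.cvx.free_support_barycenter
```

With this layout, callers can use the short path while the implementation stays in the submodule.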
83 changes: 82 additions & 1 deletion ot/lp/cvx.py
@@ -10,6 +10,7 @@
import numpy as np
import scipy as sp
import scipy.sparse as sps
import ot
Collaborator (review comment):
You shouldn't import POT inside a module.

Something with a relative path like

from .__init__ import emd

is far better, since it imports the `emd` function from `__init__.py`.


try:
    import cvxopt
@@ -26,7 +27,7 @@ def scipy_sparse_to_spmatrix(A):


def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-point'):
-    """Compute the entropic regularized wasserstein barycenter of distributions A
+    """Compute the Wasserstein barycenter of distributions A
Collaborator (review comment):
Good catch!


The function solves the following optimization problem [16]:

@@ -144,3 +145,83 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po
        return b, sol
    else:
        return b


def free_support_barycenter(measures_locations, measures_weights, X_init, b, weights=None, numItermax=100, stopThr=1e-6, verbose=False):
Collaborator (review comment):
Also use `b=None` when the weights are supposed to be uniform (this needs a test and an initialization in the function).

    """
    Solves the free support Wasserstein barycenter problem (the locations of the barycenter are optimized, not the weights), i.e. computes the weighted Frechet mean for the 2-Wasserstein distance.

    The function solves the Wasserstein barycenter problem when the barycenter measure is constrained to be supported on k atoms.
    This problem is considered in [1] (Algorithm 2). This code differs from it in two ways:
    - we do not optimize over the weights
    - we do not do line search for the location updates, i.e. we use theta = 1 in [1] (Algorithm 2). This can be seen as a discrete implementation of the fixed-point algorithm of [2] proposed in the continuous setting.

    Parameters
    ----------
    data_positions : list of (k_i,d) np.ndarray
Collaborator (review comment):
The names in the documentation differ from the code: `data_positions` vs `measures_locations`.

        The discrete support of a measure supported on k_i locations of a d-dimensional space (k_i can be different for each element of the list)
    data_weights : list of (k_i,) np.ndarray
        Numpy arrays where each numpy array has k_i non-negative values summing to one, representing the weights of each discrete input measure

    X_init : (k,d) np.ndarray
        Initialization of the support locations (on k atoms) of the barycenter
    b : (k,) np.ndarray
        Weights of the barycenter atoms (non-negative, sum to 1); they are not optimized
    weights : (N,) np.ndarray
        Coefficients of the N input measures in the barycenter (non-negative, sum to 1); uniform if not provided

    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
Collaborator (review comment):
The `log` parameter is missing from the function.

It would be nice to return the list of `displacement_square_norm` values along the iterations in a dictionary if `log=True` (similar to the `barycenter` function above, which returns a log).

        record log if True

    Returns
    -------
    X : (k,d) np.ndarray
        Support locations (on k atoms) of the barycenter

    References
    ----------

    .. [1] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.

    .. [2] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.

    """

    iter_count = 0

    d = X_init.shape[1]
    k = b.size
    N = len(measures_locations)

    # `not weights` is ambiguous for a numpy array, so test against None
    if weights is None:
        weights = np.ones((N,)) / N

    X = X_init

    displacement_square_norm = stopThr + 1.

    while (displacement_square_norm > stopThr and iter_count < numItermax):

        T_sum = np.zeros((k, d))

        for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights.tolist()):

            M_i = ot.dist(X, measure_locations_i)
            T_i = ot.emd(b, measure_weights_i, M_i)
            T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i)

        displacement_square_norm = np.sum(np.square(X - T_sum))
        X = T_sum

        if verbose:
            print('iteration %d, displacement_square_norm=%f\n' % (iter_count, displacement_square_norm))

        iter_count += 1

    return X
Collaborator (review comment):
Add the log to the return value if `log=True`.
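
The two review suggestions on this function (a `b=None` default meaning uniform weights, and a `log` flag returning the displacement history) could look roughly like the sketch below. It is not the PR's code: to stay self-contained it is restricted to 1D, where the monotone (sorted) coupling is optimal for the squared distance, so it does not call `ot.emd`. The names `monotone_plan_1d`, `free_support_barycenter_1d` and the log key `displacement_square_norms` are assumptions made for illustration.

```python
import numpy as np


def monotone_plan_1d(x, wx, y, wy, tol=1e-15):
    """Transport plan between two 1D discrete measures, built by matching
    sorted atoms in order (optimal in 1D for the squared distance)."""
    ix, iy = np.argsort(x), np.argsort(y)
    left_x = wx[ix].astype(float)  # mass still to ship from each x atom
    left_y = wy[iy].astype(float)  # mass still to receive at each y atom
    T = np.zeros((x.size, y.size))
    i = j = 0
    while i < x.size and j < y.size:
        m = min(left_x[i], left_y[j])
        T[ix[i], iy[j]] += m
        left_x[i] -= m
        left_y[j] -= m
        x_done = left_x[i] <= tol
        y_done = left_y[j] <= tol
        if x_done:
            i += 1
        if y_done:
            j += 1
    return T


def free_support_barycenter_1d(locations, weights_list, x_init, b=None,
                               weights=None, num_iter_max=100,
                               stop_thr=1e-9, log=False):
    """Fixed-point iteration on the barycenter atoms (theta = 1, no line
    search), with optional uniform defaults and a displacement log."""
    k = x_init.size
    if b is None:  # uniform barycenter weights by default
        b = np.full(k, 1.0 / k)
    if weights is None:  # uniform coefficients over the input measures
        weights = np.full(len(locations), 1.0 / len(locations))

    x = x_init.astype(float)
    history = []
    for _ in range(num_iter_max):
        x_new = np.zeros(k)
        for y_i, a_i, w_i in zip(locations, weights_list, weights):
            T_i = monotone_plan_1d(x, b, y_i, a_i)
            # barycentric projection of the plan, weighted by w_i
            x_new += w_i * (T_i @ y_i) / b
        displacement = float(np.sum((x - x_new) ** 2))
        history.append(displacement)
        x = x_new
        if displacement <= stop_thr:
            break

    if log:
        return x, {'displacement_square_norms': history}
    return x


# Usage: two input measures with two atoms each, uniform weights.
locs = [np.array([0., 1.]), np.array([2., 3.])]
wts = [np.array([.5, .5]), np.array([.5, .5])]
x_bar, info = free_support_barycenter_1d(locs, wts, np.array([5., -1.]), log=True)
```

Returning the history only when `log=True` mirrors the behaviour the reviewer describes for `barycenter`, which returns its log alongside the result.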