Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions ot/lp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True):

if np.any(~asel) or np.any(~bsel):
u, v = estimate_dual_null_weights(u, v, a, b, M)

result_code_string = check_result(result_code)
if log:
log = {}
Expand Down Expand Up @@ -389,7 +389,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
if log or return_matrix:
def f(b):
bsel = b != 0

G, cost, u, v, result_code = emd_c(a, b, M, numItermax)

if center_dual:
Expand Down Expand Up @@ -435,26 +435,36 @@ def f(b):

def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100,
stopThr=1e-7, verbose=False, log=None):
"""
Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance)
r"""
Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance), formally:

.. math::
\min_X \sum_{i=1}^N w_i W_2^2(b, X, a_i, X_i)

where :

- :math:`w \in \mathbb{(0, 1)}^{N}`'s are the barycenter weights and sum to one
- the :math:`a_i \in \mathbb{R}^{k_i}` are the empirical measures weights and sum to one for each :math:`i`
- the :math:`X_i \in \mathbb{R}^{k_i, d}` are the empirical measures atoms locations
- :math:`b \in \mathbb{R}^{k}` is the desired weights vector of the barycenter

The function solves the Wasserstein barycenter problem when the barycenter measure is constrained to be supported on k atoms.
This problem is considered in [1] (Algorithm 2). There are two differences with the following codes:

- we do not optimize over the weights
- we do not do line search for the locations updates, we use i.e. theta = 1 in [1] (Algorithm 2). This can be seen as a discrete implementation of the fixed-point algorithm of [2] proposed in the continuous setting.

Parameters
----------
measures_locations : list of (k_i,d) numpy.ndarray
measures_locations : list of N (k_i,d) numpy.ndarray
The discrete support of a measure supported on k_i locations of a d-dimensional space (k_i can be different for each element of the list)
measures_weights : list of (k_i,) numpy.ndarray
measures_weights : list of N (k_i,) numpy.ndarray
Numpy arrays where each numpy array has k_i non-negatives values summing to one representing the weights of each discrete input measure

X_init : (k,d) np.ndarray
Initialization of the support locations (on k atoms) of the barycenter
b : (k,) np.ndarray
Initialization of the weights of the barycenter (non-negatives, sum to 1)
weights : (k,) np.ndarray
weights : (N,) np.ndarray
Initialization of the coefficients of the barycenter (non-negatives, sum to 1)

numItermax : int, optional
Expand Down