Free support barycenters #56
Changes from 6 commits
6492e95
3f23fa1
98ce4cc
e39f04a
2c7b980
67ddb92
4671279
08e5c0a
af57d90
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
==================================================== | ||
2D Wasserstein barycenters of distributions | ||
==================================================== | ||
|
||
Illustration of 2D Wasserstein barycenters of distributions that are weighted | ||
sums of Diracs. | ||
|
||
""" | ||
|
||
# Author: Vivien Seguy <vivien.seguy@iip.ist.i.kyoto-u.ac.jp> | ||
# | ||
# License: MIT License | ||
|
||
import numpy as np | ||
import matplotlib.pylab as pl | ||
import ot.plot | ||
|
||
############################################################################## | ||
# Generate data | ||
# ------------- | ||
#%% parameters and data generation | ||
N = 3 | ||
d = 2 | ||
measures_locations = [] | ||
measures_weights = [] | ||
|
||
for i in range(N): | ||
|
||
n = np.random.randint(low=1, high=20) # nb samples | ||
|
||
mu = np.random.normal(0., 4., (d,)) | ||
|
||
A = np.random.rand(d, d) | ||
cov = np.dot(A, A.transpose()) | ||
|
||
x_i = ot.datasets.make_2D_samples_gauss(n, mu, cov) | ||
b_i = np.random.uniform(0., 1., (n,)) | ||
b_i = b_i / np.sum(b_i) | ||
|
||
measures_locations.append(x_i) | ||
measures_weights.append(b_i) | ||
|
||
|
||
############################################################################## | ||
# Compute free support barycenter | ||
# ------------- | ||
|
||
k = 10 | ||
X_init = np.random.normal(0., 1., (k, d)) | ||
b = np.ones((k,)) / k | ||
|
||
X = ot.lp.cvx.free_support_barycenter(measures_locations, measures_weights, X_init, b) | ||
you should import the function in ot.lp |
||
|
||
|
||
############################################################################## | ||
# Plot data | ||
# --------- | ||
|
||
#%% plot samples | ||
|
||
pl.figure(1) | ||
for (x_i, b_i) in zip(measures_locations, measures_weights): | ||
color = np.random.randint(low=1, high=10 * N) | ||
pl.scatter(x_i[:, 0], x_i[:, 1], s=b_i * 1000, label='input measure') | ||
pl.scatter(X[:, 0], X[:, 1], s=b * 1000, c='black', marker='^', label='2-Wasserstein barycenter') | ||
pl.title('Data measures and their barycenter') | ||
pl.legend(loc=0) | ||
pl.show() |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,7 @@ | |
import numpy as np | ||
import scipy as sp | ||
import scipy.sparse as sps | ||
import ot | ||
You shouldn't import pot inside a module. Something with a relative path like from .__init__ import emd is far better since it imports the emd function from |
||
|
||
try: | ||
import cvxopt | ||
|
@@ -26,7 +27,7 @@ def scipy_sparse_to_spmatrix(A): | |
|
||
|
||
def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-point'): | ||
"""Compute the entropic regularized wasserstein barycenter of distributions A | ||
"""Compute the Wasserstein barycenter of distributions A | ||
Good catch! |
||
|
||
The function solves the following optimization problem [16]: | ||
|
||
|
@@ -144,3 +145,83 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po | |
return b, sol | ||
else: | ||
return b | ||
|
||
|
||
def free_support_barycenter(measures_locations, measures_weights, X_init, b, weights=None, numItermax=100, stopThr=1e-6, verbose=False): | ||
Also allow b=None when the barycenter weights are assumed uniform (this needs a test and an initialization in the function). |
||
""" | ||
Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance) | ||
|
||
The function solves the Wasserstein barycenter problem when the barycenter measure is constrained to be supported on k atoms. | ||
This problem is considered in [1] (Algorithm 2). The following code differs from it in two ways: | ||
- we do not optimize over the weights | ||
- we do not do a line search for the location updates, i.e. we use theta = 1 in [1] (Algorithm 2). This can be seen as a discrete implementation of the fixed-point algorithm of [2], which was proposed in the continuous setting. | ||
|
||
Parameters | ||
---------- | ||
data_positions : list of (k_i,d) np.ndarray | ||
The names in the documentation differ from the code: data_positions vs measures_locations |
||
The discrete support of a measure supported on k_i locations of a d-dimensional space (k_i can be different for each element of the list) | ||
data_weights : list of (k_i,) np.ndarray | ||
Numpy arrays where each numpy array has k_i non-negative values summing to one, representing the weights of each discrete input measure | ||
|
||
X_init : (k,d) np.ndarray | ||
Initialization of the support locations (on k atoms) of the barycenter | ||
b : (k,) np.ndarray | ||
Weights of the barycenter atoms (non-negative, sum to 1); they are kept fixed | ||
weights : (N,) np.ndarray, optional | ||
Weight of each input measure in the barycenter (non-negative, sum to 1); uniform if not given | ||
|
||
numItermax : int, optional | ||
Max number of iterations | ||
stopThr : float, optional | ||
Stop threshold on error (>0) | ||
verbose : bool, optional | ||
Print information along iterations | ||
log : bool, optional | ||
The log parameter is missing from the function. It would be nice to return the list of displacement_square_norm values along the iterations in a dictionary if log=True (similar to the barycenter function above, which returns a log). |
||
record log if True | ||
|
||
Returns | ||
------- | ||
X : (k,d) np.ndarray | ||
Support locations (on k atoms) of the barycenter | ||
|
||
References | ||
---------- | ||
|
||
.. [1] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. | ||
|
||
.. [2] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. | ||
|
||
""" | ||
|
||
iter_count = 0 | ||
|
||
d = X_init.shape[1] | ||
k = b.size | ||
N = len(measures_locations) | ||
|
||
if weights is None: | ||
weights = np.ones((N,)) / N | ||
|
||
X = X_init | ||
|
||
displacement_square_norm = stopThr + 1. | ||
|
||
while (displacement_square_norm > stopThr and iter_count < numItermax): | ||
|
||
T_sum = np.zeros((k, d)) | ||
|
||
for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights.tolist()): | ||
|
||
M_i = ot.dist(X, measure_locations_i) | ||
T_i = ot.emd(b, measure_weights_i, M_i) | ||
T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i) | ||
|
||
displacement_square_norm = np.sum(np.square(X - T_sum)) | ||
X = T_sum | ||
|
||
if verbose: | ||
print('iteration %d, displacement_square_norm=%f\n' % (iter_count, displacement_square_norm)) | ||
|
||
iter_count += 1 | ||
|
||
return X | ||
Add log if log=True |
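The review requests on this function (b=None for uniform barycenter weights, and a log option that returns the displacement history) could be combined roughly as in the sketch below. This is an illustration only, not the code merged in the pull request. The name `free_support_barycenter_sketch`, the `ot_solver` argument, and the log key `displacement_square_norms` are hypothetical. `ot_solver` stands in for the `ot.dist` and `ot.emd` calls of the diff so that the example runs with NumPy alone.

```python
import numpy as np


def free_support_barycenter_sketch(measures_locations, measures_weights, X_init,
                                   ot_solver, b=None, weights=None,
                                   numItermax=100, stopThr=1e-6, log=False):
    """Fixed-point iteration on the barycenter support (illustrative sketch).

    ot_solver(b, a, X, Y) is a hypothetical callable that returns a (k, n)
    transport plan with marginals b and a.
    """
    X = np.asarray(X_init, dtype=float)
    k = X.shape[0]
    N = len(measures_locations)
    if b is None:  # uniform barycenter weights by default
        b = np.full(k, 1.0 / k)
    if weights is None:  # uniform weights over the input measures
        weights = np.full(N, 1.0 / N)

    history = []  # squared displacement of the support at each iteration
    for _ in range(numItermax):
        T_sum = np.zeros_like(X)
        for Y, a, w in zip(measures_locations, measures_weights, weights):
            T = ot_solver(b, a, X, Y)
            # Barycentric projection: atom j moves to the T-weighted mean of Y.
            T_sum += w * (T @ Y) / b[:, None]
        displacement_square_norm = float(np.sum((X - T_sum) ** 2))
        X = T_sum
        history.append(displacement_square_norm)
        if displacement_square_norm <= stopThr:
            break

    if log:
        return X, {"displacement_square_norms": history}
    return X


def dirac_solver(b, a, X, Y):
    # Toy solver: the product coupling is the only coupling (hence optimal)
    # when the target measure is a single Dirac, i.e. a == [1.0].
    return np.outer(b, a)


locations = [np.array([[0.0, 0.0]]), np.array([[2.0, 4.0]])]
masses = [np.array([1.0]), np.array([1.0])]
X, info = free_support_barycenter_sketch(
    locations, masses, np.zeros((3, 2)), dirac_solver, log=True)
```

With two Dirac inputs and uniform weights, every barycenter atom moves to the midpoint of the two Diracs, and the recorded displacement drops to zero on the second iteration. With real data, `ot_solver` would have to be replaced by an exact optimal transport solver.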
----
needs to have the proper length for good documentation generation.
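The comment on underline length reflects a reStructuredText rule: a section underline must be at least as long as its title. Assuming the example script's comment headers are rendered as reStructuredText by the documentation build, a small helper can generate them with a matching underline. The helper name `gallery_section` and the 78-character bar width are hypothetical choices for illustration.

```python
def gallery_section(title, char="-"):
    """Return a comment header whose underline has the same length as the title."""
    bar = "#" * 78  # separator width chosen for illustration
    return "\n".join([bar, "# " + title, "# " + char * len(title)])


header = gallery_section("Compute free support barycenter")
print(header)
```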