Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ The contributors to this library are:
* [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance)
* [Nathan Cassereau](https://github.com/ncassereau-idris) (Backends)
* [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning)
* [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein Barycenters)

## Acknowledgments

Expand Down
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -288,4 +288,8 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer

[40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger, G., & Weed, J. (2019, April). [Statistical optimal transport via factored couplings](http://proceedings.mlr.press/v89/forrow19a/forrow19a.pdf). In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2454-2465). PMLR.

[41] Chapel*, L., Flamary*, R., Wu, H., Févotte, C., Gasso, G. (2021). [Unbalanced Optimal Transport through Non-negative Penalized Linear Regression](https://proceedings.neurips.cc/paper/2021/file/c3c617a9b80b3ae1ebd868b0017cc349-Paper.pdf) Advances in Neural Information Processing Systems (NeurIPS), 2020. (Two first co-authors)
[41] Chapel*, L., Flamary*, R., Wu, H., Févotte, C., Gasso, G. (2021). [Unbalanced Optimal Transport through Non-negative Penalized Linear Regression](https://proceedings.neurips.cc/paper/2021/file/c3c617a9b80b3ae1ebd868b0017cc349-Paper.pdf) Advances in Neural Information Processing Systems (NeurIPS), 2020. (Two first co-authors)

[42] Delon, J., Gozlan, N., and Saint-Dizier, A. [Generalized Wasserstein barycenters between probability measures living on different subspaces](https://arxiv.org/pdf/2105.09755). arXiv preprint arXiv:2105.09755, 2021.

[43] Álvarez-Esteban, Pedro C., et al. [A fixed-point approach to barycenters in Wasserstein space.](https://arxiv.org/pdf/1511.05355.pdf) Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
6 changes: 6 additions & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Releases

## 0.8.3dev

#### New features

- Added Generalized Wasserstein Barycenter solver + example (PR #372)


## 0.8.2

Expand Down
6 changes: 3 additions & 3 deletions examples/backends/plot_sliced_wass_grad_flow_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@

# %%
# Animate trajectories of the gradient flow along iteration
# -------------------------------------------------------
# ---------------------------------------------------------

pl.figure(3, (8, 4))

Expand All @@ -122,7 +122,7 @@ def _update_plot(i):

# %%
# Compute the Sliced Wasserstein Barycenter
#
# -----------------------------------------
x1_torch = torch.tensor(x1).to(device=device)
x3_torch = torch.tensor(x3).to(device=device)
xbinit = np.random.randn(500, 2) * 10 + 16
Expand Down Expand Up @@ -169,7 +169,7 @@ def _update_plot(i):

# %%
# Animate trajectories of the barycenter along gradient descent
# -------------------------------------------------------
# -------------------------------------------------------------

pl.figure(5, (8, 4))

Expand Down
152 changes: 152 additions & 0 deletions examples/barycenters/plot_generalized_free_support_barycenter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
# -*- coding: utf-8 -*-
"""
=======================================
Generalized Wasserstein Barycenter Demo
=======================================

This example illustrates the computation of Generalized Wasserstein Barycenter
as proposed in [42].


[42] Delon, J., Gozlan, N., and Saint-Dizier, A..
Generalized Wasserstein barycenters between probability measures living on different subspaces.
arXiv preprint arXiv:2105.09755, 2021.

"""

# Author: Eloi Tanguy <eloi.tanguy@polytechnique.edu>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 2

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.pylab as pl
import ot
import matplotlib.animation as animation

########################
# Generate and plot data
# ----------------------

# Input measures
sub_sample_factor = 8
I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::sub_sample_factor, ::sub_sample_factor, 2]
I2 = pl.imread('../../data/tooth.png').astype(np.float64)[::sub_sample_factor, ::sub_sample_factor, 2]
I3 = pl.imread('../../data/heart.png').astype(np.float64)[::sub_sample_factor, ::sub_sample_factor, 2]

sz = I1.shape[0]
UU, VV = np.meshgrid(np.arange(sz), np.arange(sz))

# Input measure locations in their respective 2D spaces
X_list = [np.stack((UU[im == 0], VV[im == 0]), 1) * 1.0 for im in [I1, I2, I3]]

# Input measure weights
a_list = [ot.unif(x.shape[0]) for x in X_list]

# Projections 3D -> 2D
P1 = np.array([[1, 0, 0], [0, 1, 0]])
P2 = np.array([[0, 1, 0], [0, 0, 1]])
P3 = np.array([[1, 0, 0], [0, 0, 1]])
P_list = [P1, P2, P3]

# Barycenter weights
weights = np.array([1 / 3, 1 / 3, 1 / 3])

# Number of barycenter points to compute
n_samples_bary = 150

# Send the input measures into 3D space for visualisation
X_visu = [Xi @ Pi for (Xi, Pi) in zip(X_list, P_list)]

# Plot the input data
fig = plt.figure(figsize=(3, 3))
axis = fig.add_subplot(1, 1, 1, projection="3d")
for Xi in X_visu:
axis.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6)
axis.view_init(azim=45)
axis.set_xticks([])
axis.set_yticks([])
axis.set_zticks([])
plt.show()

#################################
# Barycenter computation and plot
# -------------------------------

Y = ot.lp.generalized_free_support_barycenter(X_list, a_list, P_list, n_samples_bary)
fig = plt.figure(figsize=(3, 3))

axis = fig.add_subplot(1, 1, 1, projection="3d")
for Xi in X_visu:
axis.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6)
axis.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6)
axis.view_init(azim=45)
axis.set_xticks([])
axis.set_yticks([])
axis.set_zticks([])
plt.show()


#############################
# Plotting projection matches
# ---------------------------

fig = plt.figure(figsize=(9, 3))

ax = fig.add_subplot(1, 3, 1, projection='3d')
for Xi in X_visu:
ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6)
ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6)
ax.view_init(elev=0, azim=0)
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])

ax = fig.add_subplot(1, 3, 2, projection='3d')
for Xi in X_visu:
ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6)
ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6)
ax.view_init(elev=0, azim=90)
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])

ax = fig.add_subplot(1, 3, 3, projection='3d')
for Xi in X_visu:
ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6)
ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6)
ax.view_init(elev=90, azim=0)
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])

plt.tight_layout()
plt.show()

##############################################
# Rotation animation
# --------------------------------------------

fig = plt.figure(figsize=(7, 7))
ax = fig.add_subplot(1, 1, 1, projection="3d")


def _init():
for Xi in X_visu:
ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6)
ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6)
ax.view_init(elev=0, azim=0)
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])
return fig,


def _update_plot(i):
ax.view_init(elev=i, azim=4 * i)
return fig,


ani = animation.FuncAnimation(fig, _update_plot, init_func=_init, frames=90, interval=50, blit=True, repeat_delay=2000)
2 changes: 1 addition & 1 deletion ot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
# utils functions
from .utils import dist, unif, tic, toc, toq

__version__ = "0.8.2"
__version__ = "0.8.3dev"

__all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils',
'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
Expand Down
Loading