Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ It provides the following solvers:
* Gromov-Wasserstein distances and barycenters ([13] and regularized [12])
* Stochastic Optimization for Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19])
* Non regularized free support Wasserstein barycenters [20].
* Unbalanced OT with KL relaxation distance and barycenter [10, 25].

Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder.

Expand Down Expand Up @@ -165,6 +166,7 @@ The contributors to this library are:
* [Kilian Fatras](https://kilianfatras.github.io/)
* [Alain Rakotomamonjy](https://sites.google.com/site/alainrakotomamonjy/home)
* [Vayer Titouan](https://tvayer.github.io/)
* [Hicham Janati](https://hichamjanati.github.io/) (Unbalanced OT)

This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages):

Expand Down Expand Up @@ -236,3 +238,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
[23] Aude, G., Peyré, G., Cuturi, M., [Learning Generative Models with Sinkhorn Divergences](https://arxiv.org/abs/1706.00292), Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018

[24] Vayer, T., Chapel, L., Flamary, R., Tavenard, R. and Courty, N. (2019). [Optimal Transport for structured data with application on graphs](http://proceedings.mlr.press/v97/titouan19a.html) Proceedings of the 36th International Conference on Machine Learning (ICML).

[25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. (2019). [Learning with a Wasserstein Loss](http://cbcl.mit.edu/wasserstein/) Advances in Neural Information Processing Systems (NIPS).
8 changes: 7 additions & 1 deletion docs/source/all.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ ot.da

.. automodule:: ot.da
:members:

ot.gpu
--------

Expand Down Expand Up @@ -80,3 +80,9 @@ ot.stochastic

.. automodule:: ot.stochastic
:members:

ot.unbalanced
-------------

.. automodule:: ot.unbalanced
:members:
76 changes: 76 additions & 0 deletions examples/plot_UOT_1D.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# -*- coding: utf-8 -*-
"""
===============================
1D Unbalanced optimal transport
===============================

This example illustrates the computation of Unbalanced Optimal transport
using a Kullback-Leibler relaxation.
"""

# Author: Hicham Janati <hicham.janati@inria.fr>
#
# License: MIT License

import numpy as np
import matplotlib.pylab as pl
import ot
import ot.plot
from ot.datasets import make_1D_gauss as gauss

##############################################################################
# Generate data
# -------------


#%% parameters

n = 100 # nb bins

# bin positions
x = np.arange(n, dtype=np.float64)

# Gaussian distributions
a = gauss(n, m=20, s=5) # m= mean, s= std
b = gauss(n, m=60, s=10)

# make distributions unbalanced
b *= 5.

# loss matrix
M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
M /= M.max()


##############################################################################
# Plot distributions and loss matrix
# ----------------------------------

#%% plot the distributions

pl.figure(1, figsize=(6.4, 3))
pl.plot(x, a, 'b', label='Source distribution')
pl.plot(x, b, 'r', label='Target distribution')
pl.legend()

# plot distributions and loss matrix

pl.figure(2, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')


##############################################################################
# Solve Unbalanced Sinkhorn
# --------------


# Sinkhorn

epsilon = 0.1 # entropy parameter
alpha = 1. # Unbalanced KL relaxation parameter
Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, verbose=True)

pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, Gs, 'UOT matrix Sinkhorn')

pl.show()
164 changes: 164 additions & 0 deletions examples/plot_UOT_barycenter_1D.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
# -*- coding: utf-8 -*-
"""
===========================================================
1D Wasserstein barycenter demo for Unbalanced distributions
===========================================================

This example illustrates the computation of regularized Wassersyein Barycenter
as proposed in [10] for Unbalanced inputs.


[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.

"""

# Author: Hicham Janati <hicham.janati@inria.fr>
#
# License: MIT License

import numpy as np
import matplotlib.pylab as pl
import ot
# necessary for 3d plot even if not used
from mpl_toolkits.mplot3d import Axes3D # noqa
from matplotlib.collections import PolyCollection

##############################################################################
# Generate data
# -------------

# parameters

n = 100 # nb bins

# bin positions
x = np.arange(n, dtype=np.float64)

# Gaussian distributions
a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std
a2 = ot.datasets.make_1D_gauss(n, m=60, s=8)

# make unbalanced dists
a2 *= 3.

# creating matrix A containing all distributions
A = np.vstack((a1, a2)).T
n_distributions = A.shape[1]

# loss matrix + normalization
M = ot.utils.dist0(n)
M /= M.max()

##############################################################################
# Plot data
# ---------

# plot the distributions

pl.figure(1, figsize=(6.4, 3))
for i in range(n_distributions):
pl.plot(x, A[:, i])
pl.title('Distributions')
pl.tight_layout()

##############################################################################
# Barycenter computation
# ----------------------

# non weighted barycenter computation

weight = 0.5 # 0<=weight<=1
weights = np.array([1 - weight, weight])

# l2bary
bary_l2 = A.dot(weights)

# wasserstein
reg = 1e-3
alpha = 1.

bary_wass = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights)

pl.figure(2)
pl.clf()
pl.subplot(2, 1, 1)
for i in range(n_distributions):
pl.plot(x, A[:, i])
pl.title('Distributions')

pl.subplot(2, 1, 2)
pl.plot(x, bary_l2, 'r', label='l2')
pl.plot(x, bary_wass, 'g', label='Wasserstein')
pl.legend()
pl.title('Barycenters')
pl.tight_layout()

##############################################################################
# Barycentric interpolation
# -------------------------

# barycenter interpolation

n_weight = 11
weight_list = np.linspace(0, 1, n_weight)


B_l2 = np.zeros((n, n_weight))

B_wass = np.copy(B_l2)

for i in range(0, n_weight):
weight = weight_list[i]
weights = np.array([1 - weight, weight])
B_l2[:, i] = A.dot(weights)
B_wass[:, i] = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights)


# plot interpolation

pl.figure(3)

cmap = pl.cm.get_cmap('viridis')
verts = []
zs = weight_list
for i, z in enumerate(zs):
ys = B_l2[:, i]
verts.append(list(zip(x, ys)))

ax = pl.gcf().gca(projection='3d')

poly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list])
poly.set_alpha(0.7)
ax.add_collection3d(poly, zs=zs, zdir='y')
ax.set_xlabel('x')
ax.set_xlim3d(0, n)
ax.set_ylabel(r'$\alpha$')
ax.set_ylim3d(0, 1)
ax.set_zlabel('')
ax.set_zlim3d(0, B_l2.max() * 1.01)
pl.title('Barycenter interpolation with l2')
pl.tight_layout()

pl.figure(4)
cmap = pl.cm.get_cmap('viridis')
verts = []
zs = weight_list
for i, z in enumerate(zs):
ys = B_wass[:, i]
verts.append(list(zip(x, ys)))

ax = pl.gcf().gca(projection='3d')

poly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list])
poly.set_alpha(0.7)
ax.add_collection3d(poly, zs=zs, zdir='y')
ax.set_xlabel('x')
ax.set_xlim3d(0, n)
ax.set_ylabel(r'$\alpha$')
ax.set_ylim3d(0, 1)
ax.set_zlabel('')
ax.set_zlim3d(0, B_l2.max() * 1.01)
pl.title('Barycenter interpolation with Wasserstein')
pl.tight_layout()

pl.show()
5 changes: 4 additions & 1 deletion ot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@
from . import gromov
from . import smooth
from . import stochastic
from . import unbalanced

# OT functions
from .lp import emd, emd2
from .bregman import sinkhorn, sinkhorn2, barycenter
from .unbalanced import sinkhorn_unbalanced, barycenter_unbalanced
from .da import sinkhorn_lpl1_mm

# utils functions
Expand All @@ -33,4 +35,5 @@

__all__ = ["emd", "emd2", "sinkhorn", "sinkhorn2", "utils", 'datasets',
'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim']
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim',
'sinkhorn_unbalanced', "barycenter_unbalanced"]
2 changes: 1 addition & 1 deletion ot/bregman.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def sink():

b = np.asarray(b, dtype=np.float64)
if len(b.shape) < 2:
b = b.reshape((-1, 1))
b = b[:, None]

return sink()

Expand Down
Loading