Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
631c642
add debiased sinkhorn barycenter + make loops pythonic
Oct 26, 2021
598a7a7
add debiased arg in tests
Oct 26, 2021
c69da39
add 1d and 2d examples of debiased barycenters
Oct 26, 2021
7296690
fix doctest
Oct 26, 2021
3253d55
fix flake8
Oct 26, 2021
6b091cd
Merge branch 'master' into debiased_barycenter
rflamary Oct 27, 2021
4a65460
Merge branch 'master' into debiased_barycenter
rflamary Oct 27, 2021
6751b74
pep8 + make func private + add convergence warnings
Oct 27, 2021
c8d0e34
remove rel paths + add rng + pylab to pyplot
Oct 27, 2021
3a6d2a9
fix stopping criterion debiased
Oct 27, 2021
bfaef83
fix conflict with log sinkhorn
Oct 27, 2021
e2ac99e
pass alex
agramfort Oct 27, 2021
366ff62
change params with new API
Oct 29, 2021
5297495
add logdomain barycenters + separate debiased API
Oct 29, 2021
6a38c03
test new API
Oct 29, 2021
17e58e7
fix jax read-only ?
Oct 30, 2021
e656cc6
Merge branch 'master' into debiased_barycenter
rflamary Oct 30, 2021
62cd9c9
raise error for jax
Nov 1, 2021
13d3575
test catch jax error
Nov 1, 2021
d8ae66f
fix pytest catch error
Nov 1, 2021
39fcf83
fix relative path
Nov 1, 2021
b6cbc2f
fix flake8
Nov 1, 2021
d0a4084
fix docstrings + conflicts
Nov 2, 2021
98ed570
Merge branch 'master' of https://github.com/PythonOT/POT into debiase…
Nov 2, 2021
154e203
add warn arg everywhere
Nov 2, 2021
ace0e67
fix ref number
Nov 2, 2021
40ae0de
catch warnings in tests
Nov 2, 2021
ebd5f6a
add contrib to readme + change ref number
Nov 2, 2021
8c2e1f2
fix convolution example + gallery thumbnails
Nov 2, 2021
8d37431
increase coverage
Nov 2, 2021
6bd076b
fix flake
Nov 2, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ POT provides the following generic OT solvers (links to examples):
* [Conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) [6] and [Generalized conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) for regularized OT [7].
* Entropic regularization OT solver with [Sinkhorn Knopp Algorithm](https://pythonot.github.io/auto_examples/plot_OT_1D.html) [2] , stabilized version [9] [10] [34], greedy Sinkhorn [22] and [Screening Sinkhorn [26] ](https://pythonot.github.io/auto_examples/plot_screenkhorn_1D.html).
* Bregman projections for [Wasserstein barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html) [3], [convolutional barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_convolutional_barycenter.html) [21] and unmixing [4].
* Sinkhorn divergence [23] and entropic regularization OT from empirical data.
* Sinkhorn divergence [23] and entropic regularization OT from empirical data.
* Debiased Sinkhorn barycenters [Sinkhorn divergence barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_debiased_barycenter.html) [37]
* [Smooth optimal transport solvers](https://pythonot.github.io/auto_examples/plot_OT_1D_smooth.html) (dual and semi-dual) for KL and squared L2 regularizations [17].
* Non regularized [Wasserstein barycenters [16] ](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html)) with LP solver (only small scale).
* [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12])
Expand Down Expand Up @@ -188,7 +189,7 @@ The contributors to this library are
* [Kilian Fatras](https://kilianfatras.github.io/) (Stochastic solvers)
* [Alain Rakotomamonjy](https://sites.google.com/site/alainrakotomamonjy/home)
* [Vayer Titouan](https://tvayer.github.io/) (Gromov-Wasserstein -, Fused-Gromov-Wasserstein)
* [Hicham Janati](https://hichamjanati.github.io/) (Unbalanced OT)
* [Hicham Janati](https://hichamjanati.github.io/) (Unbalanced OT, Debiased barycenters)
* [Romain Tavenard](https://rtavenar.github.io/) (1d Wasserstein)
* [Mokhtar Z. Alaya](http://mzalaya.github.io/) (Screenkhorn)
* [Ievgen Redko](https://ievred.github.io/) (Laplacian DA, JCPOT)
Expand Down Expand Up @@ -293,3 +294,6 @@ You can also post bug reports and feature requests in Github issues. Make sure t
(2019, May). [Sliced-Wasserstein flows: Nonparametric generative modeling
via optimal transport and diffusions](http://proceedings.mlr.press/v97/liutkus19a/liutkus19a.pdf). In International Conference on
Machine Learning (pp. 4104-4113). PMLR.

[37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th International
Conference on Machine Learning, PMLR 119:4692-4701, 2020
63 changes: 23 additions & 40 deletions examples/barycenters/plot_barycenter_1D.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 4
# sphinx_gallery_thumbnail_number = 1

import numpy as np
import matplotlib.pylab as pl
import matplotlib.pyplot as plt
import ot
# necessary for 3d plot even if not used
from mpl_toolkits.mplot3d import Axes3D # noqa
Expand Down Expand Up @@ -50,18 +50,6 @@
M = ot.utils.dist0(n)
M /= M.max()

##############################################################################
# Plot data
# ---------

#%% plot the distributions

pl.figure(1, figsize=(6.4, 3))
for i in range(n_distributions):
pl.plot(x, A[:, i])
pl.title('Distributions')
pl.tight_layout()

##############################################################################
# Barycenter computation
# ----------------------
Expand All @@ -78,24 +66,20 @@
reg = 1e-3
bary_wass = ot.bregman.barycenter(A, M, reg, weights)

pl.figure(2)
pl.clf()
pl.subplot(2, 1, 1)
for i in range(n_distributions):
pl.plot(x, A[:, i])
pl.title('Distributions')
f, (ax1, ax2) = plt.subplots(2, 1, tight_layout=True, num=1)
ax1.plot(x, A, color="black")
ax1.set_title('Distributions')

pl.subplot(2, 1, 2)
pl.plot(x, bary_l2, 'r', label='l2')
pl.plot(x, bary_wass, 'g', label='Wasserstein')
pl.legend()
pl.title('Barycenters')
pl.tight_layout()
ax2.plot(x, bary_l2, 'r', label='l2')
ax2.plot(x, bary_wass, 'g', label='Wasserstein')
ax2.set_title('Barycenters')

plt.legend()
plt.show()

##############################################################################
# Barycentric interpolation
# -------------------------

#%% barycenter interpolation

n_alpha = 11
Expand All @@ -106,24 +90,23 @@

B_wass = np.copy(B_l2)

for i in range(0, n_alpha):
for i in range(n_alpha):
alpha = alpha_list[i]
weights = np.array([1 - alpha, alpha])
B_l2[:, i] = A.dot(weights)
B_wass[:, i] = ot.bregman.barycenter(A, M, reg, weights)

#%% plot interpolation
plt.figure(2)

pl.figure(3)

cmap = pl.cm.get_cmap('viridis')
cmap = plt.cm.get_cmap('viridis')
verts = []
zs = alpha_list
for i, z in enumerate(zs):
ys = B_l2[:, i]
verts.append(list(zip(x, ys)))

ax = pl.gcf().gca(projection='3d')
ax = plt.gcf().gca(projection='3d')

poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])
poly.set_alpha(0.7)
Expand All @@ -134,18 +117,18 @@
ax.set_ylim3d(0, 1)
ax.set_zlabel('')
ax.set_zlim3d(0, B_l2.max() * 1.01)
pl.title('Barycenter interpolation with l2')
pl.tight_layout()
plt.title('Barycenter interpolation with l2')
plt.tight_layout()

pl.figure(4)
cmap = pl.cm.get_cmap('viridis')
plt.figure(3)
cmap = plt.cm.get_cmap('viridis')
verts = []
zs = alpha_list
for i, z in enumerate(zs):
ys = B_wass[:, i]
verts.append(list(zip(x, ys)))

ax = pl.gcf().gca(projection='3d')
ax = plt.gcf().gca(projection='3d')

poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])
poly.set_alpha(0.7)
Expand All @@ -156,7 +139,7 @@
ax.set_ylim3d(0, 1)
ax.set_zlabel('')
ax.set_zlim3d(0, B_l2.max() * 1.01)
pl.title('Barycenter interpolation with Wasserstein')
pl.tight_layout()
plt.title('Barycenter interpolation with Wasserstein')
plt.tight_layout()

pl.show()
plt.show()
2 changes: 1 addition & 1 deletion examples/barycenters/plot_barycenter_lp_vs_entropic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
"""
=================================================================================
1D Wasserstein barycenter comparison between exact LP and entropic regularization
1D Wasserstein barycenter: exact LP vs entropic regularization
=================================================================================

This example illustrates the computation of regularized Wasserstein Barycenter
Expand Down
53 changes: 25 additions & 28 deletions examples/barycenters/plot_convolutional_barycenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,18 @@
Convolutional Wasserstein Barycenter example
============================================

This example is designed to illustrate how the Convolutional Wasserstein Barycenter
function of POT works.
This example is designed to illustrate how the Convolutional Wasserstein
Barycenter function of POT works.
"""

# Author: Nicolas Courty <ncourty@irisa.fr>
#
# License: MIT License

import os
from pathlib import Path

import numpy as np
import pylab as pl
import matplotlib.pyplot as plt
import ot

##############################################################################
Expand All @@ -25,22 +26,19 @@
#
# The four distributions are constructed from 4 simple images

this_file = os.path.realpath('__file__')
data_path = os.path.join(Path(this_file).parent.parent.parent, 'data')

f1 = 1 - pl.imread('../../data/redcross.png')[:, :, 2]
f2 = 1 - pl.imread('../../data/duck.png')[:, :, 2]
f3 = 1 - pl.imread('../../data/heart.png')[:, :, 2]
f4 = 1 - pl.imread('../../data/tooth.png')[:, :, 2]
f1 = 1 - plt.imread(os.path.join(data_path, 'redcross.png'))[:, :, 2]
f2 = 1 - plt.imread(os.path.join(data_path, 'tooth.png'))[:, :, 2]
f3 = 1 - plt.imread(os.path.join(data_path, 'heart.png'))[:, :, 2]
f4 = 1 - plt.imread(os.path.join(data_path, 'duck.png'))[:, :, 2]

A = []
f1 = f1 / np.sum(f1)
f2 = f2 / np.sum(f2)
f3 = f3 / np.sum(f3)
f4 = f4 / np.sum(f4)
A.append(f1)
A.append(f2)
A.append(f3)
A.append(f4)
A = np.array(A)
A = np.array([f1, f2, f3, f4])

nb_images = 5

Expand All @@ -57,14 +55,13 @@
# ----------------------------------------
#

pl.figure(figsize=(10, 10))
pl.title('Convolutional Wasserstein Barycenters in POT')
fig, axes = plt.subplots(nb_images, nb_images, figsize=(7, 7))
plt.suptitle('Convolutional Wasserstein Barycenters in POT')
cm = 'Blues'
# regularization parameter
reg = 0.004
for i in range(nb_images):
for j in range(nb_images):
pl.subplot(nb_images, nb_images, i * nb_images + j + 1)
tx = float(i) / (nb_images - 1)
ty = float(j) / (nb_images - 1)

Expand All @@ -74,19 +71,19 @@
weights = (1 - ty) * tmp1 + ty * tmp2

if i == 0 and j == 0:
pl.imshow(f1, cmap=cm)
pl.axis('off')
axes[i, j].imshow(f1, cmap=cm)
elif i == 0 and j == (nb_images - 1):
pl.imshow(f3, cmap=cm)
pl.axis('off')
axes[i, j].imshow(f3, cmap=cm)
elif i == (nb_images - 1) and j == 0:
pl.imshow(f2, cmap=cm)
pl.axis('off')
axes[i, j].imshow(f2, cmap=cm)
elif i == (nb_images - 1) and j == (nb_images - 1):
pl.imshow(f4, cmap=cm)
pl.axis('off')
axes[i, j].imshow(f4, cmap=cm)
else:
# call to barycenter computation
pl.imshow(ot.bregman.convolutional_barycenter2d(A, reg, weights), cmap=cm)
pl.axis('off')
pl.show()
axes[i, j].imshow(
ot.bregman.convolutional_barycenter2d(A, reg, weights),
cmap=cm
)
axes[i, j].axis('off')
plt.tight_layout()
plt.show()
131 changes: 131 additions & 0 deletions examples/barycenters/plot_debiased_barycenter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# -*- coding: utf-8 -*-
"""
=================================
Debiased Sinkhorn barycenter demo
=================================

This example illustrates the computation of the debiased Sinkhorn barycenter
as proposed in [37]_.


.. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th
International Conference on Machine Learning, PMLR 119:4692-4701, 2020
"""

# Author: Hicham Janati <hicham.janati100@gmail.com>
#
# License: MIT License
# sphinx_gallery_thumbnail_number = 3

import os
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt

import ot
from ot.bregman import (barycenter, barycenter_debiased,
convolutional_barycenter2d,
convolutional_barycenter2d_debiased)

##############################################################################
# Debiased barycenter of 1D Gaussians
# ------------------------------------

#%% parameters

n = 100 # nb bins

# bin positions
x = np.arange(n, dtype=np.float64)

# Gaussian distributions
a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std
a2 = ot.datasets.make_1D_gauss(n, m=60, s=8)

# creating matrix A containing all distributions
A = np.vstack((a1, a2)).T
n_distributions = A.shape[1]

# loss matrix + normalization
M = ot.utils.dist0(n)
M /= M.max()

#%% barycenter computation

alpha = 0.2 # 0<=alpha<=1
weights = np.array([1 - alpha, alpha])

epsilons = [5e-3, 1e-2, 5e-2]


bars = [barycenter(A, M, reg, weights) for reg in epsilons]
bars_debiased = [barycenter_debiased(A, M, reg, weights) for reg in epsilons]
labels = ["Sinkhorn barycenter", "Debiased barycenter"]
colors = ["indianred", "gold"]

f, axes = plt.subplots(1, len(epsilons), tight_layout=True, sharey=True,
figsize=(12, 4), num=1)
for ax, eps, bar, bar_debiased in zip(axes, epsilons, bars, bars_debiased):
ax.plot(A[:, 0], color="k", ls="--", label="Input data", alpha=0.3)
ax.plot(A[:, 1], color="k", ls="--", alpha=0.3)
for data, label, color in zip([bar, bar_debiased], labels, colors):
ax.plot(data, color=color, label=label, lw=2)
ax.set_title(r"$\varepsilon = %.3f$" % eps)
plt.legend()
plt.show()


##############################################################################
# Debiased barycenter of 2D images
# ---------------------------------
this_file = os.path.realpath('__file__')
data_path = os.path.join(Path(this_file).parent.parent.parent, 'data')
f1 = 1 - plt.imread(os.path.join(data_path, 'heart.png'))[:, :, 2]
f2 = 1 - plt.imread(os.path.join(data_path, 'duck.png'))[:, :, 2]

A = np.asarray([f1, f2]) + 1e-2
A /= A.sum(axis=(1, 2))[:, None, None]

##############################################################################
# Display the input images

fig, axes = plt.subplots(1, 2, figsize=(7, 4), num=2)
for ax, img in zip(axes, A):
ax.imshow(img, cmap="Greys")
ax.axis("off")
fig.tight_layout()
plt.show()


##############################################################################
# Barycenter computation and visualization
# ----------------------------------------
#

bars_sinkhorn, bars_debiased = [], []
epsilons = [5e-3, 7e-3, 1e-2]
for eps in epsilons:
bar = convolutional_barycenter2d(A, eps)
bar_debiased, log = convolutional_barycenter2d_debiased(A, eps, log=True)
bars_sinkhorn.append(bar)
bars_debiased.append(bar_debiased)

titles = ["Sinkhorn", "Debiased"]
all_bars = [bars_sinkhorn, bars_debiased]
fig, axes = plt.subplots(2, 3, figsize=(8, 6), num=3)
for jj, (method, ax_row, bars) in enumerate(zip(titles, axes, all_bars)):
for ii, (ax, img, eps) in enumerate(zip(ax_row, bars, epsilons)):
ax.imshow(img, cmap="Greys")
if jj == 0:
ax.set_title(r"$\varepsilon = %.3f$" % eps, fontsize=13)
ax.set_xticks([])
ax.set_yticks([])
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['bottom'].set_visible(False)
ax.spines['left'].set_visible(False)
if ii == 0:
ax.set_ylabel(method, fontsize=15)
fig.tight_layout()
plt.show()
Loading