Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ POT provides the following generic OT solvers (links to examples):
* Sinkhorn divergence [23] and entropic regularization OT from empirical data.
* Debiased Sinkhorn barycenters [Sinkhorn divergence barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_debiased_barycenter.html) [37]
* [Smooth optimal transport solvers](https://pythonot.github.io/auto_examples/plot_OT_1D_smooth.html) (dual and semi-dual) for KL and squared L2 regularizations [17].
* Weak OT solver between empirical distributions [39]
* Non regularized [Wasserstein barycenters [16] ](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html)) with LP solver (only small scale).
* [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12]), differentiable using gradients from
* [Fused-Gromov-Wasserstein distances solver](https://pythonot.github.io/auto_examples/gromov/plot_fgw.html#sphx-glr-auto-examples-plot-fgw-py) and [FGW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_barycenter_fgw.html) [24]
Expand Down Expand Up @@ -301,3 +302,5 @@ Conference on Machine Learning, PMLR 119:4692-4701, 2020

[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, [Online Graph
Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Conference on Machine Learning (ICML), 2021.

[39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017). [Kantorovich duality for general transport costs and applications](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.712.1825&rep=rep1&type=pdf). Journal of Functional Analysis, 273(11), 3327-3405.
8 changes: 5 additions & 3 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@

#### New features

- Better list of related examples in quick start guide with `minigallery` (PR #334)
- Better list of related examples in quick start guide with `minigallery` (PR #334).
- Add optional log-domain Sinkhorn implementation in WDA to support smaller values
of the regularization parameter (PR #336)
- Backend implementation for `ot.lp.free_support_barycenter` (PR #340)
of the regularization parameter (PR #336).
- Backend implementation for `ot.lp.free_support_barycenter` (PR #340).
- Add weak OT solver + example (PR #341).


#### Closed issues

Expand Down
1 change: 1 addition & 0 deletions docs/source/all.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ API and modules
unbalanced
partial
sliced
weak

.. autosummary::
:toctree: ../modules/generated/
Expand Down
98 changes: 98 additions & 0 deletions examples/others/plot_WeakOT_VS_OT.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# -*- coding: utf-8 -*-
"""
====================================================
Weak Optimal Transport VS exact Optimal Transport
====================================================

Illustration of 2D optimal transport between distributions that are weighted
sum of diracs. The OT matrix is plotted with the samples.

"""

# Author: Remi Flamary <remi.flamary@polytechnique.edu>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 4

import numpy as np
import matplotlib.pylab as pl
import ot
import ot.plot

##############################################################################
# Generate data an plot it
# ------------------------

#%% parameters and data generation

n = 50 # nb samples

mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])

mu_t = np.array([4, 4])
cov_t = np.array([[1, -.8], [-.8, 1]])

xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)
xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t)

a, b = ot.unif(n), ot.unif(n) # uniform distribution on samples

# loss matrix
M = ot.dist(xs, xt)
M /= M.max()

#%% plot samples

pl.figure(1)
pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.legend(loc=0)
pl.title('Source and target distributions')

pl.figure(2)
pl.imshow(M, interpolation='nearest')
pl.title('Cost matrix M')


##############################################################################
# Compute Weak OT and exact OT solutions
# --------------------------------------

#%% EMD

G0 = ot.emd(a, b, M)

#%% Weak OT

Gweak = ot.weak_optimal_transport(xs, xt, a, b)


##############################################################################
# Plot weak OT and exact OT solutions
# --------------------------------------

pl.figure(3, (8, 5))

pl.subplot(1, 2, 1)
pl.imshow(G0, interpolation='nearest')
pl.title('OT matrix')

pl.subplot(1, 2, 2)
pl.imshow(Gweak, interpolation='nearest')
pl.title('Weak OT matrix')

pl.figure(4, (8, 5))

pl.subplot(1, 2, 1)
ot.plot.plot2D_samples_mat(xs, xt, G0, c=[.5, .5, 1])
pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.title('OT matrix with samples')

pl.subplot(1, 2, 2)
ot.plot.plot2D_samples_mat(xs, xt, Gweak, c=[.5, .5, 1])
pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.title('Weak OT matrix with samples')
5 changes: 2 additions & 3 deletions examples/plot_OT_2D_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@

# loss matrix
M = ot.dist(xs, xt)
M /= M.max()

##############################################################################
# Plot data
Expand Down Expand Up @@ -87,7 +86,7 @@
#%% sinkhorn

# reg term
lambd = 1e-3
lambd = 1e-1

Gs = ot.sinkhorn(a, b, M, lambd)

Expand All @@ -112,7 +111,7 @@
#%% sinkhorn

# reg term
lambd = 1e-3
lambd = 1e-1

Ges = ot.bregman.empirical_sinkhorn(xs, xt, lambd)

Expand Down
5 changes: 3 additions & 2 deletions ot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from . import partial
from . import backend
from . import regpath
from . import weak

# OT functions
from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d
Expand All @@ -46,7 +47,7 @@
from .sliced import sliced_wasserstein_distance, max_sliced_wasserstein_distance
from .gromov import (gromov_wasserstein, gromov_wasserstein2,
gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2)

from .weak import weak_optimal_transport
# utils functions
from .utils import dist, unif, tic, toc, toq

Expand All @@ -59,5 +60,5 @@
'sinkhorn_unbalanced', 'barycenter_unbalanced',
'sinkhorn_unbalanced2', 'sliced_wasserstein_distance',
'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2',
'max_sliced_wasserstein_distance',
'max_sliced_wasserstein_distance', 'weak_optimal_transport',
'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath']
16 changes: 16 additions & 0 deletions ot/gromov.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,10 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=F
- :math:`\mathbf{q}`: distribution in the target space
- `L`: loss function to account for the misfit between the similarity matrices

.. note:: This function is backend-compatible and will work on arrays
from all compatible backends. But the algorithm uses the C++ CPU backend
which can lead to copy overhead on GPU arrays.

Parameters
----------
C1 : array-like, shape (ns, ns)
Expand Down Expand Up @@ -436,6 +440,10 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=
Note that when using backends, this loss function is differentiable wrt the
marices and weights for quadratic loss using the gradients from [38]_.

.. note:: This function is backend-compatible and will work on arrays
from all compatible backends. But the algorithm uses the C++ CPU backend
which can lead to copy overhead on GPU arrays.

Parameters
----------
C1 : array-like, shape (ns, ns)
Expand Down Expand Up @@ -545,6 +553,10 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
- :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1)
- `L` is a loss function to account for the misfit between the similarity matrices

.. note:: This function is backend-compatible and will work on arrays
from all compatible backends. But the algorithm uses the C++ CPU backend
which can lead to copy overhead on GPU arrays.

The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] <references-fused-gromov-wasserstein>`

Parameters
Expand Down Expand Up @@ -645,6 +657,10 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
The algorithm used for solving the problem is conditional gradient as
discussed in :ref:`[24] <references-fused-gromov-wasserstein2>`

.. note:: This function is backend-compatible and will work on arrays
from all compatible backends. But the algorithm uses the C++ CPU backend
which can lead to copy overhead on GPU arrays.

Note that when using backends, this loss function is differentiable wrt the
marices and weights for quadratic loss using the gradients from [38]_.

Expand Down
9 changes: 7 additions & 2 deletions ot/lp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from ..utils import parmap
from ..backend import get_backend



__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted',
'emd_1d', 'emd2_1d', 'wasserstein_1d']

Expand Down Expand Up @@ -220,7 +222,8 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
format

.. note:: This function is backend-compatible and will work on arrays
from all compatible backends.
from all compatible backends. But the algorithm uses the C++ CPU backend
which can lead to copy overhead on GPU arrays.

Uses the algorithm proposed in :ref:`[1] <references-emd>`.

Expand Down Expand Up @@ -358,7 +361,8 @@ def emd2(a, b, M, processes=1,
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights

.. note:: This function is backend-compatible and will work on arrays
from all compatible backends.
from all compatible backends. But the algorithm uses the C++ CPU backend
which can lead to copy overhead on GPU arrays.

Uses the algorithm proposed in :ref:`[1] <references-emd2>`.

Expand Down Expand Up @@ -622,3 +626,4 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
return X, log_dict
else:
return X

1 change: 0 additions & 1 deletion ot/lp/cvx.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import scipy as sp
import scipy.sparse as sps


try:
import cvxopt
from cvxopt import solvers, matrix, spmatrix
Expand Down
12 changes: 9 additions & 3 deletions ot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,21 +116,27 @@ def proj_simplex(v, z=1):
return w


def unif(n):
def unif(n, type_as=None):
r"""
Return a uniform histogram of length `n` (simplex).

Parameters
----------
n : int
number of bins in the histogram
type_as : array_like
array of the same type of the expected output (numpy/pytorch/jax)

Returns
-------
h : np.array (`n`,)
h : array_like (`n`,)
histogram of length `n` such that :math:`\forall i, \mathbf{h}_i = \frac{1}{n}`
"""
return np.ones((n,)) / n
if type_as is None:
return np.ones((n,)) / n
else:
nx = get_backend(type_as)
return nx.ones((n,)) / n


def clean_zeros(a, b, M):
Expand Down
Loading