From 3bc2ceab4e9718e0ae5a907c12c4db65f9f41f98 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Fri, 28 Jan 2022 17:39:43 +0100 Subject: [PATCH 1/9] add info in release file --- README.md | 3 ++ RELEASES.md | 5 +- docs/source/all.rst | 1 + ot/__init__.py | 5 +- ot/lp/__init__.py | 3 ++ ot/lp/cvx.py | 1 - ot/weak.py | 123 ++++++++++++++++++++++++++++++++++++++++++++ test/test_ot.py | 2 +- test/test_weak.py | 54 +++++++++++++++++++ 9 files changed, 191 insertions(+), 6 deletions(-) create mode 100644 ot/weak.py create mode 100644 test/test_weak.py diff --git a/README.md b/README.md index 17fbe81f7..a7627dfd6 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,7 @@ POT provides the following generic OT solvers (links to examples): * Sinkhorn divergence [23] and entropic regularization OT from empirical data. * Debiased Sinkhorn barycenters [Sinkhorn divergence barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_debiased_barycenter.html) [37] * [Smooth optimal transport solvers](https://pythonot.github.io/auto_examples/plot_OT_1D_smooth.html) (dual and semi-dual) for KL and squared L2 regularizations [17]. +* Weak OT solver between empirical distributions [39] * Non regularized [Wasserstein barycenters [16] ](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html)) with LP solver (only small scale). * [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12]), differentiable using gradients from * [Fused-Gromov-Wasserstein distances solver](https://pythonot.github.io/auto_examples/gromov/plot_fgw.html#sphx-glr-auto-examples-plot-fgw-py) and [FGW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_barycenter_fgw.html) [24] @@ -301,3 +302,5 @@ Conference on Machine Learning, PMLR 119:4692-4701, 2020 [38] C. 
Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, [Online Graph Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Conference on Machine Learning (ICML), 2021. + +[39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017). [Kantorovich duality for general transport costs and applications](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.712.1825&rep=rep1&type=pdf). Journal of Functional Analysis, 273(11), 3327-3405. \ No newline at end of file diff --git a/RELEASES.md b/RELEASES.md index a5fcbe15c..14f222594 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -4,9 +4,10 @@ #### New features -- Better list of related examples in quick start guide with `minigallery` (PR #334) +- Better list of related examples in quick start guide with `minigallery` (PR #334). - Add optional log-domain Sinkhorn implementation in WDA to support smaller values - of the regularization parameter (PR #336) + of the regularization parameter (PR #336). +- Add weak OT solver (PR #341) #### Closed issues diff --git a/docs/source/all.rst b/docs/source/all.rst index 7f85a9168..76d2ff5a6 100644 --- a/docs/source/all.rst +++ b/docs/source/all.rst @@ -28,6 +28,7 @@ API and modules unbalanced partial sliced + weak .. autosummary:: :toctree: ../modules/generated/ diff --git a/ot/__init__.py b/ot/__init__.py index 1ea740332..7253318b9 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -36,6 +36,7 @@ from . import partial from . import backend from . import regpath +from . 
import weak # OT functions from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d @@ -46,7 +47,7 @@ from .sliced import sliced_wasserstein_distance, max_sliced_wasserstein_distance from .gromov import (gromov_wasserstein, gromov_wasserstein2, gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2) - +from .weak import weak_optimal_transport # utils functions from .utils import dist, unif, tic, toc, toq @@ -59,5 +60,5 @@ 'sinkhorn_unbalanced', 'barycenter_unbalanced', 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', 'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2', - 'max_sliced_wasserstein_distance', + 'max_sliced_wasserstein_distance', 'weak_optimal_transport', 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath'] diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 5da897d0b..e126fe23f 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -26,6 +26,8 @@ from ..utils import parmap from ..backend import get_backend + + __all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted', 'emd_1d', 'emd2_1d', 'wasserstein_1d'] @@ -620,3 +622,4 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None return X, log_dict else: return X + diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py index 869d450bb..fbf3c0ed0 100644 --- a/ot/lp/cvx.py +++ b/ot/lp/cvx.py @@ -11,7 +11,6 @@ import scipy as sp import scipy.sparse as sps - try: import cvxopt from cvxopt import solvers, matrix, spmatrix diff --git a/ot/weak.py b/ot/weak.py new file mode 100644 index 000000000..4dc73848e --- /dev/null +++ b/ot/weak.py @@ -0,0 +1,123 @@ +""" +Weak optimal transport solvers +""" + +# Author: Remi Flamary +# +# License: MIT License + +from .backend import get_backend +from .optim import cg +import numpy as np + +__all__ = ['weak_optimal_transport'] + + +def weak_optimal_transport(Xa, Xb, a=None, b=None, verbose=False, log=False, G0=None, 
**kwargs): + r"""Solves the weak optimal transport problem betwen two empirical distributions + + + .. math:: + \gamma = \mathop{\arg \min}_\gamma \quad \|X_a - \text{diag}(1/\mathbf{a}) \gamma X_b\|_F^2 + + s.t. \ \gamma \mathbf{1} = \mathbf{a} + + \gamma^T \mathbf{1} = \mathbf{b} + + \gamma \geq 0 + + where : + + - :math:`X_a` :math:`X_b` are the sample matrices. + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights + + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. + + Uses the conditional gradient algorithm to solve the problem proposed + in :ref:`[39] <references-weak>`. + + Parameters + ---------- + Xa : (ns,d) array-like, float + Source samples + Xb : (nt,d) array-like, float + Target samples + a : (ns,) array-like, float + Source histogram (uniform weights if None) + b : (nt,) array-like, float + Target histogram (uniform weights if None) + numItermax : int, optional + Max number of iterations + numItermaxEmd : int, optional + Max number of iterations for emd + stopThr : float, optional + Stop threshold on the relative variation (>0) + stopThr2 : float, optional + Stop threshold on the absolute variation (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + gamma: array-like, shape (ns, nt) + Optimal transportation matrix for the given + parameters + log: dict, optional + If input log is true, a dictionary containing the + cost and dual variables and exit status + + + .. _references-weak: + References + ---------- + .. [39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017). + Kantorovich duality for general transport costs and applications. + Journal of Functional Analysis, 273(11), 3327-3405. 
+ + See Also + -------- + ot.bregman.sinkhorn : Entropic regularized OT + ot.optim.cg : General regularized OT + """ + + nx = get_backend(Xa, Xb) + + Xa2 = nx.to_numpy(Xa) + Xb2 = nx.to_numpy(Xb) + + if a is None: + a2 = nx.ones((Xa.shape[0]), type_as=Xa) / Xa.shape[0] + else: + a2 = nx.to_numpy(a) + if b is None: + b2 = nx.ones((Xb.shape[0]), type_as=Xb) / Xb.shape[0] + else: + b2 = nx.to_numpy(b) + + # init uniform + if G0 is None: + T0 = a2[:, None] * b2[None, :] + else: + T0 = nx.to_numpy(G0) + + # weak OT loss + def f(T): + return np.dot(a2, np.sum((Xa2 - np.dot(T, Xb2) / a2[:, None])**2, 1)) + + # weak OT gradient + def df(T): + return -2 * np.dot(Xa2 - np.dot(T, Xb2) / a2[:, None], Xb2.T) + + # solve with conditional gradient and return solution + if log: + res, log = cg(a2, b2, 0, 1, f, df, T0, log=log, verbose=verbose, **kwargs) + log['u'] = nx.from_numpy(log['u'], type_as=Xa) + log['v'] = nx.from_numpy(log['v'], type_as=Xa) + return nx.from_numpy(res, type_as=Xa), log + else: + return nx.from_numpy(cg(a2, b2, 0, 1, f, df, T0, log=log, verbose=verbose, **kwargs), type_as=Xa) diff --git a/test/test_ot.py b/test/test_ot.py index 53edf4f65..2dad55371 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -232,7 +232,7 @@ def test_emd2_multi(): # Gaussian distributions a = gauss(n, m=20, s=5) # m= mean, s= std - ls = np.arange(20, 500, 20) + ls = np.arange(20, 500, 100) nb = len(ls) b = np.zeros((n, nb)) for i in range(nb): diff --git a/test/test_weak.py b/test/test_weak.py new file mode 100644 index 000000000..d85fbe865 --- /dev/null +++ b/test/test_weak.py @@ -0,0 +1,54 @@ +"""Tests for main module ot.weak """ + +# Author: Remi Flamary +# +# License: MIT License + +import ot +import numpy as np + + +def test_weak_ot(): + # test weak ot solver and identity stationary point + n = 50 + rng = np.random.RandomState(0) + + xs = rng.randn(n, 2) + xt = rng.randn(n, 2) + u = ot.utils.unif(n) + + G = ot.weak_optimal_transport(xs, xt, u, u) + + # check constraints + 
np.testing.assert_allclose(u, G.sum(1)) + np.testing.assert_allclose(u, G.sum(0)) + + # chaeck that identity is recovered + G = ot.weak_optimal_transport(xs, xs, u, u, G0=np.eye(n) / n) + + # check G is identity + np.testing.assert_allclose(G, np.eye(n) / n) + + # check constraints + np.testing.assert_allclose(u, G.sum(1)) + np.testing.assert_allclose(u, G.sum(0)) + + +def test_weak_ot_bakends(nx): + # test weak ot solver for different backends + n = 50 + rng = np.random.RandomState(0) + + xs = rng.randn(n, 2) + xt = rng.randn(n, 2) + u = ot.utils.unif(n) + + G = ot.weak_optimal_transport(xs, xt, u, u) + + xs2 = nx.from_numpy(xs) + xt2 = nx.from_numpy(xt) + u2 = nx.from_numpy(u) + + G2 = ot.weak_optimal_transport(xs2, xt2, u2, u2) + + np.testing.assert_allclose(nx.to_numpy(G2), G) From e83257a03918d2a3bd81fcfab97dc4533c8100a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Mon, 31 Jan 2022 14:50:09 +0100 Subject: [PATCH 2/9] update tests --- RELEASES.md | 2 +- test/test_bregman.py | 13 ++++++------- test/test_weak.py | 4 ++-- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index 14f222594..fc72d38b4 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -7,7 +7,7 @@ - Better list of related examples in quick start guide with `minigallery` (PR #334). - Add optional log-domain Sinkhorn implementation in WDA to support smaller values of the regularization parameter (PR #336). 
-- Add weak OT solver (PR #341) +- Add weak OT solver + example (PR #341) #### Closed issues diff --git a/test/test_bregman.py b/test/test_bregman.py index 6e90aa472..716c3c6a4 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -60,7 +60,7 @@ def test_convergence_warning(method): ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1) -def test_not_impemented_method(): +def test_not_implemented_method(): # test sinkhorn w = 10 n = w ** 2 @@ -635,7 +635,7 @@ def test_wasserstein_bary_2d(nx, method): with pytest.raises(NotImplementedError): ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method) else: - bary_wass_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method) + bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method, verbose=True, log=True) bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)) np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) @@ -667,7 +667,7 @@ def test_wasserstein_bary_2d_debiased(nx, method): with pytest.raises(NotImplementedError): ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method) else: - bary_wass_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method) + bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method, verbose=True, log=True) bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)) np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) @@ -941,13 +941,12 @@ def test_screenkhorn(nx): M_nx = nx.from_numpy(M, type_as=ab) # np sinkhorn - G_sink_np = ot.sinkhorn(a, b, M, 1e-03) + G_sink_np = ot.sinkhorn(a, b, M, 1e-1) # sinkhorn - G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-03)) + G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-1)) # screenkhorn - G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, M_nx, 1e-03, uniform=True, verbose=True)) + G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, 
bb, M_nx, 1e-1, uniform=True, verbose=True)) # check marginals - np.testing.assert_allclose(G_sink_np, G_sink) np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02) np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02) diff --git a/test/test_weak.py b/test/test_weak.py index d85fbe865..c4c32789e 100644 --- a/test/test_weak.py +++ b/test/test_weak.py @@ -17,14 +17,14 @@ def test_weak_ot(): xt = rng.randn(n, 2) u = ot.utils.unif(n) - G = ot.weak_optimal_transport(xs, xt, u, u) + G, log = ot.weak_optimal_transport(xs, xt, u, u, log=True) # check constraints np.testing.assert_allclose(u, G.sum(1)) np.testing.assert_allclose(u, G.sum(0)) # chaeck that identity is recovered - G = ot.weak_optimal_transport(xs, xs, u, u, G0=np.eye(n) / n) + G = ot.weak_optimal_transport(xs, xs, G0=np.eye(n) / n) # check G is identity np.testing.assert_allclose(G, np.eye(n) / n) From 19a6590c07603e369f9ba7e9320c4c04caf111f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Mon, 31 Jan 2022 14:54:50 +0100 Subject: [PATCH 3/9] pep8 --- test/test_bregman.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/test_bregman.py b/test/test_bregman.py index 716c3c6a4..1419f9bcd 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -940,8 +940,6 @@ def test_screenkhorn(nx): bb = nx.from_numpy(b) M_nx = nx.from_numpy(M, type_as=ab) - # np sinkhorn - G_sink_np = ot.sinkhorn(a, b, M, 1e-1) # sinkhorn G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-1)) # screenkhorn From a541dc1c97bac176d1e42a7f002912f43e6c0f2c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Mon, 31 Jan 2022 15:07:27 +0100 Subject: [PATCH 4/9] add weak OT example --- examples/others/plot_WeakOT_VS_OT.py | 101 +++++++++++++++++++++++++++ 1 file changed, 101 insertions(+) create mode 100644 examples/others/plot_WeakOT_VS_OT.py diff --git a/examples/others/plot_WeakOT_VS_OT.py b/examples/others/plot_WeakOT_VS_OT.py new file mode 100644 index 
000000000..0cbe116da --- /dev/null +++ b/examples/others/plot_WeakOT_VS_OT.py @@ -0,0 +1,101 @@ +# -*- coding: utf-8 -*- +""" +==================================================== +Weak Optimal Transport VS exact Optimal Transport +==================================================== + +Illustration of 2D optimal transport between distributions that are weighted +sum of diracs. The OT matrix is plotted with the samples. + +""" + +# Author: Remi Flamary +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 4 + +import numpy as np +import matplotlib.pylab as pl +import ot +import ot.plot + +############################################################################## +# Generate data and plot it +# ------------------------- + +#%% parameters and data generation + +n = 50 # nb samples + +mu_s = np.array([0, 0]) +cov_s = np.array([[1, 0], [0, 1]]) + +mu_t = np.array([4, 4]) +cov_t = np.array([[1, -.8], [-.8, 1]]) + +xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s) +xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t) + +a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples + +# loss matrix +M = ot.dist(xs, xt) +M /= M.max() + +#%% plot samples + +pl.figure(1) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.legend(loc=0) +pl.title('Source and target distributions') + +pl.figure(2) +pl.imshow(M, interpolation='nearest') +pl.title('Cost matrix M') + + +############################################################################## +# Compute Weak OT and exact OT solutions +# -------------------------------------- + +#%% EMD + +G0 = ot.emd(a, b, M) + +#%% Weak OT + +Gweak = ot.weak_optimal_transport(xs,xt,a,b) + + +############################################################################## +# Plot weak OT and exact OT solutions +# -------------------------------------- + +pl.figure(2,(8,5)) + +pl.subplot(1,2,1) +pl.imshow(G0, interpolation='nearest') 
+pl.title('OT matrix') + +pl.subplot(1,2,2) +pl.imshow(Gweak, interpolation='nearest') +pl.title('Weak OT matrix') + +pl.figure(3,(5,5)) + +ot.plot.plot2D_samples_mat(xs, xt, G0, c=[.5, .5, 1]) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.legend(loc=0) +pl.title('OT matrix with samples') + +pl.figure(4,(5,5)) + +ot.plot.plot2D_samples_mat(xs, xt, Gweak, c=[.5, .5, 1]) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.legend(loc=0) +pl.title('Weak OT matrix with samples') + From e18ae69043096dbedc059239de5427bf1ad68ce7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Mon, 31 Jan 2022 15:10:59 +0100 Subject: [PATCH 5/9] update plot in doc --- examples/others/plot_WeakOT_VS_OT.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/examples/others/plot_WeakOT_VS_OT.py b/examples/others/plot_WeakOT_VS_OT.py index 0cbe116da..80bffdc79 100644 --- a/examples/others/plot_WeakOT_VS_OT.py +++ b/examples/others/plot_WeakOT_VS_OT.py @@ -9,11 +9,11 @@ """ -# Author: Remi Flamary +# Author: Remi Flamary # # License: MIT License -# sphinx_gallery_thumbnail_number = 4 +# sphinx_gallery_thumbnail_number = 3 import numpy as np import matplotlib.pylab as pl @@ -66,36 +66,33 @@ #%% Weak OT -Gweak = ot.weak_optimal_transport(xs,xt,a,b) +Gweak = ot.weak_optimal_transport(xs, xt, a, b) ############################################################################## # Plot weak OT and exact OT solutions # -------------------------------------- -pl.figure(2,(8,5)) +pl.figure(2, (8, 5)) -pl.subplot(1,2,1) +pl.subplot(1, 2, 1) pl.imshow(G0, interpolation='nearest') pl.title('OT matrix') -pl.subplot(1,2,2) +pl.subplot(1, 2, 2) pl.imshow(Gweak, interpolation='nearest') pl.title('Weak OT matrix') -pl.figure(3,(5,5)) +pl.figure(3, (8, 5)) +pl.subplot(1, 2, 1) ot.plot.plot2D_samples_mat(xs, xt, 
G0, c=[.5, .5, 1]) pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') -pl.legend(loc=0) pl.title('OT matrix with samples') -pl.figure(4,(5,5)) - +pl.subplot(1, 2, 2) ot.plot.plot2D_samples_mat(xs, xt, Gweak, c=[.5, .5, 1]) pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') -pl.legend(loc=0) pl.title('Weak OT matrix with samples') - From 38313374479e98e46ee7089b9138ee031e1aca1a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Mon, 31 Jan 2022 15:19:03 +0100 Subject: [PATCH 6/9] correct example with empirical sinkhorn --- examples/plot_OT_2D_samples.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/plot_OT_2D_samples.py b/examples/plot_OT_2D_samples.py index af1bc127a..c3a7cd810 100644 --- a/examples/plot_OT_2D_samples.py +++ b/examples/plot_OT_2D_samples.py @@ -42,7 +42,6 @@ # loss matrix M = ot.dist(xs, xt) -M /= M.max() ############################################################################## # Plot data @@ -87,7 +86,7 @@ #%% sinkhorn # reg term -lambd = 1e-3 +lambd = 1e-1 Gs = ot.sinkhorn(a, b, M, lambd) @@ -112,7 +111,7 @@ #%% sinkhorn # reg term -lambd = 1e-3 +lambd = 1e-1 Ges = ot.bregman.empirical_sinkhorn(xs, xt, lambd) From 11fd0383ae647340ce83809c544a67e6f9a13c20 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Mon, 31 Jan 2022 15:39:22 +0100 Subject: [PATCH 7/9] better thumbnail --- examples/others/plot_WeakOT_VS_OT.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/others/plot_WeakOT_VS_OT.py b/examples/others/plot_WeakOT_VS_OT.py index 80bffdc79..79993539b 100644 --- a/examples/others/plot_WeakOT_VS_OT.py +++ b/examples/others/plot_WeakOT_VS_OT.py @@ -13,7 +13,7 @@ # # License: MIT License -# sphinx_gallery_thumbnail_number = 3 +# sphinx_gallery_thumbnail_number = 4 import numpy as np import matplotlib.pylab as pl 
@@ -73,7 +73,7 @@ # Plot weak OT and exact OT solutions # -------------------------------------- -pl.figure(2, (8, 5)) +pl.figure(3, (8, 5)) pl.subplot(1, 2, 1) pl.imshow(G0, interpolation='nearest') @@ -83,7 +83,7 @@ pl.imshow(Gweak, interpolation='nearest') pl.title('Weak OT matrix') -pl.figure(3, (8, 5)) +pl.figure(4, (8, 5)) pl.subplot(1, 2, 1) ot.plot.plot2D_samples_mat(xs, xt, G0, c=[.5, .5, 1]) From e1ca42ff40cd0dc5fcd1e5b5377f4f3f210d3323 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Wed, 2 Feb 2022 08:30:44 +0100 Subject: [PATCH 8/9] comment from review --- examples/others/plot_WeakOT_VS_OT.py | 2 +- ot/utils.py | 12 +++++++++--- ot/weak.py | 8 ++++---- test/test_utils.py | 18 +++++++++++++++--- 4 files changed, 29 insertions(+), 11 deletions(-) diff --git a/examples/others/plot_WeakOT_VS_OT.py b/examples/others/plot_WeakOT_VS_OT.py index 79993539b..a29c8756e 100644 --- a/examples/others/plot_WeakOT_VS_OT.py +++ b/examples/others/plot_WeakOT_VS_OT.py @@ -37,7 +37,7 @@ xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s) xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t) -a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples +a, b = ot.unif(n), ot.unif(n) # uniform distribution on samples # loss matrix M = ot.dist(xs, xt) diff --git a/ot/utils.py b/ot/utils.py index e6c93c8bd..725ca00ae 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -116,7 +116,7 @@ def proj_simplex(v, z=1): return w -def unif(n): +def unif(n, type_as=None): r""" Return a uniform histogram of length `n` (simplex). 
@@ -124,13 +124,19 @@ def unif(n): ---------- n : int number of bins in the histogram + type_as : array_like + array of the same type of the expected output (numpy/pytorch/jax) Returns ------- - h : np.array (`n`,) + h : array_like (`n`,) histogram of length `n` such that :math:`\forall i, \mathbf{h}_i = \frac{1}{n}` """ - return np.ones((n,)) / n + if type_as is None: + return np.ones((n,)) / n + else: + nx = get_backend(type_as) + return nx.ones((n,)) / n def clean_zeros(a, b, M): diff --git a/ot/weak.py b/ot/weak.py index 4dc73848e..b2890376c 100644 --- a/ot/weak.py +++ b/ot/weak.py @@ -14,7 +14,7 @@ def weak_optimal_transport(Xa, Xb, a=None, b=None, verbose=False, log=False, G0=None, **kwargs): - r"""Solves the weak optimal transport problem betwen two empirical distributions + r"""Solves the weak optimal transport problem between two empirical distributions .. math:: @@ -91,11 +91,11 @@ def weak_optimal_transport(Xa, Xb, a=None, b=None, verbose=False, log=False, G0= Xb2 = nx.to_numpy(Xb) if a is None: - a2 = nx.ones((Xa.shape[0]), type_as=Xa) / Xa.shape[0] + a2 = np.ones((Xa.shape[0])) / Xa.shape[0] else: a2 = nx.to_numpy(a) if b is None: - b2 = nx.ones((Xb.shape[0]), type_as=Xb) / Xb.shape[0] + b2 = np.ones((Xb.shape[0])) / Xb.shape[0] else: b2 = nx.to_numpy(b) @@ -117,7 +117,7 @@ def df(T): if log: res, log = cg(a2, b2, 0, 1, f, df, T0, log=log, verbose=verbose, **kwargs) log['u'] = nx.from_numpy(log['u'], type_as=Xa) - log['v'] = nx.from_numpy(log['v'], type_as=Xa) + log['v'] = nx.from_numpy(log['v'], type_as=Xb) return nx.from_numpy(res, type_as=Xa), log else: return nx.from_numpy(cg(a2, b2, 0, 1, f, df, T0, log=log, verbose=verbose, **kwargs), type_as=Xa) diff --git a/test/test_utils.py b/test/test_utils.py index 8b23c224e..5ad167b35 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -62,12 +62,12 @@ def test_tic_toc(): import time ot.tic() - time.sleep(0.5) + time.sleep(0.1) t = ot.toc() t2 = ot.toq() # test timing - 
np.testing.assert_allclose(0.5, t, rtol=1e-1, atol=1e-1) + np.testing.assert_allclose(0.1, t, rtol=1e-1, atol=1e-1) # test toc vs toq np.testing.assert_allclose(t, t2, rtol=1e-1, atol=1e-1) @@ -94,10 +94,22 @@ def test_unif(): np.testing.assert_allclose(1, np.sum(u)) -def test_dist(): +def test_unif_backend(nx): n = 100 + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + u = ot.unif(n, type_as=tp) + + np.testing.assert_allclose(1, np.sum(nx.to_numpy(u)), atol=1e-6) + + +def test_dist(): + + n = 10 + rng = np.random.RandomState(0) x = rng.randn(n, 2) From effe32f62a5ad0ffe6e7fd9b8b9681bdd0a08d82 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Wed, 2 Feb 2022 11:11:12 +0100 Subject: [PATCH 9/9] update documentation --- ot/gromov.py | 16 ++++++++++++++++ ot/lp/__init__.py | 6 ++++-- ot/weak.py | 3 ++- 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/ot/gromov.py b/ot/gromov.py index 65442606d..b7e794995 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -338,6 +338,10 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=F - :math:`\mathbf{q}`: distribution in the target space - `L`: loss function to account for the misfit between the similarity matrices + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + Parameters ---------- C1 : array-like, shape (ns, ns) @@ -436,6 +440,10 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo= Note that when using backends, this loss function is differentiable wrt the marices and weights for quadratic loss using the gradients from [38]_. + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. 
+ Parameters ---------- C1 : array-like, shape (ns, ns) @@ -545,6 +553,10 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1) - `L` is a loss function to account for the misfit between the similarity matrices + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] ` Parameters @@ -645,6 +657,10 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] ` + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + Note that when using backends, this loss function is differentiable wrt the marices and weights for quadratic loss using the gradients from [38]_. diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index debd1acbd..d9b6fa9d4 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -222,7 +222,8 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): format .. note:: This function is backend-compatible and will work on arrays - from all compatible backends. + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. Uses the algorithm proposed in :ref:`[1] `. @@ -360,7 +361,8 @@ def emd2(a, b, M, processes=1, - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights .. note:: This function is backend-compatible and will work on arrays - from all compatible backends. + from all compatible backends. 
But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. Uses the algorithm proposed in :ref:`[1] `. diff --git a/ot/weak.py b/ot/weak.py index b2890376c..f7d5b2339 100644 --- a/ot/weak.py +++ b/ot/weak.py @@ -33,7 +33,8 @@ def weak_optimal_transport(Xa, Xb, a=None, b=None, verbose=False, log=False, G0= .. note:: This function is backend-compatible and will work on arrays - from all compatible backends. + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. Uses the conditional gradient algorithm to solve the problem proposed in :ref:`[39] <references-weak>`.