diff --git a/README.md b/README.md index f622c5aab..dc7c5dfaa 100644 --- a/README.md +++ b/README.md @@ -334,3 +334,6 @@ distances between Gaussian distributions](https://hal.science/hal-03197398v2/fil [59] Taylor A. B. (2017). [Convex interpolation and performance estimation of first-order methods for convex optimization.](https://dial.uclouvain.be/pr/boreal/object/boreal%3A182881/datastream/PDF_01/view) PhD thesis, Catholic University of Louvain, Louvain-la-Neuve, Belgium, 2017. +[60] Feydy, J., Roussillon, P., Trouvé, A., & Gori, P. (2019). [Fast and scalable optimal transport for brain tractograms](https://arxiv.org/pdf/2107.02010.pdf). In Medical Image Computing and Computer Assisted Intervention–MICCAI 2019: 22nd International Conference, Shenzhen, China, October 13–17, 2019, Proceedings, Part III 22 (pp. 636-644). Springer International Publishing. + +[61] Charlier, B., Feydy, J., Glaunes, J. A., Collin, F. D., & Durif, G. (2021). [Kernel operations on the gpu, with autodiff, without memory overflows](https://www.jmlr.org/papers/volume22/20-275/20-275.pdf). The Journal of Machine Learning Research, 22(1), 3457-3462. \ No newline at end of file diff --git a/RELEASES.md b/RELEASES.md index 9a7c94c3f..349c56214 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -17,6 +17,7 @@ + Add KL loss to all semi-relaxed (Fused) Gromov-Wasserstein solvers (PR #559) + Further upgraded unbalanced OT solvers for more flexibility and future use (PR #551) + New API function `ot.solve_sample` for solving OT problems from empirical samples (PR #563) ++ Wrapper for `geomloss`` solver on empirical samples (PR #571) + Add `stop_criterion` feature to (un)regularized (f)gw barycenter solvers (PR #578) + Add `fixed_structure` and `fixed_features` to entropic fgw barycenter solver (PR #578) diff --git a/ot/bregman/__init__.py b/ot/bregman/__init__.py index 230982e9c..0bcb4214d 100644 --- a/ot/bregman/__init__.py +++ b/ot/bregman/__init__.py @@ -39,6 +39,8 @@ from ._dictionary import (unmix) +from ._geomloss import (empirical_sinkhorn2_geomloss, geomloss) + __all__ = ['geometricBar', 'geometricMean', 'projR', 'projC', 'sinkhorn', 'sinkhorn2', 'sinkhorn_knopp', 'sinkhorn_log', @@ -46,8 +48,8 @@ 'barycenter', 'barycenter_sinkhorn', 'free_support_sinkhorn_barycenter', 'barycenter_stabilized', 'barycenter_debiased', 'jcpot_barycenter', 'convolutional_barycenter2d', 'convolutional_barycenter2d_debiased', - 'empirical_sinkhorn', 'empirical_sinkhorn2', - 'empirical_sinkhorn_divergence', + 'empirical_sinkhorn', 'empirical_sinkhorn2', 'empirical_sinkhorn2_geomloss' + 'empirical_sinkhorn_divergence', 'geomloss', 'screenkhorn', 'unmix' ] diff --git a/ot/bregman/_geomloss.py b/ot/bregman/_geomloss.py new file mode 100644 index 000000000..594d3b8e0 --- /dev/null +++ b/ot/bregman/_geomloss.py @@ -0,0 +1,216 @@ +# -*- coding: utf-8 -*- +""" +Wrapper functions for geomloss +""" + +# Author: Remi Flamary +# +# License: MIT License + +import numpy as np +try: + import geomloss + from geomloss import SamplesLoss + import torch + from torch.autograd import grad + from ..utils import get_backend, LazyTensor, dist +except ImportError: + geomloss = False + + +def get_sinkhorn_geomloss_lazytensor(X_a, X_b, f, g, a, b, metric='sqeuclidean', blur=0.1, nx=None): + """ Get a LazyTensor of sinkhorn solution T = exp((f+g^T-C)/reg)*(ab^T) + + Parameters + ---------- + X_a : array-like, shape (n_samples_a, dim) + samples in the source domain + X_torch: array-like, shape (n_samples_b, dim) + samples in the target domain + f : array-like, shape (n_samples_a,) + First dual potentials (log space) + g : array-like, shape (n_samples_b,) + Second dual potentials (log space) + metric : str, default='sqeuclidean' + Metric used for the cost matrix computation + blur : float, default=1e-1 + blur term (blur=sqrt(reg)) >0 + nx : Backend(), default=None + Numerical backend used + + + Returns + ------- + T : LazyTensor + Lowrank tensor T = exp((f+g^T-C)/reg)*(ab^T) + """ + + if nx is None: + nx = get_backend(X_a, X_b, f, g) + + shape = (X_a.shape[0], X_b.shape[0]) + + def func(i, j, X_a, X_b, f, g, a, b, metric, blur): + if metric == 'sqeuclidean': + C = dist(X_a[i], X_b[j], metric=metric) / 2 + else: + C = dist(X_a[i], X_b[j], metric=metric) + return nx.exp((f[i, None] + g[None, j] - C) / (blur**2)) * (a[i, None] * b[None, j]) + + T = LazyTensor(shape, func, X_a=X_a, X_b=X_b, f=f, g=g, a=a, b=b, metric=metric, blur=blur) + + return T + + +def empirical_sinkhorn2_geomloss(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', scaling=0.95, + verbose=False, debias=False, log=False, backend='auto'): + r""" Solve the entropic regularization optimal transport problem with geomloss + + The function solves the following optimization problem: + + .. math:: + \gamma = arg\min_\gamma <\gamma,C>_F + reg\cdot\Omega(\gamma) + + s.t. \gamma 1 = a + + \gamma^T 1= b + + \gamma\geq 0 + + where : + + - :math:`C` is the cost matrix such that :math:`C_{i,j}=d(x_i^s,x_j^t)` and + :math:`d` is a metric. + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j}\gamma_{i,j}\log(\gamma_{i,j})-\gamma_{i,j}+1` + - :math:`a` and :math:`b` are source and target weights (sum to 1) + + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix + scaling algorithm as proposed in and computed in log space for + better stability and epsilon-scaling. The solution is computed ina lzy way + using the Geomloss [60] and the KeOps library [61]. + + Parameters + ---------- + X_s : array-like, shape (n_samples_a, dim) + samples in the source domain + X_t : array-like, shape (n_samples_b, dim) + samples in the target domain + reg : float + Regularization term >0 + a : array-like, shape (n_samples_a,), default=None + samples weights in the source domain + b : array-like, shape (n_samples_b,), default=None + samples weights in the target domain + metric : str, default='sqeuclidean' + Metric used for the cost matrix computation Only acepted values are + 'sqeuclidean' and 'euclidean'. + scaling : float, default=0.95 + Scaling parameter used for epsilon scaling. Value close to one promote + precision while value close to zero promote speed. + verbose : bool, default=False + Print information + debias : bool, default=False + Use the debiased version of Sinkhorn algorithm [12]_. + log : bool, default=False + Return log dictionary containing all computed objects + backend : str, default='auto' + Numerical backend for geomloss. Only 'auto' and 'tensorized' 'online' + and 'multiscale' are accepted values. + + Returns + ------- + value : float + OT value + log : dict + Log dictionary return only if log==True in parameters + + References + ---------- + + .. [60] Feydy, J., Roussillon, P., Trouvé, A., & Gori, P. (2019). [Fast + and scalable optimal transport for brain tractograms. In Medical Image + Computing and Computer Assisted Intervention–MICCAI 2019: 22nd + International Conference, Shenzhen, China, October 13–17, 2019, + Proceedings, Part III 22 (pp. 636-644). Springer International + Publishing. + + .. [61] Charlier, B., Feydy, J., Glaunes, J. A., Collin, F. D., & Durif, G. + (2021). Kernel operations on the gpu, with autodiff, without memory + overflows. The Journal of Machine Learning Research, 22(1), 3457-3462. + + """ + + if geomloss: + + nx = get_backend(X_s, X_t, a, b) + + if nx.__name__ not in ['torch', 'numpy']: + raise ValueError('geomloss only support torch or numpy backend') + + if a is None: + a = nx.ones(X_s.shape[0], type_as=X_s) / X_s.shape[0] + if b is None: + b = nx.ones(X_t.shape[0], type_as=X_t) / X_t.shape[0] + + if nx.__name__ == 'numpy': + X_s_torch = torch.tensor(X_s) + X_t_torch = torch.tensor(X_t) + + a_torch = torch.tensor(a) + b_torch = torch.tensor(b) + + else: + X_s_torch = X_s + X_t_torch = X_t + + a_torch = a + b_torch = b + + # after that we are all in torch + + # set blur value and p + if metric == 'sqeuclidean': + p = 2 + blur = np.sqrt(reg / 2) # because geomloss divides cost by two + elif metric == 'euclidean': + p = 1 + blur = np.sqrt(reg) + else: + raise ValueError('geomloss only supports sqeuclidean and euclidean metrics') + + # force gradients for computing dual + a_torch.requires_grad = True + b_torch.requires_grad = True + + loss = SamplesLoss(loss='sinkhorn', p=p, blur=blur, backend=backend, debias=debias, scaling=scaling, verbose=verbose) + + # compute value + value = loss(a_torch, X_s_torch, b_torch, X_t_torch) # linear + entropic/KL reg? + + # get dual potentials + f, g = grad(value, [a_torch, b_torch]) + + if metric == 'sqeuclidean': + value *= 2 # because geomloss divides cost by two + + if nx.__name__ == 'numpy': + f = f.cpu().detach().numpy() + g = g.cpu().detach().numpy() + value = value.cpu().detach().numpy() + + if log: + log = {} + log['f'] = f + log['g'] = g + log['value'] = value + + log['lazy_plan'] = get_sinkhorn_geomloss_lazytensor(X_s, X_t, f, g, a, b, metric=metric, blur=blur, nx=nx) + + return value, log + + else: + return value + + else: + raise ImportError('geomloss not installed') diff --git a/ot/solvers.py b/ot/solvers.py index aed7e8ffe..b65f1149d 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -11,7 +11,7 @@ from .lp import emd2, wasserstein_1d from .backend import get_backend from .unbalanced import mm_unbalanced, sinkhorn_knopp_unbalanced, lbfgsb_unbalanced -from .bregman import sinkhorn_log, empirical_sinkhorn2 +from .bregman import sinkhorn_log, empirical_sinkhorn2, empirical_sinkhorn2_geomloss from .partial import partial_wasserstein_lagrange from .smooth import smooth_ot_dual from .gromov import (gromov_wasserstein2, fused_gromov_wasserstein2, @@ -23,6 +23,8 @@ from .gaussian import empirical_bures_wasserstein_distance from .factored import factored_optimal_transport +lst_method_lazy = ['1d', 'gaussian', 'lowrank', 'factored', 'geomloss', 'geomloss_auto', 'geomloss_tensorized', 'geomloss_online', 'geomloss_multiscale'] + def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, unbalanced_type='KL', method=None, n_threads=1, max_iter=None, plan_init=None, @@ -865,7 +867,7 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_type="KL", unbalanced=None, - unbalanced_type='KL', lazy=False, batch_size=None, method=None, n_threads=1, max_iter=None, plan_init=None, rank=100, + unbalanced_type='KL', lazy=False, batch_size=None, method=None, n_threads=1, max_iter=None, plan_init=None, rank=100, scaling=0.95, potentials_init=None, X_init=None, tol=None, verbose=False): r"""Solve the discrete optimal transport problem using the samples in the source and target domains. @@ -922,6 +924,10 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t Maximum number of iteration, by default None (default values in each solvers) plan_init : array_like, shape (dim_a, dim_b), optional Initialization of the OT plan for iterative methods, by default None + rank : int, optional + Rank of the OT matrix for lazy solers (method='factored'), by default 100 + scaling : float, optional + Scaling factor for the epsilon scaling lazy solvers (method='geomloss'), by default 0.95 potentials_init : (array_like(dim_a,),array_like(dim_b,)), optional Initialization of the OT dual potentials for iterative methods, by default None tol : _type_, optional @@ -939,6 +945,7 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t - res.potentials : OT dual potentials - res.value : Optimal value of the optimization problem - res.value_linear : Linear OT loss with the optimal OT plan + - res.lazy_plan : Lazy OT plan (when ``lazy=True`` or lazy method) See :any:`OTResult` for more information. @@ -1148,7 +1155,7 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t """ - if method is not None and method.lower() in ['1d', 'gaussian', 'lowrank', 'factored']: + if method is not None and method.lower() in lst_method_lazy: lazy0 = lazy lazy = True @@ -1221,6 +1228,28 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t if not lazy0: # store plan if not lazy plan = lazy_plan[:] + elif method.startswith('geomloss'): # Geomloss solver for entropi OT + + split_method = method.split('_') + if len(split_method) == 2: + backend = split_method[1] + else: + if lazy0 is None: + backend = 'auto' + elif lazy0: + backend = 'online' + else: + backend = 'tensorized' + + value, log = empirical_sinkhorn2_geomloss(X_a, X_b, reg=reg, a=a, b=b, metric=metric, log=True, verbose=verbose, scaling=scaling, backend=backend) + + lazy_plan = log['lazy_plan'] + if not lazy0: # store plan if not lazy + plan = lazy_plan[:] + + # return scaled potentials (to be consistent with other solvers) + potentials = (log['f'] / (lazy_plan.blur**2), log['g'] / (lazy_plan.blur**2)) + elif reg is None or reg == 0: # exact OT if unbalanced is None: # balanced EMD solver not available for lazy diff --git a/ot/utils.py b/ot/utils.py index f64c2fea6..cb29b21c9 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -527,7 +527,7 @@ def reduce_lazytensor(a, func, axis=None, nx=None, batch_size=100): """ if nx is None: - nx = get_backend(a[0]) + nx = get_backend(a[0:1]) if axis is None: res = 0.0 diff --git a/requirements.txt b/requirements.txt index 6ac25eb3c..6af50f127 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,4 +11,6 @@ jaxlib tensorflow pytest torch_geometric -cvxpy \ No newline at end of file +cvxpy +geomloss +pykeops \ No newline at end of file diff --git a/test/test_bregman.py b/test/test_bregman.py index 67257f899..1a92a1037 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -15,6 +15,7 @@ import ot from ot.backend import tf, torch +from ot.bregman import geomloss @pytest.mark.parametrize("verbose, warn", product([True, False], [True, False])) @@ -1057,6 +1058,40 @@ def test_empirical_sinkhorn(nx): np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05) +@pytest.mark.skipif(not geomloss, reason="pytorch not installed") +@pytest.skip_backend('tf') +@pytest.skip_backend("cupy") +@pytest.skip_backend("jax") +@pytest.mark.parametrize("metric", ["sqeuclidean", "euclidean"]) +def test_geomloss_solver(nx, metric): + # test sinkhorn + n = 10 + a = ot.unif(n) + b = ot.unif(n) + + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + + ab, bb, X_sb, X_tb = nx.from_numpy(a, b, X_s, X_t) + + G_sqe = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric=metric)) + + value, log = ot.bregman.empirical_sinkhorn2_geomloss(X_sb, X_tb, 1, metric=metric, log=True) + G_geomloss = nx.to_numpy(log['lazy_plan'][:]) + + print(value) + + # call with log = False + ot.bregman.empirical_sinkhorn2_geomloss(X_sb, X_tb, 1, metric=metric) + + # check equality of plans + np.testing.assert_allclose(G_sqe, G_geomloss, atol=1e-03) # metric sqeuclidian + + # check error on wrong metric + with pytest.raises(ValueError): + ot.bregman.empirical_sinkhorn2_geomloss(X_sb, X_tb, 1, metric='wrong_metric') + + def test_lazy_empirical_sinkhorn(nx): # test sinkhorn n = 10 diff --git a/test/test_solvers.py b/test/test_solvers.py index c6e1a3770..bf07b7af8 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -8,9 +8,10 @@ import itertools import numpy as np import pytest +import sys import ot - +from ot.bregman import geomloss lst_reg = [None, 1] lst_reg_type = ['KL', 'entropy', 'L2'] @@ -348,6 +349,48 @@ def test_solve_sample_lazy(nx): np.testing.assert_allclose(sol0.plan, sol.lazy_plan[:], rtol=1e-5, atol=1e-5) +@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") +@pytest.mark.skipif(not geomloss, reason="pytorch not installed") +@pytest.skip_backend('tf') +@pytest.skip_backend("cupy") +@pytest.skip_backend("jax") +@pytest.mark.parametrize("metric", ["sqeuclidean", "euclidean"]) +def test_solve_sample_geomloss(nx, metric): + # test solve_sample when is_Lazy = False + n_samples_s = 13 + n_samples_t = 7 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples_s, n_features) + y = rng.randn(n_samples_t, n_features) + a = ot.utils.unif(n_samples_s) + b = ot.utils.unif(n_samples_t) + + xb, yb, ab, bb = nx.from_numpy(x, y, a, b) + + sol0 = ot.solve_sample(xb, yb, ab, bb, reg=1) + + # solve signe weights + sol = ot.solve_sample(xb, yb, ab, bb, reg=1, method='geomloss') + assert_allclose_sol(sol0, sol) + + sol1 = ot.solve_sample(xb, yb, ab, bb, reg=1, lazy=False, method='geomloss') + assert_allclose_sol(sol0, sol) + + sol1 = ot.solve_sample(xb, yb, ab, bb, reg=1, lazy=True, method='geomloss_tensorized') + np.testing.assert_allclose(nx.to_numpy(sol1.lazy_plan[:]), nx.to_numpy(sol.lazy_plan[:]), rtol=1e-5, atol=1e-5) + + sol1 = ot.solve_sample(xb, yb, ab, bb, reg=1, lazy=True, method='geomloss_online') + np.testing.assert_allclose(nx.to_numpy(sol1.lazy_plan[:]), nx.to_numpy(sol.lazy_plan[:]), rtol=1e-5, atol=1e-5) + + sol1 = ot.solve_sample(xb, yb, ab, bb, reg=1, lazy=True, method='geomloss_multiscale') + np.testing.assert_allclose(nx.to_numpy(sol1.lazy_plan[:]), nx.to_numpy(sol.lazy_plan[:]), rtol=1e-5, atol=1e-5) + + sol1 = ot.solve_sample(xb, yb, ab, bb, reg=1, lazy=True, method='geomloss') + np.testing.assert_allclose(nx.to_numpy(sol1.lazy_plan[:]), nx.to_numpy(sol.lazy_plan[:]), rtol=1e-5, atol=1e-5) + + @pytest.mark.parametrize("method_params", lst_method_params_solve_sample) def test_solve_sample_methods(nx, method_params):