From f3ad957bf2bf07b3479bb8ebbc18d499fdeff508 Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Tue, 16 Jul 2024 17:07:28 +0200 Subject: [PATCH 01/17] restructure unbalanced.py --- ot/__init__.py | 2 +- ot/solvers.py | 1386 ----------------- ot/unbalanced/Untitled.ipynb | 973 ++++++++++++ ot/unbalanced/__init__.py | 27 + ot/unbalanced/_lbfgs.py | 357 +++++ ot/unbalanced/_mm.py | 288 ++++ ot/{unbalanced.py => unbalanced/_sinkhorn.py} | 632 +------- test/unbalanced/__init__.py | 0 test/unbalanced/test_lbfgs.py | 126 ++ test/unbalanced/test_mm.py | 164 ++ .../test_sinkhorn.py} | 309 +--- 11 files changed, 2053 insertions(+), 2211 deletions(-) delete mode 100644 ot/solvers.py create mode 100644 ot/unbalanced/Untitled.ipynb create mode 100644 ot/unbalanced/__init__.py create mode 100644 ot/unbalanced/_lbfgs.py create mode 100644 ot/unbalanced/_mm.py rename ot/{unbalanced.py => unbalanced/_sinkhorn.py} (67%) create mode 100644 test/unbalanced/__init__.py create mode 100644 test/unbalanced/test_lbfgs.py create mode 100644 test/unbalanced/test_mm.py rename test/{test_unbalanced.py => unbalanced/test_sinkhorn.py} (55%) diff --git a/ot/__init__.py b/ot/__init__.py index e1b29ba53..9e4dd8fcc 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -72,5 +72,5 @@ 'factored_optimal_transport', 'solve', 'solve_gromov', 'solve_sample', 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers', 'binary_search_circle', 'wasserstein_circle', - 'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif', 'lowrank_sinkhorn', + 'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif', 'lowrank_sinkhorn', 'lowrank_gromov_wasserstein_samples'] diff --git a/ot/solvers.py b/ot/solvers.py deleted file mode 100644 index 95165ea11..000000000 --- a/ot/solvers.py +++ /dev/null @@ -1,1386 +0,0 @@ -# -*- coding: utf-8 -*- -""" -General OT solvers with unified API -""" - -# Author: Remi Flamary -# -# License: MIT License - -from .utils import OTResult, dist -from .lp import emd2, wasserstein_1d -from .backend import get_backend -from .unbalanced import mm_unbalanced, sinkhorn_knopp_unbalanced, lbfgsb_unbalanced -from .bregman import sinkhorn_log, empirical_sinkhorn2, empirical_sinkhorn2_geomloss -from .partial import partial_wasserstein_lagrange -from .smooth import smooth_ot_dual -from .gromov import (gromov_wasserstein2, fused_gromov_wasserstein2, - entropic_gromov_wasserstein2, entropic_fused_gromov_wasserstein2, - semirelaxed_gromov_wasserstein2, semirelaxed_fused_gromov_wasserstein2, - entropic_semirelaxed_fused_gromov_wasserstein2, - entropic_semirelaxed_gromov_wasserstein2) -from .partial import partial_gromov_wasserstein2, entropic_partial_gromov_wasserstein2 -from .gaussian import empirical_bures_wasserstein_distance -from .factored import factored_optimal_transport -from .lowrank import lowrank_sinkhorn -from .optim import cg - -lst_method_lazy = ['1d', 'gaussian', 'lowrank', 'factored', 'geomloss', 'geomloss_auto', 'geomloss_tensorized', 'geomloss_online', 'geomloss_multiscale'] - - -def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, - unbalanced_type='KL', method=None, n_threads=1, max_iter=None, plan_init=None, - potentials_init=None, tol=None, verbose=False, grad='autodiff'): - r"""Solve the discrete optimal transport problem and return :any:`OTResult` object - - The function solves the following general optimal transport problem - - .. math:: - \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) + - \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + - \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) - - The regularization is selected with `reg` (:math:`\lambda_r`) and `reg_type`. By - default ``reg=None`` and there is no regularization. The unbalanced marginal - penalization can be selected with `unbalanced` (:math:`\lambda_u`) and - `unbalanced_type`. By default ``unbalanced=None`` and the function - solves the exact optimal transport problem (respecting the marginals). - - Parameters - ---------- - M : array_like, shape (dim_a, dim_b) - Loss matrix - a : array-like, shape (dim_a,), optional - Samples weights in the source domain (default is uniform) - b : array-like, shape (dim_b,), optional - Samples weights in the source domain (default is uniform) - reg : float, optional - Regularization weight :math:`\lambda_r`, by default None (no reg., exact - OT) - reg_type : str, optional - Type of regularization :math:`R` either "KL", "L2", "entropy", - by default "KL". a tuple of functions can be provided for general - solver (see :any:`cg`). This is only used when ``reg!=None``. - unbalanced : float, optional - Unbalanced penalization weight :math:`\lambda_u`, by default None - (balanced OT) - unbalanced_type : str, optional - Type of unbalanced penalization function :math:`U` either "KL", "L2", - "TV", by default "KL". - method : str, optional - Method for solving the problem when multiple algorithms are available, - default None for automatic selection. - n_threads : int, optional - Number of OMP threads for exact OT solver, by default 1 - max_iter : int, optional - Maximum number of iterations, by default None (default values in each solvers) - plan_init : array_like, shape (dim_a, dim_b), optional - Initialization of the OT plan for iterative methods, by default None - potentials_init : (array_like(dim_a,),array_like(dim_b,)), optional - Initialization of the OT dual potentials for iterative methods, by default None - tol : _type_, optional - Tolerance for solution precision, by default None (default values in each solvers) - verbose : bool, optional - Print information in the solver, by default False - grad : str, optional - Type of gradient computation, either or 'autodiff' or 'envelope' used only for - Sinkhorn solver. By default 'autodiff' provides gradients wrt all - outputs (`plan, value, value_linear`) but with important memory cost. - 'envelope' provides gradients only for `value` and and other outputs are - detached. This is useful for memory saving when only the value is needed. - - Returns - ------- - res : OTResult() - Result of the optimization problem. The information can be obtained as follows: - - - res.plan : OT plan :math:`\mathbf{T}` - - res.potentials : OT dual potentials - - res.value : Optimal value of the optimization problem - - res.value_linear : Linear OT loss with the optimal OT plan - - See :any:`OTResult` for more information. - - Notes - ----- - - The following methods are available for solving the OT problems: - - - **Classical exact OT problem [1]** (default parameters) : - - .. math:: - \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F - - s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} - - \mathbf{T}^T \mathbf{1} = \mathbf{b} - - \mathbf{T} \geq 0 - - can be solved with the following code: - - .. code-block:: python - - res = ot.solve(M, a, b) - - - **Entropic regularized OT [2]** (when ``reg!=None``): - - .. math:: - \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda R(\mathbf{T}) - - s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} - - \mathbf{T}^T \mathbf{1} = \mathbf{b} - - \mathbf{T} \geq 0 - - can be solved with the following code: - - .. code-block:: python - - # default is ``"KL"`` regularization (``reg_type="KL"``) - res = ot.solve(M, a, b, reg=1.0) - # or for original Sinkhorn paper formulation [2] - res = ot.solve(M, a, b, reg=1.0, reg_type='entropy') - - # Use envelope theorem differentiation for memory saving - res = ot.solve(M, a, b, reg=1.0, grad='envelope') # M, a, b are torch tensors - res.value.backward() # only the value is differentiable - - Note that by default the Sinkhorn solver uses automatic differentiation to - compute the gradients of the values and plan. This can be changed with the - `grad` parameter. The `envelope` mode computes the gradients only - for the value and the other outputs are detached. This is useful for - memory saving when only the gradient of value is needed. - - - **Quadratic regularized OT [17]** (when ``reg!=None`` and ``reg_type="L2"``): - - .. math:: - \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda R(\mathbf{T}) - - s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} - - \mathbf{T}^T \mathbf{1} = \mathbf{b} - - \mathbf{T} \geq 0 - - can be solved with the following code: - - .. code-block:: python - - res = ot.solve(M,a,b,reg=1.0,reg_type='L2') - - - **Unbalanced OT [41]** (when ``unbalanced!=None``): - - .. math:: - \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) - - can be solved with the following code: - - .. code-block:: python - - # default is ``"KL"`` - res = ot.solve(M,a,b,unbalanced=1.0) - # quadratic unbalanced OT - res = ot.solve(M,a,b,unbalanced=1.0,unbalanced_type='L2') - # TV = partial OT - res = ot.solve(M,a,b,unbalanced=1.0,unbalanced_type='TV') - - - - **Regularized unbalanced regularized OT [34]** (when ``unbalanced!=None`` and ``reg!=None``): - - .. math:: - \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) - - can be solved with the following code: - - .. code-block:: python - - # default is ``"KL"`` for both - res = ot.solve(M,a,b,reg=1.0,unbalanced=1.0) - # quadratic unbalanced OT with KL regularization - res = ot.solve(M,a,b,reg=1.0,unbalanced=1.0,unbalanced_type='L2') - # both quadratic - res = ot.solve(M,a,b,reg=1.0, reg_type='L2',unbalanced=1.0,unbalanced_type='L2') - - - .. _references-solve: - References - ---------- - - .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. - (2011, December). Displacement interpolation using Lagrangian mass - transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. - 158). ACM. - - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation - of Optimal Transport, Advances in Neural Information Processing - Systems (NIPS) 26, 2013 - - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). - Scaling algorithms for unbalanced transport problems. - arXiv preprint arXiv:1607.05816. - - .. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse - Optimal Transport. Proceedings of the Twenty-First International - Conference on Artificial Intelligence and Statistics (AISTATS). - - .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, - A., & Peyré, G. (2019, April). Interpolating between optimal transport - and MMD using Sinkhorn divergences. In The 22nd International Conference - on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. - - .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). - Unbalanced optimal transport through non-negative penalized - linear regression. NeurIPS. - - """ - - # detect backend - arr = [M] - if a is not None: - arr.append(a) - if b is not None: - arr.append(b) - nx = get_backend(*arr) - - # create uniform weights if not given - if a is None: - a = nx.ones(M.shape[0], type_as=M) / M.shape[0] - if b is None: - b = nx.ones(M.shape[1], type_as=M) / M.shape[1] - - # default values for solutions - potentials = None - value = None - value_linear = None - plan = None - status = None - - if reg is None or reg == 0: # exact OT - - if unbalanced is None: # Exact balanced OT - - # default values for EMD solver - if max_iter is None: - max_iter = 1000000 - - value_linear, log = emd2(a, b, M, numItermax=max_iter, log=True, return_matrix=True, numThreads=n_threads) - - value = value_linear - potentials = (log['u'], log['v']) - plan = log['G'] - status = log["warning"] if log["warning"] is not None else 'Converged' - - elif unbalanced_type.lower() in ['kl', 'l2']: # unbalanced exact OT - - # default values for exact unbalanced OT - if max_iter is None: - max_iter = 1000 - if tol is None: - tol = 1e-12 - - plan, log = mm_unbalanced(a, b, M, reg_m=unbalanced, - div=unbalanced_type.lower(), numItermax=max_iter, - stopThr=tol, log=True, - verbose=verbose, G0=plan_init) - - value_linear = log['cost'] - - if unbalanced_type.lower() == 'kl': - value = value_linear + unbalanced * (nx.kl_div(nx.sum(plan, 1), a) + nx.kl_div(nx.sum(plan, 0), b)) - else: - err_a = nx.sum(plan, 1) - a - err_b = nx.sum(plan, 0) - b - value = value_linear + unbalanced * nx.sum(err_a**2) + unbalanced * nx.sum(err_b**2) - - elif unbalanced_type.lower() == 'tv': - - if max_iter is None: - max_iter = 1000000 - - plan, log = partial_wasserstein_lagrange(a, b, M, reg_m=unbalanced**2, log=True, numItermax=max_iter) - - value_linear = nx.sum(M * plan) - err_a = nx.sum(plan, 1) - a - err_b = nx.sum(plan, 0) - b - value = value_linear + nx.sqrt(unbalanced**2 / 2.0 * (nx.sum(nx.abs(err_a)) + - nx.sum(nx.abs(err_b)))) - - else: - raise (NotImplementedError('Unknown unbalanced_type="{}"'.format(unbalanced_type))) - - else: # regularized OT - - if unbalanced is None: # Balanced regularized OT - - if isinstance(reg_type, tuple): # general solver - - if max_iter is None: - max_iter = 1000 - if tol is None: - tol = 1e-9 - - plan, log = cg(a, b, M, reg=reg, f=reg_type[0], df=reg_type[1], numItermax=max_iter, stopThr=tol, log=True, verbose=verbose, G0=plan_init) - - value_linear = nx.sum(M * plan) - value = log['loss'][-1] - potentials = (log['u'], log['v']) - - elif reg_type.lower() in ['entropy', 'kl']: - - if grad == 'envelope': # if envelope then detach the input - M0, a0, b0 = M, a, b - M, a, b = nx.detach(M, a, b) - - # default values for sinkhorn - if max_iter is None: - max_iter = 1000 - if tol is None: - tol = 1e-9 - - plan, log = sinkhorn_log(a, b, M, reg=reg, numItermax=max_iter, - stopThr=tol, log=True, - verbose=verbose) - - value_linear = nx.sum(M * plan) - - if reg_type.lower() == 'entropy': - value = value_linear + reg * nx.sum(plan * nx.log(plan + 1e-16)) - else: - value = value_linear + reg * nx.kl_div(plan, a[:, None] * b[None, :]) - - potentials = (log['log_u'], log['log_v']) - - if grad == 'envelope': # set the gradient at convergence - - value = nx.set_gradients(value, (M0, a0, b0), - (plan, reg * (potentials[0] - potentials[0].mean()), reg * (potentials[1] - potentials[1].mean()))) - - elif reg_type.lower() == 'l2': - - if max_iter is None: - max_iter = 1000 - if tol is None: - tol = 1e-9 - - plan, log = smooth_ot_dual(a, b, M, reg=reg, numItermax=max_iter, stopThr=tol, log=True, verbose=verbose) - - value_linear = nx.sum(M * plan) - value = value_linear + reg * nx.sum(plan**2) - potentials = (log['alpha'], log['beta']) - - else: - raise (NotImplementedError('Not implemented reg_type="{}"'.format(reg_type))) - - else: # unbalanced AND regularized OT - - if not isinstance(reg_type, tuple) and reg_type.lower() in ['kl'] and unbalanced_type.lower() == 'kl': - - if max_iter is None: - max_iter = 1000 - if tol is None: - tol = 1e-9 - - plan, log = sinkhorn_knopp_unbalanced(a, b, M, reg=reg, reg_m=unbalanced, numItermax=max_iter, stopThr=tol, verbose=verbose, log=True) - - value_linear = nx.sum(M * plan) - - value = value_linear + reg * nx.kl_div(plan, a[:, None] * b[None, :]) + unbalanced * (nx.kl_div(nx.sum(plan, 1), a) + nx.kl_div(nx.sum(plan, 0), b)) - - potentials = (log['logu'], log['logv']) - - elif (isinstance(reg_type, tuple) or reg_type.lower() in ['kl', 'l2', 'entropy']) and unbalanced_type.lower() in ['kl', 'l2', 'tv']: - - if max_iter is None: - max_iter = 1000 - if tol is None: - tol = 1e-12 - if isinstance(reg_type, str): - reg_type = reg_type.lower() - - plan, log = lbfgsb_unbalanced(a, b, M, reg=reg, reg_m=unbalanced, reg_div=reg_type, regm_div=unbalanced_type.lower(), numItermax=max_iter, stopThr=tol, verbose=verbose, log=True, G0=plan_init) - - value_linear = nx.sum(M * plan) - - value = log['loss'] - - else: - raise (NotImplementedError('Not implemented reg_type="{}" and unbalanced_type="{}"'.format(reg_type, unbalanced_type))) - - res = OTResult(potentials=potentials, value=value, - value_linear=value_linear, plan=plan, status=status, backend=nx) - - return res - - -def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, - alpha=0.5, reg=None, - reg_type="entropy", unbalanced=None, unbalanced_type='KL', - n_threads=1, method=None, max_iter=None, plan_init=None, tol=None, - verbose=False): - r""" Solve the discrete (Fused) Gromov-Wasserstein and return :any:`OTResult` object - - The function solves the following optimization problem: - - .. math:: - \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + - \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} + \lambda_r R(\mathbf{T}) + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) - - The regularization is selected with `reg` (:math:`\lambda_r`) and - `reg_type`. By default ``reg=None`` and there is no regularization. The - unbalanced marginal penalization can be selected with `unbalanced` - (:math:`\lambda_u`) and `unbalanced_type`. By default ``unbalanced=None`` - and the function solves the exact optimal transport problem (respecting the - marginals). - - Parameters - ---------- - Ca : array_like, shape (dim_a, dim_a) - Cost matrix in the source domain - Cb : array_like, shape (dim_b, dim_b) - Cost matrix in the target domain - M : array_like, shape (dim_a, dim_b), optional - Linear cost matrix for Fused Gromov-Wasserstein (default is None). - a : array-like, shape (dim_a,), optional - Samples weights in the source domain (default is uniform) - b : array-like, shape (dim_b,), optional - Samples weights in the source domain (default is uniform) - loss : str, optional - Type of loss function, either ``"L2"`` or ``"KL"``, by default ``"L2"`` - symmetric : bool, optional - Use symmetric version of the Gromov-Wasserstein problem, by default None - tests whether the matrices are symmetric or True/False to avoid the test. - reg : float, optional - Regularization weight :math:`\lambda_r`, by default None (no reg., exact - OT) - reg_type : str, optional - Type of regularization :math:`R`, by default "entropy" (only used when - ``reg!=None``) - alpha : float, optional - Weight the quadratic term (alpha*Gromov) and the linear term - ((1-alpha)*Wass) in the Fused Gromov-Wasserstein problem. Not used for - Gromov problem (when M is not provided). By default ``alpha=None`` - corresponds to ``alpha=1`` for Gromov problem (``M==None``) and - ``alpha=0.5`` for Fused Gromov-Wasserstein problem (``M!=None``) - unbalanced : float, optional - Unbalanced penalization weight :math:`\lambda_u`, by default None - (balanced OT), Not implemented yet - unbalanced_type : str, optional - Type of unbalanced penalization function :math:`U` either "KL", "semirelaxed", - "partial", by default "KL" but note that it is not implemented yet. - n_threads : int, optional - Number of OMP threads for exact OT solver, by default 1 - method : str, optional - Method for solving the problem when multiple algorithms are available, - default None for automatic selection. - max_iter : int, optional - Maximum number of iterations, by default None (default values in each - solvers) - plan_init : array_like, shape (dim_a, dim_b), optional - Initialization of the OT plan for iterative methods, by default None - tol : float, optional - Tolerance for solution precision, by default None (default values in - each solvers) - verbose : bool, optional - Print information in the solver, by default False - - Returns - ------- - res : OTResult() - Result of the optimization problem. The information can be obtained as follows: - - - res.plan : OT plan :math:`\mathbf{T}` - - res.potentials : OT dual potentials - - res.value : Optimal value of the optimization problem - - res.value_linear : Linear OT loss with the optimal OT plan - - res.value_quad : Quadratic (GW) part of the OT loss with the optimal OT plan - - See :any:`OTResult` for more information. - - Notes - ----- - The following methods are available for solving the Gromov-Wasserstein - problem: - - - **Classical Gromov-Wasserstein (GW) problem [3]** (default parameters): - - .. math:: - \min_{\mathbf{T}\geq 0} \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l} - - s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} - - \mathbf{T}^T \mathbf{1} = \mathbf{b} - - \mathbf{T} \geq 0 - - can be solved with the following code: - - .. code-block:: python - - res = ot.solve_gromov(Ca, Cb) # uniform weights - res = ot.solve_gromov(Ca, Cb, a=a, b=b) # given weights - res = ot.solve_gromov(Ca, Cb, loss='KL') # KL loss - - plan = res.plan # GW plan - value = res.value # GW value - - - **Fused Gromov-Wasserstein (FGW) problem [24]** (when ``M!=None``): - - .. math:: - \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + - \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l} - - s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} - - \mathbf{T}^T \mathbf{1} = \mathbf{b} - - \mathbf{T} \geq 0 - - can be solved with the following code: - - .. code-block:: python - - res = ot.solve_gromov(Ca, Cb, M) # uniform weights, alpha=0.5 (default) - res = ot.solve_gromov(Ca, Cb, M, a=a, b=b, alpha=0.1) # given weights and alpha - - plan = res.plan # FGW plan - loss_linear_term = res.value_linear # Wasserstein part of the loss - loss_quad_term = res.value_quad # Gromov part of the loss - loss = res.value # FGW value - - - **Regularized (Fused) Gromov-Wasserstein (GW) problem [12]** (when ``reg!=None``): - - .. math:: - \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + - \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l} + \lambda_r R(\mathbf{T}) - - s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} - - \mathbf{T}^T \mathbf{1} = \mathbf{b} - - \mathbf{T} \geq 0 - - can be solved with the following code: - - .. code-block:: python - - res = ot.solve_gromov(Ca, Cb, reg=1.0) # GW entropy regularization (default) - res = ot.solve_gromov(Ca, Cb, M, a=a, b=b, reg=10, alpha=0.1) # FGW with entropy - - plan = res.plan # FGW plan - loss_linear_term = res.value_linear # Wasserstein part of the loss - loss_quad_term = res.value_quad # Gromov part of the loss - loss = res.value # FGW value (including regularization) - - - **Semi-relaxed (Fused) Gromov-Wasserstein (GW) [48]** (when ``unbalanced='semirelaxed'``): - - .. math:: - \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + - \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l} - - s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} - - \mathbf{T} \geq 0 - - can be solved with the following code: - - .. code-block:: python - - res = ot.solve_gromov(Ca, Cb, unbalanced='semirelaxed') # semirelaxed GW - res = ot.solve_gromov(Ca, Cb, unbalanced='semirelaxed', reg=1) # entropic semirelaxed GW - res = ot.solve_gromov(Ca, Cb, M, unbalanced='semirelaxed', alpha=0.1) # semirelaxed FGW - - plan = res.plan # FGW plan - right_marginal = res.marginal_b # right marginal of the plan - - - **Partial (Fused) Gromov-Wasserstein (GW) problem [29]** (when ``unbalanced='partial'``): - - .. math:: - \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + - \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l} - - s.t. \ \mathbf{T} \mathbf{1} \leq \mathbf{a} - - \mathbf{T}^T \mathbf{1} \leq \mathbf{b} - - \mathbf{T} \geq 0 - - \mathbf{1}^T\mathbf{T}\mathbf{1} = m - - can be solved with the following code: - - .. code-block:: python - - res = ot.solve_gromov(Ca, Cb, unbalanced_type='partial', unbalanced=0.8) # partial GW with m=0.8 - - - .. _references-solve-gromov: - References - ---------- - - .. [3] Mémoli, F. (2011). Gromov–Wasserstein distances and the metric - approach to object matching. Foundations of computational mathematics, - 11(4), 417-487. - - .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon (2016), - Gromov-Wasserstein averaging of kernel and distance matrices - International Conference on Machine Learning (ICML). - - .. [24] Vayer, T., Chapel, L., Flamary, R., Tavenard, R. and Courty, N. - (2019). Optimal Transport for structured data with application on graphs - Proceedings of the 36th International Conference on Machine Learning - (ICML). - - .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, - Nicolas Courty (2022). Semi-relaxed Gromov-Wasserstein divergence and - applications on graphs. International Conference on Learning - Representations (ICLR), 2022. - - .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). Partial Optimal Transport - with Applications on Positive-Unlabeled Learning, Advances in Neural - Information Processing Systems (NeurIPS), 2020. - - """ - - # detect backend - nx = get_backend(Ca, Cb, M, a, b) - - # create uniform weights if not given - if a is None: - a = nx.ones(Ca.shape[0], type_as=Ca) / Ca.shape[0] - if b is None: - b = nx.ones(Cb.shape[1], type_as=Cb) / Cb.shape[1] - - # default values for solutions - potentials = None - value = None - value_linear = None - value_quad = None - plan = None - status = None - log = None - - loss_dict = {'l2': 'square_loss', 'kl': 'kl_loss'} - - if loss.lower() not in loss_dict.keys(): - raise (NotImplementedError('Not implemented GW loss="{}"'.format(loss))) - loss_fun = loss_dict[loss.lower()] - - if reg is None or reg == 0: # exact OT - - if unbalanced is None and unbalanced_type.lower() not in ['semirelaxed']: # Exact balanced OT - - if M is None or alpha == 1: # Gromov-Wasserstein problem - - # default values for solver - if max_iter is None: - max_iter = 10000 - if tol is None: - tol = 1e-9 - - value, log = gromov_wasserstein2(Ca, Cb, a, b, loss_fun=loss_fun, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) - - value_quad = value - if alpha == 1: # set to 0 for FGW with alpha=1 - value_linear = 0 - plan = log['T'] - potentials = (log['u'], log['v']) - - elif alpha == 0: # Wasserstein problem - - # default values for EMD solver - if max_iter is None: - max_iter = 1000000 - - value_linear, log = emd2(a, b, M, numItermax=max_iter, log=True, return_matrix=True, numThreads=n_threads) - - value = value_linear - potentials = (log['u'], log['v']) - plan = log['G'] - status = log["warning"] if log["warning"] is not None else 'Converged' - value_quad = 0 - - else: # Fused Gromov-Wasserstein problem - - # default values for solver - if max_iter is None: - max_iter = 10000 - if tol is None: - tol = 1e-9 - - value, log = fused_gromov_wasserstein2(M, Ca, Cb, a, b, loss_fun=loss_fun, alpha=alpha, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) - - value_linear = log['lin_loss'] - value_quad = log['quad_loss'] - plan = log['T'] - potentials = (log['u'], log['v']) - - elif unbalanced_type.lower() in ['semirelaxed']: # Semi-relaxed OT - - if M is None or alpha == 1: # Semi relaxed Gromov-Wasserstein problem - - # default values for solver - if max_iter is None: - max_iter = 10000 - if tol is None: - tol = 1e-9 - - value, log = semirelaxed_gromov_wasserstein2(Ca, Cb, a, loss_fun=loss_fun, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) - - value_quad = value - if alpha == 1: # set to 0 for FGW with alpha=1 - value_linear = 0 - plan = log['T'] - # potentials = (log['u'], log['v']) TODO - - else: # Semi relaxed Fused Gromov-Wasserstein problem - - # default values for solver - if max_iter is None: - max_iter = 10000 - if tol is None: - tol = 1e-9 - - value, log = semirelaxed_fused_gromov_wasserstein2(M, Ca, Cb, a, loss_fun=loss_fun, alpha=alpha, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) - - value_linear = log['lin_loss'] - value_quad = log['quad_loss'] - plan = log['T'] - # potentials = (log['u'], log['v']) TODO - - elif unbalanced_type.lower() in ['partial']: # Partial OT - - if M is None: # Partial Gromov-Wasserstein problem - - if unbalanced > nx.sum(a) or unbalanced > nx.sum(b): - raise (ValueError('Partial GW mass given in reg is too large')) - if loss.lower() != 'l2': - raise (NotImplementedError('Partial GW only implemented with L2 loss')) - if symmetric is not None: - raise (NotImplementedError('Partial GW only implemented with symmetric=True')) - - # default values for solver - if max_iter is None: - max_iter = 1000 - if tol is None: - tol = 1e-7 - - value, log = partial_gromov_wasserstein2(Ca, Cb, a, b, m=unbalanced, log=True, numItermax=max_iter, G0=plan_init, tol=tol, verbose=verbose) - - value_quad = value - plan = log['T'] - # potentials = (log['u'], log['v']) TODO - - else: # partial FGW - - raise (NotImplementedError('Partial FGW not implemented yet')) - - elif unbalanced_type.lower() in ['kl', 'l2']: # unbalanced exact OT - - raise (NotImplementedError('Unbalanced_type="{}"'.format(unbalanced_type))) - - else: - raise (NotImplementedError('Unknown unbalanced_type="{}"'.format(unbalanced_type))) - - else: # regularized OT - - if unbalanced is None and unbalanced_type.lower() not in ['semirelaxed']: # Balanced regularized OT - - if reg_type.lower() in ['entropy'] and (M is None or alpha == 1): # Entropic Gromov-Wasserstein problem - - # default values for solver - if max_iter is None: - max_iter = 1000 - if tol is None: - tol = 1e-9 - if method is None: - method = 'PGD' - - value_quad, log = entropic_gromov_wasserstein2(Ca, Cb, a, b, epsilon=reg, loss_fun=loss_fun, log=True, symmetric=symmetric, solver=method, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) - - plan = log['T'] - value_linear = 0 - value = value_quad + reg * nx.sum(plan * nx.log(plan + 1e-16)) - # potentials = (log['log_u'], log['log_v']) #TODO - - elif reg_type.lower() in ['entropy'] and M is not None and alpha == 0: # Entropic Wasserstein problem - - # default values for solver - if max_iter is None: - max_iter = 1000 - if tol is None: - tol = 1e-9 - - plan, log = sinkhorn_log(a, b, M, reg=reg, numItermax=max_iter, - stopThr=tol, log=True, - verbose=verbose) - - value_linear = nx.sum(M * plan) - value = value_linear + reg * nx.sum(plan * nx.log(plan + 1e-16)) - potentials = (log['log_u'], log['log_v']) - - elif reg_type.lower() in ['entropy'] and M is not None: # Entropic Fused Gromov-Wasserstein problem - - # default values for solver - if max_iter is None: - max_iter = 1000 - if tol is None: - tol = 1e-9 - if method is None: - method = 'PGD' - - value_noreg, log = entropic_fused_gromov_wasserstein2(M, Ca, Cb, a, b, loss_fun=loss_fun, alpha=alpha, log=True, symmetric=symmetric, solver=method, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) - - value_linear = log['lin_loss'] - value_quad = log['quad_loss'] - plan = log['T'] - # potentials = (log['u'], log['v']) - value = value_noreg + reg * nx.sum(plan * nx.log(plan + 1e-16)) - - else: - raise (NotImplementedError('Not implemented reg_type="{}"'.format(reg_type))) - - elif unbalanced_type.lower() in ['semirelaxed']: # Semi-relaxed OT - - if reg_type.lower() in ['entropy'] and (M is None or alpha == 1): # Entropic Semi-relaxed Gromov-Wasserstein problem - - # default values for solver - if max_iter is None: - max_iter = 1000 - if tol is None: - tol = 1e-9 - - value_quad, log = entropic_semirelaxed_gromov_wasserstein2(Ca, Cb, a, epsilon=reg, loss_fun=loss_fun, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) - - plan = log['T'] - value_linear = 0 - value = value_quad + reg * nx.sum(plan * nx.log(plan + 1e-16)) - - else: # Entropic Semi-relaxed FGW problem - - # default values for solver - if max_iter is None: - max_iter = 1000 - if tol is None: - tol = 1e-9 - - value_noreg, log = entropic_semirelaxed_fused_gromov_wasserstein2(M, Ca, Cb, a, loss_fun=loss_fun, alpha=alpha, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) - - value_linear = log['lin_loss'] - value_quad = log['quad_loss'] - plan = log['T'] - value = value_noreg + reg * nx.sum(plan * nx.log(plan + 1e-16)) - - elif unbalanced_type.lower() in ['partial']: # Partial OT - - if M is None: # Partial Gromov-Wasserstein problem - - if unbalanced > nx.sum(a) or unbalanced > nx.sum(b): - raise (ValueError('Partial GW mass given in reg is too large')) - if loss.lower() != 'l2': - raise (NotImplementedError('Partial GW only implemented with L2 loss')) - if symmetric is not None: - raise (NotImplementedError('Partial GW only implemented with symmetric=True')) - - # default values for solver - if max_iter is None: - max_iter = 1000 - if tol is None: - tol = 1e-7 - - value_quad, log = entropic_partial_gromov_wasserstein2(Ca, Cb, a, b, reg=reg, m=unbalanced, log=True, numItermax=max_iter, G0=plan_init, tol=tol, verbose=verbose) - - value_quad = value - plan = log['T'] - # potentials = (log['u'], log['v']) TODO - - else: # partial FGW - - raise (NotImplementedError('Partial entropic FGW not implemented yet')) - - else: # unbalanced AND regularized OT - - raise (NotImplementedError('Not implemented reg_type="{}" and unbalanced_type="{}"'.format(reg_type, unbalanced_type))) - - res = OTResult(potentials=potentials, value=value, - value_linear=value_linear, value_quad=value_quad, plan=plan, status=status, backend=nx, log=log) - - return res - - -def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_type="KL", - unbalanced=None, - unbalanced_type='KL', lazy=False, batch_size=None, method=None, n_threads=1, max_iter=None, plan_init=None, rank=100, scaling=0.95, - potentials_init=None, X_init=None, tol=None, verbose=False, - grad='autodiff'): - r"""Solve the discrete optimal transport problem using the samples in the source and target domains. - - The function solves the following general optimal transport problem - - .. math:: - \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) + - \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + - \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) - - where the cost matrix :math:`\mathbf{M}` is computed from the samples in the - source and target domains such that :math:`M_{i,j} = d(x_i,y_j)` where - :math:`d` is a metric (by default the squared Euclidean distance). - - The regularization is selected with `reg` (:math:`\lambda_r`) and `reg_type`. By - default ``reg=None`` and there is no regularization. The unbalanced marginal - penalization can be selected with `unbalanced` (:math:`\lambda_u`) and - `unbalanced_type`. By default ``unbalanced=None`` and the function - solves the exact optimal transport problem (respecting the marginals). - - Parameters - ---------- - X_s : array-like, shape (n_samples_a, dim) - samples in the source domain - X_t : array-like, shape (n_samples_b, dim) - samples in the target domain - a : array-like, shape (dim_a,), optional - Samples weights in the source domain (default is uniform) - b : array-like, shape (dim_b,), optional - Samples weights in the source domain (default is uniform) - reg : float, optional - Regularization weight :math:`\lambda_r`, by default None (no reg., exact - OT) - reg_type : str, optional - Type of regularization :math:`R` either "KL", "L2", "entropy", by default "KL" - unbalanced : float, optional - Unbalanced penalization weight :math:`\lambda_u`, by default None - (balanced OT) - unbalanced_type : str, optional - Type of unbalanced penalization function :math:`U` either "KL", "L2", "TV", by default "KL" - lazy : bool, optional - Return :any:`OTResultlazy` object to reduce memory cost when True, by - default False - batch_size : int, optional - Batch size for lazy solver, by default None (default values in each - solvers) - method : str, optional - Method for solving the problem, this can be used to select the solver - for unbalanced problems (see :any:`ot.solve`), or to select a specific - large scale solver. - n_threads : int, optional - Number of OMP threads for exact OT solver, by default 1 - max_iter : int, optional - Maximum number of iteration, by default None (default values in each solvers) - plan_init : array_like, shape (dim_a, dim_b), optional - Initialization of the OT plan for iterative methods, by default None - rank : int, optional - Rank of the OT matrix for lazy solers (method='factored'), by default 100 - scaling : float, optional - Scaling factor for the epsilon scaling lazy solvers (method='geomloss'), by default 0.95 - potentials_init : (array_like(dim_a,),array_like(dim_b,)), optional - Initialization of the OT dual potentials for iterative methods, by default None - tol : _type_, optional - Tolerance for solution precision, by default None (default values in each solvers) - verbose : bool, optional - Print information in the solver, by default False - grad : str, optional - Type of gradient computation, either or 'autodiff' or 'envelope' used only for - Sinkhorn solver. By default 'autodiff' provides gradients wrt all - outputs (`plan, value, value_linear`) but with important memory cost. - 'envelope' provides gradients only for `value` and and other outputs are - detached. This is useful for memory saving when only the value is needed. - - Returns - ------- - - res : OTResult() - Result of the optimization problem. The information can be obtained as follows: - - - res.plan : OT plan :math:`\mathbf{T}` - - res.potentials : OT dual potentials - - res.value : Optimal value of the optimization problem - - res.value_linear : Linear OT loss with the optimal OT plan - - res.lazy_plan : Lazy OT plan (when ``lazy=True`` or lazy method) - - See :any:`OTResult` for more information. - - Notes - ----- - - The following methods are available for solving the OT problems: - - - **Classical exact OT problem [1]** (default parameters) : - - .. math:: - \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F - - s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} - - \mathbf{T}^T \mathbf{1} = \mathbf{b} - - \mathbf{T} \geq 0, M_{i,j} = d(x_i,y_j) - - - - can be solved with the following code: - - .. code-block:: python - - res = ot.solve_sample(xa, xb, a, b) - - # for uniform weights - res = ot.solve_sample(xa, xb) - - - **Entropic regularized OT [2]** (when ``reg!=None``): - - .. math:: - \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda R(\mathbf{T}) - - s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} - - \mathbf{T}^T \mathbf{1} = \mathbf{b} - - \mathbf{T} \geq 0, M_{i,j} = d(x_i,y_j) - - can be solved with the following code: - - .. code-block:: python - - # default is ``"KL"`` regularization (``reg_type="KL"``) - res = ot.solve_sample(xa, xb, a, b, reg=1.0) - # or for original Sinkhorn paper formulation [2] - res = ot.solve_sample(xa, xb, a, b, reg=1.0, reg_type='entropy') - - # lazy solver of memory complexity O(n) - res = ot.solve_sample(xa, xb, a, b, reg=1.0, lazy=True, batch_size=100) - # lazy OT plan - lazy_plan = res.lazy_plan - - # Use envelope theorem differentiation for memory saving - res = ot.solve_sample(xa, xb, a, b, reg=1.0, grad='envelope') - res.value.backward() # only the value is differentiable - - Note that by default the Sinkhorn solver uses automatic differentiation to - compute the gradients of the values and plan. This can be changed with the - `grad` parameter. The `envelope` mode computes the gradients only - for the value and the other outputs are detached. This is useful for - memory saving when only the gradient of value is needed. - - We also have a very efficient solver with compiled CPU/CUDA code using - geomloss/PyKeOps that can be used with the following code: - - .. code-block:: python - - # automatic solver - res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss') - - # force O(n) memory efficient solver - res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss_online') - - # force pre-computed cost matrix - res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss_tensorized') - - # use multiscale solver - res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss_multiscale') - - # One can play with speed (small scaling factor) and precision (scaling close to 1) - res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss', scaling=0.5) - - - **Quadratic regularized OT [17]** (when ``reg!=None`` and ``reg_type="L2"``): - - .. math:: - \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda R(\mathbf{T}) - - s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} - - \mathbf{T}^T \mathbf{1} = \mathbf{b} - - \mathbf{T} \geq 0, M_{i,j} = d(x_i,y_j) - - can be solved with the following code: - - .. code-block:: python - - res = ot.solve_sample(xa, xb, a, b, reg=1.0, reg_type='L2') - - - **Unbalanced OT [41]** (when ``unbalanced!=None``): - - .. math:: - \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) - - with M_{i,j} = d(x_i,y_j) - - can be solved with the following code: - - .. code-block:: python - - # default is ``"KL"`` - res = ot.solve_sample(xa, xb, a, b, unbalanced=1.0) - # quadratic unbalanced OT - res = ot.solve_sample(xa, xb, a, b, unbalanced=1.0,unbalanced_type='L2') - # TV = partial OT - res = ot.solve_sample(xa, xb, a, b, unbalanced=1.0,unbalanced_type='TV') - - - - **Regularized unbalanced regularized OT [34]** (when ``unbalanced!=None`` and ``reg!=None``): - - .. math:: - \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) - - with M_{i,j} = d(x_i,y_j) - - can be solved with the following code: - - .. code-block:: python - - # default is ``"KL"`` for both - res = ot.solve_sample(xa, xb, a, b, reg=1.0, unbalanced=1.0) - # quadratic unbalanced OT with KL regularization - res = ot.solve_sample(xa, xb, a, b, reg=1.0, unbalanced=1.0,unbalanced_type='L2') - # both quadratic - res = ot.solve_sample(xa, xb, a, b, reg=1.0, reg_type='L2', - unbalanced=1.0, unbalanced_type='L2') - - - - **Factored OT [2]** (when ``method='factored'``): - - This method solve the following OT problem [40]_ - - .. math:: - \mathop{\arg \min}_\mu \quad W_2^2(\mu_a,\mu)+ W_2^2(\mu,\mu_b) - - where $\mu$ is a uniform weighted empirical distribution of :math:`\mu_a` and :math:`\mu_b` are the empirical measures associated - to the samples in the source and target domains, and :math:`W_2` is the - Wasserstein distance. This problem is solved using exact OT solvers for - `reg=None` and the Sinkhorn solver for `reg!=None`. The solution provides - two transport plans that can be used to recover a low rank OT plan between - the two distributions. - - .. code-block:: python - - res = ot.solve_sample(xa, xb, method='factored', rank=10) - - # recover the lazy low rank plan - factored_solution_lazy = res.lazy_plan - - # recover the full low rank plan - factored_solution = factored_solution_lazy[:] - - - **Gaussian Bures-Wasserstein [2]** (when ``method='gaussian'``): - - This method computes the Gaussian Bures-Wasserstein distance between two - Gaussian distributions estimated from teh empirical distributions - - .. math:: - \mathcal{W}(\mu_s, \mu_t)_2^2= \left\lVert \mathbf{m}_s - \mathbf{m}_t \right\rVert^2 + \mathcal{B}(\Sigma_s, \Sigma_t)^{2} - - where : - - .. math:: - \mathbf{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s + \Sigma_t - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right) - - The covariances and means are estimated from the data. - - .. code-block:: python - - res = ot.solve_sample(xa, xb, method='gaussian') - - # recover the squared Gaussian Bures-Wasserstein distance - BW_dist = res.value - - - **Wasserstein 1d [1]** (when ``method='1D'``): - - This method computes the Wasserstein distance between two 1d distributions - estimated from the empirical distributions. For multivariate data the - distances are computed independently for each dimension. - - .. code-block:: python - - res = ot.solve_sample(xa, xb, method='1D') - - # recover the squared Wasserstein distances - W_dists = res.value - - - .. _references-solve-sample: - References - ---------- - - .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. - (2011, December). Displacement interpolation using Lagrangian mass - transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. - 158). ACM. - - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation - of Optimal Transport, Advances in Neural Information Processing - Systems (NIPS) 26, 2013 - - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). - Scaling algorithms for unbalanced transport problems. - arXiv preprint arXiv:1607.05816. - - .. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse - Optimal Transport. Proceedings of the Twenty-First International - Conference on Artificial Intelligence and Statistics (AISTATS). - - .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, - A., & Peyré, G. (2019, April). Interpolating between optimal transport - and MMD using Sinkhorn divergences. In The 22nd International Conference - on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. - - .. [40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger, - G., & Weed, J. (2019, April). Statistical optimal transport via factored - couplings. In The 22nd International Conference on Artificial - Intelligence and Statistics (pp. 2454-2465). PMLR. - - .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). - Unbalanced optimal transport through non-negative penalized - linear regression. NeurIPS. - - .. [65] Scetbon, M., Cuturi, M., & Peyré, G. (2021). - Low-rank Sinkhorn Factorization. In International Conference on - Machine Learning. - - - """ - - if method is not None and method.lower() in lst_method_lazy: - lazy0 = lazy - lazy = True - - if not lazy: # default non lazy solver calls ot.solve - - # compute cost matrix M and use solve function - M = dist(X_a, X_b, metric) - - res = solve(M, a, b, reg, reg_type, unbalanced, unbalanced_type, method, n_threads, max_iter, plan_init, potentials_init, tol, verbose, grad) - - return res - - else: - - # Detect backend - nx = get_backend(X_a, X_b, a, b) - - # default values for solutions - potentials = None - value = None - value_linear = None - plan = None - lazy_plan = None - status = None - log = None - - method = method.lower() if method is not None else '' - - if method == '1d': # Wasserstein 1d (parallel on all dimensions) - if metric == 'sqeuclidean': - p = 2 - elif metric in ['euclidean', 'cityblock']: - p = 1 - else: - raise (NotImplementedError('Not implemented metric="{}"'.format(metric))) - - value = wasserstein_1d(X_a, X_b, a, b, p=p) - value_linear = value - - elif method == 'gaussian': # Gaussian Bures-Wasserstein - - if not metric.lower() in ['sqeuclidean']: - raise (NotImplementedError('Not implemented metric="{}"'.format(metric))) - - if reg is None: - reg = 1e-6 - - value, log = empirical_bures_wasserstein_distance(X_a, X_b, reg=reg, log=True) - value = value**2 # return the value (squared bures distance) - value_linear = value # return the value - - elif method == 'factored': # Factored OT - - if not metric.lower() in ['sqeuclidean']: - raise (NotImplementedError('Not implemented metric="{}"'.format(metric))) - - if max_iter is None: - max_iter = 100 - if tol is None: - tol = 1e-7 - if reg is None: - reg = 0 - - Q, R, X, log = factored_optimal_transport(X_a, X_b, reg=reg, r=rank, log=True, stopThr=tol, numItermax=max_iter, verbose=verbose) - log['X'] = X - - value_linear = log['costa'] + log['costb'] - value = value_linear # TODO add reg term - lazy_plan = log['lazy_plan'] - if not lazy0: # store plan if not lazy - plan = lazy_plan[:] - - elif method == "lowrank": - - if not metric.lower() in ['sqeuclidean']: - raise (NotImplementedError('Not implemented metric="{}"'.format(metric))) - - if max_iter is None: - max_iter = 2000 - if tol is None: - tol = 1e-7 - if reg is None: - reg = 0 - - Q, R, g, log = lowrank_sinkhorn(X_a, X_b, rank=rank, reg=reg, a=a, b=b, numItermax=max_iter, stopThr=tol, log=True) - value = log['value'] - value_linear = log['value_linear'] - lazy_plan = log['lazy_plan'] - if not lazy0: # store plan if not lazy - plan = lazy_plan[:] - - elif method.startswith('geomloss'): # Geomloss solver for entropic OT - - split_method = method.split('_') - if len(split_method) == 2: - backend = split_method[1] - else: - if lazy0 is None: - backend = 'auto' - elif lazy0: - backend = 'online' - else: - backend = 'tensorized' - - value, log = empirical_sinkhorn2_geomloss(X_a, X_b, reg=reg, a=a, b=b, metric=metric, log=True, verbose=verbose, scaling=scaling, backend=backend) - - lazy_plan = log['lazy_plan'] - if not lazy0: # store plan if not lazy - plan = lazy_plan[:] - - # return scaled potentials (to be consistent with other solvers) - potentials = (log['f'] / (lazy_plan.blur**2), log['g'] / (lazy_plan.blur**2)) - - elif reg is None or reg == 0: # exact OT - - if unbalanced is None: # balanced EMD solver not available for lazy - raise (NotImplementedError('Exact OT solver with lazy=True not implemented')) - - else: - raise (NotImplementedError('Non regularized solver with unbalanced_type="{}" not implemented'.format(unbalanced_type))) - - else: - if unbalanced is None: - - if max_iter is None: - max_iter = 1000 - if tol is None: - tol = 1e-9 - if batch_size is None: - batch_size = 100 - - value_linear, log = empirical_sinkhorn2(X_a, X_b, reg, a, b, metric=metric, numIterMax=max_iter, stopThr=tol, - isLazy=True, batchSize=batch_size, verbose=verbose, log=True) - # compute potentials - potentials = (log["u"], log["v"]) - lazy_plan = log['lazy_plan'] - - else: - raise (NotImplementedError('Not implemented unbalanced_type="{}" with regularization'.format(unbalanced_type))) - - res = OTResult(potentials=potentials, value=value, lazy_plan=lazy_plan, - value_linear=value_linear, plan=plan, status=status, backend=nx, log=log) - return res diff --git a/ot/unbalanced/Untitled.ipynb b/ot/unbalanced/Untitled.ipynb new file mode 100644 index 000000000..87737c676 --- /dev/null +++ b/ot/unbalanced/Untitled.ipynb @@ -0,0 +1,973 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 232, + "id": "bab1ddba", + "metadata": {}, + "outputs": [], + "source": [ + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "fa20df35", + "metadata": {}, + "outputs": [], + "source": [ + "x = torch.ones(3, 3)\n", + "y = 2 * x\n", + "z = 3 * x\n", + "t = torch.stack([y, z], dim=2)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "e58f96d3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([3, 3, 2])" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "t.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "0b1cc563", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([3, 3])" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "7ec187d3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[2., 3.],\n", + " [2., 3.],\n", + " [2., 3.]],\n", + "\n", + " [[2., 3.],\n", + " [2., 3.],\n", + " [2., 3.]],\n", + "\n", + " [[2., 3.],\n", + " [2., 3.],\n", + " [2., 3.]]])" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x[:, :, None] * t" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "041cc43d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[1., 1., 1.],\n", + " [1., 1., 1.],\n", + " [1., 1., 1.]])" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "03d039f4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[2., 3.],\n", + " [2., 3.],\n", + " [2., 3.]],\n", + "\n", + " [[2., 3.],\n", + " [2., 3.],\n", + " [2., 3.]],\n", + "\n", + " [[2., 3.],\n", + " [2., 3.],\n", + " [2., 3.]]])" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "t" + ] + }, + { + "cell_type": "code", + "execution_count": 390, + "id": "6c50108f", + "metadata": {}, + "outputs": [], + "source": [ + "a = torch.ones(3, 1)\n", + "b = torch.ones(5) * 2\n", + "c = torch.ones(5) * 3\n", + "d = torch.stack([b, c],dim=1)\n", + "m = torch.randn(3, 5)\n", + "\n", + "aa = torch.concat([a, a],dim=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 303, + "id": "14167c8d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[2., 2., 2., 2., 2.],\n", + " [2., 2., 2., 2., 2.],\n", + " [2., 2., 2., 2., 2.]])" + ] + }, + "execution_count": 303, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "a @ b[None, :]" + ] + }, + { + "cell_type": "code", + "execution_count": 304, + "id": "67f0db6c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[2., 3.],\n", + " [2., 3.],\n", + " [2., 3.],\n", + " [2., 3.],\n", + " [2., 3.]])" + ] + }, + "execution_count": 304, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "d" + ] + }, + { + "cell_type": "code", + "execution_count": 305, + "id": "924d115d", + "metadata": {}, + "outputs": [], + "source": [ + "e = torch.einsum('ij, jkm->mik', a, d[None, :, :])" + ] + }, + { + "cell_type": "code", + "execution_count": 306, + "id": "c32ee8cc", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[2., 2., 2., 2., 2.],\n", + " [2., 2., 2., 2., 2.],\n", + " [2., 2., 2., 2., 2.]],\n", + "\n", + " [[3., 3., 3., 3., 3.],\n", + " [3., 3., 3., 3., 3.],\n", + " [3., 3., 3., 3., 3.]]])" + ] + }, + "execution_count": 306, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "e" + ] + }, + { + "cell_type": "code", + "execution_count": 307, + "id": "88abc5fb", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[2., 2., 2., 2., 2.],\n", + " [3., 3., 3., 3., 3.]])" + ] + }, + "execution_count": 307, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dd" + ] + }, + { + "cell_type": "code", + "execution_count": 308, + "id": "5fbd8e4b", + "metadata": {}, + "outputs": [ + { + "ename": "RuntimeError", + "evalue": "The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/tmp/ipykernel_17575/3808899050.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mdd\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0md\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mT\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0;34m(\u001b[0m\u001b[0mdd\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1" + ] + } + ], + "source": [ + "dd = d.T\n", + "(dd * e).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 309, + "id": "b6dd203c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[2., 3.],\n", + " [2., 3.],\n", + " [2., 3.],\n", + " [2., 3.],\n", + " [2., 3.]])" + ] + }, + "execution_count": 309, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "d" + ] + }, + { + "cell_type": "code", + "execution_count": 310, + "id": "aa0e4d3c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[20., 45.],\n", + " [20., 45.],\n", + " [20., 45.]])" + ] + }, + "execution_count": 310, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "f = torch.einsum('ijk, ki->ji', e, d)\n", + "f" + ] + }, + { + "cell_type": "code", + "execution_count": 289, + "id": "07c71321", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[2., 2., 2., 2., 2.],\n", + " [2., 2., 2., 2., 2.]],\n", + "\n", + " [[3., 3., 3., 3., 3.],\n", + " [3., 3., 3., 3., 3.]]])" + ] + }, + "execution_count": 289, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "e" + ] + }, + { + "cell_type": "code", + "execution_count": 290, + "id": "0d7076ce", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[2., 3.],\n", + " [2., 3.],\n", + " [2., 3.],\n", + " [2., 3.],\n", + " [2., 3.]])" + ] + }, + "execution_count": 290, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "d" + ] + }, + { + "cell_type": "code", + "execution_count": 320, + "id": "2ad6c8a0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[6., 9.],\n", + " [6., 9.],\n", + " [6., 9.],\n", + " [6., 9.],\n", + " [6., 9.]])" + ] + }, + "execution_count": 320, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "eT = e.reshape(2, 5, 3)\n", + "g = torch.einsum('ijk, ki->ji', eT, a)\n", + "g" + ] + }, + { + "cell_type": "code", + "execution_count": 321, + "id": "496709df", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[0.3333, 0.3333],\n", + " [0.3333, 0.3333],\n", + " [0.3333, 0.3333],\n", + " [0.3333, 0.3333],\n", + " [0.3333, 0.3333]])" + ] + }, + "execution_count": 321, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "d / g" + ] + }, + { + "cell_type": "code", + "execution_count": 324, + "id": "f5cb0a5b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[2., 3.],\n", + " [2., 3.],\n", + " [2., 3.],\n", + " [2., 3.],\n", + " [2., 3.]])" + ] + }, + "execution_count": 324, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "d" + ] + }, + { + "cell_type": "code", + "execution_count": 329, + "id": "20d8b409", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([10.8688, 24.4548])" + ] + }, + "execution_count": 329, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "res = torch.einsum('ik,kij,jk,ij->k', a, e, d, m)\n", + "res" + ] + }, + { + "cell_type": "code", + "execution_count": 330, + "id": "5b419e6a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[1.],\n", + " [1.],\n", + " [1.]])" + ] + }, + "execution_count": 330, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "a" + ] + }, + { + "cell_type": "code", + "execution_count": 331, + "id": "84cf81ac", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[2., 2., 2., 2., 2.],\n", + " [2., 2., 2., 2., 2.],\n", + " [2., 2., 2., 2., 2.]],\n", + "\n", + " [[3., 3., 3., 3., 3.],\n", + " [3., 3., 3., 3., 3.],\n", + " [3., 3., 3., 3., 3.]]])" + ] + }, + "execution_count": 331, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "e" + ] + }, + { + "cell_type": "code", + "execution_count": 332, + "id": "1521a955", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[2., 3.],\n", + " [2., 3.],\n", + " [2., 3.],\n", + " [2., 3.],\n", + " [2., 3.]])" + ] + }, + "execution_count": 332, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "d" + ] + }, + { + "cell_type": "code", + "execution_count": 365, + "id": "a84c2925", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.]],\n", + "\n", + " [[1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.]]])" + ] + }, + "execution_count": 365, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "(a[None, :,:] + d.T[:,None, :])" + ] + }, + { + "cell_type": "code", + "execution_count": 340, + "id": "346847db", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[1.],\n", + " [1.],\n", + " [1.]])" + ] + }, + "execution_count": 340, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "a" + ] + }, + { + "cell_type": "code", + "execution_count": 341, + "id": "da1574b9", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[2., 3.],\n", + " [2., 3.],\n", + " [2., 3.],\n", + " [2., 3.],\n", + " [2., 3.]])" + ] + }, + "execution_count": 341, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "d" + ] + }, + { + "cell_type": "code", + "execution_count": 361, + "id": "5c25b074", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[3., 3., 3., 3., 3.],\n", + " [3., 3., 3., 3., 3.],\n", + " [3., 3., 3., 3., 3.]])" + ] + }, + "execution_count": 361, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "a + d[:, 0]" + ] + }, + { + "cell_type": "code", + "execution_count": 366, + "id": "f10f2752", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[2., 3.],\n", + " [2., 3.],\n", + " [2., 3.],\n", + " [2., 3.],\n", + " [2., 3.]])" + ] + }, + "execution_count": 366, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "d" + ] + }, + { + "cell_type": "code", + "execution_count": 382, + "id": "4d48e54a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[4., 4., 4., 4., 4.],\n", + " [4., 4., 4., 4., 4.],\n", + " [4., 4., 4., 4., 4.],\n", + " [4., 4., 4., 4., 4.],\n", + " [4., 4., 4., 4., 4.]])" + ] + }, + "execution_count": 382, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "d[:, 0][:, None] + d[:, 0][None, :]" + ] + }, + { + "cell_type": "code", + "execution_count": 381, + "id": "80975d51", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[4., 4., 4., 4., 4.],\n", + " [4., 4., 4., 4., 4.],\n", + " [4., 4., 4., 4., 4.],\n", + " [4., 4., 4., 4., 4.],\n", + " [4., 4., 4., 4., 4.]],\n", + "\n", + " [[6., 6., 6., 6., 6.],\n", + " [6., 6., 6., 6., 6.],\n", + " [6., 6., 6., 6., 6.],\n", + " [6., 6., 6., 6., 6.],\n", + " [6., 6., 6., 6., 6.]]])" + ] + }, + "execution_count": 381, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ee = d.T[:, :, None] + d.T[:, None, :]\n", + "ee" + ] + }, + { + "cell_type": "code", + "execution_count": 377, + "id": "6e7f9913", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([2, 5, 5])" + ] + }, + "execution_count": 377, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ee.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 378, + "id": "c4bc1337", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([2, 3, 5])" + ] + }, + "execution_count": 378, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "e.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 379, + "id": "70cb3f8b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[1.],\n", + " [1.],\n", + " [1.]])" + ] + }, + "execution_count": 379, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "a" + ] + }, + { + "cell_type": "code", + "execution_count": 380, + "id": "a93d3f3f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[2., 3.],\n", + " [2., 3.],\n", + " [2., 3.],\n", + " [2., 3.],\n", + " [2., 3.]])" + ] + }, + "execution_count": 380, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "d" + ] + }, + { + "cell_type": "code", + "execution_count": 391, + "id": "91f386ea", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[1., 1.],\n", + " [1., 1.],\n", + " [1., 1.]])" + ] + }, + "execution_count": 391, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "aa" + ] + }, + { + "cell_type": "code", + "execution_count": 395, + "id": "99b2b52f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[2., 2., 2., 2., 2.],\n", + " [2., 2., 2., 2., 2.],\n", + " [2., 2., 2., 2., 2.]],\n", + "\n", + " [[3., 3., 3., 3., 3.],\n", + " [3., 3., 3., 3., 3.],\n", + " [3., 3., 3., 3., 3.]]])" + ] + }, + "execution_count": 395, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.einsum('ij, jkm->mik', a, d[None, :, :])" + ] + }, + { + "cell_type": "code", + "execution_count": 398, + "id": "aed12015", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[2., 2., 2., 2., 2.],\n", + " [2., 2., 2., 2., 2.],\n", + " [2., 2., 2., 2., 2.]])" + ] + }, + "execution_count": 398, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "a * d[:, 0]" + ] + }, + { + "cell_type": "code", + "execution_count": 397, + "id": "bca65b4e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[2., 3.],\n", + " [2., 3.],\n", + " [2., 3.],\n", + " [2., 3.],\n", + " [2., 3.]])" + ] + }, + "execution_count": 397, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "d" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e6187bb9", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/ot/unbalanced/__init__.py b/ot/unbalanced/__init__.py new file mode 100644 index 000000000..67ef6cdf8 --- /dev/null +++ b/ot/unbalanced/__init__.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- +""" +Solvers related to Unbalanced Optimal Transport problems. + +""" + +# Author: Quang Huy Tran +# +# License: MIT License + +# All submodules and packages +from ._sinkhorn import (sinkhorn_knopp_unbalanced, + sinkhorn_unbalanced, + sinkhorn_stabilized_unbalanced, + sinkhorn_unbalanced2, + barycenter_unbalanced_sinkhorn, + barycenter_unbalanced_stabilized, + barycenter_unbalanced) + +from ._mm import (mm_unbalanced, mm_unbalanced2) + +from ._lbfgs import (_get_loss_unbalanced, lbfgsb_unbalanced, lbfgsb_unbalanced2) + +__all__ = ['sinkhorn_knopp_unbalanced', 'sinkhorn_unbalanced', 'sinkhorn_stabilized_unbalanced', + 'sinkhorn_unbalanced2', 'barycenter_unbalanced_sinkhorn', 'barycenter_unbalanced_stabilized', + 'barycenter_unbalanced', 'mm_unbalanced', 'mm_unbalanced2', '_get_loss_unbalanced', + 'lbfgsb_unbalanced', 'lbfgsb_unbalanced2'] diff --git a/ot/unbalanced/_lbfgs.py b/ot/unbalanced/_lbfgs.py new file mode 100644 index 000000000..734924948 --- /dev/null +++ b/ot/unbalanced/_lbfgs.py @@ -0,0 +1,357 @@ +# -*- coding: utf-8 -*- +""" +Regularized Unbalanced OT solvers +""" + +# Author: Hicham Janati +# Laetitia Chapel +# Quang Huy Tran +# +# License: MIT License + +from __future__ import division +import warnings + +import numpy as np +from scipy.optimize import minimize, Bounds + +from ..backend import get_backend +from ..utils import list_to_array, get_parameter_pair + + +def _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div='kl', regm_div='kl'): + """ + return the loss function (scipy.optimize compatible) for regularized + unbalanced OT + """ + + m, n = M.shape + + def kl(p, q): + return np.sum(p * np.log(p / q + 1e-16)) - np.sum(p) + np.sum(q) + + def reg_l2(G): + return np.sum((G - c)**2) / 2 + + def grad_l2(G): + return G - c + + def reg_kl(G): + return kl(G, c) + + def grad_kl(G): + return np.log(G / c + 1e-16) + + if reg_div == 'kl': + reg_fun = reg_kl + grad_reg_fun = grad_kl + elif isinstance(reg_div, tuple): + reg_fun = reg_div[0] + grad_reg_fun = reg_div[1] + else: + reg_fun = reg_l2 + grad_reg_fun = grad_l2 + + def marg_l2(G): + return reg_m1 * 0.5 * np.sum((G.sum(1) - a)**2) + \ + reg_m2 * 0.5 * np.sum((G.sum(0) - b)**2) + + def grad_marg_l2(G): + return reg_m1 * np.outer((G.sum(1) - a), np.ones(n)) + \ + reg_m2 * np.outer(np.ones(m), (G.sum(0) - b)) + + def marg_kl(G): + return reg_m1 * kl(G.sum(1), a) + reg_m2 * kl(G.sum(0), b) + + def grad_marg_kl(G): + return reg_m1 * np.outer(np.log(G.sum(1) / a + 1e-16), np.ones(n)) + \ + reg_m2 * np.outer(np.ones(m), np.log(G.sum(0) / b + 1e-16)) + + def marg_tv(G): + return reg_m1 * np.sum(np.abs(G.sum(1) - a)) + \ + reg_m2 * np.sum(np.abs(G.sum(0) - b)) + + def grad_marg_tv(G): + return reg_m1 * np.outer(np.sign(G.sum(1) - a), np.ones(n)) + \ + reg_m2 * np.outer(np.ones(m), np.sign(G.sum(0) - b)) + + if regm_div == 'kl': + regm_fun = marg_kl + grad_regm_fun = grad_marg_kl + elif regm_div == 'tv': + regm_fun = marg_tv + grad_regm_fun = grad_marg_tv + else: + regm_fun = marg_l2 + grad_regm_fun = grad_marg_l2 + + def _func(G): + G = G.reshape((m, n)) + + # compute loss + val = np.sum(G * M) + reg * reg_fun(G) + regm_fun(G) + + # compute gradient + grad = M + reg * grad_reg_fun(G) + grad_regm_fun(G) + + return val, grad.ravel() + + return _func + + +def lbfgsb_unbalanced(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', G0=None, numItermax=1000, + stopThr=1e-15, method='L-BFGS-B', verbose=False, log=False): + r""" + Solve the unbalanced optimal transport problem and return the OT plan using L-BFGS-B. + The function solves the following optimization problem: + + .. math:: + W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + + \mathrm{reg} \mathrm{div}(\gamma, \mathbf{c}) + \mathrm{reg_{m1}} \cdot \mathrm{div_m}(\gamma \mathbf{1}, \mathbf{a}) + + \mathrm{reg_{m2}} \cdot \mathrm{div_m}(\gamma^T \mathbf{1}, \mathbf{b}) + + s.t. + \gamma \geq 0 + + where: + + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions + - :math:`\mathbf{c}` is a reference distribution for the regularization + - :math:`\mathrm{div_m}` is a divergence, either Kullback-Leibler divergence, + or half-squared :math:`\ell_2` divergence, or Total variation + - :math:`\mathrm{div}` is a divergence, either Kullback-Leibler divergence, + or half-squared :math:`\ell_2` divergence + + The algorithm used for solving the problem is a L-BFGS-B from scipy.optimize + + Parameters + ---------- + a : array-like (dim_a,) + Unnormalized histogram of dimension `dim_a` + b : array-like (dim_b,) + Unnormalized histogram of dimension `dim_b` + M : array-like (dim_a, dim_b) + loss matrix + reg: float + regularization term >=0 + c : array-like (dim_a, dim_b), optional (default = None) + Reference measure for the regularization. + If None, then use `\mathbf{c} = \mathbf{a} \mathbf{b}^T`. + reg_m: float or indexable object of length 1 or 2 + Marginal relaxation term >= 0, but cannot be infinity. + If reg_m is a scalar or an indexable object of length 1, + then the same reg_m is applied to both marginal relaxations. + If reg_m is an array, it must be a Numpy array. + reg_div: string, optional + Divergence used for regularization. + Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) or a tuple + of two calable functions returning the reg term and its derivative. + Note that the callable functions should be able to handle numpy arrays + and not tesors from the backend + regm_div: string, optional + Divergence to quantify the difference between the marginals. + Can take three values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) or 'tv' (Total Variation) + G0: array-like (dim_a, dim_b) + Initialization of the transport matrix + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (> 0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + Returns + ------- + gamma : (dim_a, dim_b) array-like + Optimal transportation matrix for the given parameters + log : dict + log dictionary returned only if `log` is `True` + + Examples + -------- + >>> import ot + >>> import numpy as np + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> M=[[1., 36.],[9., 4.]] + >>> np.round(ot.unbalanced.lbfgsb_unbalanced(a, b, M, reg=0, reg_m=5, reg_div='kl', regm_div='kl'), 2) + array([[0.45, 0. ], + [0. , 0.34]]) + >>> np.round(ot.unbalanced.lbfgsb_unbalanced(a, b, M, reg=0, reg_m=5, reg_div='l2', regm_div='l2'), 2) + array([[0.4, 0. ], + [0. , 0.1]]) + + References + ---------- + .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression. NeurIPS. + See Also + -------- + ot.lp.emd2 : Unregularized OT loss + ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss + """ + + M, a, b = list_to_array(M, a, b) + nx = get_backend(M, a, b) + M0 = M + + # convert to numpy + a, b, M = nx.to_numpy(a, b, M) + G0 = np.zeros(M.shape) if G0 is None else nx.to_numpy(G0) + if reg > 0: # regularized case + c = a[:, None] * b[None, :] if c is None else nx.to_numpy(c) + else: # unregularized case + c = 0 + + # wrap the callable function to handle numpy arrays + if isinstance(reg_div, tuple): + f0, df0 = reg_div + try: + f0(G0) + df0(G0) + except BaseException: + warnings.warn("The callable functions should be able to handle numpy arrays, wrapper ar added to handle this which comes with overhead") + + def f(x): + return nx.to_numpy(f0(nx.from_numpy(x, type_as=M0))) + + def df(x): + return nx.to_numpy(df0(nx.from_numpy(x, type_as=M0))) + + reg_div = (f, df) + + reg_m1, reg_m2 = get_parameter_pair(reg_m) + _func = _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div, regm_div) + + res = minimize(_func, G0.ravel(), method=method, jac=True, bounds=Bounds(0, np.inf), + tol=stopThr, options=dict(maxiter=numItermax, disp=verbose)) + + G = nx.from_numpy(res.x.reshape(M.shape), type_as=M0) + + if log: + log = {'cost': nx.sum(G * M), 'res': res} + log['total_cost'] = nx.from_numpy(res.fun, type_as=M0) + return G, log + else: + return G + + +def lbfgsb_unbalanced2(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', + G0=None, returnCost="linear", numItermax=1000, stopThr=1e-15, + method='L-BFGS-B', verbose=False, log=False): + r""" + Solve the unbalanced optimal transport problem and return the OT plan using L-BFGS-B. + The function solves the following optimization problem: + + .. math:: + W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + + \mathrm{reg} \mathrm{div}(\gamma, \mathbf{c}) + \mathrm{reg_{m1}} \cdot \mathrm{div_m}(\gamma \mathbf{1}, \mathbf{a}) + + \mathrm{reg_{m2}} \cdot \mathrm{div_m}(\gamma^T \mathbf{1}, \mathbf{b}) + + s.t. + \gamma \geq 0 + + where: + + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions + - :math:`\mathbf{c}` is a reference distribution for the regularization + - :math:`\mathrm{div_m}` is a divergence, either Kullback-Leibler divergence, + or half-squared :math:`\ell_2` divergence, or Total variation + - :math:`\mathrm{div}` is a divergence, either Kullback-Leibler divergence, + or half-squared :math:`\ell_2` divergence + + The algorithm used for solving the problem is a L-BFGS-B from scipy.optimize + + Parameters + ---------- + a : array-like (dim_a,) + Unnormalized histogram of dimension `dim_a` + b : array-like (dim_b,) + Unnormalized histogram of dimension `dim_b` + M : array-like (dim_a, dim_b) + loss matrix + reg: float + regularization term >=0 + c : array-like (dim_a, dim_b), optional (default = None) + Reference measure for the regularization. + If None, then use `\mathbf{c} = \mathbf{a} \mathbf{b}^T`. + reg_m: float or indexable object of length 1 or 2 + Marginal relaxation term >= 0, but cannot be infinity. + If reg_m is a scalar or an indexable object of length 1, + then the same reg_m is applied to both marginal relaxations. + If reg_m is an array, it must be a Numpy array. + reg_div: string, optional + Divergence used for regularization. + Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) or a tuple + of two calable functions returning the reg term and its derivative. + Note that the callable functions should be able to handle numpy arrays + and not tesors from the backend + regm_div: string, optional + Divergence to quantify the difference between the marginals. + Can take three values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) or 'tv' (Total Variation) + G0: array-like (dim_a, dim_b) + Initialization of the transport matrix + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (> 0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + Returns + ------- + gamma : (dim_a, dim_b) array-like + Optimal transportation matrix for the given parameters + log : dict + log dictionary returned only if `log` is `True` + + Examples + -------- + >>> import ot + >>> import numpy as np + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> M=[[1., 36.],[9., 4.]] + >>> np.round(ot.unbalanced.lbfgsb_unbalanced(a, b, M, reg=0, reg_m=5, reg_div='kl', regm_div='kl'), 2) + array([[0.45, 0. ], + [0. , 0.34]]) + >>> np.round(ot.unbalanced.lbfgsb_unbalanced(a, b, M, reg=0, reg_m=5, reg_div='l2', regm_div='l2'), 2) + array([[0.4, 0. ], + [0. , 0.1]]) + + References + ---------- + .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression. NeurIPS. + See Also + -------- + ot.lp.emd2 : Unregularized OT loss + ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss + """ + + _, log_lbfgs = lbfgsb_unbalanced(a=a, b=b, M=M, reg=reg, reg_m=reg_m, c=c, + reg_div=reg_div, regm_div=regm_div, G0=G0, + numItermax=numItermax, stopThr=stopThr, + method=method, verbose=verbose, log=True) + + if returnCost == "linear": + cost = log_lbfgs['cost'] + elif returnCost == "total": + cost = log_lbfgs['total_cost'] + else: + raise ValueError("Unknown returnCost = {}".format(returnCost)) + + if log: + return cost, log_lbfgs + else: + return cost diff --git a/ot/unbalanced/_mm.py b/ot/unbalanced/_mm.py new file mode 100644 index 000000000..464c89f78 --- /dev/null +++ b/ot/unbalanced/_mm.py @@ -0,0 +1,288 @@ +# -*- coding: utf-8 -*- +""" +Regularized Unbalanced OT solvers +""" + +# Author: Hicham Janati +# Laetitia Chapel +# Quang Huy Tran +# +# License: MIT License + +from __future__ import division + +from ..backend import get_backend +from ..utils import list_to_array, get_parameter_pair + + +def mm_unbalanced(a, b, M, reg_m, c=None, reg=0, div='kl', G0=None, numItermax=1000, + stopThr=1e-15, verbose=False, log=False): + r""" + Solve the unbalanced optimal transport problem and return the OT plan. + The function solves the following optimization problem: + + .. math:: + W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg_{m1}} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) + + \mathrm{reg_{m2}} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) + + \mathrm{reg} \cdot \mathrm{div}(\gamma, \mathbf{c}) + + s.t. + \gamma \geq 0 + + where: + + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target + unbalanced distributions + - :math:`\mathbf{c}` is a reference distribution for the regularization + - div is a divergence, either Kullback-Leibler or half-squared :math:`\ell_2` divergence + + The algorithm used for solving the problem is a maximization- + minimization algorithm as proposed in :ref:`[41] ` + + Parameters + ---------- + a : array-like (dim_a,) + Unnormalized histogram of dimension `dim_a` + b : array-like (dim_b,) + Unnormalized histogram of dimension `dim_b` + M : array-like (dim_a, dim_b) + loss matrix + reg_m: float or indexable object of length 1 or 2 + Marginal relaxation term >= 0, but cannot be infinity. + If reg_m is a scalar or an indexable object of length 1, + then the same reg_m is applied to both marginal relaxations. + If reg_m is an array, it must have the same backend as input arrays (a, b, M). + reg : float, optional (default = 0) + Regularization term >= 0. + By default, solve the unregularized problem + c : array-like (dim_a, dim_b), optional (default = None) + Reference measure for the regularization. + If None, then use `\mathbf{c} = \mathbf{a} \mathbf{b}^T`. + div: string, optional + Divergence to quantify the difference between the marginals. + Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) + G0: array-like (dim_a, dim_b) + Initialization of the transport matrix + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (> 0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + Returns + ------- + gamma : (dim_a, dim_b) array-like + Optimal transportation matrix for the given parameters + log : dict + log dictionary returned only if `log` is `True` + + Examples + -------- + >>> import ot + >>> import numpy as np + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> M=[[1., 36.],[9., 4.]] + >>> np.round(ot.unbalanced.mm_unbalanced(a, b, M, 5, div='kl'), 2) + array([[0.45, 0. ], + [0. , 0.34]]) + >>> np.round(ot.unbalanced.mm_unbalanced(a, b, M, 5, div='l2'), 2) + array([[0.4, 0. ], + [0. , 0.1]]) + + + .. _references-regpath: + References + ---------- + .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression. NeurIPS. + See Also + -------- + ot.lp.emd : Unregularized OT + ot.unbalanced.sinkhorn_unbalanced : Entropic regularized OT + """ + + M, a, b = list_to_array(M, a, b) + nx = get_backend(M, a, b) + + dim_a, dim_b = M.shape + + if len(a) == 0: + a = nx.ones(dim_a, type_as=M) / dim_a + if len(b) == 0: + b = nx.ones(dim_b, type_as=M) / dim_b + + G = a[:, None] * b[None, :] if G0 is None else G0 + if reg > 0: # regularized case + c = a[:, None] * b[None, :] if c is None else c + else: # unregularized case + c = 0 + + reg_m1, reg_m2 = get_parameter_pair(reg_m) + + if log: + log = {'err': [], 'G': []} + + if div == 'kl': + sum_r = reg + reg_m1 + reg_m2 + r1, r2, r = reg_m1 / sum_r, reg_m2 / sum_r, reg / sum_r + K = (a[:, None]**r1) * (b[None, :]**r2) * (c**r) * nx.exp(- M / sum_r) + elif div == 'l2': + K = reg_m1 * a[:, None] + reg_m2 * b[None, :] + reg * c - M + K = nx.maximum(K, nx.zeros((dim_a, dim_b), type_as=M)) + else: + raise ValueError("Unknown div = {}. Must be either 'kl' or 'l2'".format(div)) + + for i in range(numItermax): + Gprev = G + + if div == 'kl': + Gd = (nx.sum(G, 1, keepdims=True)**r1) * (nx.sum(G, 0, keepdims=True)**r2) + 1e-16 + G = K * G**(r1 + r2) / Gd + elif div == 'l2': + Gd = reg_m1 * nx.sum(G, 1, keepdims=True) + \ + reg_m2 * nx.sum(G, 0, keepdims=True) + reg * G + 1e-16 + G = K * G / Gd + + err = nx.sqrt(nx.sum((G - Gprev) ** 2)) + if log: + log['err'].append(err) + log['G'].append(G) + if verbose: + print('{:5d}|{:8e}|'.format(i, err)) + if err < stopThr: + break + + if log: + linear_cost = nx.sum(G * M) + log['cost'] = linear_cost + + m1, m2 = nx.sum(G, 1), nx.sum(G, 0) + if div == "kl": + cost = linear_cost + reg_m1 * nx.kl_div(m1, a, mass=True) + reg_m2 * nx.kl_div(m2, b, mass=True) + if reg > 0: + cost = cost + reg * nx.kl_div(G, c, mass=True) + else: + cost = linear_cost + reg_m1 * 0.5 * nx.sum((m1 - a)**2) + reg_m2 * 0.5 * nx.sum((m2 - b)**2) + if reg > 0: + cost = cost + reg * 0.5 * nx.sum((G - c)**2) + + log["total_cost"] = cost + + return G, log + else: + return G + + +def mm_unbalanced2(a, b, M, reg_m, c=None, reg=0, div='kl', G0=None, returnCost="linear", + numItermax=1000, stopThr=1e-15, verbose=False, log=False): + r""" + Solve the unbalanced optimal transport problem and return the OT plan. + The function solves the following optimization problem: + + .. math:: + W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg_{m1}} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) + + \mathrm{reg_{m2}} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) + + \mathrm{reg} \cdot \mathrm{div}(\gamma, \mathbf{c}) + + s.t. + \gamma \geq 0 + + where: + + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target + unbalanced distributions + - :math:`\mathbf{c}` is a reference distribution for the regularization + - :math:`\mathrm{div}` is a divergence, either Kullback-Leibler or half-squared :math:`\ell_2` divergence + + The algorithm used for solving the problem is a maximization- + minimization algorithm as proposed in :ref:`[41] ` + + Parameters + ---------- + a : array-like (dim_a,) + Unnormalized histogram of dimension `dim_a` + b : array-like (dim_b,) + Unnormalized histogram of dimension `dim_b` + M : array-like (dim_a, dim_b) + loss matrix + reg_m: float or indexable object of length 1 or 2 + Marginal relaxation term >= 0, but cannot be infinity. + If reg_m is a scalar or an indexable object of length 1, + then the same reg_m is applied to both marginal relaxations. + If reg_m is an array, it must have the same backend as input arrays (a, b, M). + reg : float, optional (default = 0) + Entropy regularization term >= 0. + By default, solve the unregularized problem + c : array-like (dim_a, dim_b), optional (default = None) + Reference measure for the regularization. + If None, then use `\mathbf{c} = mathbf{a} mathbf{b}^T`. + div: string, optional + Divergence to quantify the difference between the marginals. + Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) + G0: array-like (dim_a, dim_b) + Initialization of the transport matrix + returnCost: string, optional (default = "linear") + If returnCost = "linear", then return the linear part of the unbalanced OT loss. + If returnCost = "total", then return total unbalanced OT loss. + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (> 0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + Returns + ------- + ot_distance : array-like + the OT cost between :math:`\mathbf{a}` and :math:`\mathbf{b}` + log : dict + log dictionary returned only if `log` is `True` + + Examples + -------- + >>> import ot + >>> import numpy as np + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> M=[[1., 36.],[9., 4.]] + >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 5, div='l2'), 2) + 0.8 + >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 5, div='kl'), 2) + 1.79 + + References + ---------- + .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression. NeurIPS. + See Also + -------- + ot.lp.emd2 : Unregularized OT loss + ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss + """ + + _, log_mm = mm_unbalanced(a, b, M, reg_m, c=c, reg=reg, div=div, G0=G0, + numItermax=numItermax, stopThr=stopThr, + verbose=verbose, log=True) + + if returnCost == "linear": + cost = log_mm['cost'] + elif returnCost == "total": + cost = log_mm['total_cost'] + else: + raise ValueError("Unknown returnCost = {}".format(returnCost)) + + if log: + return cost, log_mm + else: + return cost diff --git a/ot/unbalanced.py b/ot/unbalanced/_sinkhorn.py similarity index 67% rename from ot/unbalanced.py rename to ot/unbalanced/_sinkhorn.py index c39888a31..f49f70c46 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced/_sinkhorn.py @@ -12,16 +12,13 @@ from __future__ import division import warnings -import numpy as np -from scipy.optimize import minimize, Bounds +from ..backend import get_backend +from ..utils import list_to_array, get_parameter_pair -from .backend import get_backend -from .utils import list_to_array, get_parameter_pair - -def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', - reg_type="entropy", warmstart=None, numItermax=1000, - stopThr=1e-6, verbose=False, log=False, **kwargs): +def sinkhorn_unbalanced(a, b, M, reg, reg_m, c=None, method='sinkhorn', + warmstart=None, numItermax=1000, stopThr=1e-6, + verbose=False, log=False, **kwargs): r""" Solve the unbalanced entropic regularization optimal transport problem and return the OT plan @@ -30,7 +27,7 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', .. math:: W = \min_\gamma \ \langle \gamma, \mathbf{M} \rangle_F + - \mathrm{reg} \cdot \Omega(\gamma) + + \mathrm{reg} \cdot \mathrm{KL}(\gamma, \mathbf{c}) + \mathrm{reg_{m1}} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) + \mathrm{reg_{m2}} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b}) @@ -40,8 +37,8 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term, can be either KL divergence or negative entropy - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions + - :math:`\mathbf{c}` is a reference distribution for the regularization - KL is the Kullback-Leibler divergence The algorithm used for solving the problem is the generalized @@ -67,15 +64,12 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', For semi-relaxed case, use either `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`. If reg_m is an array, it must have the same backend as input arrays (a, b, M). + c : array-like (dim_a, dim_b), optional (default=None) + Reference measure for the regularization. + If None, then use `\mathbf{c} = \mathbf{a} \mathbf{b}^T`. method : str method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or 'sinkhorn_reg_scaling', see those function for specific parameters - reg_type : string, optional - Regularizer term. Can take two values: - 'entropy' (negative entropy) - :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}`, or - 'kl' (Kullback-Leibler) - :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`. warmstart: tuple of arrays, shape (dim_a, dim_b), optional Initialization of dual potentials. If provided, the dual potentials should be given (that is the logarithm of the u,v sinkhorn scaling vectors). @@ -143,20 +137,20 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', """ if method.lower() == 'sinkhorn': - return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, c, warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) elif method.lower() == 'sinkhorn_stabilized': - return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type, + return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, c, warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) elif method.lower() in ['sinkhorn_reg_scaling']: warnings.warn('Method not implemented yet. Using classic Sinkhorn-Knopp') - return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, c, warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) @@ -164,9 +158,9 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', raise ValueError("Unknown method '%s'." % method) -def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', - reg_type="entropy", warmstart=None, numItermax=1000, - stopThr=1e-6, verbose=False, log=False, **kwargs): +def sinkhorn_unbalanced2(a, b, M, reg, reg_m, c=None, method='sinkhorn', + warmstart=None, numItermax=1000, stopThr=1e-6, + verbose=False, log=False, **kwargs): r""" Solve the entropic regularization unbalanced optimal transport problem and return the loss @@ -175,7 +169,7 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', .. math:: W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + - \mathrm{reg} \cdot \Omega(\gamma) + + \mathrm{reg} \cdot \mathrm{KL}(\gamma, \mathbf{c}) + \mathrm{reg_{m1}} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) + \mathrm{reg_{m2}} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b}) @@ -184,8 +178,8 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term, can be either KL divergence or negative entropy - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions + - :math:`\mathbf{c}` is a reference distribution for the regularization - KL is the Kullback-Leibler divergence The algorithm used for solving the problem is the generalized @@ -211,15 +205,12 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', For semi-relaxed case, use either `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`. If reg_m is an array, it must have the same backend as input arrays (a, b, M). + c : array-like (dim_a, dim_b), optional (default=None) + Reference measure for the regularization. + If None, then use `\mathbf{c} = \mathbf{a} \mathbf{b}^T`. method : str method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or 'sinkhorn_reg_scaling', see those function for specific parameterss - reg_type : string, optional - Regularizer term. Can take two values: - 'entropy' (negative entropy) - :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}`, or - 'kl' (Kullback-Leibler) - :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`. warmstart: tuple of arrays, shape (dim_a, dim_b), optional Initialization of dual potentials. If provided, the dual potentials should be given (that is the logarithm of the u,v sinkhorn scaling vectors). @@ -281,19 +272,19 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', if len(b.shape) < 2: if method.lower() == 'sinkhorn': - res = sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, + res = sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, c, warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) elif method.lower() == 'sinkhorn_stabilized': - res = sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type, + res = sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, c, warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) elif method.lower() in ['sinkhorn_reg_scaling']: warnings.warn('Method not implemented yet. Using classic Sinkhorn-Knopp') - res = sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, + res = sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, c, warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) @@ -307,19 +298,19 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', else: if method.lower() == 'sinkhorn': - return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, c, warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) elif method.lower() == 'sinkhorn_stabilized': - return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type, + return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, c, warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) elif method.lower() in ['sinkhorn_reg_scaling']: warnings.warn('Method not implemented yet. Using classic Sinkhorn-Knopp') - return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, c, warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) @@ -327,7 +318,7 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', raise ValueError('Unknown method %s.' % method) -def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", +def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, c=None, warmstart=None, numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs): r""" @@ -338,7 +329,7 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", .. math:: W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + - \mathrm{reg} \cdot \Omega(\gamma) + + \mathrm{reg} \cdot \mathrm{KL}(\gamma, \mathbf{c}) + \mathrm{reg_{m1}} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) + \mathrm{reg_{m2}} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b}) @@ -348,8 +339,8 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term, can be either KL divergence or negative entropy - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions + - :math:`\mathbf{c}` is a reference distribution for the regularization - KL is the Kullback-Leibler divergence The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10, 25] ` @@ -374,12 +365,9 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", For semi-relaxed case, use either `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`. If reg_m is an array, it must have the same backend as input arrays (a, b, M). - reg_type : string, optional - Regularizer term. Can take two values: - 'entropy' (negative entropy) - :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}`, or - 'kl' (Kullback-Leibler) - :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`. + c : array-like (dim_a, dim_b), optional (default=None) + Reference measure for the regularization. + If None, then use `\mathbf{c} = \mathbf{a} \mathbf{b}^T`. warmstart: tuple of arrays, shape (dim_a, dim_b), optional Initialization of dual potentials. If provided, the dual potentials should be given (that is the logarithm of the u,v sinkhorn scaling vectors). @@ -452,7 +440,7 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", reg_m1, reg_m2 = get_parameter_pair(reg_m) if log: - log = {'err': []} + dict_log = {'err': []} # we assume that no distances are null except those of the diagonal of # distances @@ -467,10 +455,11 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", else: u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1]) - if reg_type == "kl": - K = nx.exp(-M / reg) * a.reshape(-1)[:, None] * b.reshape(-1)[None, :] - elif reg_type == "entropy": + if n_hists: K = nx.exp(-M / reg) + else: + c = a[:, None] * b[None, :] if c is None else c + K = nx.exp(-M / reg) * c fi_1 = reg_m1 / (reg_m1 + reg) if reg_m1 != float("inf") else 1 fi_2 = reg_m2 / (reg_m2 + reg) if reg_m2 != float("inf") else 1 @@ -504,7 +493,7 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", ) err = 0.5 * (err_u + err_v) if log: - log['err'].append(err) + dict_log['err'].append(err) if verbose: if i % 50 == 0: print( @@ -514,25 +503,35 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", break if log: - log['logu'] = nx.log(u + 1e-300) - log['logv'] = nx.log(v + 1e-300) + dict_log['logu'] = nx.log(u + 1e-300) + dict_log['logv'] = nx.log(v + 1e-300) if n_hists: # return only loss res = nx.einsum('ik,ij,jk,ij->k', u, K, v, M) if log: - return res, log + return res, dict_log else: return res else: # return OT matrix + plan = u[:, None] * K * v[None, :] if log: - return u[:, None] * K * v[None, :], log + linear_cost = nx.sum(plan * M) + total_cost = linear_cost + reg * nx.kl_div(plan, c) + if reg_m1 != float("inf"): + total_cost = total_cost + reg_m1 * nx.kl_div(nx.sum(plan, 1), a) + if reg_m2 != float("inf"): + total_cost = total_cost + reg_m2 * nx.kl_div(nx.sum(plan, 0), b) + dict_log["cost"] = linear_cost + dict_log["total_cost"] = total_cost + + return plan, dict_log else: - return u[:, None] * K * v[None, :] + return plan -def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", +def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, c=None, warmstart=None, tau=1e5, numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs): @@ -545,7 +544,7 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", .. math:: W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + - \mathrm{reg} \cdot \Omega(\gamma) + + \mathrm{reg} \cdot \mathrm{KL}(\gamma, \mathbf{c}) + \mathrm{reg_{m1}} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) + \mathrm{reg_{m2}} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b}) @@ -555,8 +554,8 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term, can be either KL divergence or negative entropy - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions + - :math:`\mathbf{c}` is a reference distribution for the regularization - KL is the Kullback-Leibler divergence The algorithm used for solving the problem is the generalized @@ -582,12 +581,9 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", For semi-relaxed case, use either `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`. If reg_m is an array, it must have the same backend as input arrays (a, b, M). - reg_type : string, optional - Regularizer term. Can take two values: - 'entropy' (negative entropy) - :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}`, or - 'kl' (Kullback-Leibler) - :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`. + c : array-like (dim_a, dim_b), optional (default=None) + Reference measure for the regularization. + If None, then use `\mathbf{c} = \mathbf{a} \mathbf{b}^T`. warmstart: tuple of arrays, shape (dim_a, dim_b), optional Initialization of dual potentials. If provided, the dual potentials should be given (that is the logarithm of the u,v sinkhorn scaling vectors). @@ -660,7 +656,7 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", reg_m1, reg_m2 = get_parameter_pair(reg_m) if log: - log = {'err': []} + dict_log = {'err': []} # we assume that no distances are null except those of the diagonal of # distances @@ -675,12 +671,11 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", else: u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1]) - if reg_type == "kl": - log_ab = nx.log(a + 1e-16).reshape(-1)[:, None] + nx.log(b + 1e-16).reshape(-1)[None, :] - M0 = M - reg * log_ab - else: + if n_hists: M0 = M - + else: + c = a[:, None] * b[None, :] if c is None else c + M0 = M - reg * nx.log(c) K = nx.exp(-M0 / reg) fi_1 = reg_m1 / (reg_m1 + reg) if reg_m1 != float("inf") else 1 @@ -736,7 +731,7 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1. ) if log: - log['err'].append(err) + dict_log['err'].append(err) if verbose: if cpt % 200 == 0: print( @@ -755,8 +750,8 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", logu = alpha / reg + nx.log(u) logv = beta / reg + nx.log(v) if log: - log['logu'] = logu - log['logv'] = logv + dict_log['logu'] = logu + dict_log['logv'] = logv if n_hists: # return only loss res = nx.logsumexp( nx.log(M + 1e-100)[:, :, None] @@ -767,16 +762,25 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", ) res = nx.exp(res) if log: - return res, log + return res, dict_log else: return res else: # return OT matrix - ot_matrix = nx.exp(logu[:, None] + logv[None, :] - M0 / reg) + plan = nx.exp(logu[:, None] + logv[None, :] - M0 / reg) if log: - return ot_matrix, log + linear_cost = nx.sum(plan * M) + total_cost = linear_cost + reg * nx.kl_div(plan, c) + if reg_m1 != float("inf"): + total_cost = total_cost + reg_m1 * nx.kl_div(nx.sum(plan, 1), a) + if reg_m2 != float("inf"): + total_cost = total_cost + reg_m2 * nx.kl_div(nx.sum(plan, 0), b) + dict_log["cost"] = linear_cost + dict_log["total_cost"] = total_cost + + return plan, dict_log else: - return ot_matrix + return plan def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, @@ -1147,477 +1151,3 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None, log=log, **kwargs) else: raise ValueError("Unknown method '%s'." % method) - - -def mm_unbalanced(a, b, M, reg_m, c=None, reg=0, div='kl', G0=None, numItermax=1000, - stopThr=1e-15, verbose=False, log=False): - r""" - Solve the unbalanced optimal transport problem and return the OT plan. - The function solves the following optimization problem: - - .. math:: - W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + - \mathrm{reg_{m1}} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) + - \mathrm{reg_{m2}} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) + - \mathrm{reg} \cdot \mathrm{div}(\gamma, \mathbf{c}) - - s.t. - \gamma \geq 0 - - where: - - - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target - unbalanced distributions - - :math:`\mathbf{c}` is a reference distribution for the regularization - - div is a divergence, either Kullback-Leibler or :math:`\ell_2` divergence - - The algorithm used for solving the problem is a maximization- - minimization algorithm as proposed in :ref:`[41] ` - - Parameters - ---------- - a : array-like (dim_a,) - Unnormalized histogram of dimension `dim_a` - b : array-like (dim_b,) - Unnormalized histogram of dimension `dim_b` - M : array-like (dim_a, dim_b) - loss matrix - reg_m: float or indexable object of length 1 or 2 - Marginal relaxation term >= 0, but cannot be infinity. - If reg_m is a scalar or an indexable object of length 1, - then the same reg_m is applied to both marginal relaxations. - If reg_m is an array, it must have the same backend as input arrays (a, b, M). - reg : float, optional (default = 0) - Regularization term >= 0. - By default, solve the unregularized problem - c : array-like (dim_a, dim_b), optional (default = None) - Reference measure for the regularization. - If None, then use `\mathbf{c} = \mathbf{a} \mathbf{b}^T`. - div: string, optional - Divergence to quantify the difference between the marginals. - Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) - G0: array-like (dim_a, dim_b) - Initialization of the transport matrix - numItermax : int, optional - Max number of iterations - stopThr : float, optional - Stop threshold on error (> 0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - Returns - ------- - gamma : (dim_a, dim_b) array-like - Optimal transportation matrix for the given parameters - log : dict - log dictionary returned only if `log` is `True` - - Examples - -------- - >>> import ot - >>> import numpy as np - >>> a=[.5, .5] - >>> b=[.5, .5] - >>> M=[[1., 36.],[9., 4.]] - >>> np.round(ot.unbalanced.mm_unbalanced(a, b, M, 5, div='kl'), 2) - array([[0.45, 0. ], - [0. , 0.34]]) - >>> np.round(ot.unbalanced.mm_unbalanced(a, b, M, 5, div='l2'), 2) - array([[0.4, 0. ], - [0. , 0.1]]) - - - .. _references-regpath: - References - ---------- - .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). - Unbalanced optimal transport through non-negative penalized - linear regression. NeurIPS. - See Also - -------- - ot.lp.emd : Unregularized OT - ot.unbalanced.sinkhorn_unbalanced : Entropic regularized OT - """ - - M, a, b = list_to_array(M, a, b) - nx = get_backend(M, a, b) - - dim_a, dim_b = M.shape - - if len(a) == 0: - a = nx.ones(dim_a, type_as=M) / dim_a - if len(b) == 0: - b = nx.ones(dim_b, type_as=M) / dim_b - - G = a[:, None] * b[None, :] if G0 is None else G0 - c = a[:, None] * b[None, :] if c is None else c - - reg_m1, reg_m2 = get_parameter_pair(reg_m) - - if log: - log = {'err': [], 'G': []} - - if div not in ["kl", "l2"]: - warnings.warn("The div parameter should be either equal to 'kl' or \ - 'l2': it has been set to 'kl'.") - div = 'kl' - - if div == 'kl': - sum_r = reg + reg_m1 + reg_m2 - r1, r2, r = reg_m1 / sum_r, reg_m2 / sum_r, reg / sum_r - K = (a[:, None]**r1) * (b[None, :]**r2) * (c**r) * nx.exp(- M / sum_r) - elif div == 'l2': - K = reg_m1 * a[:, None] + reg_m2 * b[None, :] + reg * c - M - K = nx.maximum(K, nx.zeros((dim_a, dim_b), type_as=M)) - - for i in range(numItermax): - Gprev = G - - if div == 'kl': - Gd = (nx.sum(G, 1, keepdims=True)**r1) * (nx.sum(G, 0, keepdims=True)**r2) + 1e-16 - G = K * G**(r1 + r2) / Gd - elif div == 'l2': - Gd = reg_m1 * nx.sum(G, 1, keepdims=True) + \ - reg_m2 * nx.sum(G, 0, keepdims=True) + reg * G + 1e-16 - G = K * G / Gd - - err = nx.sqrt(nx.sum((G - Gprev) ** 2)) - if log: - log['err'].append(err) - log['G'].append(G) - if verbose: - print('{:5d}|{:8e}|'.format(i, err)) - if err < stopThr: - break - - if log: - log['cost'] = nx.sum(G * M) - return G, log - else: - return G - - -def mm_unbalanced2(a, b, M, reg_m, c=None, reg=0, div='kl', G0=None, numItermax=1000, - stopThr=1e-15, verbose=False, log=False): - r""" - Solve the unbalanced optimal transport problem and return the OT plan. - The function solves the following optimization problem: - - .. math:: - W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + - \mathrm{reg_{m1}} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) + - \mathrm{reg_{m2}} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) + - \mathrm{reg} \cdot \mathrm{div}(\gamma, \mathbf{c}) - - s.t. - \gamma \geq 0 - - where: - - - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target - unbalanced distributions - - :math:`\mathbf{c}` is a reference distribution for the regularization - - :math:`\mathrm{div}` is a divergence, either Kullback-Leibler or :math:`\ell_2` divergence - - The algorithm used for solving the problem is a maximization- - minimization algorithm as proposed in :ref:`[41] ` - - Parameters - ---------- - a : array-like (dim_a,) - Unnormalized histogram of dimension `dim_a` - b : array-like (dim_b,) - Unnormalized histogram of dimension `dim_b` - M : array-like (dim_a, dim_b) - loss matrix - reg_m: float or indexable object of length 1 or 2 - Marginal relaxation term >= 0, but cannot be infinity. - If reg_m is a scalar or an indexable object of length 1, - then the same reg_m is applied to both marginal relaxations. - If reg_m is an array, it must have the same backend as input arrays (a, b, M). - reg : float, optional (default = 0) - Entropy regularization term >= 0. - By default, solve the unregularized problem - c : array-like (dim_a, dim_b), optional (default = None) - Reference measure for the regularization. - If None, then use `\mathbf{c} = mathbf{a} mathbf{b}^T`. - div: string, optional - Divergence to quantify the difference between the marginals. - Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) - G0: array-like (dim_a, dim_b) - Initialization of the transport matrix - numItermax : int, optional - Max number of iterations - stopThr : float, optional - Stop threshold on error (> 0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - - Returns - ------- - ot_distance : array-like - the OT distance between :math:`\mathbf{a}` and :math:`\mathbf{b}` - log : dict - log dictionary returned only if `log` is `True` - - Examples - -------- - >>> import ot - >>> import numpy as np - >>> a=[.5, .5] - >>> b=[.5, .5] - >>> M=[[1., 36.],[9., 4.]] - >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 5, div='l2'), 2) - 0.8 - >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 5, div='kl'), 2) - 1.79 - - References - ---------- - .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). - Unbalanced optimal transport through non-negative penalized - linear regression. NeurIPS. - See Also - -------- - ot.lp.emd2 : Unregularized OT loss - ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss - """ - _, log_mm = mm_unbalanced(a, b, M, reg_m, c=c, reg=reg, div=div, G0=G0, - numItermax=numItermax, stopThr=stopThr, - verbose=verbose, log=True) - - if log: - return log_mm['cost'], log_mm - else: - return log_mm['cost'] - - -def _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div='kl', regm_div='kl'): - """ - return the loss function (scipy.optimize compatible) for regularized - unbalanced OT - """ - - m, n = M.shape - - def kl(p, q): - return np.sum(p * np.log(p / q + 1e-16)) - np.sum(p) + np.sum(q) - - def reg_l2(G): - return np.sum((G - c)**2) / 2 - - def grad_l2(G): - return G - c - - def reg_kl(G): - return kl(G, c) - - def grad_kl(G): - return np.log(G / c + 1e-16) - - def reg_entropy(G): - return np.sum(G * np.log(G + 1e-16)) - np.sum(G) - - def grad_entropy(G): - return np.log(G + 1e-16) - - if reg_div == 'kl': - reg_fun = reg_kl - grad_reg_fun = grad_kl - elif reg_div == 'entropy': - reg_fun = reg_entropy - grad_reg_fun = grad_entropy - elif isinstance(reg_div, tuple): - reg_fun = reg_div[0] - grad_reg_fun = reg_div[1] - else: - reg_fun = reg_l2 - grad_reg_fun = grad_l2 - - def marg_l2(G): - return reg_m1 * 0.5 * np.sum((G.sum(1) - a)**2) + \ - reg_m2 * 0.5 * np.sum((G.sum(0) - b)**2) - - def grad_marg_l2(G): - return reg_m1 * np.outer((G.sum(1) - a), np.ones(n)) + \ - reg_m2 * np.outer(np.ones(m), (G.sum(0) - b)) - - def marg_kl(G): - return reg_m1 * kl(G.sum(1), a) + reg_m2 * kl(G.sum(0), b) - - def grad_marg_kl(G): - return reg_m1 * np.outer(np.log(G.sum(1) / a + 1e-16), np.ones(n)) + \ - reg_m2 * np.outer(np.ones(m), np.log(G.sum(0) / b + 1e-16)) - - def marg_tv(G): - return reg_m1 * np.sum(np.abs(G.sum(1) - a)) + \ - reg_m2 * np.sum(np.abs(G.sum(0) - b)) - - def grad_marg_tv(G): - return reg_m1 * np.outer(np.sign(G.sum(1) - a), np.ones(n)) + \ - reg_m2 * np.outer(np.ones(m), np.sign(G.sum(0) - b)) - - if regm_div == 'kl': - regm_fun = marg_kl - grad_regm_fun = grad_marg_kl - elif regm_div == 'tv': - regm_fun = marg_tv - grad_regm_fun = grad_marg_tv - else: - regm_fun = marg_l2 - grad_regm_fun = grad_marg_l2 - - def _func(G): - G = G.reshape((m, n)) - - # compute loss - val = np.sum(G * M) + reg * reg_fun(G) + regm_fun(G) - - # compute gradient - grad = M + reg * grad_reg_fun(G) + grad_regm_fun(G) - - return val, grad.ravel() - - return _func - - -def lbfgsb_unbalanced(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', G0=None, numItermax=1000, - stopThr=1e-15, method='L-BFGS-B', verbose=False, log=False): - r""" - Solve the unbalanced optimal transport problem and return the OT plan using L-BFGS-B. - The function solves the following optimization problem: - - .. math:: - W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + - + \mathrm{reg} \mathrm{div}(\gamma, \mathbf{c}) - \mathrm{reg_{m1}} \cdot \mathrm{div_m}(\gamma \mathbf{1}, \mathbf{a}) + - \mathrm{reg_{m2}} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) - - s.t. - \gamma \geq 0 - - where: - - - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target - unbalanced distributions - - :math:`\mathbf{c}` is a reference distribution for the regularization - - :math:`\mathrm{div}` is a divergence, either Kullback-Leibler or :math:`\ell_2` divergence - - The algorithm used for solving the problem is a L-BFGS-B from scipy.optimize - - Parameters - ---------- - a : array-like (dim_a,) - Unnormalized histogram of dimension `dim_a` - b : array-like (dim_b,) - Unnormalized histogram of dimension `dim_b` - M : array-like (dim_a, dim_b) - loss matrix - reg: float - regularization term >=0 - c : array-like (dim_a, dim_b), optional (default = None) - Reference measure for the regularization. - If None, then use `\mathbf{c} = \mathbf{a} \mathbf{b}^T`. - reg_m: float or indexable object of length 1 or 2 - Marginal relaxation term >= 0, but cannot be infinity. - If reg_m is a scalar or an indexable object of length 1, - then the same reg_m is applied to both marginal relaxations. - If reg_m is an array, it must be a Numpy array. - reg_div: string, optional - Divergence used for regularization. - Can take three values: 'entropy' (negative entropy), or - 'kl' (Kullback-Leibler) or 'l2' (quadratic) or a tuple - of two calable functions returning the reg term and its derivative. - Note that the callable functions should be able to handle numpy arrays - and not tesors from the backend - regm_div: string, optional - Divergence to quantify the difference between the marginals. - Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) - G0: array-like (dim_a, dim_b) - Initialization of the transport matrix - numItermax : int, optional - Max number of iterations - stopThr : float, optional - Stop threshold on error (> 0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - - Returns - ------- - gamma : (dim_a, dim_b) array-like - Optimal transportation matrix for the given parameters - log : dict - log dictionary returned only if `log` is `True` - - Examples - -------- - >>> import ot - >>> import numpy as np - >>> a=[.5, .5] - >>> b=[.5, .5] - >>> M=[[1., 36.],[9., 4.]] - >>> np.round(ot.unbalanced.lbfgsb_unbalanced(a, b, M, reg=0, reg_m=5, reg_div='kl', regm_div='kl'), 2) - array([[0.45, 0. ], - [0. , 0.34]]) - >>> np.round(ot.unbalanced.lbfgsb_unbalanced(a, b, M, reg=0, reg_m=5, reg_div='l2', regm_div='l2'), 2) - array([[0.4, 0. ], - [0. , 0.1]]) - - References - ---------- - .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). - Unbalanced optimal transport through non-negative penalized - linear regression. NeurIPS. - See Also - -------- - ot.lp.emd2 : Unregularized OT loss - ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss - """ - - M, a, b = list_to_array(M, a, b) - nx = get_backend(M, a, b) - M0 = M - - # convert to numpy - a, b, M = nx.to_numpy(a, b, M) - G0 = np.zeros(M.shape) if G0 is None else nx.to_numpy(G0) - c = a[:, None] * b[None, :] if c is None else nx.to_numpy(c) - - # wrap the callable function to handle numpy arrays - if isinstance(reg_div, tuple): - f0, df0 = reg_div - try: - f0(G0) - df0(G0) - except BaseException: - warnings.warn("The callable functions should be able to handle numpy arrays, wrapper ar added to handle this which comes with overhead") - - def f(x): - return nx.to_numpy(f0(nx.from_numpy(x, type_as=M0))) - - def df(x): - return nx.to_numpy(df0(nx.from_numpy(x, type_as=M0))) - - reg_div = (f, df) - - reg_m1, reg_m2 = get_parameter_pair(reg_m) - _func = _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div, regm_div) - - res = minimize(_func, G0.ravel(), method=method, jac=True, bounds=Bounds(0, np.inf), - tol=stopThr, options=dict(maxiter=numItermax, disp=verbose)) - - G = nx.from_numpy(res.x.reshape(M.shape), type_as=M0) - - if log: - log = {'loss': nx.from_numpy(res.fun, type_as=M0), 'res': res} - return G, log - else: - return G diff --git a/test/unbalanced/__init__.py b/test/unbalanced/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/test/unbalanced/test_lbfgs.py b/test/unbalanced/test_lbfgs.py new file mode 100644 index 000000000..24f127d06 --- /dev/null +++ b/test/unbalanced/test_lbfgs.py @@ -0,0 +1,126 @@ +"""Tests for module Unbalanced OT with entropy regularization""" + +# Author: Hicham Janati +# Laetitia Chapel +# Quang Huy Tran +# +# License: MIT License + + +import itertools +import numpy as np +import ot +import pytest + + +@pytest.mark.parametrize("reg_div,regm_div,returnCost", itertools.product(['kl', 'l2'], ['kl', 'l2'], ['linear', 'total'])) +def test_lbfgsb_unbalanced(nx, reg_div, regm_div, returnCost): + + np.random.seed(42) + + xs = np.random.randn(5, 2) + xt = np.random.randn(6, 2) + + M = ot.dist(xs, xt) + + a = ot.unif(5) + b = ot.unif(6) + + G, log = ot.unbalanced.lbfgsb_unbalanced(a, b, M, 1, 10, + reg_div=reg_div, regm_div=regm_div, + log=True, verbose=False) + loss, _ = ot.unbalanced.lbfgsb_unbalanced2(a, b, M, 1, 10, + reg_div=reg_div, regm_div=regm_div, + returnCost=returnCost, log=True, verbose=False) + + ab, bb, Mb = nx.from_numpy(a, b, M) + + Gb, log = ot.unbalanced.lbfgsb_unbalanced(ab, bb, Mb, 1, 10, + reg_div=reg_div, regm_div=regm_div, + log=True, verbose=False) + loss0, log = ot.unbalanced.lbfgsb_unbalanced2(ab, bb, Mb, 1, 10, + reg_div=reg_div, regm_div=regm_div, + returnCost=returnCost, log=True, verbose=False) + + np.testing.assert_allclose(G, nx.to_numpy(Gb)) + np.testing.assert_allclose(loss, nx.to_numpy(loss0), atol=1e-06) + + +@pytest.mark.parametrize("reg_div,regm_div,returnCost", itertools.product(['kl', 'l2'], ['kl', 'l2'], ['linear', 'total'])) +def test_lbfgsb_unbalanced_relaxation_parameters(nx, reg_div, regm_div, returnCost): + + np.random.seed(42) + + xs = np.random.randn(5, 2) + xt = np.random.randn(6, 2) + + M = ot.dist(xs, xt) + + a = ot.unif(5) + b = ot.unif(6) + + a, b, M = nx.from_numpy(a, b, M) + + reg_m = 10 + full_list_reg_m = [reg_m, reg_m] + full_tuple_reg_m = (reg_m, reg_m) + tuple_reg_m, list_reg_m = (reg_m), [reg_m] + np1_reg_m = reg_m * np.ones(1) + np2_reg_m = reg_m * np.ones(2) + + list_options = [np1_reg_m, np2_reg_m, full_tuple_reg_m, + tuple_reg_m, full_list_reg_m, list_reg_m] + + G = ot.unbalanced.lbfgsb_unbalanced(a, b, M, 1, reg_m=reg_m, + reg_div=reg_div, regm_div=regm_div, + log=False, verbose=False) + loss = ot.unbalanced.lbfgsb_unbalanced2(a, b, M, 1, + reg_m=reg_m, reg_div=reg_div, regm_div=regm_div, + returnCost=returnCost, log=False, verbose=False) + + for opt in list_options: + G0 = ot.unbalanced.lbfgsb_unbalanced( + a, b, M, 1, reg_m=opt, reg_div=reg_div, + regm_div=regm_div, log=False, verbose=False + ) + loss0 = ot.unbalanced.lbfgsb_unbalanced2( + a, b, M, 1, reg_m=opt, reg_div=reg_div, + regm_div=regm_div, returnCost=returnCost, + log=False, verbose=False + ) + + np.testing.assert_allclose(nx.to_numpy(G), nx.to_numpy(G0), atol=1e-06) + np.testing.assert_allclose(nx.to_numpy(loss), nx.to_numpy(loss0), atol=1e-06) + + +@pytest.mark.parametrize("reg_div,regm_div,returnCost", itertools.product(['kl', 'l2'], ['kl', 'l2'], ['linear', 'total'])) +def test_lbfgsb_reference_measure(nx, reg_div, regm_div, returnCost): + + np.random.seed(42) + + xs = np.random.randn(5, 2) + xt = np.random.randn(6, 2) + M = ot.dist(xs, xt) + a = ot.unif(5) + b = ot.unif(6) + + a, b, M = nx.from_numpy(a, b, M) + c = a[:, None] * b[None, :] + + G, _ = ot.unbalanced.lbfgsb_unbalanced(a, b, M, reg=1, reg_m=10, c=None, + reg_div=reg_div, regm_div=regm_div, + log=True, verbose=False) + loss, _ = ot.unbalanced.lbfgsb_unbalanced2(a, b, M, reg=1, reg_m=10, c=None, + reg_div=reg_div, regm_div=regm_div, + returnCost=returnCost, log=True, verbose=False) + + G0, _ = ot.unbalanced.lbfgsb_unbalanced(a, b, M, reg=1, reg_m=10, c=c, + reg_div=reg_div, regm_div=regm_div, + log=True, verbose=False) + + loss0, _ = ot.unbalanced.lbfgsb_unbalanced2(a, b, M, reg=1, reg_m=10, c=c, + reg_div=reg_div, regm_div=regm_div, + returnCost=returnCost, log=True, verbose=False) + + np.testing.assert_allclose(nx.to_numpy(G), nx.to_numpy(G0), atol=1e-06) + np.testing.assert_allclose(nx.to_numpy(loss), nx.to_numpy(loss0), atol=1e-06) diff --git a/test/unbalanced/test_mm.py b/test/unbalanced/test_mm.py new file mode 100644 index 000000000..956f69de9 --- /dev/null +++ b/test/unbalanced/test_mm.py @@ -0,0 +1,164 @@ +"""Tests for module Unbalanced OT with entropy regularization""" + +# Author: Hicham Janati +# Laetitia Chapel +# Quang Huy Tran +# +# License: MIT License + + +import numpy as np +import ot +import pytest + + +@pytest.mark.parametrize("div", ["kl", "l2"]) +def test_mm_convergence(nx, div): + n = 100 + rng = np.random.RandomState(42) + x = rng.randn(n, 2) + rng = np.random.RandomState(75) + y = rng.randn(n, 2) + a_np = ot.utils.unif(n) + b_np = ot.utils.unif(n) + + M = ot.dist(x, y) + M = M / M.max() + reg_m = 100 + a, b, M = nx.from_numpy(a_np, b_np, M) + + G, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div=div, + verbose=False, log=True) + _, log = ot.unbalanced.mm_unbalanced2(a, b, M, reg_m, div=div, verbose=True, log=True) + linear_cost = nx.to_numpy(log["cost"]) + + # check if the marginals come close to the true ones when large reg + np.testing.assert_allclose(np.sum(nx.to_numpy(G), 1), a_np, atol=1e-03) + np.testing.assert_allclose(np.sum(nx.to_numpy(G), 0), b_np, atol=1e-03) + + # check if mm_unbalanced2 returns the correct loss + np.testing.assert_allclose(nx.to_numpy(nx.sum(G * M)), linear_cost, atol=1e-5) + + # check in case no histogram is provided + a_np, b_np = np.array([]), np.array([]) + a, b = nx.from_numpy(a_np, b_np) + + G_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div=div, verbose=False) + np.testing.assert_allclose(nx.to_numpy(G_null), nx.to_numpy(G)) + + # test when G0 is given + G0 = ot.emd(a, b, M) + G0_np = nx.to_numpy(G0) + reg_m = 10000 + G = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div=div, G0=G0, verbose=False) + np.testing.assert_allclose(G0_np, nx.to_numpy(G), atol=1e-05) + + +@pytest.mark.parametrize("div", ["kl", "l2"]) +def test_mm_relaxation_parameters(nx, div): + n = 100 + rng = np.random.RandomState(42) + x = rng.randn(n, 2) + rng = np.random.RandomState(75) + y = rng.randn(n, 2) + a_np = ot.utils.unif(n) + b_np = ot.utils.unif(n) + + M = ot.dist(x, y) + M = M / M.max() + a, b, M = nx.from_numpy(a_np, b_np, M) + + reg = 1e-2 + + reg_m = 100 + full_list_reg_m = [reg_m, reg_m] + full_tuple_reg_m = (reg_m, reg_m) + tuple_reg_m, list_reg_m = (reg_m), [reg_m] + nx1_reg_m = reg_m * nx.ones(1) + nx2_reg_m = reg_m * nx.ones(2) + + list_options = [nx1_reg_m, nx2_reg_m, full_tuple_reg_m, + tuple_reg_m, full_list_reg_m, list_reg_m] + + G0, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, reg=reg, + div=div, verbose=False, log=True) + loss_0 = nx.to_numpy( + ot.unbalanced.mm_unbalanced2(a, b, M, reg_m=reg_m, reg=reg, + div=div, verbose=True) + ) + + for opt in list_options: + G1, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=opt, + reg=reg, div=div, + verbose=False, log=True) + loss_1 = nx.to_numpy( + ot.unbalanced.mm_unbalanced2(a, b, M, reg_m=opt, + reg=reg, div=div, verbose=True) + ) + + np.testing.assert_allclose(nx.to_numpy(G0), nx.to_numpy(G1), atol=1e-05) + np.testing.assert_allclose(loss_0, loss_1, atol=1e-5) + + +@pytest.mark.parametrize("div", ["kl", "l2"]) +def test_mm_reference_measure(nx, div): + n = 100 + rng = np.random.RandomState(42) + x = rng.randn(n, 2) + rng = np.random.RandomState(75) + y = rng.randn(n, 2) + a_np = ot.utils.unif(n) + b_np = ot.utils.unif(n) + + M = ot.dist(x, y) + M = M / M.max() + a, b, M = nx.from_numpy(a_np, b_np, M) + c = a[:, None] * b[None, :] + + reg = 1e-2 + reg_m = 100 + + G0, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, c=None, reg=reg, + div=div, verbose=False, log=True) + loss_0 = ot.unbalanced.mm_unbalanced2(a, b, M, reg_m=reg_m, c=None, reg=reg, + div=div, verbose=True) + loss_0 = nx.to_numpy(loss_0) + + G1, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, c=c, + reg=reg, div=div, + verbose=False, log=True) + loss_1 = ot.unbalanced.mm_unbalanced2(a, b, M, reg_m=reg_m, c=c, + reg=reg, div=div, verbose=True) + loss_1 = nx.to_numpy(loss_1) + + np.testing.assert_allclose(nx.to_numpy(G0), nx.to_numpy(G1), atol=1e-05) + np.testing.assert_allclose(loss_0, loss_1, atol=1e-5) + + +def test_mm_wrong_divergence(nx): + + n = 100 + rng = np.random.RandomState(42) + x = rng.randn(n, 2) + rng = np.random.RandomState(75) + y = rng.randn(n, 2) + a_np = ot.utils.unif(n) + b_np = ot.utils.unif(n) + + M = ot.dist(x, y) + M = M / M.max() + a, b, M = nx.from_numpy(a_np, b_np, M) + + reg = 1e-2 + reg_m = 100 + + def mm_div(div): + return ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, reg=reg, + div=div, verbose=False, log=True) + + def mm2_div(div): + return ot.unbalanced.mm_unbalanced2(a, b, M, reg_m=reg_m, reg=reg, + div=div, verbose=True) + + np.testing.assert_raises(ValueError, mm_div, "div_not_existed") + np.testing.assert_raises(ValueError, mm2_div, "div_not_existed") diff --git a/test/test_unbalanced.py b/test/unbalanced/test_sinkhorn.py similarity index 55% rename from test/test_unbalanced.py rename to test/unbalanced/test_sinkhorn.py index 7007e336b..4d32660fd 100644 --- a/test/test_unbalanced.py +++ b/test/unbalanced/test_sinkhorn.py @@ -14,8 +14,8 @@ from ot.unbalanced import barycenter_unbalanced -@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"], ["kl", "entropy"])) -def test_unbalanced_convergence(nx, method, reg_type): +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"]) +def test_unbalanced_convergence(nx, method): # test generalized sinkhorn for unbalanced OT n = 100 rng = np.random.RandomState(42) @@ -30,27 +30,26 @@ def test_unbalanced_convergence(nx, method, reg_type): epsilon = 1. reg_m = 1. + stopThr = 1e-12 G, log = ot.unbalanced.sinkhorn_unbalanced( a, b, M, reg=epsilon, reg_m=reg_m, method=method, - reg_type=reg_type, log=True, verbose=True + c=None, log=True, verbose=True, numItermax=1000, stopThr=stopThr ) loss = nx.to_numpy(ot.unbalanced.sinkhorn_unbalanced2( a, b, M, reg=epsilon, reg_m=reg_m, method=method, - reg_type=reg_type, verbose=True + c=None, verbose=True, numItermax=1000, stopThr=stopThr )) # check fixed point equations # in log-domain fi = reg_m / (reg_m + epsilon) + logb = nx.log(b + 1e-16) loga = nx.log(a + 1e-16) - if reg_type == "entropy": - logKtu = nx.logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1) - logKv = nx.logsumexp(log["logv"][None, :] - M / epsilon, axis=1) - elif reg_type == "kl": - log_ab = loga[:, None] + logb[None, :] - logKtu = nx.logsumexp(log["logu"][None, :] - M.T / epsilon + log_ab.T, axis=1) - logKv = nx.logsumexp(log["logv"][None, :] - M / epsilon + log_ab, axis=1) + log_ab = loga[:, None] + logb[None, :] + + logKtu = nx.logsumexp(log["logu"][None, :] - M.T / epsilon + log_ab.T, axis=1) + logKv = nx.logsumexp(log["logv"][None, :] - M / epsilon + log_ab, axis=1) v_final = fi * (logb - logKtu) u_final = fi * (loga - logKv) @@ -69,17 +68,19 @@ def test_unbalanced_convergence(nx, method, reg_type): G = ot.unbalanced.sinkhorn_unbalanced( a, b, M, reg=epsilon, reg_m=reg_m, - method=method, reg_type=reg_type, verbose=True + method=method, c=None, verbose=True, + stopThr=stopThr ) G_np = ot.unbalanced.sinkhorn_unbalanced( a_np, b_np, M_np, reg=epsilon, reg_m=reg_m, - method=method, reg_type=reg_type, verbose=True + method=method, c=None, verbose=True, + stopThr=stopThr ) np.testing.assert_allclose(G_np, nx.to_numpy(G)) -@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"], ["kl", "entropy"])) -def test_unbalanced_warmstart(nx, method, reg_type): +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"]) +def test_unbalanced_warmstart(nx, method): # test generalized sinkhorn for unbalanced OT n = 100 rng = np.random.RandomState(42) @@ -95,33 +96,33 @@ def test_unbalanced_warmstart(nx, method, reg_type): G0, log0 = ot.unbalanced.sinkhorn_unbalanced( a, b, M, reg=epsilon, reg_m=reg_m, method=method, - reg_type=reg_type, warmstart=None, log=True, verbose=True + c=None, warmstart=None, log=True, verbose=True ) loss0 = ot.unbalanced.sinkhorn_unbalanced2( a, b, M, reg=epsilon, reg_m=reg_m, method=method, - reg_type=reg_type, warmstart=None, verbose=True + c=None, warmstart=None, verbose=True ) dim_a, dim_b = M.shape warmstart = (nx.zeros(dim_a, type_as=M), nx.zeros(dim_b, type_as=M)) G, log = ot.unbalanced.sinkhorn_unbalanced( a, b, M, reg=epsilon, reg_m=reg_m, method=method, - reg_type=reg_type, warmstart=warmstart, log=True, verbose=True + c=None, warmstart=warmstart, log=True, verbose=True ) loss = ot.unbalanced.sinkhorn_unbalanced2( a, b, M, reg=epsilon, reg_m=reg_m, method=method, - reg_type=reg_type, warmstart=warmstart, verbose=True + c=None, warmstart=warmstart, verbose=True ) _, log_emd = ot.lp.emd(a, b, M, log=True) warmstart1 = (log_emd["u"], log_emd["v"]) G1, log1 = ot.unbalanced.sinkhorn_unbalanced( a, b, M, reg=epsilon, reg_m=reg_m, method=method, - reg_type=reg_type, warmstart=warmstart1, log=True, verbose=True + c=None, warmstart=warmstart1, log=True, verbose=True ) loss1 = ot.unbalanced.sinkhorn_unbalanced2( a, b, M, reg=epsilon, reg_m=reg_m, method=method, - reg_type=reg_type, warmstart=warmstart1, verbose=True + c=None, warmstart=warmstart1, verbose=True ) np.testing.assert_allclose( @@ -140,8 +141,8 @@ def test_unbalanced_warmstart(nx, method, reg_type): np.testing.assert_allclose(nx.to_numpy(loss0), nx.to_numpy(loss1), atol=1e-5) -@pytest.mark.parametrize("method,reg_type, log", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"], ["kl", "entropy"], [True, False])) -def test_sinkhorn_unbalanced2(nx, method, reg_type, log): +@pytest.mark.parametrize("method, log", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"], [True, False])) +def test_sinkhorn_unbalanced2(nx, method, log): n = 100 rng = np.random.RandomState(42) @@ -158,12 +159,12 @@ def test_sinkhorn_unbalanced2(nx, method, reg_type, log): loss = nx.to_numpy(ot.unbalanced.sinkhorn_unbalanced2( a, b, M, reg=epsilon, reg_m=reg_m, method=method, - reg_type=reg_type, log=False, verbose=True + c=None, log=False, verbose=True )) res = ot.unbalanced.sinkhorn_unbalanced2( a, b, M, reg=epsilon, reg_m=reg_m, method=method, - reg_type=reg_type, log=log, verbose=True + c=None, log=log, verbose=True ) loss0 = res[0] if log else res @@ -248,6 +249,11 @@ def test_unbalanced_multiple_inputs(nx, method): v_final = fi * (logb - logKtu) u_final = fi * (loga - logKv) + print("u_final shape = {}".format(u_final.shape)) + print("v_final shape = {}".format(v_final.shape)) + print("logu shape = {}".format(log["logu"].shape)) + print("logv shape = {}".format(log["logv"].shape)) + np.testing.assert_allclose( nx.to_numpy(u_final), nx.to_numpy(log["logu"]), atol=1e-05) np.testing.assert_allclose( @@ -268,19 +274,20 @@ def test_stabilized_vs_sinkhorn(nx): M = ot.utils.dist0(n) M /= np.median(M) - epsilon = 0.1 + epsilon = 1 reg_m = 1. + stopThr = 1e-12 ab, bb, Mb = nx.from_numpy(a, b, M) G, _ = ot.unbalanced.sinkhorn_unbalanced2( - ab, bb, Mb, epsilon, reg_m, method="sinkhorn_stabilized", log=True + ab, bb, Mb, epsilon, reg_m, method="sinkhorn_stabilized", log=True, stopThr=stopThr, ) G2, _ = ot.unbalanced.sinkhorn_unbalanced2( - ab, bb, Mb, epsilon, reg_m, method="sinkhorn", log=True + ab, bb, Mb, epsilon, reg_m, method="sinkhorn", log=True, stopThr=stopThr ) G2_np, _ = ot.unbalanced.sinkhorn_unbalanced2( - a, b, M, epsilon, reg_m, method="sinkhorn", log=True + a, b, M, epsilon, reg_m, method="sinkhorn", log=True, stopThr=stopThr ) G = nx.to_numpy(G) G2 = nx.to_numpy(G2) @@ -428,247 +435,3 @@ def test_implemented_methods(nx): method=method) barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m, method=method) - - -@pytest.mark.parametrize("reg_div,regm_div", itertools.product(['kl', 'l2', 'entropy'], ['kl', 'l2'])) -def test_lbfgsb_unbalanced(nx, reg_div, regm_div): - - np.random.seed(42) - - xs = np.random.randn(5, 2) - xt = np.random.randn(6, 2) - - M = ot.dist(xs, xt) - - a = ot.unif(5) - b = ot.unif(6) - - G, log = ot.unbalanced.lbfgsb_unbalanced(a, b, M, 1, 10, reg_div=reg_div, regm_div=regm_div, log=True, verbose=False) - - ab, bb, Mb = nx.from_numpy(a, b, M) - - Gb, log = ot.unbalanced.lbfgsb_unbalanced(ab, bb, Mb, 1, 10, reg_div=reg_div, regm_div=regm_div, log=True, verbose=False) - - np.testing.assert_allclose(G, nx.to_numpy(Gb)) - - -@pytest.mark.parametrize("reg_div,regm_div", itertools.product(['kl', 'l2', 'entropy'], ['kl', 'l2'])) -def test_lbfgsb_unbalanced_relaxation_parameters(nx, reg_div, regm_div): - - np.random.seed(42) - - xs = np.random.randn(5, 2) - xt = np.random.randn(6, 2) - - M = ot.dist(xs, xt) - - a = ot.unif(5) - b = ot.unif(6) - - a, b, M = nx.from_numpy(a, b, M) - - reg_m = 10 - full_list_reg_m = [reg_m, reg_m] - full_tuple_reg_m = (reg_m, reg_m) - tuple_reg_m, list_reg_m = (reg_m), [reg_m] - np1_reg_m = reg_m * np.ones(1) - np2_reg_m = reg_m * np.ones(2) - - list_options = [np1_reg_m, np2_reg_m, full_tuple_reg_m, - tuple_reg_m, full_list_reg_m, list_reg_m] - - G = ot.unbalanced.lbfgsb_unbalanced(a, b, M, 1, reg_m=reg_m, - reg_div=reg_div, regm_div=regm_div, - log=False, verbose=False) - - for opt in list_options: - G0 = ot.unbalanced.lbfgsb_unbalanced( - a, b, M, 1, reg_m=opt, reg_div=reg_div, - regm_div=regm_div, log=False, verbose=False - ) - - np.testing.assert_allclose(nx.to_numpy(G), nx.to_numpy(G0), atol=1e-06) - - -@pytest.mark.parametrize("reg_div,regm_div", itertools.product(['kl', 'l2', 'entropy'], ['kl', 'l2'])) -def test_lbfgsb_reference_measure(nx, reg_div, regm_div): - - np.random.seed(42) - - xs = np.random.randn(5, 2) - xt = np.random.randn(6, 2) - M = ot.dist(xs, xt) - a = ot.unif(5) - b = ot.unif(6) - - a, b, M = nx.from_numpy(a, b, M) - c = a[:, None] * b[None, :] - - G, _ = ot.unbalanced.lbfgsb_unbalanced(a, b, M, reg=1, reg_m=10, c=None, - reg_div=reg_div, regm_div=regm_div, - log=True, verbose=False) - - G0, _ = ot.unbalanced.lbfgsb_unbalanced(a, b, M, reg=1, reg_m=10, c=c, - reg_div=reg_div, regm_div=regm_div, - log=True, verbose=False) - - np.testing.assert_allclose(nx.to_numpy(G), nx.to_numpy(G0), atol=1e-06) - - -@pytest.mark.parametrize("div", ["kl", "l2"]) -def test_mm_convergence(nx, div): - n = 100 - rng = np.random.RandomState(42) - x = rng.randn(n, 2) - rng = np.random.RandomState(75) - y = rng.randn(n, 2) - a_np = ot.utils.unif(n) - b_np = ot.utils.unif(n) - - M = ot.dist(x, y) - M = M / M.max() - reg_m = 100 - a, b, M = nx.from_numpy(a_np, b_np, M) - - G, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div=div, - verbose=False, log=True) - loss = nx.to_numpy( - ot.unbalanced.mm_unbalanced2(a, b, M, reg_m, div=div, verbose=True) - ) - - # check if the marginals come close to the true ones when large reg - np.testing.assert_allclose(np.sum(nx.to_numpy(G), 1), a_np, atol=1e-03) - np.testing.assert_allclose(np.sum(nx.to_numpy(G), 0), b_np, atol=1e-03) - - # check if mm_unbalanced2 returns the correct loss - np.testing.assert_allclose(nx.to_numpy(nx.sum(G * M)), loss, atol=1e-5) - - # check in case no histogram is provided - a_np, b_np = np.array([]), np.array([]) - a, b = nx.from_numpy(a_np, b_np) - - G_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div=div, verbose=False) - np.testing.assert_allclose(nx.to_numpy(G_null), nx.to_numpy(G)) - - # test when G0 is given - G0 = ot.emd(a, b, M) - G0_np = nx.to_numpy(G0) - reg_m = 10000 - G = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div=div, G0=G0, verbose=False) - np.testing.assert_allclose(G0_np, nx.to_numpy(G), atol=1e-05) - - -@pytest.mark.parametrize("div", ["kl", "l2"]) -def test_mm_relaxation_parameters(nx, div): - n = 100 - rng = np.random.RandomState(42) - x = rng.randn(n, 2) - rng = np.random.RandomState(75) - y = rng.randn(n, 2) - a_np = ot.utils.unif(n) - b_np = ot.utils.unif(n) - - M = ot.dist(x, y) - M = M / M.max() - a, b, M = nx.from_numpy(a_np, b_np, M) - - reg = 1e-2 - - reg_m = 100 - full_list_reg_m = [reg_m, reg_m] - full_tuple_reg_m = (reg_m, reg_m) - tuple_reg_m, list_reg_m = (reg_m), [reg_m] - nx1_reg_m = reg_m * nx.ones(1) - nx2_reg_m = reg_m * nx.ones(2) - - list_options = [nx1_reg_m, nx2_reg_m, full_tuple_reg_m, - tuple_reg_m, full_list_reg_m, list_reg_m] - - G0, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, reg=reg, - div=div, verbose=False, log=True) - loss_0 = nx.to_numpy( - ot.unbalanced.mm_unbalanced2(a, b, M, reg_m=reg_m, reg=reg, - div=div, verbose=True) - ) - - for opt in list_options: - G1, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=opt, - reg=reg, div=div, - verbose=False, log=True) - loss_1 = nx.to_numpy( - ot.unbalanced.mm_unbalanced2(a, b, M, reg_m=opt, - reg=reg, div=div, verbose=True) - ) - - np.testing.assert_allclose(nx.to_numpy(G0), nx.to_numpy(G1), atol=1e-05) - np.testing.assert_allclose(loss_0, loss_1, atol=1e-5) - - -@pytest.mark.parametrize("div", ["kl", "l2"]) -def test_mm_reference_measure(nx, div): - n = 100 - rng = np.random.RandomState(42) - x = rng.randn(n, 2) - rng = np.random.RandomState(75) - y = rng.randn(n, 2) - a_np = ot.utils.unif(n) - b_np = ot.utils.unif(n) - - M = ot.dist(x, y) - M = M / M.max() - a, b, M = nx.from_numpy(a_np, b_np, M) - c = a[:, None] * b[None, :] - - reg = 1e-2 - reg_m = 100 - - G0, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, c=None, reg=reg, - div=div, verbose=False, log=True) - loss_0 = ot.unbalanced.mm_unbalanced2(a, b, M, reg_m=reg_m, c=None, reg=reg, - div=div, verbose=True) - loss_0 = nx.to_numpy(loss_0) - - G1, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, c=c, - reg=reg, div=div, - verbose=False, log=True) - loss_1 = ot.unbalanced.mm_unbalanced2(a, b, M, reg_m=reg_m, c=c, - reg=reg, div=div, verbose=True) - loss_1 = nx.to_numpy(loss_1) - - np.testing.assert_allclose(nx.to_numpy(G0), nx.to_numpy(G1), atol=1e-05) - np.testing.assert_allclose(loss_0, loss_1, atol=1e-5) - - -def test_mm_wrong_divergence(nx): - - n = 100 - rng = np.random.RandomState(42) - x = rng.randn(n, 2) - rng = np.random.RandomState(75) - y = rng.randn(n, 2) - a_np = ot.utils.unif(n) - b_np = ot.utils.unif(n) - - M = ot.dist(x, y) - M = M / M.max() - a, b, M = nx.from_numpy(a_np, b_np, M) - - reg = 1e-2 - reg_m = 100 - - G0, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, reg=reg, - div="kl", verbose=False, log=True) - loss_0 = nx.to_numpy( - ot.unbalanced.mm_unbalanced2(a, b, M, reg_m=reg_m, reg=reg, - div="kl", verbose=True) - ) - - G1, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, reg=reg, - div="wrong_div", verbose=False, log=True) - loss_1 = nx.to_numpy( - ot.unbalanced.mm_unbalanced2(a, b, M, reg_m=reg_m, reg=reg, - div="wrong_div", verbose=True) - ) - - np.testing.assert_allclose(nx.to_numpy(G0), nx.to_numpy(G1), atol=1e-05) - np.testing.assert_allclose(loss_0, loss_1, atol=1e-5) From 4b01c6b0c356f8acdc86a8c83d933f44ec296cc0 Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Tue, 16 Jul 2024 17:28:14 +0200 Subject: [PATCH 02/17] remove notebook --- ot/unbalanced/Untitled.ipynb | 973 ----------------------------------- 1 file changed, 973 deletions(-) delete mode 100644 ot/unbalanced/Untitled.ipynb diff --git a/ot/unbalanced/Untitled.ipynb b/ot/unbalanced/Untitled.ipynb deleted file mode 100644 index 87737c676..000000000 --- a/ot/unbalanced/Untitled.ipynb +++ /dev/null @@ -1,973 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 232, - "id": "bab1ddba", - "metadata": {}, - "outputs": [], - "source": [ - "import torch" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "fa20df35", - "metadata": {}, - "outputs": [], - "source": [ - "x = torch.ones(3, 3)\n", - "y = 2 * x\n", - "z = 3 * x\n", - "t = torch.stack([y, z], dim=2)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "e58f96d3", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([3, 3, 2])" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "t.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "0b1cc563", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([3, 3])" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "x.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "7ec187d3", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[[2., 3.],\n", - " [2., 3.],\n", - " [2., 3.]],\n", - "\n", - " [[2., 3.],\n", - " [2., 3.],\n", - " [2., 3.]],\n", - "\n", - " [[2., 3.],\n", - " [2., 3.],\n", - " [2., 3.]]])" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "x[:, :, None] * t" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "041cc43d", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[1., 1., 1.],\n", - " [1., 1., 1.],\n", - " [1., 1., 1.]])" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "x" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "03d039f4", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[[2., 3.],\n", - " [2., 3.],\n", - " [2., 3.]],\n", - "\n", - " [[2., 3.],\n", - " [2., 3.],\n", - " [2., 3.]],\n", - "\n", - " [[2., 3.],\n", - " [2., 3.],\n", - " [2., 3.]]])" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "t" - ] - }, - { - "cell_type": "code", - "execution_count": 390, - "id": "6c50108f", - "metadata": {}, - "outputs": [], - "source": [ - "a = torch.ones(3, 1)\n", - "b = torch.ones(5) * 2\n", - "c = torch.ones(5) * 3\n", - "d = torch.stack([b, c],dim=1)\n", - "m = torch.randn(3, 5)\n", - "\n", - "aa = torch.concat([a, a],dim=1)" - ] - }, - { - "cell_type": "code", - "execution_count": 303, - "id": "14167c8d", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[2., 2., 2., 2., 2.],\n", - " [2., 2., 2., 2., 2.],\n", - " [2., 2., 2., 2., 2.]])" - ] - }, - "execution_count": 303, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a @ b[None, :]" - ] - }, - { - "cell_type": "code", - "execution_count": 304, - "id": "67f0db6c", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[2., 3.],\n", - " [2., 3.],\n", - " [2., 3.],\n", - " [2., 3.],\n", - " [2., 3.]])" - ] - }, - "execution_count": 304, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "d" - ] - }, - { - "cell_type": "code", - "execution_count": 305, - "id": "924d115d", - "metadata": {}, - "outputs": [], - "source": [ - "e = torch.einsum('ij, jkm->mik', a, d[None, :, :])" - ] - }, - { - "cell_type": "code", - "execution_count": 306, - "id": "c32ee8cc", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[[2., 2., 2., 2., 2.],\n", - " [2., 2., 2., 2., 2.],\n", - " [2., 2., 2., 2., 2.]],\n", - "\n", - " [[3., 3., 3., 3., 3.],\n", - " [3., 3., 3., 3., 3.],\n", - " [3., 3., 3., 3., 3.]]])" - ] - }, - "execution_count": 306, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "e" - ] - }, - { - "cell_type": "code", - "execution_count": 307, - "id": "88abc5fb", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[2., 2., 2., 2., 2.],\n", - " [3., 3., 3., 3., 3.]])" - ] - }, - "execution_count": 307, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "dd" - ] - }, - { - "cell_type": "code", - "execution_count": 308, - "id": "5fbd8e4b", - "metadata": {}, - "outputs": [ - { - "ename": "RuntimeError", - "evalue": "The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m/tmp/ipykernel_17575/3808899050.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mdd\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0md\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mT\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0;34m(\u001b[0m\u001b[0mdd\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mRuntimeError\u001b[0m: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1" - ] - } - ], - "source": [ - "dd = d.T\n", - "(dd * e).shape" - ] - }, - { - "cell_type": "code", - "execution_count": 309, - "id": "b6dd203c", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[2., 3.],\n", - " [2., 3.],\n", - " [2., 3.],\n", - " [2., 3.],\n", - " [2., 3.]])" - ] - }, - "execution_count": 309, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "d" - ] - }, - { - "cell_type": "code", - "execution_count": 310, - "id": "aa0e4d3c", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[20., 45.],\n", - " [20., 45.],\n", - " [20., 45.]])" - ] - }, - "execution_count": 310, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "f = torch.einsum('ijk, ki->ji', e, d)\n", - "f" - ] - }, - { - "cell_type": "code", - "execution_count": 289, - "id": "07c71321", - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[[2., 2., 2., 2., 2.],\n", - " [2., 2., 2., 2., 2.]],\n", - "\n", - " [[3., 3., 3., 3., 3.],\n", - " [3., 3., 3., 3., 3.]]])" - ] - }, - "execution_count": 289, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "e" - ] - }, - { - "cell_type": "code", - "execution_count": 290, - "id": "0d7076ce", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[2., 3.],\n", - " [2., 3.],\n", - " [2., 3.],\n", - " [2., 3.],\n", - " [2., 3.]])" - ] - }, - "execution_count": 290, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "d" - ] - }, - { - "cell_type": "code", - "execution_count": 320, - "id": "2ad6c8a0", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[6., 9.],\n", - " [6., 9.],\n", - " [6., 9.],\n", - " [6., 9.],\n", - " [6., 9.]])" - ] - }, - "execution_count": 320, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "eT = e.reshape(2, 5, 3)\n", - "g = torch.einsum('ijk, ki->ji', eT, a)\n", - "g" - ] - }, - { - "cell_type": "code", - "execution_count": 321, - "id": "496709df", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[0.3333, 0.3333],\n", - " [0.3333, 0.3333],\n", - " [0.3333, 0.3333],\n", - " [0.3333, 0.3333],\n", - " [0.3333, 0.3333]])" - ] - }, - "execution_count": 321, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "d / g" - ] - }, - { - "cell_type": "code", - "execution_count": 324, - "id": "f5cb0a5b", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[2., 3.],\n", - " [2., 3.],\n", - " [2., 3.],\n", - " [2., 3.],\n", - " [2., 3.]])" - ] - }, - "execution_count": 324, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "d" - ] - }, - { - "cell_type": "code", - "execution_count": 329, - "id": "20d8b409", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([10.8688, 24.4548])" - ] - }, - "execution_count": 329, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "\n", - "res = torch.einsum('ik,kij,jk,ij->k', a, e, d, m)\n", - "res" - ] - }, - { - "cell_type": "code", - "execution_count": 330, - "id": "5b419e6a", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[1.],\n", - " [1.],\n", - " [1.]])" - ] - }, - "execution_count": 330, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a" - ] - }, - { - "cell_type": "code", - "execution_count": 331, - "id": "84cf81ac", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[[2., 2., 2., 2., 2.],\n", - " [2., 2., 2., 2., 2.],\n", - " [2., 2., 2., 2., 2.]],\n", - "\n", - " [[3., 3., 3., 3., 3.],\n", - " [3., 3., 3., 3., 3.],\n", - " [3., 3., 3., 3., 3.]]])" - ] - }, - "execution_count": 331, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "e" - ] - }, - { - "cell_type": "code", - "execution_count": 332, - "id": "1521a955", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[2., 3.],\n", - " [2., 3.],\n", - " [2., 3.],\n", - " [2., 3.],\n", - " [2., 3.]])" - ] - }, - "execution_count": 332, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "d" - ] - }, - { - "cell_type": "code", - "execution_count": 365, - "id": "a84c2925", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[[1., 1., 1., 1., 1.],\n", - " [1., 1., 1., 1., 1.],\n", - " [1., 1., 1., 1., 1.]],\n", - "\n", - " [[1., 1., 1., 1., 1.],\n", - " [1., 1., 1., 1., 1.],\n", - " [1., 1., 1., 1., 1.]]])" - ] - }, - "execution_count": 365, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "(a[None, :,:] + d.T[:,None, :])" - ] - }, - { - "cell_type": "code", - "execution_count": 340, - "id": "346847db", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[1.],\n", - " [1.],\n", - " [1.]])" - ] - }, - "execution_count": 340, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a" - ] - }, - { - "cell_type": "code", - "execution_count": 341, - "id": "da1574b9", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[2., 3.],\n", - " [2., 3.],\n", - " [2., 3.],\n", - " [2., 3.],\n", - " [2., 3.]])" - ] - }, - "execution_count": 341, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "d" - ] - }, - { - "cell_type": "code", - "execution_count": 361, - "id": "5c25b074", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[3., 3., 3., 3., 3.],\n", - " [3., 3., 3., 3., 3.],\n", - " [3., 3., 3., 3., 3.]])" - ] - }, - "execution_count": 361, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a + d[:, 0]" - ] - }, - { - "cell_type": "code", - "execution_count": 366, - "id": "f10f2752", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[2., 3.],\n", - " [2., 3.],\n", - " [2., 3.],\n", - " [2., 3.],\n", - " [2., 3.]])" - ] - }, - "execution_count": 366, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "d" - ] - }, - { - "cell_type": "code", - "execution_count": 382, - "id": "4d48e54a", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[4., 4., 4., 4., 4.],\n", - " [4., 4., 4., 4., 4.],\n", - " [4., 4., 4., 4., 4.],\n", - " [4., 4., 4., 4., 4.],\n", - " [4., 4., 4., 4., 4.]])" - ] - }, - "execution_count": 382, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "d[:, 0][:, None] + d[:, 0][None, :]" - ] - }, - { - "cell_type": "code", - "execution_count": 381, - "id": "80975d51", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[[4., 4., 4., 4., 4.],\n", - " [4., 4., 4., 4., 4.],\n", - " [4., 4., 4., 4., 4.],\n", - " [4., 4., 4., 4., 4.],\n", - " [4., 4., 4., 4., 4.]],\n", - "\n", - " [[6., 6., 6., 6., 6.],\n", - " [6., 6., 6., 6., 6.],\n", - " [6., 6., 6., 6., 6.],\n", - " [6., 6., 6., 6., 6.],\n", - " [6., 6., 6., 6., 6.]]])" - ] - }, - "execution_count": 381, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "ee = d.T[:, :, None] + d.T[:, None, :]\n", - "ee" - ] - }, - { - "cell_type": "code", - "execution_count": 377, - "id": "6e7f9913", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([2, 5, 5])" - ] - }, - "execution_count": 377, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "ee.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 378, - "id": "c4bc1337", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([2, 3, 5])" - ] - }, - "execution_count": 378, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "e.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 379, - "id": "70cb3f8b", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[1.],\n", - " [1.],\n", - " [1.]])" - ] - }, - "execution_count": 379, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a" - ] - }, - { - "cell_type": "code", - "execution_count": 380, - "id": "a93d3f3f", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[2., 3.],\n", - " [2., 3.],\n", - " [2., 3.],\n", - " [2., 3.],\n", - " [2., 3.]])" - ] - }, - "execution_count": 380, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "d" - ] - }, - { - "cell_type": "code", - "execution_count": 391, - "id": "91f386ea", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[1., 1.],\n", - " [1., 1.],\n", - " [1., 1.]])" - ] - }, - "execution_count": 391, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "aa" - ] - }, - { - "cell_type": "code", - "execution_count": 395, - "id": "99b2b52f", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[[2., 2., 2., 2., 2.],\n", - " [2., 2., 2., 2., 2.],\n", - " [2., 2., 2., 2., 2.]],\n", - "\n", - " [[3., 3., 3., 3., 3.],\n", - " [3., 3., 3., 3., 3.],\n", - " [3., 3., 3., 3., 3.]]])" - ] - }, - "execution_count": 395, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "torch.einsum('ij, jkm->mik', a, d[None, :, :])" - ] - }, - { - "cell_type": "code", - "execution_count": 398, - "id": "aed12015", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[2., 2., 2., 2., 2.],\n", - " [2., 2., 2., 2., 2.],\n", - " [2., 2., 2., 2., 2.]])" - ] - }, - "execution_count": 398, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a * d[:, 0]" - ] - }, - { - "cell_type": "code", - "execution_count": 397, - "id": "bca65b4e", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[2., 3.],\n", - " [2., 3.],\n", - " [2., 3.],\n", - " [2., 3.],\n", - " [2., 3.]])" - ] - }, - "execution_count": 397, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "d" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e6187bb9", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.7" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} From 9380ff30654b989ddf4ffe5b7ba95dcd2fd86a0f Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Tue, 16 Jul 2024 17:29:54 +0200 Subject: [PATCH 03/17] recover solver.py --- ot/solvers.py | 1386 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1386 insertions(+) create mode 100644 ot/solvers.py diff --git a/ot/solvers.py b/ot/solvers.py new file mode 100644 index 000000000..95165ea11 --- /dev/null +++ b/ot/solvers.py @@ -0,0 +1,1386 @@ +# -*- coding: utf-8 -*- +""" +General OT solvers with unified API +""" + +# Author: Remi Flamary +# +# License: MIT License + +from .utils import OTResult, dist +from .lp import emd2, wasserstein_1d +from .backend import get_backend +from .unbalanced import mm_unbalanced, sinkhorn_knopp_unbalanced, lbfgsb_unbalanced +from .bregman import sinkhorn_log, empirical_sinkhorn2, empirical_sinkhorn2_geomloss +from .partial import partial_wasserstein_lagrange +from .smooth import smooth_ot_dual +from .gromov import (gromov_wasserstein2, fused_gromov_wasserstein2, + entropic_gromov_wasserstein2, entropic_fused_gromov_wasserstein2, + semirelaxed_gromov_wasserstein2, semirelaxed_fused_gromov_wasserstein2, + entropic_semirelaxed_fused_gromov_wasserstein2, + entropic_semirelaxed_gromov_wasserstein2) +from .partial import partial_gromov_wasserstein2, entropic_partial_gromov_wasserstein2 +from .gaussian import empirical_bures_wasserstein_distance +from .factored import factored_optimal_transport +from .lowrank import lowrank_sinkhorn +from .optim import cg + +lst_method_lazy = ['1d', 'gaussian', 'lowrank', 'factored', 'geomloss', 'geomloss_auto', 'geomloss_tensorized', 'geomloss_online', 'geomloss_multiscale'] + + +def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, + unbalanced_type='KL', method=None, n_threads=1, max_iter=None, plan_init=None, + potentials_init=None, tol=None, verbose=False, grad='autodiff'): + r"""Solve the discrete optimal transport problem and return :any:`OTResult` object + + The function solves the following general optimal transport problem + + .. math:: + \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) + + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) + + The regularization is selected with `reg` (:math:`\lambda_r`) and `reg_type`. By + default ``reg=None`` and there is no regularization. The unbalanced marginal + penalization can be selected with `unbalanced` (:math:`\lambda_u`) and + `unbalanced_type`. By default ``unbalanced=None`` and the function + solves the exact optimal transport problem (respecting the marginals). + + Parameters + ---------- + M : array_like, shape (dim_a, dim_b) + Loss matrix + a : array-like, shape (dim_a,), optional + Samples weights in the source domain (default is uniform) + b : array-like, shape (dim_b,), optional + Samples weights in the source domain (default is uniform) + reg : float, optional + Regularization weight :math:`\lambda_r`, by default None (no reg., exact + OT) + reg_type : str, optional + Type of regularization :math:`R` either "KL", "L2", "entropy", + by default "KL". a tuple of functions can be provided for general + solver (see :any:`cg`). This is only used when ``reg!=None``. + unbalanced : float, optional + Unbalanced penalization weight :math:`\lambda_u`, by default None + (balanced OT) + unbalanced_type : str, optional + Type of unbalanced penalization function :math:`U` either "KL", "L2", + "TV", by default "KL". + method : str, optional + Method for solving the problem when multiple algorithms are available, + default None for automatic selection. + n_threads : int, optional + Number of OMP threads for exact OT solver, by default 1 + max_iter : int, optional + Maximum number of iterations, by default None (default values in each solvers) + plan_init : array_like, shape (dim_a, dim_b), optional + Initialization of the OT plan for iterative methods, by default None + potentials_init : (array_like(dim_a,),array_like(dim_b,)), optional + Initialization of the OT dual potentials for iterative methods, by default None + tol : _type_, optional + Tolerance for solution precision, by default None (default values in each solvers) + verbose : bool, optional + Print information in the solver, by default False + grad : str, optional + Type of gradient computation, either or 'autodiff' or 'envelope' used only for + Sinkhorn solver. By default 'autodiff' provides gradients wrt all + outputs (`plan, value, value_linear`) but with important memory cost. + 'envelope' provides gradients only for `value` and and other outputs are + detached. This is useful for memory saving when only the value is needed. + + Returns + ------- + res : OTResult() + Result of the optimization problem. The information can be obtained as follows: + + - res.plan : OT plan :math:`\mathbf{T}` + - res.potentials : OT dual potentials + - res.value : Optimal value of the optimization problem + - res.value_linear : Linear OT loss with the optimal OT plan + + See :any:`OTResult` for more information. + + Notes + ----- + + The following methods are available for solving the OT problems: + + - **Classical exact OT problem [1]** (default parameters) : + + .. math:: + \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + + s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} + + \mathbf{T}^T \mathbf{1} = \mathbf{b} + + \mathbf{T} \geq 0 + + can be solved with the following code: + + .. code-block:: python + + res = ot.solve(M, a, b) + + - **Entropic regularized OT [2]** (when ``reg!=None``): + + .. math:: + \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda R(\mathbf{T}) + + s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} + + \mathbf{T}^T \mathbf{1} = \mathbf{b} + + \mathbf{T} \geq 0 + + can be solved with the following code: + + .. code-block:: python + + # default is ``"KL"`` regularization (``reg_type="KL"``) + res = ot.solve(M, a, b, reg=1.0) + # or for original Sinkhorn paper formulation [2] + res = ot.solve(M, a, b, reg=1.0, reg_type='entropy') + + # Use envelope theorem differentiation for memory saving + res = ot.solve(M, a, b, reg=1.0, grad='envelope') # M, a, b are torch tensors + res.value.backward() # only the value is differentiable + + Note that by default the Sinkhorn solver uses automatic differentiation to + compute the gradients of the values and plan. This can be changed with the + `grad` parameter. The `envelope` mode computes the gradients only + for the value and the other outputs are detached. This is useful for + memory saving when only the gradient of value is needed. + + - **Quadratic regularized OT [17]** (when ``reg!=None`` and ``reg_type="L2"``): + + .. math:: + \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda R(\mathbf{T}) + + s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} + + \mathbf{T}^T \mathbf{1} = \mathbf{b} + + \mathbf{T} \geq 0 + + can be solved with the following code: + + .. code-block:: python + + res = ot.solve(M,a,b,reg=1.0,reg_type='L2') + + - **Unbalanced OT [41]** (when ``unbalanced!=None``): + + .. math:: + \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) + + can be solved with the following code: + + .. code-block:: python + + # default is ``"KL"`` + res = ot.solve(M,a,b,unbalanced=1.0) + # quadratic unbalanced OT + res = ot.solve(M,a,b,unbalanced=1.0,unbalanced_type='L2') + # TV = partial OT + res = ot.solve(M,a,b,unbalanced=1.0,unbalanced_type='TV') + + + - **Regularized unbalanced regularized OT [34]** (when ``unbalanced!=None`` and ``reg!=None``): + + .. math:: + \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) + + can be solved with the following code: + + .. code-block:: python + + # default is ``"KL"`` for both + res = ot.solve(M,a,b,reg=1.0,unbalanced=1.0) + # quadratic unbalanced OT with KL regularization + res = ot.solve(M,a,b,reg=1.0,unbalanced=1.0,unbalanced_type='L2') + # both quadratic + res = ot.solve(M,a,b,reg=1.0, reg_type='L2',unbalanced=1.0,unbalanced_type='L2') + + + .. _references-solve: + References + ---------- + + .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. + (2011, December). Displacement interpolation using Lagrangian mass + transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. + 158). ACM. + + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation + of Optimal Transport, Advances in Neural Information Processing + Systems (NIPS) 26, 2013 + + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. + arXiv preprint arXiv:1607.05816. + + .. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse + Optimal Transport. Proceedings of the Twenty-First International + Conference on Artificial Intelligence and Statistics (AISTATS). + + .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, + A., & Peyré, G. (2019, April). Interpolating between optimal transport + and MMD using Sinkhorn divergences. In The 22nd International Conference + on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. + + .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression. NeurIPS. + + """ + + # detect backend + arr = [M] + if a is not None: + arr.append(a) + if b is not None: + arr.append(b) + nx = get_backend(*arr) + + # create uniform weights if not given + if a is None: + a = nx.ones(M.shape[0], type_as=M) / M.shape[0] + if b is None: + b = nx.ones(M.shape[1], type_as=M) / M.shape[1] + + # default values for solutions + potentials = None + value = None + value_linear = None + plan = None + status = None + + if reg is None or reg == 0: # exact OT + + if unbalanced is None: # Exact balanced OT + + # default values for EMD solver + if max_iter is None: + max_iter = 1000000 + + value_linear, log = emd2(a, b, M, numItermax=max_iter, log=True, return_matrix=True, numThreads=n_threads) + + value = value_linear + potentials = (log['u'], log['v']) + plan = log['G'] + status = log["warning"] if log["warning"] is not None else 'Converged' + + elif unbalanced_type.lower() in ['kl', 'l2']: # unbalanced exact OT + + # default values for exact unbalanced OT + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-12 + + plan, log = mm_unbalanced(a, b, M, reg_m=unbalanced, + div=unbalanced_type.lower(), numItermax=max_iter, + stopThr=tol, log=True, + verbose=verbose, G0=plan_init) + + value_linear = log['cost'] + + if unbalanced_type.lower() == 'kl': + value = value_linear + unbalanced * (nx.kl_div(nx.sum(plan, 1), a) + nx.kl_div(nx.sum(plan, 0), b)) + else: + err_a = nx.sum(plan, 1) - a + err_b = nx.sum(plan, 0) - b + value = value_linear + unbalanced * nx.sum(err_a**2) + unbalanced * nx.sum(err_b**2) + + elif unbalanced_type.lower() == 'tv': + + if max_iter is None: + max_iter = 1000000 + + plan, log = partial_wasserstein_lagrange(a, b, M, reg_m=unbalanced**2, log=True, numItermax=max_iter) + + value_linear = nx.sum(M * plan) + err_a = nx.sum(plan, 1) - a + err_b = nx.sum(plan, 0) - b + value = value_linear + nx.sqrt(unbalanced**2 / 2.0 * (nx.sum(nx.abs(err_a)) + + nx.sum(nx.abs(err_b)))) + + else: + raise (NotImplementedError('Unknown unbalanced_type="{}"'.format(unbalanced_type))) + + else: # regularized OT + + if unbalanced is None: # Balanced regularized OT + + if isinstance(reg_type, tuple): # general solver + + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-9 + + plan, log = cg(a, b, M, reg=reg, f=reg_type[0], df=reg_type[1], numItermax=max_iter, stopThr=tol, log=True, verbose=verbose, G0=plan_init) + + value_linear = nx.sum(M * plan) + value = log['loss'][-1] + potentials = (log['u'], log['v']) + + elif reg_type.lower() in ['entropy', 'kl']: + + if grad == 'envelope': # if envelope then detach the input + M0, a0, b0 = M, a, b + M, a, b = nx.detach(M, a, b) + + # default values for sinkhorn + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-9 + + plan, log = sinkhorn_log(a, b, M, reg=reg, numItermax=max_iter, + stopThr=tol, log=True, + verbose=verbose) + + value_linear = nx.sum(M * plan) + + if reg_type.lower() == 'entropy': + value = value_linear + reg * nx.sum(plan * nx.log(plan + 1e-16)) + else: + value = value_linear + reg * nx.kl_div(plan, a[:, None] * b[None, :]) + + potentials = (log['log_u'], log['log_v']) + + if grad == 'envelope': # set the gradient at convergence + + value = nx.set_gradients(value, (M0, a0, b0), + (plan, reg * (potentials[0] - potentials[0].mean()), reg * (potentials[1] - potentials[1].mean()))) + + elif reg_type.lower() == 'l2': + + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-9 + + plan, log = smooth_ot_dual(a, b, M, reg=reg, numItermax=max_iter, stopThr=tol, log=True, verbose=verbose) + + value_linear = nx.sum(M * plan) + value = value_linear + reg * nx.sum(plan**2) + potentials = (log['alpha'], log['beta']) + + else: + raise (NotImplementedError('Not implemented reg_type="{}"'.format(reg_type))) + + else: # unbalanced AND regularized OT + + if not isinstance(reg_type, tuple) and reg_type.lower() in ['kl'] and unbalanced_type.lower() == 'kl': + + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-9 + + plan, log = sinkhorn_knopp_unbalanced(a, b, M, reg=reg, reg_m=unbalanced, numItermax=max_iter, stopThr=tol, verbose=verbose, log=True) + + value_linear = nx.sum(M * plan) + + value = value_linear + reg * nx.kl_div(plan, a[:, None] * b[None, :]) + unbalanced * (nx.kl_div(nx.sum(plan, 1), a) + nx.kl_div(nx.sum(plan, 0), b)) + + potentials = (log['logu'], log['logv']) + + elif (isinstance(reg_type, tuple) or reg_type.lower() in ['kl', 'l2', 'entropy']) and unbalanced_type.lower() in ['kl', 'l2', 'tv']: + + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-12 + if isinstance(reg_type, str): + reg_type = reg_type.lower() + + plan, log = lbfgsb_unbalanced(a, b, M, reg=reg, reg_m=unbalanced, reg_div=reg_type, regm_div=unbalanced_type.lower(), numItermax=max_iter, stopThr=tol, verbose=verbose, log=True, G0=plan_init) + + value_linear = nx.sum(M * plan) + + value = log['loss'] + + else: + raise (NotImplementedError('Not implemented reg_type="{}" and unbalanced_type="{}"'.format(reg_type, unbalanced_type))) + + res = OTResult(potentials=potentials, value=value, + value_linear=value_linear, plan=plan, status=status, backend=nx) + + return res + + +def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, + alpha=0.5, reg=None, + reg_type="entropy", unbalanced=None, unbalanced_type='KL', + n_threads=1, method=None, max_iter=None, plan_init=None, tol=None, + verbose=False): + r""" Solve the discrete (Fused) Gromov-Wasserstein and return :any:`OTResult` object + + The function solves the following optimization problem: + + .. math:: + \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} + \lambda_r R(\mathbf{T}) + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) + + The regularization is selected with `reg` (:math:`\lambda_r`) and + `reg_type`. By default ``reg=None`` and there is no regularization. The + unbalanced marginal penalization can be selected with `unbalanced` + (:math:`\lambda_u`) and `unbalanced_type`. By default ``unbalanced=None`` + and the function solves the exact optimal transport problem (respecting the + marginals). + + Parameters + ---------- + Ca : array_like, shape (dim_a, dim_a) + Cost matrix in the source domain + Cb : array_like, shape (dim_b, dim_b) + Cost matrix in the target domain + M : array_like, shape (dim_a, dim_b), optional + Linear cost matrix for Fused Gromov-Wasserstein (default is None). + a : array-like, shape (dim_a,), optional + Samples weights in the source domain (default is uniform) + b : array-like, shape (dim_b,), optional + Samples weights in the source domain (default is uniform) + loss : str, optional + Type of loss function, either ``"L2"`` or ``"KL"``, by default ``"L2"`` + symmetric : bool, optional + Use symmetric version of the Gromov-Wasserstein problem, by default None + tests whether the matrices are symmetric or True/False to avoid the test. + reg : float, optional + Regularization weight :math:`\lambda_r`, by default None (no reg., exact + OT) + reg_type : str, optional + Type of regularization :math:`R`, by default "entropy" (only used when + ``reg!=None``) + alpha : float, optional + Weight the quadratic term (alpha*Gromov) and the linear term + ((1-alpha)*Wass) in the Fused Gromov-Wasserstein problem. Not used for + Gromov problem (when M is not provided). By default ``alpha=None`` + corresponds to ``alpha=1`` for Gromov problem (``M==None``) and + ``alpha=0.5`` for Fused Gromov-Wasserstein problem (``M!=None``) + unbalanced : float, optional + Unbalanced penalization weight :math:`\lambda_u`, by default None + (balanced OT), Not implemented yet + unbalanced_type : str, optional + Type of unbalanced penalization function :math:`U` either "KL", "semirelaxed", + "partial", by default "KL" but note that it is not implemented yet. + n_threads : int, optional + Number of OMP threads for exact OT solver, by default 1 + method : str, optional + Method for solving the problem when multiple algorithms are available, + default None for automatic selection. + max_iter : int, optional + Maximum number of iterations, by default None (default values in each + solvers) + plan_init : array_like, shape (dim_a, dim_b), optional + Initialization of the OT plan for iterative methods, by default None + tol : float, optional + Tolerance for solution precision, by default None (default values in + each solvers) + verbose : bool, optional + Print information in the solver, by default False + + Returns + ------- + res : OTResult() + Result of the optimization problem. The information can be obtained as follows: + + - res.plan : OT plan :math:`\mathbf{T}` + - res.potentials : OT dual potentials + - res.value : Optimal value of the optimization problem + - res.value_linear : Linear OT loss with the optimal OT plan + - res.value_quad : Quadratic (GW) part of the OT loss with the optimal OT plan + + See :any:`OTResult` for more information. + + Notes + ----- + The following methods are available for solving the Gromov-Wasserstein + problem: + + - **Classical Gromov-Wasserstein (GW) problem [3]** (default parameters): + + .. math:: + \min_{\mathbf{T}\geq 0} \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l} + + s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} + + \mathbf{T}^T \mathbf{1} = \mathbf{b} + + \mathbf{T} \geq 0 + + can be solved with the following code: + + .. code-block:: python + + res = ot.solve_gromov(Ca, Cb) # uniform weights + res = ot.solve_gromov(Ca, Cb, a=a, b=b) # given weights + res = ot.solve_gromov(Ca, Cb, loss='KL') # KL loss + + plan = res.plan # GW plan + value = res.value # GW value + + - **Fused Gromov-Wasserstein (FGW) problem [24]** (when ``M!=None``): + + .. math:: + \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l} + + s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} + + \mathbf{T}^T \mathbf{1} = \mathbf{b} + + \mathbf{T} \geq 0 + + can be solved with the following code: + + .. code-block:: python + + res = ot.solve_gromov(Ca, Cb, M) # uniform weights, alpha=0.5 (default) + res = ot.solve_gromov(Ca, Cb, M, a=a, b=b, alpha=0.1) # given weights and alpha + + plan = res.plan # FGW plan + loss_linear_term = res.value_linear # Wasserstein part of the loss + loss_quad_term = res.value_quad # Gromov part of the loss + loss = res.value # FGW value + + - **Regularized (Fused) Gromov-Wasserstein (GW) problem [12]** (when ``reg!=None``): + + .. math:: + \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l} + \lambda_r R(\mathbf{T}) + + s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} + + \mathbf{T}^T \mathbf{1} = \mathbf{b} + + \mathbf{T} \geq 0 + + can be solved with the following code: + + .. code-block:: python + + res = ot.solve_gromov(Ca, Cb, reg=1.0) # GW entropy regularization (default) + res = ot.solve_gromov(Ca, Cb, M, a=a, b=b, reg=10, alpha=0.1) # FGW with entropy + + plan = res.plan # FGW plan + loss_linear_term = res.value_linear # Wasserstein part of the loss + loss_quad_term = res.value_quad # Gromov part of the loss + loss = res.value # FGW value (including regularization) + + - **Semi-relaxed (Fused) Gromov-Wasserstein (GW) [48]** (when ``unbalanced='semirelaxed'``): + + .. math:: + \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l} + + s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} + + \mathbf{T} \geq 0 + + can be solved with the following code: + + .. code-block:: python + + res = ot.solve_gromov(Ca, Cb, unbalanced='semirelaxed') # semirelaxed GW + res = ot.solve_gromov(Ca, Cb, unbalanced='semirelaxed', reg=1) # entropic semirelaxed GW + res = ot.solve_gromov(Ca, Cb, M, unbalanced='semirelaxed', alpha=0.1) # semirelaxed FGW + + plan = res.plan # FGW plan + right_marginal = res.marginal_b # right marginal of the plan + + - **Partial (Fused) Gromov-Wasserstein (GW) problem [29]** (when ``unbalanced='partial'``): + + .. math:: + \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l} + + s.t. \ \mathbf{T} \mathbf{1} \leq \mathbf{a} + + \mathbf{T}^T \mathbf{1} \leq \mathbf{b} + + \mathbf{T} \geq 0 + + \mathbf{1}^T\mathbf{T}\mathbf{1} = m + + can be solved with the following code: + + .. code-block:: python + + res = ot.solve_gromov(Ca, Cb, unbalanced_type='partial', unbalanced=0.8) # partial GW with m=0.8 + + + .. _references-solve-gromov: + References + ---------- + + .. [3] Mémoli, F. (2011). Gromov–Wasserstein distances and the metric + approach to object matching. Foundations of computational mathematics, + 11(4), 417-487. + + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon (2016), + Gromov-Wasserstein averaging of kernel and distance matrices + International Conference on Machine Learning (ICML). + + .. [24] Vayer, T., Chapel, L., Flamary, R., Tavenard, R. and Courty, N. + (2019). Optimal Transport for structured data with application on graphs + Proceedings of the 36th International Conference on Machine Learning + (ICML). + + .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, + Nicolas Courty (2022). Semi-relaxed Gromov-Wasserstein divergence and + applications on graphs. International Conference on Learning + Representations (ICLR), 2022. + + .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). Partial Optimal Transport + with Applications on Positive-Unlabeled Learning, Advances in Neural + Information Processing Systems (NeurIPS), 2020. + + """ + + # detect backend + nx = get_backend(Ca, Cb, M, a, b) + + # create uniform weights if not given + if a is None: + a = nx.ones(Ca.shape[0], type_as=Ca) / Ca.shape[0] + if b is None: + b = nx.ones(Cb.shape[1], type_as=Cb) / Cb.shape[1] + + # default values for solutions + potentials = None + value = None + value_linear = None + value_quad = None + plan = None + status = None + log = None + + loss_dict = {'l2': 'square_loss', 'kl': 'kl_loss'} + + if loss.lower() not in loss_dict.keys(): + raise (NotImplementedError('Not implemented GW loss="{}"'.format(loss))) + loss_fun = loss_dict[loss.lower()] + + if reg is None or reg == 0: # exact OT + + if unbalanced is None and unbalanced_type.lower() not in ['semirelaxed']: # Exact balanced OT + + if M is None or alpha == 1: # Gromov-Wasserstein problem + + # default values for solver + if max_iter is None: + max_iter = 10000 + if tol is None: + tol = 1e-9 + + value, log = gromov_wasserstein2(Ca, Cb, a, b, loss_fun=loss_fun, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) + + value_quad = value + if alpha == 1: # set to 0 for FGW with alpha=1 + value_linear = 0 + plan = log['T'] + potentials = (log['u'], log['v']) + + elif alpha == 0: # Wasserstein problem + + # default values for EMD solver + if max_iter is None: + max_iter = 1000000 + + value_linear, log = emd2(a, b, M, numItermax=max_iter, log=True, return_matrix=True, numThreads=n_threads) + + value = value_linear + potentials = (log['u'], log['v']) + plan = log['G'] + status = log["warning"] if log["warning"] is not None else 'Converged' + value_quad = 0 + + else: # Fused Gromov-Wasserstein problem + + # default values for solver + if max_iter is None: + max_iter = 10000 + if tol is None: + tol = 1e-9 + + value, log = fused_gromov_wasserstein2(M, Ca, Cb, a, b, loss_fun=loss_fun, alpha=alpha, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) + + value_linear = log['lin_loss'] + value_quad = log['quad_loss'] + plan = log['T'] + potentials = (log['u'], log['v']) + + elif unbalanced_type.lower() in ['semirelaxed']: # Semi-relaxed OT + + if M is None or alpha == 1: # Semi relaxed Gromov-Wasserstein problem + + # default values for solver + if max_iter is None: + max_iter = 10000 + if tol is None: + tol = 1e-9 + + value, log = semirelaxed_gromov_wasserstein2(Ca, Cb, a, loss_fun=loss_fun, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) + + value_quad = value + if alpha == 1: # set to 0 for FGW with alpha=1 + value_linear = 0 + plan = log['T'] + # potentials = (log['u'], log['v']) TODO + + else: # Semi relaxed Fused Gromov-Wasserstein problem + + # default values for solver + if max_iter is None: + max_iter = 10000 + if tol is None: + tol = 1e-9 + + value, log = semirelaxed_fused_gromov_wasserstein2(M, Ca, Cb, a, loss_fun=loss_fun, alpha=alpha, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) + + value_linear = log['lin_loss'] + value_quad = log['quad_loss'] + plan = log['T'] + # potentials = (log['u'], log['v']) TODO + + elif unbalanced_type.lower() in ['partial']: # Partial OT + + if M is None: # Partial Gromov-Wasserstein problem + + if unbalanced > nx.sum(a) or unbalanced > nx.sum(b): + raise (ValueError('Partial GW mass given in reg is too large')) + if loss.lower() != 'l2': + raise (NotImplementedError('Partial GW only implemented with L2 loss')) + if symmetric is not None: + raise (NotImplementedError('Partial GW only implemented with symmetric=True')) + + # default values for solver + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-7 + + value, log = partial_gromov_wasserstein2(Ca, Cb, a, b, m=unbalanced, log=True, numItermax=max_iter, G0=plan_init, tol=tol, verbose=verbose) + + value_quad = value + plan = log['T'] + # potentials = (log['u'], log['v']) TODO + + else: # partial FGW + + raise (NotImplementedError('Partial FGW not implemented yet')) + + elif unbalanced_type.lower() in ['kl', 'l2']: # unbalanced exact OT + + raise (NotImplementedError('Unbalanced_type="{}"'.format(unbalanced_type))) + + else: + raise (NotImplementedError('Unknown unbalanced_type="{}"'.format(unbalanced_type))) + + else: # regularized OT + + if unbalanced is None and unbalanced_type.lower() not in ['semirelaxed']: # Balanced regularized OT + + if reg_type.lower() in ['entropy'] and (M is None or alpha == 1): # Entropic Gromov-Wasserstein problem + + # default values for solver + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-9 + if method is None: + method = 'PGD' + + value_quad, log = entropic_gromov_wasserstein2(Ca, Cb, a, b, epsilon=reg, loss_fun=loss_fun, log=True, symmetric=symmetric, solver=method, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) + + plan = log['T'] + value_linear = 0 + value = value_quad + reg * nx.sum(plan * nx.log(plan + 1e-16)) + # potentials = (log['log_u'], log['log_v']) #TODO + + elif reg_type.lower() in ['entropy'] and M is not None and alpha == 0: # Entropic Wasserstein problem + + # default values for solver + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-9 + + plan, log = sinkhorn_log(a, b, M, reg=reg, numItermax=max_iter, + stopThr=tol, log=True, + verbose=verbose) + + value_linear = nx.sum(M * plan) + value = value_linear + reg * nx.sum(plan * nx.log(plan + 1e-16)) + potentials = (log['log_u'], log['log_v']) + + elif reg_type.lower() in ['entropy'] and M is not None: # Entropic Fused Gromov-Wasserstein problem + + # default values for solver + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-9 + if method is None: + method = 'PGD' + + value_noreg, log = entropic_fused_gromov_wasserstein2(M, Ca, Cb, a, b, loss_fun=loss_fun, alpha=alpha, log=True, symmetric=symmetric, solver=method, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) + + value_linear = log['lin_loss'] + value_quad = log['quad_loss'] + plan = log['T'] + # potentials = (log['u'], log['v']) + value = value_noreg + reg * nx.sum(plan * nx.log(plan + 1e-16)) + + else: + raise (NotImplementedError('Not implemented reg_type="{}"'.format(reg_type))) + + elif unbalanced_type.lower() in ['semirelaxed']: # Semi-relaxed OT + + if reg_type.lower() in ['entropy'] and (M is None or alpha == 1): # Entropic Semi-relaxed Gromov-Wasserstein problem + + # default values for solver + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-9 + + value_quad, log = entropic_semirelaxed_gromov_wasserstein2(Ca, Cb, a, epsilon=reg, loss_fun=loss_fun, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) + + plan = log['T'] + value_linear = 0 + value = value_quad + reg * nx.sum(plan * nx.log(plan + 1e-16)) + + else: # Entropic Semi-relaxed FGW problem + + # default values for solver + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-9 + + value_noreg, log = entropic_semirelaxed_fused_gromov_wasserstein2(M, Ca, Cb, a, loss_fun=loss_fun, alpha=alpha, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) + + value_linear = log['lin_loss'] + value_quad = log['quad_loss'] + plan = log['T'] + value = value_noreg + reg * nx.sum(plan * nx.log(plan + 1e-16)) + + elif unbalanced_type.lower() in ['partial']: # Partial OT + + if M is None: # Partial Gromov-Wasserstein problem + + if unbalanced > nx.sum(a) or unbalanced > nx.sum(b): + raise (ValueError('Partial GW mass given in reg is too large')) + if loss.lower() != 'l2': + raise (NotImplementedError('Partial GW only implemented with L2 loss')) + if symmetric is not None: + raise (NotImplementedError('Partial GW only implemented with symmetric=True')) + + # default values for solver + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-7 + + value_quad, log = entropic_partial_gromov_wasserstein2(Ca, Cb, a, b, reg=reg, m=unbalanced, log=True, numItermax=max_iter, G0=plan_init, tol=tol, verbose=verbose) + + value_quad = value + plan = log['T'] + # potentials = (log['u'], log['v']) TODO + + else: # partial FGW + + raise (NotImplementedError('Partial entropic FGW not implemented yet')) + + else: # unbalanced AND regularized OT + + raise (NotImplementedError('Not implemented reg_type="{}" and unbalanced_type="{}"'.format(reg_type, unbalanced_type))) + + res = OTResult(potentials=potentials, value=value, + value_linear=value_linear, value_quad=value_quad, plan=plan, status=status, backend=nx, log=log) + + return res + + +def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_type="KL", + unbalanced=None, + unbalanced_type='KL', lazy=False, batch_size=None, method=None, n_threads=1, max_iter=None, plan_init=None, rank=100, scaling=0.95, + potentials_init=None, X_init=None, tol=None, verbose=False, + grad='autodiff'): + r"""Solve the discrete optimal transport problem using the samples in the source and target domains. + + The function solves the following general optimal transport problem + + .. math:: + \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) + + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) + + where the cost matrix :math:`\mathbf{M}` is computed from the samples in the + source and target domains such that :math:`M_{i,j} = d(x_i,y_j)` where + :math:`d` is a metric (by default the squared Euclidean distance). + + The regularization is selected with `reg` (:math:`\lambda_r`) and `reg_type`. By + default ``reg=None`` and there is no regularization. The unbalanced marginal + penalization can be selected with `unbalanced` (:math:`\lambda_u`) and + `unbalanced_type`. By default ``unbalanced=None`` and the function + solves the exact optimal transport problem (respecting the marginals). + + Parameters + ---------- + X_s : array-like, shape (n_samples_a, dim) + samples in the source domain + X_t : array-like, shape (n_samples_b, dim) + samples in the target domain + a : array-like, shape (dim_a,), optional + Samples weights in the source domain (default is uniform) + b : array-like, shape (dim_b,), optional + Samples weights in the source domain (default is uniform) + reg : float, optional + Regularization weight :math:`\lambda_r`, by default None (no reg., exact + OT) + reg_type : str, optional + Type of regularization :math:`R` either "KL", "L2", "entropy", by default "KL" + unbalanced : float, optional + Unbalanced penalization weight :math:`\lambda_u`, by default None + (balanced OT) + unbalanced_type : str, optional + Type of unbalanced penalization function :math:`U` either "KL", "L2", "TV", by default "KL" + lazy : bool, optional + Return :any:`OTResultlazy` object to reduce memory cost when True, by + default False + batch_size : int, optional + Batch size for lazy solver, by default None (default values in each + solvers) + method : str, optional + Method for solving the problem, this can be used to select the solver + for unbalanced problems (see :any:`ot.solve`), or to select a specific + large scale solver. + n_threads : int, optional + Number of OMP threads for exact OT solver, by default 1 + max_iter : int, optional + Maximum number of iteration, by default None (default values in each solvers) + plan_init : array_like, shape (dim_a, dim_b), optional + Initialization of the OT plan for iterative methods, by default None + rank : int, optional + Rank of the OT matrix for lazy solers (method='factored'), by default 100 + scaling : float, optional + Scaling factor for the epsilon scaling lazy solvers (method='geomloss'), by default 0.95 + potentials_init : (array_like(dim_a,),array_like(dim_b,)), optional + Initialization of the OT dual potentials for iterative methods, by default None + tol : _type_, optional + Tolerance for solution precision, by default None (default values in each solvers) + verbose : bool, optional + Print information in the solver, by default False + grad : str, optional + Type of gradient computation, either or 'autodiff' or 'envelope' used only for + Sinkhorn solver. By default 'autodiff' provides gradients wrt all + outputs (`plan, value, value_linear`) but with important memory cost. + 'envelope' provides gradients only for `value` and and other outputs are + detached. This is useful for memory saving when only the value is needed. + + Returns + ------- + + res : OTResult() + Result of the optimization problem. The information can be obtained as follows: + + - res.plan : OT plan :math:`\mathbf{T}` + - res.potentials : OT dual potentials + - res.value : Optimal value of the optimization problem + - res.value_linear : Linear OT loss with the optimal OT plan + - res.lazy_plan : Lazy OT plan (when ``lazy=True`` or lazy method) + + See :any:`OTResult` for more information. + + Notes + ----- + + The following methods are available for solving the OT problems: + + - **Classical exact OT problem [1]** (default parameters) : + + .. math:: + \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + + s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} + + \mathbf{T}^T \mathbf{1} = \mathbf{b} + + \mathbf{T} \geq 0, M_{i,j} = d(x_i,y_j) + + + + can be solved with the following code: + + .. code-block:: python + + res = ot.solve_sample(xa, xb, a, b) + + # for uniform weights + res = ot.solve_sample(xa, xb) + + - **Entropic regularized OT [2]** (when ``reg!=None``): + + .. math:: + \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda R(\mathbf{T}) + + s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} + + \mathbf{T}^T \mathbf{1} = \mathbf{b} + + \mathbf{T} \geq 0, M_{i,j} = d(x_i,y_j) + + can be solved with the following code: + + .. code-block:: python + + # default is ``"KL"`` regularization (``reg_type="KL"``) + res = ot.solve_sample(xa, xb, a, b, reg=1.0) + # or for original Sinkhorn paper formulation [2] + res = ot.solve_sample(xa, xb, a, b, reg=1.0, reg_type='entropy') + + # lazy solver of memory complexity O(n) + res = ot.solve_sample(xa, xb, a, b, reg=1.0, lazy=True, batch_size=100) + # lazy OT plan + lazy_plan = res.lazy_plan + + # Use envelope theorem differentiation for memory saving + res = ot.solve_sample(xa, xb, a, b, reg=1.0, grad='envelope') + res.value.backward() # only the value is differentiable + + Note that by default the Sinkhorn solver uses automatic differentiation to + compute the gradients of the values and plan. This can be changed with the + `grad` parameter. The `envelope` mode computes the gradients only + for the value and the other outputs are detached. This is useful for + memory saving when only the gradient of value is needed. + + We also have a very efficient solver with compiled CPU/CUDA code using + geomloss/PyKeOps that can be used with the following code: + + .. code-block:: python + + # automatic solver + res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss') + + # force O(n) memory efficient solver + res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss_online') + + # force pre-computed cost matrix + res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss_tensorized') + + # use multiscale solver + res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss_multiscale') + + # One can play with speed (small scaling factor) and precision (scaling close to 1) + res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss', scaling=0.5) + + - **Quadratic regularized OT [17]** (when ``reg!=None`` and ``reg_type="L2"``): + + .. math:: + \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda R(\mathbf{T}) + + s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} + + \mathbf{T}^T \mathbf{1} = \mathbf{b} + + \mathbf{T} \geq 0, M_{i,j} = d(x_i,y_j) + + can be solved with the following code: + + .. code-block:: python + + res = ot.solve_sample(xa, xb, a, b, reg=1.0, reg_type='L2') + + - **Unbalanced OT [41]** (when ``unbalanced!=None``): + + .. math:: + \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) + + with M_{i,j} = d(x_i,y_j) + + can be solved with the following code: + + .. code-block:: python + + # default is ``"KL"`` + res = ot.solve_sample(xa, xb, a, b, unbalanced=1.0) + # quadratic unbalanced OT + res = ot.solve_sample(xa, xb, a, b, unbalanced=1.0,unbalanced_type='L2') + # TV = partial OT + res = ot.solve_sample(xa, xb, a, b, unbalanced=1.0,unbalanced_type='TV') + + + - **Regularized unbalanced regularized OT [34]** (when ``unbalanced!=None`` and ``reg!=None``): + + .. math:: + \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) + + with M_{i,j} = d(x_i,y_j) + + can be solved with the following code: + + .. code-block:: python + + # default is ``"KL"`` for both + res = ot.solve_sample(xa, xb, a, b, reg=1.0, unbalanced=1.0) + # quadratic unbalanced OT with KL regularization + res = ot.solve_sample(xa, xb, a, b, reg=1.0, unbalanced=1.0,unbalanced_type='L2') + # both quadratic + res = ot.solve_sample(xa, xb, a, b, reg=1.0, reg_type='L2', + unbalanced=1.0, unbalanced_type='L2') + + + - **Factored OT [2]** (when ``method='factored'``): + + This method solve the following OT problem [40]_ + + .. math:: + \mathop{\arg \min}_\mu \quad W_2^2(\mu_a,\mu)+ W_2^2(\mu,\mu_b) + + where $\mu$ is a uniform weighted empirical distribution of :math:`\mu_a` and :math:`\mu_b` are the empirical measures associated + to the samples in the source and target domains, and :math:`W_2` is the + Wasserstein distance. This problem is solved using exact OT solvers for + `reg=None` and the Sinkhorn solver for `reg!=None`. The solution provides + two transport plans that can be used to recover a low rank OT plan between + the two distributions. + + .. code-block:: python + + res = ot.solve_sample(xa, xb, method='factored', rank=10) + + # recover the lazy low rank plan + factored_solution_lazy = res.lazy_plan + + # recover the full low rank plan + factored_solution = factored_solution_lazy[:] + + - **Gaussian Bures-Wasserstein [2]** (when ``method='gaussian'``): + + This method computes the Gaussian Bures-Wasserstein distance between two + Gaussian distributions estimated from teh empirical distributions + + .. math:: + \mathcal{W}(\mu_s, \mu_t)_2^2= \left\lVert \mathbf{m}_s - \mathbf{m}_t \right\rVert^2 + \mathcal{B}(\Sigma_s, \Sigma_t)^{2} + + where : + + .. math:: + \mathbf{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s + \Sigma_t - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right) + + The covariances and means are estimated from the data. + + .. code-block:: python + + res = ot.solve_sample(xa, xb, method='gaussian') + + # recover the squared Gaussian Bures-Wasserstein distance + BW_dist = res.value + + - **Wasserstein 1d [1]** (when ``method='1D'``): + + This method computes the Wasserstein distance between two 1d distributions + estimated from the empirical distributions. For multivariate data the + distances are computed independently for each dimension. + + .. code-block:: python + + res = ot.solve_sample(xa, xb, method='1D') + + # recover the squared Wasserstein distances + W_dists = res.value + + + .. _references-solve-sample: + References + ---------- + + .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. + (2011, December). Displacement interpolation using Lagrangian mass + transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. + 158). ACM. + + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation + of Optimal Transport, Advances in Neural Information Processing + Systems (NIPS) 26, 2013 + + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. + arXiv preprint arXiv:1607.05816. + + .. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse + Optimal Transport. Proceedings of the Twenty-First International + Conference on Artificial Intelligence and Statistics (AISTATS). + + .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, + A., & Peyré, G. (2019, April). Interpolating between optimal transport + and MMD using Sinkhorn divergences. In The 22nd International Conference + on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. + + .. [40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger, + G., & Weed, J. (2019, April). Statistical optimal transport via factored + couplings. In The 22nd International Conference on Artificial + Intelligence and Statistics (pp. 2454-2465). PMLR. + + .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression. NeurIPS. + + .. [65] Scetbon, M., Cuturi, M., & Peyré, G. (2021). + Low-rank Sinkhorn Factorization. In International Conference on + Machine Learning. + + + """ + + if method is not None and method.lower() in lst_method_lazy: + lazy0 = lazy + lazy = True + + if not lazy: # default non lazy solver calls ot.solve + + # compute cost matrix M and use solve function + M = dist(X_a, X_b, metric) + + res = solve(M, a, b, reg, reg_type, unbalanced, unbalanced_type, method, n_threads, max_iter, plan_init, potentials_init, tol, verbose, grad) + + return res + + else: + + # Detect backend + nx = get_backend(X_a, X_b, a, b) + + # default values for solutions + potentials = None + value = None + value_linear = None + plan = None + lazy_plan = None + status = None + log = None + + method = method.lower() if method is not None else '' + + if method == '1d': # Wasserstein 1d (parallel on all dimensions) + if metric == 'sqeuclidean': + p = 2 + elif metric in ['euclidean', 'cityblock']: + p = 1 + else: + raise (NotImplementedError('Not implemented metric="{}"'.format(metric))) + + value = wasserstein_1d(X_a, X_b, a, b, p=p) + value_linear = value + + elif method == 'gaussian': # Gaussian Bures-Wasserstein + + if not metric.lower() in ['sqeuclidean']: + raise (NotImplementedError('Not implemented metric="{}"'.format(metric))) + + if reg is None: + reg = 1e-6 + + value, log = empirical_bures_wasserstein_distance(X_a, X_b, reg=reg, log=True) + value = value**2 # return the value (squared bures distance) + value_linear = value # return the value + + elif method == 'factored': # Factored OT + + if not metric.lower() in ['sqeuclidean']: + raise (NotImplementedError('Not implemented metric="{}"'.format(metric))) + + if max_iter is None: + max_iter = 100 + if tol is None: + tol = 1e-7 + if reg is None: + reg = 0 + + Q, R, X, log = factored_optimal_transport(X_a, X_b, reg=reg, r=rank, log=True, stopThr=tol, numItermax=max_iter, verbose=verbose) + log['X'] = X + + value_linear = log['costa'] + log['costb'] + value = value_linear # TODO add reg term + lazy_plan = log['lazy_plan'] + if not lazy0: # store plan if not lazy + plan = lazy_plan[:] + + elif method == "lowrank": + + if not metric.lower() in ['sqeuclidean']: + raise (NotImplementedError('Not implemented metric="{}"'.format(metric))) + + if max_iter is None: + max_iter = 2000 + if tol is None: + tol = 1e-7 + if reg is None: + reg = 0 + + Q, R, g, log = lowrank_sinkhorn(X_a, X_b, rank=rank, reg=reg, a=a, b=b, numItermax=max_iter, stopThr=tol, log=True) + value = log['value'] + value_linear = log['value_linear'] + lazy_plan = log['lazy_plan'] + if not lazy0: # store plan if not lazy + plan = lazy_plan[:] + + elif method.startswith('geomloss'): # Geomloss solver for entropic OT + + split_method = method.split('_') + if len(split_method) == 2: + backend = split_method[1] + else: + if lazy0 is None: + backend = 'auto' + elif lazy0: + backend = 'online' + else: + backend = 'tensorized' + + value, log = empirical_sinkhorn2_geomloss(X_a, X_b, reg=reg, a=a, b=b, metric=metric, log=True, verbose=verbose, scaling=scaling, backend=backend) + + lazy_plan = log['lazy_plan'] + if not lazy0: # store plan if not lazy + plan = lazy_plan[:] + + # return scaled potentials (to be consistent with other solvers) + potentials = (log['f'] / (lazy_plan.blur**2), log['g'] / (lazy_plan.blur**2)) + + elif reg is None or reg == 0: # exact OT + + if unbalanced is None: # balanced EMD solver not available for lazy + raise (NotImplementedError('Exact OT solver with lazy=True not implemented')) + + else: + raise (NotImplementedError('Non regularized solver with unbalanced_type="{}" not implemented'.format(unbalanced_type))) + + else: + if unbalanced is None: + + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-9 + if batch_size is None: + batch_size = 100 + + value_linear, log = empirical_sinkhorn2(X_a, X_b, reg, a, b, metric=metric, numIterMax=max_iter, stopThr=tol, + isLazy=True, batchSize=batch_size, verbose=verbose, log=True) + # compute potentials + potentials = (log["u"], log["v"]) + lazy_plan = log['lazy_plan'] + + else: + raise (NotImplementedError('Not implemented unbalanced_type="{}" with regularization'.format(unbalanced_type))) + + res = OTResult(potentials=potentials, value=value, lazy_plan=lazy_plan, + value_linear=value_linear, plan=plan, status=status, backend=nx, log=log) + return res From fc8b96c0b2f6a2d43275327919ec785741cc96b9 Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Wed, 17 Jul 2024 14:04:10 +0200 Subject: [PATCH 04/17] fix dictionary key error --- ot/solvers.py | 2 +- ot/unbalanced/_lbfgs.py | 2 -- ot/unbalanced/_mm.py | 2 -- test/test_solvers.py | 2 +- 4 files changed, 2 insertions(+), 6 deletions(-) diff --git a/ot/solvers.py b/ot/solvers.py index 95165ea11..6c8865560 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -403,7 +403,7 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, value_linear = nx.sum(M * plan) - value = log['loss'] + value = log['cost'] else: raise (NotImplementedError('Not implemented reg_type="{}" and unbalanced_type="{}"'.format(reg_type, unbalanced_type))) diff --git a/ot/unbalanced/_lbfgs.py b/ot/unbalanced/_lbfgs.py index 734924948..f03a8820d 100644 --- a/ot/unbalanced/_lbfgs.py +++ b/ot/unbalanced/_lbfgs.py @@ -9,9 +9,7 @@ # # License: MIT License -from __future__ import division import warnings - import numpy as np from scipy.optimize import minimize, Bounds diff --git a/ot/unbalanced/_mm.py b/ot/unbalanced/_mm.py index 464c89f78..6b9a27800 100644 --- a/ot/unbalanced/_mm.py +++ b/ot/unbalanced/_mm.py @@ -9,8 +9,6 @@ # # License: MIT License -from __future__ import division - from ..backend import get_backend from ..utils import list_to_array, get_parameter_pair diff --git a/test/test_solvers.py b/test/test_solvers.py index 16e6df295..160451207 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -16,7 +16,7 @@ lst_reg = [None, 1] -lst_reg_type = ['KL', 'entropy', 'L2', 'tuple'] +lst_reg_type = ['KL', 'L2', 'tuple'] lst_unbalanced = [None, 0.9] lst_unbalanced_type = ['KL', 'L2', 'TV'] From ebd553d7b7560dc563903473098c09386d9caa40 Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Sun, 21 Jul 2024 20:39:52 +0200 Subject: [PATCH 05/17] fix documentation error --- ot/unbalanced/_lbfgs.py | 25 ++++++++++++------------- ot/unbalanced/_sinkhorn.py | 14 +++++++------- 2 files changed, 19 insertions(+), 20 deletions(-) diff --git a/ot/unbalanced/_lbfgs.py b/ot/unbalanced/_lbfgs.py index f03a8820d..cde6a7192 100644 --- a/ot/unbalanced/_lbfgs.py +++ b/ot/unbalanced/_lbfgs.py @@ -87,10 +87,13 @@ def _func(G): G = G.reshape((m, n)) # compute loss - val = np.sum(G * M) + reg * reg_fun(G) + regm_fun(G) - + val = np.sum(G * M) + regm_fun(G) + if reg > 0: + val = val + reg * reg_fun(G) # compute gradient - grad = M + reg * grad_reg_fun(G) + grad_regm_fun(G) + grad = M + grad_regm_fun(G) + if reg > 0: + grad = grad + reg * grad_reg_fun(G) return val, grad.ravel() @@ -105,7 +108,7 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', .. math:: W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + - + \mathrm{reg} \mathrm{div}(\gamma, \mathbf{c}) + \mathrm{reg} \mathrm{div}(\gamma, \mathbf{c}) + \mathrm{reg_{m1}} \cdot \mathrm{div_m}(\gamma \mathbf{1}, \mathbf{a}) + \mathrm{reg_{m2}} \cdot \mathrm{div_m}(\gamma^T \mathbf{1}, \mathbf{b}) @@ -203,8 +206,6 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', G0 = np.zeros(M.shape) if G0 is None else nx.to_numpy(G0) if reg > 0: # regularized case c = a[:, None] * b[None, :] if c is None else nx.to_numpy(c) - else: # unregularized case - c = 0 # wrap the callable function to handle numpy arrays if isinstance(reg_div, tuple): @@ -248,7 +249,7 @@ def lbfgsb_unbalanced2(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', .. math:: W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + - + \mathrm{reg} \mathrm{div}(\gamma, \mathbf{c}) + \mathrm{reg} \mathrm{div}(\gamma, \mathbf{c}) + \mathrm{reg_{m1}} \cdot \mathrm{div_m}(\gamma \mathbf{1}, \mathbf{a}) + \mathrm{reg_{m2}} \cdot \mathrm{div_m}(\gamma^T \mathbf{1}, \mathbf{b}) @@ -319,12 +320,10 @@ def lbfgsb_unbalanced2(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', >>> a=[.5, .5] >>> b=[.5, .5] >>> M=[[1., 36.],[9., 4.]] - >>> np.round(ot.unbalanced.lbfgsb_unbalanced(a, b, M, reg=0, reg_m=5, reg_div='kl', regm_div='kl'), 2) - array([[0.45, 0. ], - [0. , 0.34]]) - >>> np.round(ot.unbalanced.lbfgsb_unbalanced(a, b, M, reg=0, reg_m=5, reg_div='l2', regm_div='l2'), 2) - array([[0.4, 0. ], - [0. , 0.1]]) + >>> np.round(ot.unbalanced.lbfgsb_unbalanced2(a, b, M, reg=0, reg_m=5, reg_div='kl', regm_div='kl'), 2) + 0.8 + >>> np.round(ot.unbalanced.lbfgsb_unbalanced2(a, b, M, reg=0, reg_m=5, reg_div='l2', regm_div='l2'), 2) + 1.79 References ---------- diff --git a/ot/unbalanced/_sinkhorn.py b/ot/unbalanced/_sinkhorn.py index f49f70c46..678fc06f4 100644 --- a/ot/unbalanced/_sinkhorn.py +++ b/ot/unbalanced/_sinkhorn.py @@ -104,8 +104,8 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, c=None, method='sinkhorn', >>> b=[.5, .5] >>> M=[[0., 1.], [1., 0.]] >>> ot.sinkhorn_unbalanced(a, b, M, 1, 1) - array([[0.51122814, 0.18807032], - [0.18807032, 0.51122814]]) + array([[0.32205361, 0.1184769 ], + [0.1184769 , 0.32205361]]) .. _references-sinkhorn-unbalanced: References @@ -240,7 +240,7 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, c=None, method='sinkhorn', >>> b=[.5, .5] >>> M=[[0., 1.],[1., 0.]] >>> np.round(ot.unbalanced.sinkhorn_unbalanced2(a, b, M, 1., 1.), 8) - 0.31912858 + 0.19600125 .. _references-sinkhorn-unbalanced2: References @@ -402,8 +402,8 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, c=None, >>> b=[.5, .5] >>> M=[[0., 1.],[1., 0.]] >>> ot.unbalanced.sinkhorn_knopp_unbalanced(a, b, M, 1., 1.) - array([[0.51122814, 0.18807032], - [0.18807032, 0.51122814]]) + array([[0.32205361, 0.1184769 ], + [0.1184769 , 0.32205361]]) .. _references-sinkhorn-knopp-unbalanced: References @@ -619,8 +619,8 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, c=None, >>> b=[.5, .5] >>> M=[[0., 1.],[1., 0.]] >>> ot.unbalanced.sinkhorn_stabilized_unbalanced(a, b, M, 1., 1.) - array([[0.51122814, 0.18807032], - [0.18807032, 0.51122814]]) + array([[0.32205361, 0.1184769 ], + [0.1184769 , 0.32205361]]) .. _references-sinkhorn-stabilized-unbalanced: References From c835b633cdb4602db5932de9d5319b794ce49576 Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Sun, 21 Jul 2024 21:15:14 +0200 Subject: [PATCH 06/17] fix doctest error --- ot/unbalanced/_lbfgs.py | 4 ++-- ot/unbalanced/_sinkhorn.py | 18 +++++++++--------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/ot/unbalanced/_lbfgs.py b/ot/unbalanced/_lbfgs.py index cde6a7192..3d831b390 100644 --- a/ot/unbalanced/_lbfgs.py +++ b/ot/unbalanced/_lbfgs.py @@ -321,9 +321,9 @@ def lbfgsb_unbalanced2(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', >>> b=[.5, .5] >>> M=[[1., 36.],[9., 4.]] >>> np.round(ot.unbalanced.lbfgsb_unbalanced2(a, b, M, reg=0, reg_m=5, reg_div='kl', regm_div='kl'), 2) - 0.8 - >>> np.round(ot.unbalanced.lbfgsb_unbalanced2(a, b, M, reg=0, reg_m=5, reg_div='l2', regm_div='l2'), 2) 1.79 + >>> np.round(ot.unbalanced.lbfgsb_unbalanced2(a, b, M, reg=0, reg_m=5, reg_div='l2', regm_div='l2'), 2) + 0.8 References ---------- diff --git a/ot/unbalanced/_sinkhorn.py b/ot/unbalanced/_sinkhorn.py index 678fc06f4..002fea0bd 100644 --- a/ot/unbalanced/_sinkhorn.py +++ b/ot/unbalanced/_sinkhorn.py @@ -103,9 +103,9 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, c=None, method='sinkhorn', >>> a=[.5, .5] >>> b=[.5, .5] >>> M=[[0., 1.], [1., 0.]] - >>> ot.sinkhorn_unbalanced(a, b, M, 1, 1) - array([[0.32205361, 0.1184769 ], - [0.1184769 , 0.32205361]]) + >>> ot.round(ot.sinkhorn_unbalanced(a, b, M, 1, 1), 7) + array([[0.3220536, 0.1184769], + [0.1184769, 0.3220536]]) .. _references-sinkhorn-unbalanced: References @@ -401,9 +401,9 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, c=None, >>> a=[.5, .5] >>> b=[.5, .5] >>> M=[[0., 1.],[1., 0.]] - >>> ot.unbalanced.sinkhorn_knopp_unbalanced(a, b, M, 1., 1.) - array([[0.32205361, 0.1184769 ], - [0.1184769 , 0.32205361]]) + >>> ot.round(ot.unbalanced.sinkhorn_knopp_unbalanced(a, b, M, 1., 1.), 7) + array([[0.3220536, 0.1184769], + [0.1184769, 0.3220536]]) .. _references-sinkhorn-knopp-unbalanced: References @@ -618,9 +618,9 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, c=None, >>> a=[.5, .5] >>> b=[.5, .5] >>> M=[[0., 1.],[1., 0.]] - >>> ot.unbalanced.sinkhorn_stabilized_unbalanced(a, b, M, 1., 1.) - array([[0.32205361, 0.1184769 ], - [0.1184769 , 0.32205361]]) + >>> np.round(ot.unbalanced.sinkhorn_stabilized_unbalanced(a, b, M, 1., 1.), 7) + array([[0.3220536, 0.1184769], + [0.1184769, 0.3220536]]) .. _references-sinkhorn-stabilized-unbalanced: References From 1f478cc4d50d340822419db7e0c1062fc1e5b223 Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Sun, 21 Jul 2024 21:34:06 +0200 Subject: [PATCH 07/17] fix doctest error --- ot/unbalanced/_sinkhorn.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/ot/unbalanced/_sinkhorn.py b/ot/unbalanced/_sinkhorn.py index 002fea0bd..396ee04ba 100644 --- a/ot/unbalanced/_sinkhorn.py +++ b/ot/unbalanced/_sinkhorn.py @@ -100,10 +100,11 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, c=None, method='sinkhorn', -------- >>> import ot + >>> import numpy as np >>> a=[.5, .5] >>> b=[.5, .5] >>> M=[[0., 1.], [1., 0.]] - >>> ot.round(ot.sinkhorn_unbalanced(a, b, M, 1, 1), 7) + >>> np.round(ot.sinkhorn_unbalanced(a, b, M, 1, 1), 7) array([[0.3220536, 0.1184769], [0.1184769, 0.3220536]]) @@ -398,10 +399,11 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, c=None, -------- >>> import ot + >>> import numpy as np >>> a=[.5, .5] >>> b=[.5, .5] >>> M=[[0., 1.],[1., 0.]] - >>> ot.round(ot.unbalanced.sinkhorn_knopp_unbalanced(a, b, M, 1., 1.), 7) + >>> np.round(ot.unbalanced.sinkhorn_knopp_unbalanced(a, b, M, 1., 1.), 7) array([[0.3220536, 0.1184769], [0.1184769, 0.3220536]]) @@ -615,6 +617,7 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, c=None, -------- >>> import ot + >>> import numpy as np >>> a=[.5, .5] >>> b=[.5, .5] >>> M=[[0., 1.],[1., 0.]] From 7cf6290fe1891acc53dffb1fc1ac4eafafd73487 Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Thu, 5 Sep 2024 15:03:15 +0200 Subject: [PATCH 08/17] recover reg_type argument --- ot/unbalanced/_lbfgs.py | 18 +++++-- ot/unbalanced/_sinkhorn.py | 86 +++++++++++++++++++++++--------- test/unbalanced/test_lbfgs.py | 6 +-- test/unbalanced/test_sinkhorn.py | 43 ++++++++-------- 4 files changed, 101 insertions(+), 52 deletions(-) diff --git a/ot/unbalanced/_lbfgs.py b/ot/unbalanced/_lbfgs.py index 3d831b390..99a8805b2 100644 --- a/ot/unbalanced/_lbfgs.py +++ b/ot/unbalanced/_lbfgs.py @@ -40,9 +40,18 @@ def reg_kl(G): def grad_kl(G): return np.log(G / c + 1e-16) + def reg_entropy(G): + return np.sum(G * np.log(G + 1e-16)) - np.sum(G) + + def grad_entropy(G): + return np.log(G + 1e-16) + if reg_div == 'kl': reg_fun = reg_kl grad_reg_fun = grad_kl + elif reg_div == 'entropy': + reg_fun = reg_entropy + grad_reg_fun = grad_entropy elif isinstance(reg_div, tuple): reg_fun = reg_div[0] grad_reg_fun = reg_div[1] @@ -147,7 +156,8 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', If reg_m is an array, it must be a Numpy array. reg_div: string, optional Divergence used for regularization. - Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) or a tuple + Can take three values: 'entropy' (negative entropy), or + 'kl' (Kullback-Leibler) or 'l2' (quadratic) or a tuple of two calable functions returning the reg term and its derivative. Note that the callable functions should be able to handle numpy arrays and not tesors from the backend @@ -204,8 +214,7 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', # convert to numpy a, b, M = nx.to_numpy(a, b, M) G0 = np.zeros(M.shape) if G0 is None else nx.to_numpy(G0) - if reg > 0: # regularized case - c = a[:, None] * b[None, :] if c is None else nx.to_numpy(c) + c = a[:, None] * b[None, :] if c is None else nx.to_numpy(c) # wrap the callable function to handle numpy arrays if isinstance(reg_div, tuple): @@ -288,7 +297,8 @@ def lbfgsb_unbalanced2(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', If reg_m is an array, it must be a Numpy array. reg_div: string, optional Divergence used for regularization. - Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) or a tuple + Can take three values: 'entropy' (negative entropy), or + 'kl' (Kullback-Leibler) or 'l2' (quadratic) or a tuple of two calable functions returning the reg term and its derivative. Note that the callable functions should be able to handle numpy arrays and not tesors from the backend diff --git a/ot/unbalanced/_sinkhorn.py b/ot/unbalanced/_sinkhorn.py index 396ee04ba..7d22a8376 100644 --- a/ot/unbalanced/_sinkhorn.py +++ b/ot/unbalanced/_sinkhorn.py @@ -16,9 +16,9 @@ from ..utils import list_to_array, get_parameter_pair -def sinkhorn_unbalanced(a, b, M, reg, reg_m, c=None, method='sinkhorn', - warmstart=None, numItermax=1000, stopThr=1e-6, - verbose=False, log=False, **kwargs): +def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', + reg_type="kl", c=None, warmstart=None, numItermax=1000, + stopThr=1e-6, verbose=False, log=False, **kwargs): r""" Solve the unbalanced entropic regularization optimal transport problem and return the OT plan @@ -64,12 +64,20 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, c=None, method='sinkhorn', For semi-relaxed case, use either `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`. If reg_m is an array, it must have the same backend as input arrays (a, b, M). - c : array-like (dim_a, dim_b), optional (default=None) - Reference measure for the regularization. - If None, then use `\mathbf{c} = \mathbf{a} \mathbf{b}^T`. method : str method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or 'sinkhorn_reg_scaling', see those function for specific parameters + reg_type : string, optional + Regularizer term. Can take two values: + reg_type = 'entropy' (negative entropy) + :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}`. + This is equivalent (up to a constant) to :math:`\Omega(\gamma) = \text{KL}(\gamma, 1_{dim_a} 1_{dim_b}^T)`. + reg_type = 'kl' (Kullback-Leibler) + :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`. + c : array-like (dim_a, dim_b), optional (default=None) + Reference measure for the regularization. + If None, then use `\mathbf{c} = \mathbf{a} \mathbf{b}^T`. + If reg_type = 'entropy', then `\mathbf{c} = 1_{dim_a} 1_{dim_b}^T`. warmstart: tuple of arrays, shape (dim_a, dim_b), optional Initialization of dual potentials. If provided, the dual potentials should be given (that is the logarithm of the u,v sinkhorn scaling vectors). @@ -138,20 +146,20 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, c=None, method='sinkhorn', """ if method.lower() == 'sinkhorn': - return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, c, + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, c, warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) elif method.lower() == 'sinkhorn_stabilized': - return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, c, + return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type, c, warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) elif method.lower() in ['sinkhorn_reg_scaling']: warnings.warn('Method not implemented yet. Using classic Sinkhorn-Knopp') - return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, c, + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, c, warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) @@ -159,9 +167,9 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, c=None, method='sinkhorn', raise ValueError("Unknown method '%s'." % method) -def sinkhorn_unbalanced2(a, b, M, reg, reg_m, c=None, method='sinkhorn', - warmstart=None, numItermax=1000, stopThr=1e-6, - verbose=False, log=False, **kwargs): +def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', + reg_type="kl", c=None, warmstart=None, numItermax=1000, + stopThr=1e-6, verbose=False, log=False, **kwargs): r""" Solve the entropic regularization unbalanced optimal transport problem and return the loss @@ -206,12 +214,20 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, c=None, method='sinkhorn', For semi-relaxed case, use either `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`. If reg_m is an array, it must have the same backend as input arrays (a, b, M). - c : array-like (dim_a, dim_b), optional (default=None) - Reference measure for the regularization. - If None, then use `\mathbf{c} = \mathbf{a} \mathbf{b}^T`. method : str method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or 'sinkhorn_reg_scaling', see those function for specific parameterss + reg_type : string, optional + Regularizer term. Can take two values: + reg_type = 'entropy' (negative entropy) + :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}`. + This is equivalent (up to a constant) to :math:`\Omega(\gamma) = \text{KL}(\gamma, 1_{dim_a} 1_{dim_b}^T)`. + reg_type = 'kl' (Kullback-Leibler) + :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`. + c : array-like (dim_a, dim_b), optional (default=None) + Reference measure for the regularization. + If None, then use `\mathbf{c} = \mathbf{a} \mathbf{b}^T`. + If reg_type = 'entropy', then `\mathbf{c} = 1_{dim_a} 1_{dim_b}^T`. warmstart: tuple of arrays, shape (dim_a, dim_b), optional Initialization of dual potentials. If provided, the dual potentials should be given (that is the logarithm of the u,v sinkhorn scaling vectors). @@ -273,19 +289,19 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, c=None, method='sinkhorn', if len(b.shape) < 2: if method.lower() == 'sinkhorn': - res = sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, c, + res = sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, c, warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) elif method.lower() == 'sinkhorn_stabilized': - res = sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, c, + res = sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type, c, warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) elif method.lower() in ['sinkhorn_reg_scaling']: warnings.warn('Method not implemented yet. Using classic Sinkhorn-Knopp') - res = sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, c, + res = sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, c, warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) @@ -299,19 +315,19 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, c=None, method='sinkhorn', else: if method.lower() == 'sinkhorn': - return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, c, + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, c, warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) elif method.lower() == 'sinkhorn_stabilized': - return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, c, + return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type, c, warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) elif method.lower() in ['sinkhorn_reg_scaling']: warnings.warn('Method not implemented yet. Using classic Sinkhorn-Knopp') - return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, c, + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, c, warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) @@ -319,7 +335,7 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, c=None, method='sinkhorn', raise ValueError('Unknown method %s.' % method) -def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, c=None, +def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type="kl", c=None, warmstart=None, numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs): r""" @@ -366,9 +382,17 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, c=None, For semi-relaxed case, use either `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`. If reg_m is an array, it must have the same backend as input arrays (a, b, M). + reg_type : string, optional + Regularizer term. Can take two values: + reg_type = 'entropy' (negative entropy) + :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}`. + This is equivalent (up to a constant) to :math:`\Omega(\gamma) = \text{KL}(\gamma, 1_{dim_a} 1_{dim_b}^T)`. + reg_type = 'kl' (Kullback-Leibler) + :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`. c : array-like (dim_a, dim_b), optional (default=None) Reference measure for the regularization. If None, then use `\mathbf{c} = \mathbf{a} \mathbf{b}^T`. + If reg_type = 'entropy', then `\mathbf{c} = 1_{dim_a} 1_{dim_b}^T`. warmstart: tuple of arrays, shape (dim_a, dim_b), optional Initialization of dual potentials. If provided, the dual potentials should be given (that is the logarithm of the u,v sinkhorn scaling vectors). @@ -457,6 +481,10 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, c=None, else: u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1]) + if reg_type == "entropy": + warnings.warn('If reg_type = entropy, then the matrix c is overwritten by the one matrix.') + c = nx.ones((dim_a, dim_b), type_as=M) + if n_hists: K = nx.exp(-M / reg) else: @@ -533,7 +561,7 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, c=None, return plan -def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, c=None, +def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="kl", c=None, warmstart=None, tau=1e5, numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs): @@ -583,9 +611,17 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, c=None, For semi-relaxed case, use either `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`. If reg_m is an array, it must have the same backend as input arrays (a, b, M). + reg_type : string, optional + Regularizer term. Can take two values: + reg_type = 'entropy' (negative entropy) + :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}`. + This is equivalent (up to a constant) to :math:`\Omega(\gamma) = \text{KL}(\gamma, 1_{dim_a} 1_{dim_b}^T)`. + reg_type = 'kl' (Kullback-Leibler) + :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`. c : array-like (dim_a, dim_b), optional (default=None) Reference measure for the regularization. If None, then use `\mathbf{c} = \mathbf{a} \mathbf{b}^T`. + If reg_type = 'entropy', then `\mathbf{c} = 1_{dim_a} 1_{dim_b}^T`. warmstart: tuple of arrays, shape (dim_a, dim_b), optional Initialization of dual potentials. If provided, the dual potentials should be given (that is the logarithm of the u,v sinkhorn scaling vectors). @@ -674,6 +710,10 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, c=None, else: u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1]) + if reg_type == "entropy": + warnings.warn('If reg_type = entropy, then the matrix c is overwritten by the one matrix.') + c = nx.ones((dim_a, dim_b), type_as=M) + if n_hists: M0 = M else: diff --git a/test/unbalanced/test_lbfgs.py b/test/unbalanced/test_lbfgs.py index 24f127d06..f3651ced7 100644 --- a/test/unbalanced/test_lbfgs.py +++ b/test/unbalanced/test_lbfgs.py @@ -13,7 +13,7 @@ import pytest -@pytest.mark.parametrize("reg_div,regm_div,returnCost", itertools.product(['kl', 'l2'], ['kl', 'l2'], ['linear', 'total'])) +@pytest.mark.parametrize("reg_div,regm_div,returnCost", itertools.product(['kl', 'l2', 'entropy'], ['kl', 'l2'], ['linear', 'total'])) def test_lbfgsb_unbalanced(nx, reg_div, regm_div, returnCost): np.random.seed(42) @@ -46,7 +46,7 @@ def test_lbfgsb_unbalanced(nx, reg_div, regm_div, returnCost): np.testing.assert_allclose(loss, nx.to_numpy(loss0), atol=1e-06) -@pytest.mark.parametrize("reg_div,regm_div,returnCost", itertools.product(['kl', 'l2'], ['kl', 'l2'], ['linear', 'total'])) +@pytest.mark.parametrize("reg_div,regm_div,returnCost", itertools.product(['kl', 'l2', 'entropy'], ['kl', 'l2'], ['linear', 'total'])) def test_lbfgsb_unbalanced_relaxation_parameters(nx, reg_div, regm_div, returnCost): np.random.seed(42) @@ -93,7 +93,7 @@ def test_lbfgsb_unbalanced_relaxation_parameters(nx, reg_div, regm_div, returnCo np.testing.assert_allclose(nx.to_numpy(loss), nx.to_numpy(loss0), atol=1e-06) -@pytest.mark.parametrize("reg_div,regm_div,returnCost", itertools.product(['kl', 'l2'], ['kl', 'l2'], ['linear', 'total'])) +@pytest.mark.parametrize("reg_div,regm_div,returnCost", itertools.product(['kl', 'l2', 'entropy'], ['kl', 'l2'], ['linear', 'total'])) def test_lbfgsb_reference_measure(nx, reg_div, regm_div, returnCost): np.random.seed(42) diff --git a/test/unbalanced/test_sinkhorn.py b/test/unbalanced/test_sinkhorn.py index 4d32660fd..130b08674 100644 --- a/test/unbalanced/test_sinkhorn.py +++ b/test/unbalanced/test_sinkhorn.py @@ -14,8 +14,8 @@ from ot.unbalanced import barycenter_unbalanced -@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"]) -def test_unbalanced_convergence(nx, method): +@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"], ["kl", "entropy"])) +def test_unbalanced_convergence(nx, method, reg_type): # test generalized sinkhorn for unbalanced OT n = 100 rng = np.random.RandomState(42) @@ -30,26 +30,27 @@ def test_unbalanced_convergence(nx, method): epsilon = 1. reg_m = 1. - stopThr = 1e-12 G, log = ot.unbalanced.sinkhorn_unbalanced( a, b, M, reg=epsilon, reg_m=reg_m, method=method, - c=None, log=True, verbose=True, numItermax=1000, stopThr=stopThr + reg_type=reg_type, log=True, verbose=True ) loss = nx.to_numpy(ot.unbalanced.sinkhorn_unbalanced2( a, b, M, reg=epsilon, reg_m=reg_m, method=method, - c=None, verbose=True, numItermax=1000, stopThr=stopThr + reg_type=reg_type, verbose=True )) # check fixed point equations # in log-domain fi = reg_m / (reg_m + epsilon) - logb = nx.log(b + 1e-16) loga = nx.log(a + 1e-16) - log_ab = loga[:, None] + logb[None, :] - - logKtu = nx.logsumexp(log["logu"][None, :] - M.T / epsilon + log_ab.T, axis=1) - logKv = nx.logsumexp(log["logv"][None, :] - M / epsilon + log_ab, axis=1) + if reg_type == "entropy": + logKtu = nx.logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1) + logKv = nx.logsumexp(log["logv"][None, :] - M / epsilon, axis=1) + elif reg_type == "kl": + log_ab = loga[:, None] + logb[None, :] + logKtu = nx.logsumexp(log["logu"][None, :] - M.T / epsilon + log_ab.T, axis=1) + logKv = nx.logsumexp(log["logv"][None, :] - M / epsilon + log_ab, axis=1) v_final = fi * (logb - logKtu) u_final = fi * (loga - logKv) @@ -68,19 +69,17 @@ def test_unbalanced_convergence(nx, method): G = ot.unbalanced.sinkhorn_unbalanced( a, b, M, reg=epsilon, reg_m=reg_m, - method=method, c=None, verbose=True, - stopThr=stopThr + method=method, reg_type=reg_type, verbose=True ) G_np = ot.unbalanced.sinkhorn_unbalanced( a_np, b_np, M_np, reg=epsilon, reg_m=reg_m, - method=method, c=None, verbose=True, - stopThr=stopThr + method=method, reg_type=reg_type, verbose=True ) np.testing.assert_allclose(G_np, nx.to_numpy(G)) -@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"]) -def test_unbalanced_warmstart(nx, method): +@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"], ["kl", "entropy"])) +def test_unbalanced_warmstart(nx, method, reg_type): # test generalized sinkhorn for unbalanced OT n = 100 rng = np.random.RandomState(42) @@ -96,33 +95,33 @@ def test_unbalanced_warmstart(nx, method): G0, log0 = ot.unbalanced.sinkhorn_unbalanced( a, b, M, reg=epsilon, reg_m=reg_m, method=method, - c=None, warmstart=None, log=True, verbose=True + reg_type=reg_type, warmstart=None, log=True, verbose=True ) loss0 = ot.unbalanced.sinkhorn_unbalanced2( a, b, M, reg=epsilon, reg_m=reg_m, method=method, - c=None, warmstart=None, verbose=True + reg_type=reg_type, warmstart=None, verbose=True ) dim_a, dim_b = M.shape warmstart = (nx.zeros(dim_a, type_as=M), nx.zeros(dim_b, type_as=M)) G, log = ot.unbalanced.sinkhorn_unbalanced( a, b, M, reg=epsilon, reg_m=reg_m, method=method, - c=None, warmstart=warmstart, log=True, verbose=True + reg_type=reg_type, warmstart=warmstart, log=True, verbose=True ) loss = ot.unbalanced.sinkhorn_unbalanced2( a, b, M, reg=epsilon, reg_m=reg_m, method=method, - c=None, warmstart=warmstart, verbose=True + reg_type=reg_type, warmstart=warmstart, verbose=True ) _, log_emd = ot.lp.emd(a, b, M, log=True) warmstart1 = (log_emd["u"], log_emd["v"]) G1, log1 = ot.unbalanced.sinkhorn_unbalanced( a, b, M, reg=epsilon, reg_m=reg_m, method=method, - c=None, warmstart=warmstart1, log=True, verbose=True + reg_type=reg_type, warmstart=warmstart1, log=True, verbose=True ) loss1 = ot.unbalanced.sinkhorn_unbalanced2( a, b, M, reg=epsilon, reg_m=reg_m, method=method, - c=None, warmstart=warmstart1, verbose=True + reg_type=reg_type, warmstart=warmstart1, verbose=True ) np.testing.assert_allclose( From c339d55a09bfe99946c2d932c62d2855e5e8f050 Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Mon, 9 Sep 2024 10:40:57 +0200 Subject: [PATCH 09/17] add documentation based on code review --- ot/unbalanced/__init__.py | 2 +- ot/unbalanced/_lbfgs.py | 114 +++++++++++++++------ ot/unbalanced/_mm.py | 40 ++++---- ot/unbalanced/_sinkhorn.py | 165 ++++++++++++++++++------------- test/unbalanced/test_lbfgs.py | 75 +++++++++++++- test/unbalanced/test_mm.py | 21 ++++ test/unbalanced/test_sinkhorn.py | 22 +++++ 7 files changed, 314 insertions(+), 125 deletions(-) diff --git a/ot/unbalanced/__init__.py b/ot/unbalanced/__init__.py index 67ef6cdf8..3a20af30d 100644 --- a/ot/unbalanced/__init__.py +++ b/ot/unbalanced/__init__.py @@ -19,7 +19,7 @@ from ._mm import (mm_unbalanced, mm_unbalanced2) -from ._lbfgs import (_get_loss_unbalanced, lbfgsb_unbalanced, lbfgsb_unbalanced2) +from ._lbfgs import (lbfgsb_unbalanced, lbfgsb_unbalanced2) __all__ = ['sinkhorn_knopp_unbalanced', 'sinkhorn_unbalanced', 'sinkhorn_stabilized_unbalanced', 'sinkhorn_unbalanced2', 'barycenter_unbalanced_sinkhorn', 'barycenter_unbalanced_stabilized', diff --git a/ot/unbalanced/_lbfgs.py b/ot/unbalanced/_lbfgs.py index 99a8805b2..2d5a1309b 100644 --- a/ot/unbalanced/_lbfgs.py +++ b/ot/unbalanced/_lbfgs.py @@ -19,14 +19,47 @@ def _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div='kl', regm_div='kl'): """ - return the loss function (scipy.optimize compatible) for regularized - unbalanced OT + Return loss function for the L-BFGS-B solver + + .. note:: This function will be fed into scipy.optimize, so all input arrays must be Numpy arrays. + + Parameters + ---------- + a : array-like (dim_a,) + Unnormalized histogram of dimension `dim_a` + b : array-like (dim_b,) + Unnormalized histogram of dimension `dim_b` + M : array-like (dim_a, dim_b) + loss matrix + reg: float + regularization term >=0 + c : array-like (dim_a, dim_b), optional (default = None) + Reference measure for the regularization. + If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`. + reg_m1: float + Marginal relaxation term with respect to the first marginal: + nonnegative (including 0) but cannot be infinity. + reg_m2: float + Marginal relaxation term with respect to the second marginal: + nonnegative (including 0) but cannot be infinity. + reg_div: string, optional + Divergence used for regularization. + Can take three values: 'entropy' (negative entropy), or + 'kl' (Kullback-Leibler) or 'l2' (half-squared) or a tuple + of two calable functions returning the reg term and its derivative. + Note that the callable functions should be able to handle Numpy arrays + and not tesors from the backend + regm_div: string, optional + Divergence to quantify the difference between the marginals. + Can take three values: 'kl' (Kullback-Leibler) or 'l2' (half-squared) or 'tv' (Total Variation) + + Returns + ------- + Loss function (scipy.optimize compatible) for regularized unbalanced OT """ m, n = M.shape - - def kl(p, q): - return np.sum(p * np.log(p / q + 1e-16)) - np.sum(p) + np.sum(q) + nx_numpy = get_backend(M, a, b) def reg_l2(G): return np.sum((G - c)**2) / 2 @@ -35,7 +68,7 @@ def grad_l2(G): return G - c def reg_kl(G): - return kl(G, c) + return nx_numpy.kl_div(G, c, mass=True) def grad_kl(G): return np.log(G / c + 1e-16) @@ -68,7 +101,7 @@ def grad_marg_l2(G): reg_m2 * np.outer(np.ones(m), (G.sum(0) - b)) def marg_kl(G): - return reg_m1 * kl(G.sum(1), a) + reg_m2 * kl(G.sum(0), b) + return reg_m1 * nx_numpy.kl_div(G.sum(1), a, mass=True) + reg_m2 * nx_numpy.kl_div(G.sum(0), b, mass=True) def grad_marg_kl(G): return reg_m1 * np.outer(np.log(G.sum(1) / a + 1e-16), np.ones(n)) + \ @@ -112,11 +145,11 @@ def _func(G): def lbfgsb_unbalanced(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', G0=None, numItermax=1000, stopThr=1e-15, method='L-BFGS-B', verbose=False, log=False): r""" - Solve the unbalanced optimal transport problem and return the OT plan using L-BFGS-B. + Solve the unbalanced optimal transport problem and return the OT plan using L-BFGS-B algorithm. The function solves the following optimization problem: .. math:: - W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + W = \arg \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg} \mathrm{div}(\gamma, \mathbf{c}) + \mathrm{reg_{m1}} \cdot \mathrm{div_m}(\gamma \mathbf{1}, \mathbf{a}) + \mathrm{reg_{m2}} \cdot \mathrm{div_m}(\gamma^T \mathbf{1}, \mathbf{b}) @@ -134,7 +167,9 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', - :math:`\mathrm{div}` is a divergence, either Kullback-Leibler divergence, or half-squared :math:`\ell_2` divergence - The algorithm used for solving the problem is a L-BFGS-B from scipy.optimize + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. First, it converts all arrays into Numpy arrays, + then uses the L-BFGS-B algorithm from scipy.optimize to solve the optimization problem. Parameters ---------- @@ -148,22 +183,22 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', regularization term >=0 c : array-like (dim_a, dim_b), optional (default = None) Reference measure for the regularization. - If None, then use `\mathbf{c} = \mathbf{a} \mathbf{b}^T`. + If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`. reg_m: float or indexable object of length 1 or 2 - Marginal relaxation term >= 0, but cannot be infinity. - If reg_m is a scalar or an indexable object of length 1, - then the same reg_m is applied to both marginal relaxations. - If reg_m is an array, it must be a Numpy array. + Marginal relaxation term: nonnegative (including 0) but cannot be infinity. + If :math:`\mathrm{reg_{m}}` is a scalar or an indexable object of length 1, + then the same :math:`\mathrm{reg_{m}}` is applied to both marginal relaxations. + If :math:`\mathrm{reg_{m}}` is an array, it must be a Numpy array. reg_div: string, optional Divergence used for regularization. Can take three values: 'entropy' (negative entropy), or - 'kl' (Kullback-Leibler) or 'l2' (quadratic) or a tuple + 'kl' (Kullback-Leibler) or 'l2' (half-squared) or a tuple of two calable functions returning the reg term and its derivative. - Note that the callable functions should be able to handle numpy arrays + Note that the callable functions should be able to handle Numpy arrays and not tesors from the backend regm_div: string, optional Divergence to quantify the difference between the marginals. - Can take three values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) or 'tv' (Total Variation) + Can take three values: 'kl' (Kullback-Leibler) or 'l2' (half-squared) or 'tv' (Total Variation) G0: array-like (dim_a, dim_b) Initialization of the transport matrix numItermax : int, optional @@ -207,10 +242,22 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss """ + if reg_div not in ["entropy", "kl", "l2"]: + raise ValueError("Unknown reg_div = {}. Must be either 'entropy', 'kl' or 'l2'".format(reg_div)) + if regm_div not in ["kl", "l2", "tv"]: + raise ValueError("Unknown regm_div = {}. Must be either 'kl', 'l2' or 'tv'".format(regm_div)) + M, a, b = list_to_array(M, a, b) nx = get_backend(M, a, b) M0 = M + dim_a, dim_b = M.shape + + if len(a) == 0 or a is None: + a = nx.ones(dim_a, type_as=M) / dim_a + if len(b) == 0 or b is None: + b = nx.ones(dim_b, type_as=M) / dim_b + # convert to numpy a, b, M = nx.to_numpy(a, b, M) G0 = np.zeros(M.shape) if G0 is None else nx.to_numpy(G0) @@ -253,11 +300,11 @@ def lbfgsb_unbalanced2(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', G0=None, returnCost="linear", numItermax=1000, stopThr=1e-15, method='L-BFGS-B', verbose=False, log=False): r""" - Solve the unbalanced optimal transport problem and return the OT plan using L-BFGS-B. + Solve the unbalanced optimal transport problem and return the OT cost using L-BFGS-B. The function solves the following optimization problem: .. math:: - W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg} \mathrm{div}(\gamma, \mathbf{c}) + \mathrm{reg_{m1}} \cdot \mathrm{div_m}(\gamma \mathbf{1}, \mathbf{a}) + \mathrm{reg_{m2}} \cdot \mathrm{div_m}(\gamma^T \mathbf{1}, \mathbf{b}) @@ -275,7 +322,9 @@ def lbfgsb_unbalanced2(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', - :math:`\mathrm{div}` is a divergence, either Kullback-Leibler divergence, or half-squared :math:`\ell_2` divergence - The algorithm used for solving the problem is a L-BFGS-B from scipy.optimize + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. First, it converts all arrays into Numpy arrays, + then uses the L-BFGS-B algorithm from scipy.optimize to solve the optimization problem. Parameters ---------- @@ -289,24 +338,27 @@ def lbfgsb_unbalanced2(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', regularization term >=0 c : array-like (dim_a, dim_b), optional (default = None) Reference measure for the regularization. - If None, then use `\mathbf{c} = \mathbf{a} \mathbf{b}^T`. + If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`. reg_m: float or indexable object of length 1 or 2 - Marginal relaxation term >= 0, but cannot be infinity. - If reg_m is a scalar or an indexable object of length 1, - then the same reg_m is applied to both marginal relaxations. - If reg_m is an array, it must be a Numpy array. + Marginal relaxation term: nonnegative (including 0) but cannot be infinity. + If :math:`\mathrm{reg_{m}}` is a scalar or an indexable object of length 1, + then the same :math:`\mathrm{reg_{m}}` is applied to both marginal relaxations. + If :math:`\mathrm{reg_{m}}` is an array, it must be a Numpy array. reg_div: string, optional Divergence used for regularization. Can take three values: 'entropy' (negative entropy), or - 'kl' (Kullback-Leibler) or 'l2' (quadratic) or a tuple + 'kl' (Kullback-Leibler) or 'l2' (half-squared) or a tuple of two calable functions returning the reg term and its derivative. - Note that the callable functions should be able to handle numpy arrays + Note that the callable functions should be able to handle Numpy arrays and not tesors from the backend regm_div: string, optional Divergence to quantify the difference between the marginals. - Can take three values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) or 'tv' (Total Variation) + Can take three values: 'kl' (Kullback-Leibler) or 'l2' (half-squared) or 'tv' (Total Variation) G0: array-like (dim_a, dim_b) Initialization of the transport matrix + returnCost: string, optional (default = "linear") + If `returnCost` = "linear", then return the linear part of the unbalanced OT loss. + If `returnCost` = "total", then return the total unbalanced OT loss. numItermax : int, optional Max number of iterations stopThr : float, optional @@ -318,8 +370,8 @@ def lbfgsb_unbalanced2(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', Returns ------- - gamma : (dim_a, dim_b) array-like - Optimal transportation matrix for the given parameters + ot_cost : array-like + the OT cost between :math:`\mathbf{a}` and :math:`\mathbf{b}` log : dict log dictionary returned only if `log` is `True` diff --git a/ot/unbalanced/_mm.py b/ot/unbalanced/_mm.py index 6b9a27800..8dbb3a9cb 100644 --- a/ot/unbalanced/_mm.py +++ b/ot/unbalanced/_mm.py @@ -20,7 +20,7 @@ def mm_unbalanced(a, b, M, reg_m, c=None, reg=0, div='kl', G0=None, numItermax=1 The function solves the following optimization problem: .. math:: - W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + W = \arg \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg_{m1}} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) + \mathrm{reg_{m2}} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) + \mathrm{reg} \cdot \mathrm{div}(\gamma, \mathbf{c}) @@ -48,19 +48,20 @@ def mm_unbalanced(a, b, M, reg_m, c=None, reg=0, div='kl', G0=None, numItermax=1 M : array-like (dim_a, dim_b) loss matrix reg_m: float or indexable object of length 1 or 2 - Marginal relaxation term >= 0, but cannot be infinity. - If reg_m is a scalar or an indexable object of length 1, - then the same reg_m is applied to both marginal relaxations. - If reg_m is an array, it must have the same backend as input arrays (a, b, M). + Marginal relaxation term: nonnegative but cannot be infinity. + If :math:`\mathrm{reg_{m}}` is a scalar or an indexable object of length 1, + then the same :math:`\mathrm{reg_{m}}` is applied to both marginal relaxations. + If :math:`\mathrm{reg_{m}}` is an array, + it must have the same backend as input arrays `(a, b, M)`. reg : float, optional (default = 0) Regularization term >= 0. By default, solve the unregularized problem c : array-like (dim_a, dim_b), optional (default = None) Reference measure for the regularization. - If None, then use `\mathbf{c} = \mathbf{a} \mathbf{b}^T`. + If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`. div: string, optional Divergence to quantify the difference between the marginals. - Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) + Can take two values: 'kl' (Kullback-Leibler) or 'l2' (half-squared) G0: array-like (dim_a, dim_b) Initialization of the transport matrix numItermax : int, optional @@ -131,7 +132,7 @@ def mm_unbalanced(a, b, M, reg_m, c=None, reg=0, div='kl', G0=None, numItermax=1 r1, r2, r = reg_m1 / sum_r, reg_m2 / sum_r, reg / sum_r K = (a[:, None]**r1) * (b[None, :]**r2) * (c**r) * nx.exp(- M / sum_r) elif div == 'l2': - K = reg_m1 * a[:, None] + reg_m2 * b[None, :] + reg * c - M + K = (reg_m1 * a[:, None]) + (reg_m2 * b[None, :]) + reg * c - M K = nx.maximum(K, nx.zeros((dim_a, dim_b), type_as=M)) else: raise ValueError("Unknown div = {}. Must be either 'kl' or 'l2'".format(div)) @@ -180,11 +181,11 @@ def mm_unbalanced(a, b, M, reg_m, c=None, reg=0, div='kl', G0=None, numItermax=1 def mm_unbalanced2(a, b, M, reg_m, c=None, reg=0, div='kl', G0=None, returnCost="linear", numItermax=1000, stopThr=1e-15, verbose=False, log=False): r""" - Solve the unbalanced optimal transport problem and return the OT plan. + Solve the unbalanced optimal transport problem and return the OT cost. The function solves the following optimization problem: .. math:: - W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg_{m1}} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) + \mathrm{reg_{m2}} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) + \mathrm{reg} \cdot \mathrm{div}(\gamma, \mathbf{c}) @@ -212,24 +213,25 @@ def mm_unbalanced2(a, b, M, reg_m, c=None, reg=0, div='kl', G0=None, returnCost= M : array-like (dim_a, dim_b) loss matrix reg_m: float or indexable object of length 1 or 2 - Marginal relaxation term >= 0, but cannot be infinity. - If reg_m is a scalar or an indexable object of length 1, - then the same reg_m is applied to both marginal relaxations. - If reg_m is an array, it must have the same backend as input arrays (a, b, M). + Marginal relaxation term: nonnegative but cannot be infinity. + If :math:`\mathrm{reg_{m}}` is a scalar or an indexable object of length 1, + then the same :math:`\mathrm{reg_{m}}` is applied to both marginal relaxations. + If :math:`\mathrm{reg_{m}}` is an array, + it must have the same backend as input arrays `(a, b, M)`. reg : float, optional (default = 0) Entropy regularization term >= 0. By default, solve the unregularized problem c : array-like (dim_a, dim_b), optional (default = None) Reference measure for the regularization. - If None, then use `\mathbf{c} = mathbf{a} mathbf{b}^T`. + If None, then use :math:`\mathbf{c} = mathbf{a} mathbf{b}^T`. div: string, optional Divergence to quantify the difference between the marginals. - Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) + Can take two values: 'kl' (Kullback-Leibler) or 'l2' (half-squared) G0: array-like (dim_a, dim_b) Initialization of the transport matrix returnCost: string, optional (default = "linear") - If returnCost = "linear", then return the linear part of the unbalanced OT loss. - If returnCost = "total", then return total unbalanced OT loss. + If `returnCost` = "linear", then return the linear part of the unbalanced OT loss. + If `returnCost` = "total", then return the total unbalanced OT loss. numItermax : int, optional Max number of iterations stopThr : float, optional @@ -241,7 +243,7 @@ def mm_unbalanced2(a, b, M, reg_m, c=None, reg=0, div='kl', G0=None, returnCost= Returns ------- - ot_distance : array-like + ot_cost : array-like the OT cost between :math:`\mathbf{a}` and :math:`\mathbf{b}` log : dict log dictionary returned only if `log` is `True` diff --git a/ot/unbalanced/_sinkhorn.py b/ot/unbalanced/_sinkhorn.py index 7d22a8376..7201c9750 100644 --- a/ot/unbalanced/_sinkhorn.py +++ b/ot/unbalanced/_sinkhorn.py @@ -26,7 +26,7 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', The function solves the following optimization problem: .. math:: - W = \min_\gamma \ \langle \gamma, \mathbf{M} \rangle_F + + W = \arg \min_\gamma \ \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg} \cdot \mathrm{KL}(\gamma, \mathbf{c}) + \mathrm{reg_{m1}} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) + \mathrm{reg_{m2}} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b}) @@ -58,29 +58,31 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', Entropy regularization term > 0 reg_m: float or indexable object of length 1 or 2 Marginal relaxation term. - If reg_m is a scalar or an indexable object of length 1, - then the same reg_m is applied to both marginal relaxations. - The entropic balanced OT can be recovered using `reg_m=float("inf")`. + If :math:`\mathrm{reg_{m}}` is a scalar or an indexable object of length 1, + then the same :math:`\mathrm{reg_{m}}` is applied to both marginal relaxations. + The entropic balanced OT can be recovered using :math:`\mathrm{reg_{m}}=float("inf")`. For semi-relaxed case, use either - `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`. - If reg_m is an array, it must have the same backend as input arrays (a, b, M). + :math:`\mathrm{reg_{m}}=(float("inf"), scalar)` or + :math:`\mathrm{reg_{m}}=(scalar, float("inf"))`. + If :math:`\mathrm{reg_{m}}` is an array, + it must have the same backend as input arrays `(a, b, M)`. method : str method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or 'sinkhorn_reg_scaling', see those function for specific parameters reg_type : string, optional Regularizer term. Can take two values: - reg_type = 'entropy' (negative entropy) + + Negative entropy: 'entropy': :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}`. This is equivalent (up to a constant) to :math:`\Omega(\gamma) = \text{KL}(\gamma, 1_{dim_a} 1_{dim_b}^T)`. - reg_type = 'kl' (Kullback-Leibler) + + Kullback-Leibler divergence: 'kl': :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`. c : array-like (dim_a, dim_b), optional (default=None) Reference measure for the regularization. - If None, then use `\mathbf{c} = \mathbf{a} \mathbf{b}^T`. - If reg_type = 'entropy', then `\mathbf{c} = 1_{dim_a} 1_{dim_b}^T`. + If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`. + If :math:`\texttt{reg_type}='entropy'`, then :math:`\mathbf{c} = 1_{dim_a} 1_{dim_b}^T`. warmstart: tuple of arrays, shape (dim_a, dim_b), optional Initialization of dual potentials. If provided, the dual potentials should be given - (that is the logarithm of the u,v sinkhorn scaling vectors). + (that is the logarithm of the `u`, `v` sinkhorn scaling vectors). numItermax : int, optional Max number of iterations stopThr : float, optional @@ -88,7 +90,7 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', verbose : bool, optional Print information along iterations log : bool, optional - record log if True + record `log` if `True` Returns @@ -168,16 +170,17 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', - reg_type="kl", c=None, warmstart=None, numItermax=1000, + reg_type="kl", c=None, warmstart=None, + returnCost="linear", numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs): r""" Solve the entropic regularization unbalanced optimal transport problem and - return the loss + return the cost The function solves the following optimization problem: .. math:: - W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg} \cdot \mathrm{KL}(\gamma, \mathbf{c}) + \mathrm{reg_{m1}} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) + \mathrm{reg_{m2}} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b}) @@ -208,29 +211,34 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', Entropy regularization term > 0 reg_m: float or indexable object of length 1 or 2 Marginal relaxation term. - If reg_m is a scalar or an indexable object of length 1, - then the same reg_m is applied to both marginal relaxations. - The entropic balanced OT can be recovered using `reg_m=float("inf")`. + If :math:`\mathrm{reg_{m}}` is a scalar or an indexable object of length 1, + then the same :math:`\mathrm{reg_{m}}` is applied to both marginal relaxations. + The entropic balanced OT can be recovered using :math:`\mathrm{reg_{m}}=float("inf")`. For semi-relaxed case, use either - `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`. - If reg_m is an array, it must have the same backend as input arrays (a, b, M). + :math:`\mathrm{reg_{m}}=(float("inf"), scalar)` or + :math:`\mathrm{reg_{m}}=(scalar, float("inf"))`. + If :math:`\mathrm{reg_{m}}` is an array, + it must have the same backend as input arrays `(a, b, M)`. method : str method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or - 'sinkhorn_reg_scaling', see those function for specific parameterss + 'sinkhorn_reg_scaling', see those function for specific parameters reg_type : string, optional Regularizer term. Can take two values: - reg_type = 'entropy' (negative entropy) + + Negative entropy: 'entropy': :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}`. This is equivalent (up to a constant) to :math:`\Omega(\gamma) = \text{KL}(\gamma, 1_{dim_a} 1_{dim_b}^T)`. - reg_type = 'kl' (Kullback-Leibler) + + Kullback-Leibler divergence: 'kl': :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`. c : array-like (dim_a, dim_b), optional (default=None) Reference measure for the regularization. - If None, then use `\mathbf{c} = \mathbf{a} \mathbf{b}^T`. - If reg_type = 'entropy', then `\mathbf{c} = 1_{dim_a} 1_{dim_b}^T`. + If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`. + If :math:`\texttt{reg_type}='entropy'`, then :math:`\mathbf{c} = 1_{dim_a} 1_{dim_b}^T`. warmstart: tuple of arrays, shape (dim_a, dim_b), optional Initialization of dual potentials. If provided, the dual potentials should be given (that is the logarithm of the u,v sinkhorn scaling vectors). + returnCost: string, optional (default = "linear") + If `returnCost` = "linear", then return the linear part of the unbalanced OT loss. + If `returnCost` = "total", then return the total unbalanced OT loss. numItermax : int, optional Max number of iterations stopThr : float, optional @@ -238,13 +246,13 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', verbose : bool, optional Print information along iterations log : bool, optional - record log if True + record `log` if `True` Returns ------- - ot_distance : (n_hists,) array-like - the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i` + ot_cost : (n_hists,) array-like + the OT cost between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i` log : dict log dictionary returned only if `log` is `True` @@ -285,33 +293,39 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', """ M, a, b = list_to_array(M, a, b) - nx = get_backend(M, a, b) if len(b.shape) < 2: if method.lower() == 'sinkhorn': res = sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, c, warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + log=True, **kwargs) elif method.lower() == 'sinkhorn_stabilized': res = sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type, c, warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + log=True, **kwargs) elif method.lower() in ['sinkhorn_reg_scaling']: warnings.warn('Method not implemented yet. Using classic Sinkhorn-Knopp') res = sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, c, warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + log=True, **kwargs) else: raise ValueError('Unknown method %s.' % method) + if returnCost == "linear": + cost = res[1]['cost'] + elif returnCost == "total": + cost = res[1]['total_cost'] + else: + raise ValueError("Unknown returnCost = {}".format(returnCost)) + if log: - return nx.sum(M * res[0]), res[1] + return cost, res[1] else: - return nx.sum(M * res) + return cost else: if method.lower() == 'sinkhorn': @@ -345,7 +359,7 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type="kl", c=None, The function solves the following optimization problem: .. math:: - W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + W = \arg \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg} \cdot \mathrm{KL}(\gamma, \mathbf{c}) + \mathrm{reg_{m1}} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) + \mathrm{reg_{m2}} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b}) @@ -376,26 +390,28 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type="kl", c=None, Entropy regularization term > 0 reg_m: float or indexable object of length 1 or 2 Marginal relaxation term. - If reg_m is a scalar or an indexable object of length 1, - then the same reg_m is applied to both marginal relaxations. - The entropic balanced OT can be recovered using `reg_m=float("inf")`. + If :math:`\mathrm{reg_{m}}` is a scalar or an indexable object of length 1, + then the same :math:`\mathrm{reg_{m}}` is applied to both marginal relaxations. + The entropic balanced OT can be recovered using :math:`\mathrm{reg_{m}}=float("inf")`. For semi-relaxed case, use either - `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`. - If reg_m is an array, it must have the same backend as input arrays (a, b, M). + :math:`\mathrm{reg_{m}}=(float("inf"), scalar)` or + :math:`\mathrm{reg_{m}}=(scalar, float("inf"))`. + If :math:`\mathrm{reg_{m}}` is an array, + it must have the same backend as input arrays `(a, b, M)`. reg_type : string, optional Regularizer term. Can take two values: - reg_type = 'entropy' (negative entropy) + + Negative entropy: 'entropy': :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}`. This is equivalent (up to a constant) to :math:`\Omega(\gamma) = \text{KL}(\gamma, 1_{dim_a} 1_{dim_b}^T)`. - reg_type = 'kl' (Kullback-Leibler) + + Kullback-Leibler divergence: 'kl': :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`. c : array-like (dim_a, dim_b), optional (default=None) Reference measure for the regularization. - If None, then use `\mathbf{c} = \mathbf{a} \mathbf{b}^T`. - If reg_type = 'entropy', then `\mathbf{c} = 1_{dim_a} 1_{dim_b}^T`. + If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`. + If :math:`\texttt{reg_type}='entropy'`, then :math:`\mathbf{c} = 1_{dim_a} 1_{dim_b}^T`. warmstart: tuple of arrays, shape (dim_a, dim_b), optional Initialization of dual potentials. If provided, the dual potentials should be given - (that is the logarithm of the u,v sinkhorn scaling vectors). + (that is the logarithm of the `u`, `v` sinkhorn scaling vectors). numItermax : int, optional Max number of iterations stopThr : float, optional @@ -403,7 +419,7 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type="kl", c=None, verbose : bool, optional Print information along iterations log : bool, optional - record log if True + record `log` if `True` Returns @@ -414,8 +430,8 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type="kl", c=None, - log : dict log dictionary returned only if `log` is `True` else: - - ot_distance : (n_hists,) array-like - the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i` + - ot_cost : (n_hists,) array-like + the OT cost between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i` - log : dict log dictionary returned only if `log` is `True` @@ -548,12 +564,13 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type="kl", c=None, if log: linear_cost = nx.sum(plan * M) + dict_log["cost"] = linear_cost + total_cost = linear_cost + reg * nx.kl_div(plan, c) if reg_m1 != float("inf"): total_cost = total_cost + reg_m1 * nx.kl_div(nx.sum(plan, 1), a) if reg_m2 != float("inf"): total_cost = total_cost + reg_m2 * nx.kl_div(nx.sum(plan, 0), b) - dict_log["cost"] = linear_cost dict_log["total_cost"] = total_cost return plan, dict_log @@ -573,7 +590,7 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="kl", c=None, stabilization as proposed in :ref:`[10] `: .. math:: - W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + W = \arg \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg} \cdot \mathrm{KL}(\gamma, \mathbf{c}) + \mathrm{reg_{m1}} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) + \mathrm{reg_{m2}} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b}) @@ -605,28 +622,33 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="kl", c=None, Entropy regularization term > 0 reg_m: float or indexable object of length 1 or 2 Marginal relaxation term. - If reg_m is a scalar or an indexable object of length 1, - then the same reg_m is applied to both marginal relaxations. - The entropic balanced OT can be recovered using `reg_m=float("inf")`. + If :math:`\mathrm{reg_{m}}` is a scalar or an indexable object of length 1, + then the same :math:`\mathrm{reg_{m}}` is applied to both marginal relaxations. + The entropic balanced OT can be recovered using :math:`\mathrm{reg_{m}}=float("inf")`. For semi-relaxed case, use either - `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`. - If reg_m is an array, it must have the same backend as input arrays (a, b, M). + :math:`\mathrm{reg_{m}}=(float("inf"), scalar)` or + :math:`\mathrm{reg_{m}}=(scalar, float("inf"))`. + If :math:`\mathrm{reg_{m}}` is an array, + it must have the same backend as input arrays `(a, b, M)`. + method : str + method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or + 'sinkhorn_reg_scaling', see those function for specific parameters reg_type : string, optional Regularizer term. Can take two values: - reg_type = 'entropy' (negative entropy) + + Negative entropy: 'entropy': :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}`. This is equivalent (up to a constant) to :math:`\Omega(\gamma) = \text{KL}(\gamma, 1_{dim_a} 1_{dim_b}^T)`. - reg_type = 'kl' (Kullback-Leibler) + + Kullback-Leibler divergence: 'kl': :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`. c : array-like (dim_a, dim_b), optional (default=None) Reference measure for the regularization. - If None, then use `\mathbf{c} = \mathbf{a} \mathbf{b}^T`. - If reg_type = 'entropy', then `\mathbf{c} = 1_{dim_a} 1_{dim_b}^T`. + If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`. + If :math:`\texttt{reg_type}='entropy'`, then :math:`\mathbf{c} = 1_{dim_a} 1_{dim_b}^T`. warmstart: tuple of arrays, shape (dim_a, dim_b), optional Initialization of dual potentials. If provided, the dual potentials should be given - (that is the logarithm of the u,v sinkhorn scaling vectors). + (that is the logarithm of the `u`, `v` sinkhorn scaling vectors). tau : float - threshold for max value in u or v for log scaling + threshold for max value in `u` or `v` for log scaling numItermax : int, optional Max number of iterations stopThr : float, optional @@ -634,7 +656,7 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="kl", c=None, verbose : bool, optional Print information along iterations log : bool, optional - record log if True + record `log` if `True` Returns @@ -645,8 +667,8 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="kl", c=None, - log : dict log dictionary returned only if `log` is `True` else: - - ot_distance : (n_hists,) array-like - the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i` + - ot_cost : (n_hists,) array-like + the OT cost between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i` - log : dict log dictionary returned only if `log` is `True` Examples @@ -682,9 +704,9 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="kl", c=None, dim_a, dim_b = M.shape - if len(a) == 0: + if len(a) == 0 or a is None: a = nx.ones(dim_a, type_as=M) / dim_a - if len(b) == 0: + if len(b) == 0 or b is None: b = nx.ones(dim_b, type_as=M) / dim_b if len(b.shape) > 1: @@ -813,12 +835,13 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="kl", c=None, plan = nx.exp(logu[:, None] + logv[None, :] - M0 / reg) if log: linear_cost = nx.sum(plan * M) + dict_log["cost"] = linear_cost + total_cost = linear_cost + reg * nx.kl_div(plan, c) if reg_m1 != float("inf"): total_cost = total_cost + reg_m1 * nx.kl_div(nx.sum(plan, 1), a) if reg_m2 != float("inf"): total_cost = total_cost + reg_m2 * nx.kl_div(nx.sum(plan, 0), b) - dict_log["cost"] = linear_cost dict_log["total_cost"] = total_cost return plan, dict_log @@ -868,7 +891,7 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, verbose : bool, optional Print information along iterations log : bool, optional - record log if True + record `log` if `True` Returns @@ -876,7 +899,7 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, a : (dim,) array-like Unbalanced Wasserstein barycenter log : dict - log dictionary return only if log==True in parameters + log dictionary return only if :math:`log==True` in parameters .. _references-barycenter-unbalanced-stabilized: @@ -1016,7 +1039,7 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, verbose : bool, optional Print information along iterations log : bool, optional - record log if True + record `log` if `True` Returns @@ -1024,7 +1047,7 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, a : (dim,) array-like Unbalanced Wasserstein barycenter log : dict - log dictionary return only if log==True in parameters + log dictionary return only if :math:`log==True` in parameters .. _references-barycenter-unbalanced-sinkhorn: diff --git a/test/unbalanced/test_lbfgs.py b/test/unbalanced/test_lbfgs.py index f3651ced7..4f3ae8177 100644 --- a/test/unbalanced/test_lbfgs.py +++ b/test/unbalanced/test_lbfgs.py @@ -13,7 +13,7 @@ import pytest -@pytest.mark.parametrize("reg_div,regm_div,returnCost", itertools.product(['kl', 'l2', 'entropy'], ['kl', 'l2'], ['linear', 'total'])) +@pytest.mark.parametrize("reg_div,regm_div,returnCost", itertools.product(['kl', 'l2', 'entropy'], ['kl', 'l2', 'tv'], ['linear', 'total'])) def test_lbfgsb_unbalanced(nx, reg_div, regm_div, returnCost): np.random.seed(42) @@ -46,7 +46,7 @@ def test_lbfgsb_unbalanced(nx, reg_div, regm_div, returnCost): np.testing.assert_allclose(loss, nx.to_numpy(loss0), atol=1e-06) -@pytest.mark.parametrize("reg_div,regm_div,returnCost", itertools.product(['kl', 'l2', 'entropy'], ['kl', 'l2'], ['linear', 'total'])) +@pytest.mark.parametrize("reg_div,regm_div,returnCost", itertools.product(['kl', 'l2', 'entropy'], ['kl', 'l2', 'tv'], ['linear', 'total'])) def test_lbfgsb_unbalanced_relaxation_parameters(nx, reg_div, regm_div, returnCost): np.random.seed(42) @@ -93,7 +93,7 @@ def test_lbfgsb_unbalanced_relaxation_parameters(nx, reg_div, regm_div, returnCo np.testing.assert_allclose(nx.to_numpy(loss), nx.to_numpy(loss0), atol=1e-06) -@pytest.mark.parametrize("reg_div,regm_div,returnCost", itertools.product(['kl', 'l2', 'entropy'], ['kl', 'l2'], ['linear', 'total'])) +@pytest.mark.parametrize("reg_div,regm_div,returnCost", itertools.product(['kl', 'l2', 'entropy'], ['kl', 'l2', 'tv'], ['linear', 'total'])) def test_lbfgsb_reference_measure(nx, reg_div, regm_div, returnCost): np.random.seed(42) @@ -124,3 +124,72 @@ def test_lbfgsb_reference_measure(nx, reg_div, regm_div, returnCost): np.testing.assert_allclose(nx.to_numpy(G), nx.to_numpy(G0), atol=1e-06) np.testing.assert_allclose(nx.to_numpy(loss), nx.to_numpy(loss0), atol=1e-06) + + +def test_lbfgsb_wrong_divergence(nx): + + n = 100 + rng = np.random.RandomState(42) + x = rng.randn(n, 2) + rng = np.random.RandomState(75) + y = rng.randn(n, 2) + a_np = ot.utils.unif(n) + b_np = ot.utils.unif(n) + + M = ot.dist(x, y) + M = M / M.max() + a, b, M = nx.from_numpy(a_np, b_np, M) + + def lbfgsb_div(div): + return ot.unbalanced.lbfgsb_unbalanced(a, b, M, reg=1, reg_m=10, reg_div=div) + + def lbfgsb2_div(div): + return ot.unbalanced.lbfgsb_unbalanced2(a, b, M, reg=1, reg_m=10, reg_div=div) + + np.testing.assert_raises(ValueError, lbfgsb_div, "div_not_existed") + np.testing.assert_raises(ValueError, lbfgsb2_div, "div_not_existed") + + +def test_lbfgsb_wrong_marginal_divergence(nx): + + n = 100 + rng = np.random.RandomState(42) + x = rng.randn(n, 2) + rng = np.random.RandomState(75) + y = rng.randn(n, 2) + a_np = ot.utils.unif(n) + b_np = ot.utils.unif(n) + + M = ot.dist(x, y) + M = M / M.max() + a, b, M = nx.from_numpy(a_np, b_np, M) + + def lbfgsb_div(div): + return ot.unbalanced.lbfgsb_unbalanced(a, b, M, reg=1, reg_m=10, regm_div=div) + + def lbfgsb2_div(div): + return ot.unbalanced.lbfgsb_unbalanced2(a, b, M, reg=1, reg_m=10, regm_div=div) + + np.testing.assert_raises(ValueError, lbfgsb_div, "div_not_existed") + np.testing.assert_raises(ValueError, lbfgsb2_div, "div_not_existed") + + +def test_lbfgsb_wrong_returnCost(nx): + + n = 100 + rng = np.random.RandomState(42) + x = rng.randn(n, 2) + rng = np.random.RandomState(75) + y = rng.randn(n, 2) + a_np = ot.utils.unif(n) + b_np = ot.utils.unif(n) + + M = ot.dist(x, y) + M = M / M.max() + a, b, M = nx.from_numpy(a_np, b_np, M) + + def lbfgsb2(returnCost): + return ot.unbalanced.lbfgsb_unbalanced2(a, b, M, reg=1, reg_m=10, + returnCost=returnCost, verbose=True) + + np.testing.assert_raises(ValueError, lbfgsb2, "invalid_returnCost") diff --git a/test/unbalanced/test_mm.py b/test/unbalanced/test_mm.py index 956f69de9..c68fc1a93 100644 --- a/test/unbalanced/test_mm.py +++ b/test/unbalanced/test_mm.py @@ -162,3 +162,24 @@ def mm2_div(div): np.testing.assert_raises(ValueError, mm_div, "div_not_existed") np.testing.assert_raises(ValueError, mm2_div, "div_not_existed") + + +def test_mm_wrong_returnCost(nx): + + n = 100 + rng = np.random.RandomState(42) + x = rng.randn(n, 2) + rng = np.random.RandomState(75) + y = rng.randn(n, 2) + a_np = ot.utils.unif(n) + b_np = ot.utils.unif(n) + + M = ot.dist(x, y) + M = M / M.max() + a, b, M = nx.from_numpy(a_np, b_np, M) + + def mm2(returnCost): + return ot.unbalanced.mm_unbalanced2(a, b, M, reg_m=100, reg=1e-2, + returnCost=returnCost, verbose=True) + + np.testing.assert_raises(ValueError, mm2, "invalid_returnCost") diff --git a/test/unbalanced/test_sinkhorn.py b/test/unbalanced/test_sinkhorn.py index 130b08674..01a59dc74 100644 --- a/test/unbalanced/test_sinkhorn.py +++ b/test/unbalanced/test_sinkhorn.py @@ -295,6 +295,28 @@ def test_stabilized_vs_sinkhorn(nx): np.testing.assert_allclose(G2, G2_np, atol=1e-5) +def test_sinkhorn_wrong_returnCost(nx): + + n = 100 + rng = np.random.RandomState(42) + x = rng.randn(n, 2) + rng = np.random.RandomState(75) + y = rng.randn(n, 2) + a_np = ot.utils.unif(n) + b_np = ot.utils.unif(n) + + M = ot.dist(x, y) + M = M / M.max() + a, b, M = nx.from_numpy(a_np, b_np, M) + epsilon = 1 + reg_m = 1. + + def sinkhorn2(returnCost): + return ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m, returnCost=returnCost) + + np.testing.assert_raises(ValueError, sinkhorn2, "invalid_returnCost") + + @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) def test_unbalanced_barycenter(nx, method): # test generalized sinkhorn for unbalanced OT barycenter From e08726be4f36df99cada6e61c9486a7c030c2345 Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Mon, 9 Sep 2024 10:43:36 +0200 Subject: [PATCH 10/17] recover previous code --- test/test_solvers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_solvers.py b/test/test_solvers.py index 160451207..16e6df295 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -16,7 +16,7 @@ lst_reg = [None, 1] -lst_reg_type = ['KL', 'L2', 'tuple'] +lst_reg_type = ['KL', 'entropy', 'L2', 'tuple'] lst_unbalanced = [None, 0.9] lst_unbalanced_type = ['KL', 'L2', 'TV'] From 52ef1f8ac54afb42b017e8b0d391cb48b68ee81a Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Mon, 9 Sep 2024 11:09:46 +0200 Subject: [PATCH 11/17] fix bug in tests --- ot/unbalanced/_lbfgs.py | 1 + ot/unbalanced/_mm.py | 1 + test/test_solvers.py | 2 +- 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/ot/unbalanced/_lbfgs.py b/ot/unbalanced/_lbfgs.py index 2d5a1309b..c82f6fe5d 100644 --- a/ot/unbalanced/_lbfgs.py +++ b/ot/unbalanced/_lbfgs.py @@ -242,6 +242,7 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss """ + reg_div, regm_div = reg_div.lower(), regm_div.lower() if reg_div not in ["entropy", "kl", "l2"]: raise ValueError("Unknown reg_div = {}. Must be either 'entropy', 'kl' or 'l2'".format(reg_div)) if regm_div not in ["kl", "l2", "tv"]: diff --git a/ot/unbalanced/_mm.py b/ot/unbalanced/_mm.py index 8dbb3a9cb..73f6cd9f5 100644 --- a/ot/unbalanced/_mm.py +++ b/ot/unbalanced/_mm.py @@ -127,6 +127,7 @@ def mm_unbalanced(a, b, M, reg_m, c=None, reg=0, div='kl', G0=None, numItermax=1 if log: log = {'err': [], 'G': []} + div = div.lower() if div == 'kl': sum_r = reg + reg_m1 + reg_m2 r1, r2, r = reg_m1 / sum_r, reg_m2 / sum_r, reg / sum_r diff --git a/test/test_solvers.py b/test/test_solvers.py index 16e6df295..863c35f76 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -16,7 +16,7 @@ lst_reg = [None, 1] -lst_reg_type = ['KL', 'entropy', 'L2', 'tuple'] +lst_reg_type = ['KL', 'entropy', 'L2'] lst_unbalanced = [None, 0.9] lst_unbalanced_type = ['KL', 'L2', 'TV'] From b1a737079d92a0fbde2125c73e04b4d3b90efb7f Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Mon, 9 Sep 2024 13:25:02 +0200 Subject: [PATCH 12/17] add test to test_sinkhorn --- test/unbalanced/test_sinkhorn.py | 45 ++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/test/unbalanced/test_sinkhorn.py b/test/unbalanced/test_sinkhorn.py index 01a59dc74..5a35115bc 100644 --- a/test/unbalanced/test_sinkhorn.py +++ b/test/unbalanced/test_sinkhorn.py @@ -140,6 +140,51 @@ def test_unbalanced_warmstart(nx, method, reg_type): np.testing.assert_allclose(nx.to_numpy(loss0), nx.to_numpy(loss1), atol=1e-5) +@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"], ["kl", "entropy"])) +def test_unbalanced_reference_measure(nx, method, reg_type): + # test generalized sinkhorn for unbalanced OT + n = 100 + rng = np.random.RandomState(42) + + x = rng.randn(n, 2) + a = ot.utils.unif(n) + b = ot.utils.unif(n) + M = ot.dist(x, x) + a, b, M = nx.from_numpy(a, b, M) + + epsilon = 1. + reg_m = 1. + + G0, log0 = ot.unbalanced.sinkhorn_unbalanced( + a, b, M, reg=epsilon, reg_m=reg_m, method=method, + reg_type=reg_type, c=None, log=True + ) + loss0 = ot.unbalanced.sinkhorn_unbalanced2( + a, b, M, reg=epsilon, reg_m=reg_m, method=method, reg_type=reg_type, c=None + ) + + if reg_type == "kl": + c = a[:, None] * b[None, :] + elif reg_type == "entropy": + c = nx.ones(M.shape, type_as=M) + + G, log = ot.unbalanced.sinkhorn_unbalanced( + a, b, M, reg=epsilon, reg_m=reg_m, method=method, + reg_type=reg_type, c=c, log=True + ) + loss = ot.unbalanced.sinkhorn_unbalanced2( + a, b, M, reg=epsilon, reg_m=reg_m, method=method, + reg_type=reg_type, c=c + ) + + np.testing.assert_allclose( + nx.to_numpy(log["logu"]), nx.to_numpy(log0["logu"]), atol=1e-05) + np.testing.assert_allclose( + nx.to_numpy(log["logv"]), nx.to_numpy(log0["logv"]), atol=1e-05) + np.testing.assert_allclose(nx.to_numpy(G), nx.to_numpy(G0), atol=1e-05) + np.testing.assert_allclose(nx.to_numpy(loss), nx.to_numpy(loss0), atol=1e-5) + + @pytest.mark.parametrize("method, log", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"], [True, False])) def test_sinkhorn_unbalanced2(nx, method, log): n = 100 From e804d0a4eedba208e12919ffcfcea6505893cc92 Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Mon, 9 Sep 2024 14:28:56 +0200 Subject: [PATCH 13/17] improve codecov for lbfgsb --- ot/unbalanced/_lbfgs.py | 19 +------------------ test/unbalanced/test_sinkhorn.py | 2 +- 2 files changed, 2 insertions(+), 19 deletions(-) diff --git a/ot/unbalanced/_lbfgs.py b/ot/unbalanced/_lbfgs.py index c82f6fe5d..e7aae7f86 100644 --- a/ot/unbalanced/_lbfgs.py +++ b/ot/unbalanced/_lbfgs.py @@ -263,25 +263,8 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', a, b, M = nx.to_numpy(a, b, M) G0 = np.zeros(M.shape) if G0 is None else nx.to_numpy(G0) c = a[:, None] * b[None, :] if c is None else nx.to_numpy(c) - - # wrap the callable function to handle numpy arrays - if isinstance(reg_div, tuple): - f0, df0 = reg_div - try: - f0(G0) - df0(G0) - except BaseException: - warnings.warn("The callable functions should be able to handle numpy arrays, wrapper ar added to handle this which comes with overhead") - - def f(x): - return nx.to_numpy(f0(nx.from_numpy(x, type_as=M0))) - - def df(x): - return nx.to_numpy(df0(nx.from_numpy(x, type_as=M0))) - - reg_div = (f, df) - reg_m1, reg_m2 = get_parameter_pair(reg_m) + _func = _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div, regm_div) res = minimize(_func, G0.ravel(), method=method, jac=True, bounds=Bounds(0, np.inf), diff --git a/test/unbalanced/test_sinkhorn.py b/test/unbalanced/test_sinkhorn.py index 5a35115bc..2240f2769 100644 --- a/test/unbalanced/test_sinkhorn.py +++ b/test/unbalanced/test_sinkhorn.py @@ -362,7 +362,7 @@ def sinkhorn2(returnCost): np.testing.assert_raises(ValueError, sinkhorn2, "invalid_returnCost") -@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"]) def test_unbalanced_barycenter(nx, method): # test generalized sinkhorn for unbalanced OT barycenter n = 100 From c8c10e31765da554aa05a15806362c3a7f556b5f Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Mon, 9 Sep 2024 14:32:34 +0200 Subject: [PATCH 14/17] remove unused impor --- ot/unbalanced/_lbfgs.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ot/unbalanced/_lbfgs.py b/ot/unbalanced/_lbfgs.py index e7aae7f86..570500f52 100644 --- a/ot/unbalanced/_lbfgs.py +++ b/ot/unbalanced/_lbfgs.py @@ -9,7 +9,6 @@ # # License: MIT License -import warnings import numpy as np from scipy.optimize import minimize, Bounds From 99d78e92ce28bdc11913a6a1ce24ddfdfbb23f32 Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Tue, 10 Sep 2024 15:42:48 +0200 Subject: [PATCH 15/17] fix bug in example --- ot/unbalanced/_lbfgs.py | 30 ++++++++++++++++++++++++++---- test/test_solvers.py | 2 +- 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/ot/unbalanced/_lbfgs.py b/ot/unbalanced/_lbfgs.py index 570500f52..c1758c859 100644 --- a/ot/unbalanced/_lbfgs.py +++ b/ot/unbalanced/_lbfgs.py @@ -9,6 +9,7 @@ # # License: MIT License +import warnings import numpy as np from scipy.optimize import minimize, Bounds @@ -241,12 +242,34 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss """ - reg_div, regm_div = reg_div.lower(), regm_div.lower() - if reg_div not in ["entropy", "kl", "l2"]: - raise ValueError("Unknown reg_div = {}. Must be either 'entropy', 'kl' or 'l2'".format(reg_div)) + # wrap the callable function to handle numpy arrays + if isinstance(reg_div, tuple): + f0, df0 = reg_div + try: + f0(G0) + df0(G0) + except BaseException: + warnings.warn("The callable functions should be able to handle numpy arrays, wrapper ar added to handle this which comes with overhead") + + def f(x): + return nx.to_numpy(f0(nx.from_numpy(x, type_as=M0))) + + def df(x): + return nx.to_numpy(df0(nx.from_numpy(x, type_as=M0))) + + reg_div = (f, df) + + else: + reg_div = reg_div.lower() + if reg_div not in ["entropy", "kl", "l2"]: + raise ValueError("Unknown reg_div = {}. Must be either 'entropy', 'kl' or 'l2', or a tuple".format(reg_div)) + + regm_div = regm_div.lower() if regm_div not in ["kl", "l2", "tv"]: raise ValueError("Unknown regm_div = {}. Must be either 'kl', 'l2' or 'tv'".format(regm_div)) + reg_m1, reg_m2 = get_parameter_pair(reg_m) + M, a, b = list_to_array(M, a, b) nx = get_backend(M, a, b) M0 = M @@ -262,7 +285,6 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', a, b, M = nx.to_numpy(a, b, M) G0 = np.zeros(M.shape) if G0 is None else nx.to_numpy(G0) c = a[:, None] * b[None, :] if c is None else nx.to_numpy(c) - reg_m1, reg_m2 = get_parameter_pair(reg_m) _func = _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div, regm_div) diff --git a/test/test_solvers.py b/test/test_solvers.py index 863c35f76..16e6df295 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -16,7 +16,7 @@ lst_reg = [None, 1] -lst_reg_type = ['KL', 'entropy', 'L2'] +lst_reg_type = ['KL', 'entropy', 'L2', 'tuple'] lst_unbalanced = [None, 0.9] lst_unbalanced_type = ['KL', 'L2', 'TV'] From e1eecc92df7d987cca003cbc9870f7c2b7cb6567 Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Tue, 10 Sep 2024 18:30:20 +0200 Subject: [PATCH 16/17] add documentation and tests --- ot/unbalanced/_lbfgs.py | 14 +++++++-- ot/unbalanced/_mm.py | 8 +++++ ot/unbalanced/_sinkhorn.py | 38 +++++++++++++++++------- test/unbalanced/test_lbfgs.py | 35 ++++++++++++++++++++++ test/unbalanced/test_mm.py | 37 +++++++++++++++++++++++ test/unbalanced/test_sinkhorn.py | 51 +++++++++++++++++++++++++------- 6 files changed, 158 insertions(+), 25 deletions(-) diff --git a/ot/unbalanced/_lbfgs.py b/ot/unbalanced/_lbfgs.py index c1758c859..062a472f9 100644 --- a/ot/unbalanced/_lbfgs.py +++ b/ot/unbalanced/_lbfgs.py @@ -175,8 +175,12 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', ---------- a : array-like (dim_a,) Unnormalized histogram of dimension `dim_a` + If `a` is an empty list or array ([]), + then `a` is set to uniform distribution. b : array-like (dim_b,) Unnormalized histogram of dimension `dim_b` + If `b` is an empty list or array ([]), + then `b` is set to uniform distribution. M : array-like (dim_a, dim_b) loss matrix reg: float @@ -276,14 +280,14 @@ def df(x): dim_a, dim_b = M.shape - if len(a) == 0 or a is None: + if len(a) == 0: a = nx.ones(dim_a, type_as=M) / dim_a - if len(b) == 0 or b is None: + if len(b) == 0: b = nx.ones(dim_b, type_as=M) / dim_b # convert to numpy a, b, M = nx.to_numpy(a, b, M) - G0 = np.zeros(M.shape) if G0 is None else nx.to_numpy(G0) + G0 = a[:, None] * b[None, :] if G0 is None else nx.to_numpy(G0) c = a[:, None] * b[None, :] if c is None else nx.to_numpy(c) _func = _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div, regm_div) @@ -335,8 +339,12 @@ def lbfgsb_unbalanced2(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', ---------- a : array-like (dim_a,) Unnormalized histogram of dimension `dim_a` + If `a` is an empty list or array ([]), + then `a` is set to uniform distribution. b : array-like (dim_b,) Unnormalized histogram of dimension `dim_b` + If `b` is an empty list or array ([]), + then `b` is set to uniform distribution. M : array-like (dim_a, dim_b) loss matrix reg: float diff --git a/ot/unbalanced/_mm.py b/ot/unbalanced/_mm.py index 73f6cd9f5..b22f234d1 100644 --- a/ot/unbalanced/_mm.py +++ b/ot/unbalanced/_mm.py @@ -43,8 +43,12 @@ def mm_unbalanced(a, b, M, reg_m, c=None, reg=0, div='kl', G0=None, numItermax=1 ---------- a : array-like (dim_a,) Unnormalized histogram of dimension `dim_a` + If `a` is an empty list or array ([]), + then `a` is set to uniform distribution. b : array-like (dim_b,) Unnormalized histogram of dimension `dim_b` + If `b` is an empty list or array ([]), + then `b` is set to uniform distribution. M : array-like (dim_a, dim_b) loss matrix reg_m: float or indexable object of length 1 or 2 @@ -209,8 +213,12 @@ def mm_unbalanced2(a, b, M, reg_m, c=None, reg=0, div='kl', G0=None, returnCost= ---------- a : array-like (dim_a,) Unnormalized histogram of dimension `dim_a` + If `a` is an empty list or array ([]), + then `a` is set to uniform distribution. b : array-like (dim_b,) Unnormalized histogram of dimension `dim_b` + If `b` is an empty list or array ([]), + then `b` is set to uniform distribution. M : array-like (dim_a, dim_b) loss matrix reg_m: float or indexable object of length 1 or 2 diff --git a/ot/unbalanced/_sinkhorn.py b/ot/unbalanced/_sinkhorn.py index 7201c9750..37e85253b 100644 --- a/ot/unbalanced/_sinkhorn.py +++ b/ot/unbalanced/_sinkhorn.py @@ -49,9 +49,13 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', ---------- a : array-like (dim_a,) Unnormalized histogram of dimension `dim_a` - b : array-like (dim_b,) or array-like (dim_b, n_hists) + If `a` is an empty list or array ([]), + then `a` is set to uniform distribution. + b : array-like (dim_b,) One or multiple unnormalized histograms of dimension `dim_b`. - If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i` + If `b` is an empty list or array ([]), + then `b` is set to uniform distribution. + If many, compute all the OT costs :math:`(\mathbf{a}, \mathbf{b}_i)_i` M : array-like (dim_a, dim_b) loss matrix reg : float @@ -202,9 +206,13 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', ---------- a : array-like (dim_a,) Unnormalized histogram of dimension `dim_a` - b : array-like (dim_b,) or array-like (dim_b, n_hists) + If `a` is an empty list or array ([]), + then `a` is set to uniform distribution. + b : array-like (dim_b,) One or multiple unnormalized histograms of dimension `dim_b`. - If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i` + If `b` is an empty list or array ([]), + then `b` is set to uniform distribution. + If many, compute all the OT costs :math:`(\mathbf{a}, \mathbf{b}_i)_i` M : array-like (dim_a, dim_b) loss matrix reg : float @@ -381,9 +389,13 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type="kl", c=None, ---------- a : array-like (dim_a,) Unnormalized histogram of dimension `dim_a` - b : array-like (dim_b,) or array-like (dim_b, n_hists) - One or multiple unnormalized histograms of dimension `dim_b` - If many, compute all the OT distances (a, b_i) + If `a` is an empty list or array ([]), + then `a` is set to uniform distribution. + b : array-like (dim_b,) + One or multiple unnormalized histograms of dimension `dim_b`. + If `b` is an empty list or array ([]), + then `b` is set to uniform distribution. + If many, compute all the OT costs :math:`(\mathbf{a}, \mathbf{b}_i)_i` M : array-like (dim_a, dim_b) loss matrix reg : float @@ -613,9 +625,13 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="kl", c=None, ---------- a : array-like (dim_a,) Unnormalized histogram of dimension `dim_a` - b : array-like (dim_b,) or array-like (dim_b, n_hists) + If `a` is an empty list or array ([]), + then `a` is set to uniform distribution. + b : array-like (dim_b,) One or multiple unnormalized histograms of dimension `dim_b`. - If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i` + If `b` is an empty list or array ([]), + then `b` is set to uniform distribution. + If many, compute all the OT costs :math:`(\mathbf{a}, \mathbf{b}_i)_i` M : array-like (dim_a, dim_b) loss matrix reg : float @@ -704,9 +720,9 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="kl", c=None, dim_a, dim_b = M.shape - if len(a) == 0 or a is None: + if len(a) == 0: a = nx.ones(dim_a, type_as=M) / dim_a - if len(b) == 0 or b is None: + if len(b) == 0: b = nx.ones(dim_b, type_as=M) / dim_b if len(b.shape) > 1: diff --git a/test/unbalanced/test_lbfgs.py b/test/unbalanced/test_lbfgs.py index 4f3ae8177..4b33fc526 100644 --- a/test/unbalanced/test_lbfgs.py +++ b/test/unbalanced/test_lbfgs.py @@ -126,6 +126,41 @@ def test_lbfgsb_reference_measure(nx, reg_div, regm_div, returnCost): np.testing.assert_allclose(nx.to_numpy(loss), nx.to_numpy(loss0), atol=1e-06) +@pytest.mark.parametrize("reg_div,regm_div,returnCost", itertools.product(['kl', 'l2', 'entropy'], ['kl', 'l2', 'tv'], ['linear', 'total'])) +def test_lbfgsb_marginals(nx, reg_div, regm_div, returnCost): + + np.random.seed(42) + + xs = np.random.randn(5, 2) + xt = np.random.randn(6, 2) + M = ot.dist(xs, xt) + a = ot.unif(5) + b = ot.unif(6) + + a, b, M = nx.from_numpy(a, b, M) + + G, _ = ot.unbalanced.lbfgsb_unbalanced(a, b, M, reg=1, reg_m=10, + reg_div=reg_div, regm_div=regm_div, + log=True, verbose=False) + loss, _ = ot.unbalanced.lbfgsb_unbalanced2(a, b, M, reg=1, reg_m=10, + reg_div=reg_div, regm_div=regm_div, + returnCost=returnCost, log=True, verbose=False) + + a_empty, b_empty = np.array([]), np.array([]) + a_empty, b_empty = nx.from_numpy(a_empty, b_empty) + + G0, _ = ot.unbalanced.lbfgsb_unbalanced(a_empty, b_empty, M, reg=1, reg_m=10, + reg_div=reg_div, regm_div=regm_div, + log=True, verbose=False) + + loss0, _ = ot.unbalanced.lbfgsb_unbalanced2(a_empty, b_empty, M, reg=1, reg_m=10, + reg_div=reg_div, regm_div=regm_div, + returnCost=returnCost, log=True, verbose=False) + + np.testing.assert_allclose(nx.to_numpy(G), nx.to_numpy(G0), atol=1e-06) + np.testing.assert_allclose(nx.to_numpy(loss), nx.to_numpy(loss0), atol=1e-06) + + def test_lbfgsb_wrong_divergence(nx): n = 100 diff --git a/test/unbalanced/test_mm.py b/test/unbalanced/test_mm.py index c68fc1a93..ea9f00869 100644 --- a/test/unbalanced/test_mm.py +++ b/test/unbalanced/test_mm.py @@ -135,6 +135,43 @@ def test_mm_reference_measure(nx, div): np.testing.assert_allclose(loss_0, loss_1, atol=1e-5) +@pytest.mark.parametrize("div", ["kl", "l2"]) +def test_mm_marginals(nx, div): + n = 100 + rng = np.random.RandomState(42) + x = rng.randn(n, 2) + rng = np.random.RandomState(75) + y = rng.randn(n, 2) + a_np = ot.utils.unif(n) + b_np = ot.utils.unif(n) + + M = ot.dist(x, y) + M = M / M.max() + a, b, M = nx.from_numpy(a_np, b_np, M) + + reg = 1e-2 + reg_m = 100 + + G0, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, c=None, reg=reg, + div=div, verbose=False, log=True) + loss_0 = ot.unbalanced.mm_unbalanced2(a, b, M, reg_m=reg_m, c=None, reg=reg, + div=div, verbose=True) + loss_0 = nx.to_numpy(loss_0) + + a_empty, b_empty = np.array([]), np.array([]) + a_empty, b_empty = nx.from_numpy(a_empty, b_empty) + + G1, _ = ot.unbalanced.mm_unbalanced(a_empty, b_empty, M, reg_m=reg_m, + reg=reg, div=div, + verbose=False, log=True) + loss_1 = ot.unbalanced.mm_unbalanced2(a_empty, b_empty, M, reg_m=reg_m, + reg=reg, div=div, verbose=True) + loss_1 = nx.to_numpy(loss_1) + + np.testing.assert_allclose(nx.to_numpy(G0), nx.to_numpy(G1), atol=1e-05) + np.testing.assert_allclose(loss_0, loss_1, atol=1e-5) + + def test_mm_wrong_divergence(nx): n = 100 diff --git a/test/unbalanced/test_sinkhorn.py b/test/unbalanced/test_sinkhorn.py index 2240f2769..595f9ba97 100644 --- a/test/unbalanced/test_sinkhorn.py +++ b/test/unbalanced/test_sinkhorn.py @@ -62,20 +62,49 @@ def test_unbalanced_convergence(nx, method, reg_type): # check if sinkhorn_unbalanced2 returns the correct loss np.testing.assert_allclose(nx.to_numpy(nx.sum(G * M)), loss, atol=1e-5) - # check in case no histogram is provided - M_np = nx.to_numpy(M) - a_np, b_np = np.array([]), np.array([]) - a, b = nx.from_numpy(a_np, b_np) - G = ot.unbalanced.sinkhorn_unbalanced( - a, b, M, reg=epsilon, reg_m=reg_m, - method=method, reg_type=reg_type, verbose=True +@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"], ["kl", "entropy"])) +def test_unbalanced_marginals(nx, method, reg_type): + # test generalized sinkhorn for unbalanced OT + n = 100 + rng = np.random.RandomState(42) + + x = rng.randn(n, 2) + a = ot.utils.unif(n) + b = ot.utils.unif(n) + M = ot.dist(x, x) + a, b, M = nx.from_numpy(a, b, M) + + epsilon = 1. + reg_m = 1. + + G0, log0 = ot.unbalanced.sinkhorn_unbalanced( + a, b, M, reg=epsilon, reg_m=reg_m, method=method, + reg_type=reg_type, log=True + ) + loss0 = ot.unbalanced.sinkhorn_unbalanced2( + a, b, M, reg=epsilon, reg_m=reg_m, method=method, reg_type=reg_type, ) - G_np = ot.unbalanced.sinkhorn_unbalanced( - a_np, b_np, M_np, reg=epsilon, reg_m=reg_m, - method=method, reg_type=reg_type, verbose=True + + # check in case no histogram is provided or histogram is None + a_empty, b_empty = np.array([]), np.array([]) + a_empty, b_empty = nx.from_numpy(a_empty, b_empty) + + G_empty, log_empty = ot.unbalanced.sinkhorn_unbalanced( + a_empty, b_empty, M, reg=epsilon, reg_m=reg_m, method=method, + reg_type=reg_type, log=True + ) + loss_empty = ot.unbalanced.sinkhorn_unbalanced2( + a_empty, b_empty, M, reg=epsilon, reg_m=reg_m, method=method, + reg_type=reg_type ) - np.testing.assert_allclose(G_np, nx.to_numpy(G)) + + np.testing.assert_allclose( + nx.to_numpy(log_empty["logu"]), nx.to_numpy(log0["logu"]), atol=1e-05) + np.testing.assert_allclose( + nx.to_numpy(log_empty["logv"]), nx.to_numpy(log0["logv"]), atol=1e-05) + np.testing.assert_allclose(nx.to_numpy(G_empty), nx.to_numpy(G0), atol=1e-05) + np.testing.assert_allclose(nx.to_numpy(loss_empty), nx.to_numpy(loss0), atol=1e-5) @pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"], ["kl", "entropy"])) From ad2e6f8d0cc095fbfc7541143c54c7f287cab722 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Tue, 10 Sep 2024 22:28:18 +0200 Subject: [PATCH 17/17] Update RELEASES.md --- RELEASES.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/RELEASES.md b/RELEASES.md index cc18cc91b..277af7847 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -10,6 +10,8 @@ - Improved `ot.plot.plot1D_mat` (PR #649) - Added `nx.det` (PR #649) - `nx.sqrtm` is now broadcastable (takes ..., d, d) inputs (PR #649) +- restructure `ot.unbalanced` module (PR #658) +- add `ot.unbalanced.lbfgsb_unbalanced2` and add flexible reference measure `c` in all unbalanced solvers (PR #658) #### Closed issues - Fixed `ot.gaussian` ignoring weights when computing means (PR #649, Issue #648)