diff --git a/RELEASES.md b/RELEASES.md index 6e2b45a6d..821432548 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -2,6 +2,9 @@ ## 0.9.5dev +#### Breaking change +- Custom functions provided as parameter `line_search` to `ot.optim.generic_conditional_gradient` must now have the signature `line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs)`, adding as input `df_G` the gradient of the regularizer evaluated at the transport plan `G`. This change aims at improving speed of solvers having quadratic polynomial functions as regularizer such as the Gromov-Wassertein loss (PR #663). + #### New features - Added feature `mass=True` for `nx.kl_div` (PR #654) - Implemented Gaussian Mixture Model OT `ot.gmm` (PR #649) @@ -13,6 +16,8 @@ - Restructured `ot.unbalanced` module (PR #658) - Added `ot.unbalanced.lbfgsb_unbalanced2` and add flexible reference measure `c` in all unbalanced solvers (PR #658) - Implemented Fused unbalanced Gromov-Wasserstein and unbalanced Co-Optimal Transport (PR #677) +- Notes before depreciating partial Gromov-Wasserstein function in `ot.partial` moved to ot.gromov (PR #663) +- Create `ot.gromov._partial` add new features `loss_fun = "kl_loss"` and `symmetry=False` to all solvers while increasing speed + updating adequatly `ot.solvers` (PR #663) - Added `ot.unbalanced.sinkhorn_unbalanced_translation_invariant` (PR #676) #### Closed issues diff --git a/examples/unbalanced-partial/plot_partial_wass_and_gromov.py b/examples/unbalanced-partial/plot_partial_wass_and_gromov.py index ac4194ca0..5c85a5a22 100755 --- a/examples/unbalanced-partial/plot_partial_wass_and_gromov.py +++ b/examples/unbalanced-partial/plot_partial_wass_and_gromov.py @@ -125,10 +125,10 @@ # transport 100% of the mass print('------m = 1') m = 1 -res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True) -res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10, - m=m, log=True, - verbose=True) +res0, log0 = ot.gromov.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True) +res, log = ot.gromov.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10, + m=m, log=True, + verbose=True) print('Wasserstein distance (m = 1): ' + str(log0['partial_gw_dist'])) print('Entropic Wasserstein distance (m = 1): ' + str(log['partial_gw_dist'])) @@ -146,11 +146,11 @@ # transport 2/3 of the mass print('------m = 2/3') m = 2 / 3 -res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True, - verbose=True) -res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10, - m=m, log=True, - verbose=True) +res0, log0 = ot.gromov.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True, + verbose=True) +res, log = ot.gromov.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10, + m=m, log=True, + verbose=True) print('Partial Wasserstein distance (m = 2/3): ' + str(log0['partial_gw_dist'])) diff --git a/ot/gromov/__init__.py b/ot/gromov/__init__.py index 6d3f56d8b..2efd69ccd 100644 --- a/ot/gromov/__init__.py +++ b/ot/gromov/__init__.py @@ -72,33 +72,55 @@ unbalanced_co_optimal_transport2, fused_unbalanced_across_spaces_divergence) +from ._partial import (partial_gromov_wasserstein, + partial_gromov_wasserstein2, + solve_partial_gromov_linesearch, + entropic_partial_gromov_wasserstein, + entropic_partial_gromov_wasserstein2) + + __all__ = ['init_matrix', 'tensor_product', 'gwloss', 'gwggrad', 'init_matrix_semirelaxed', 'semirelaxed_init_plan', 'update_barycenter_structure', 'update_barycenter_feature', 'div_between_product', 'div_to_product', 'fused_unbalanced_across_spaces_cost', 'uot_cost_matrix', 'uot_parameters_and_measures', - 'gromov_wasserstein', 'gromov_wasserstein2', 'fused_gromov_wasserstein', - 'fused_gromov_wasserstein2', 'solve_gromov_linesearch', 'gromov_barycenters', - 'fgw_barycenters', 'entropic_gromov_wasserstein', 'entropic_gromov_wasserstein2', - 'BAPG_gromov_wasserstein', 'BAPG_gromov_wasserstein2', - 'entropic_gromov_barycenters', 'entropic_fused_gromov_wasserstein', + 'gromov_wasserstein', 'gromov_wasserstein2', + 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2', + 'solve_gromov_linesearch', 'gromov_barycenters', + 'fgw_barycenters', 'entropic_gromov_wasserstein', + 'entropic_gromov_wasserstein2', 'BAPG_gromov_wasserstein', + 'BAPG_gromov_wasserstein2', 'entropic_gromov_barycenters', + 'entropic_fused_gromov_wasserstein', 'entropic_fused_gromov_wasserstein2', 'BAPG_fused_gromov_wasserstein', 'BAPG_fused_gromov_wasserstein2', 'entropic_fused_gromov_barycenters', - 'GW_distance_estimation', 'pointwise_gromov_wasserstein', 'sampled_gromov_wasserstein', + 'GW_distance_estimation', 'pointwise_gromov_wasserstein', + 'sampled_gromov_wasserstein', 'semirelaxed_gromov_wasserstein', 'semirelaxed_gromov_wasserstein2', - 'semirelaxed_fused_gromov_wasserstein', 'semirelaxed_fused_gromov_wasserstein2', - 'solve_semirelaxed_gromov_linesearch', 'entropic_semirelaxed_gromov_wasserstein', - 'entropic_semirelaxed_gromov_wasserstein2', 'entropic_semirelaxed_fused_gromov_wasserstein', + 'semirelaxed_fused_gromov_wasserstein', + 'semirelaxed_fused_gromov_wasserstein2', + 'solve_semirelaxed_gromov_linesearch', + 'entropic_semirelaxed_gromov_wasserstein', + 'entropic_semirelaxed_gromov_wasserstein2', + 'entropic_semirelaxed_fused_gromov_wasserstein', 'entropic_semirelaxed_fused_gromov_wasserstein2', 'semirelaxed_fgw_barycenters', 'semirelaxed_gromov_barycenters', 'gromov_wasserstein_dictionary_learning', - 'gromov_wasserstein_linear_unmixing', 'fused_gromov_wasserstein_dictionary_learning', - 'fused_gromov_wasserstein_linear_unmixing', 'lowrank_gromov_wasserstein_samples', - 'quantized_fused_gromov_wasserstein_partitioned', 'get_graph_partition', - 'get_graph_representants', 'format_partitioned_graph', - 'quantized_fused_gromov_wasserstein', 'get_partition_and_representants_samples', - 'format_partitioned_samples', 'quantized_fused_gromov_wasserstein_samples', - 'fused_unbalanced_gromov_wasserstein', 'fused_unbalanced_gromov_wasserstein2', - 'unbalanced_co_optimal_transport', 'unbalanced_co_optimal_transport2', - 'fused_unbalanced_across_spaces_divergence' + 'gromov_wasserstein_linear_unmixing', + 'fused_gromov_wasserstein_dictionary_learning', + 'fused_gromov_wasserstein_linear_unmixing', + 'lowrank_gromov_wasserstein_samples', + 'quantized_fused_gromov_wasserstein_partitioned', + 'get_graph_partition', 'get_graph_representants', + 'format_partitioned_graph', 'quantized_fused_gromov_wasserstein', + 'get_partition_and_representants_samples', 'format_partitioned_samples', + 'quantized_fused_gromov_wasserstein_samples', + 'fused_unbalanced_gromov_wasserstein', + 'fused_unbalanced_gromov_wasserstein2', + 'unbalanced_co_optimal_transport', + 'unbalanced_co_optimal_transport2', + 'fused_unbalanced_across_spaces_divergence', + 'partial_gromov_wasserstein', 'partial_gromov_wasserstein2', + 'solve_partial_gromov_linesearch', + 'entropic_partial_gromov_wasserstein', + 'entropic_partial_gromov_wasserstein2' ] diff --git a/ot/gromov/_gw.py b/ot/gromov/_gw.py index 1cbc98909..806e691e1 100644 --- a/ot/gromov/_gw.py +++ b/ot/gromov/_gw.py @@ -167,10 +167,10 @@ def df(G): return 0.5 * (gwggrad(constC, hC1, hC2, G, np_) + gwggrad(constCt, hC1t, hC2t, G, np_)) if armijo: - def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): + def line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs): return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **kwargs) else: - def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): + def line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs): return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=0., reg=1., nx=np_, symmetric=symmetric, **kwargs) if not nx.is_floating_point(C10): @@ -475,11 +475,12 @@ def df(G): return 0.5 * (gwggrad(constC, hC1, hC2, G, np_) + gwggrad(constCt, hC1t, hC2t, G, np_)) if armijo: - def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): + def line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs): return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **kwargs) else: - def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): + def line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs): return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=(1 - alpha) * M, reg=alpha, nx=np_, symmetric=symmetric, **kwargs) + if not nx.is_floating_point(M0): warnings.warn( "Input feature matrix consists of integer. The transport plan will be " @@ -625,9 +626,11 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss', if loss_fun == 'square_loss': gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T)) gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T)) + elif loss_fun == 'kl_loss': gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T)) gC2 = - nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q) + if isinstance(alpha, int) or isinstance(alpha, float): fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M), (log_fgw['u'] - nx.mean(log_fgw['u']), diff --git a/ot/gromov/_partial.py b/ot/gromov/_partial.py new file mode 100644 index 000000000..f1840655c --- /dev/null +++ b/ot/gromov/_partial.py @@ -0,0 +1,821 @@ +# -*- coding: utf-8 -*- +""" +Partial (Fused) Gromov-Wasserstein solvers. +""" + +# Author: Laetitia Chapel +# Cédric Vincent-Cuaz +# Yikun Bai < yikun.bai@vanderbilt.edu > +# +# License: MIT License + + +from ..utils import list_to_array, unif +from ..backend import get_backend, NumpyBackend +from ..partial import entropic_partial_wasserstein +from ._utils import _transform_matrix, gwloss, gwggrad +from ..optim import partial_cg, solve_1d_linesearch_quad + +import numpy as np +import warnings + + +def partial_gromov_wasserstein( + C1, C2, p=None, q=None, m=None, loss_fun='square_loss', nb_dummies=1, + G0=None, thres=1, numItermax=1e4, tol=1e-8, symmetric=None, warn=True, + log=False, verbose=False, **kwargs): + r""" + Returns the Partial Gromov-Wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` + and :math:`(\mathbf{C_2}, \mathbf{q})`. + + The function solves the following optimization problem using Conditional Gradient: + + .. math:: + \mathbf{T}^* \in \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l} + L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} + + \mathbf{T}^T \mathbf{1} &= \mathbf{q} + + \mathbf{T} &\geq 0 + + \mathbf{1}^T \mathbf{T}^T \mathbf{1} = m &\leq \min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\} + + where : + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space. + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space. + - :math:`\mathbf{p}`: Distribution in the source space. + - :math:`\mathbf{q}`: Distribution in the target space. + - `m` is the amount of mass to be transported + - `L`: Loss function to account for the misfit between the similarity matrices. + + The formulation of the problem has been proposed in + :ref:`[29] ` + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + .. note:: All computations in the conjugate gradient solver are done with + numpy to limit memory overhead. + .. note:: This function will cast the computed transport plan to the data + type of the provided input :math:`\mathbf{C}_1`. Casting to an integer + tensor might result in a loss of precision. If this behaviour is + unwanted, please make sure to provide a floating point input. + + Parameters + ---------- + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space + C2 : array-like, shape (nt, nt) + Metric costfr matrix in the target space + p : array-like, shape (ns,), optional + Distribution in the source space. + If let to its default value None, uniform distribution is taken. + q : array-like, shape (nt,), optional + Distribution in the target space. + If let to its default value None, uniform distribution is taken. + m : float, optional + Amount of mass to be transported + (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`) + loss_fun : str, optional + Loss function used for the solver either 'square_loss' or 'kl_loss'. + nb_dummies : int, optional + Number of dummy points to add (avoid instabilities in the EMD solver) + G0 : array-like, shape (ns, nt), optional + Initialization of the transportation matrix + thres : float, optional + quantile of the gradient matrix to populate the cost matrix when 0 + (default: 1) + numItermax : int, optional + Max number of iterations + tol : float, optional + tolerance for stopping iterations + symmetric : bool, optional + Either C1 and C2 are to be assumed symmetric or not. + If let to its default None value, a symmetry test will be conducted. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric). + warn: bool, optional. + Whether to raise a warning when EMD did not converge. + log : bool, optional + return log if True + verbose : bool, optional + Print information along iterations + **kwargs : dict + parameters can be directly passed to the emd solver + + + Returns + ------- + T : array-like, shape (`ns`, `nt`) + Optimal transport matrix between the two spaces. + + log : dict + Convergence information and loss. + + Examples + -------- + >>> from ot.gromov import partial_gromov_wasserstein + >>> import scipy as sp + >>> a = np.array([0.25] * 4) + >>> b = np.array([0.25] * 4) + >>> x = np.array([1,2,100,200]).reshape((-1,1)) + >>> y = np.array([3,2,98,199]).reshape((-1,1)) + >>> C1 = sp.spatial.distance.cdist(x, x) + >>> C2 = sp.spatial.distance.cdist(y, y) + >>> np.round(partial_gromov_wasserstein(C1, C2, a, b),2) + array([[0. , 0.25, 0. , 0. ], + [0.25, 0. , 0. , 0. ], + [0. , 0. , 0.25, 0. ], + [0. , 0. , 0. , 0.25]]) + >>> np.round(partial_gromov_wasserstein(C1, C2, a, b, m=0.25),2) + array([[0. , 0. , 0. , 0. ], + [0. , 0. , 0. , 0. ], + [0. , 0. , 0.25, 0. ], + [0. , 0. , 0. , 0. ]]) + + + .. _references-partial-gromov-wasserstein: + References + ---------- + .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal + Transport with Applications on Positive-Unlabeled Learning". + NeurIPS. + + """ + arr = [C1, C2] + if p is not None: + arr.append(list_to_array(p)) + else: + p = unif(C1.shape[0], type_as=C1) + if q is not None: + arr.append(list_to_array(q)) + else: + q = unif(C2.shape[0], type_as=C1) + if G0 is not None: + G0_ = G0 + arr.append(G0) + + nx = get_backend(*arr) + p0, q0, C10, C20 = p, q, C1, C2 + + p = nx.to_numpy(p0) + q = nx.to_numpy(q0) + C1 = nx.to_numpy(C10) + C2 = nx.to_numpy(C20) + if symmetric is None: + symmetric = np.allclose(C1, C1.T, atol=1e-10) and np.allclose(C2, C2.T, atol=1e-10) + + if m is None: + m = min(np.sum(p), np.sum(q)) + elif m < 0: + raise ValueError("Problem infeasible. Parameter m should be greater" + " than 0.") + elif m > min(np.sum(p), np.sum(q)): + raise ValueError("Problem infeasible. Parameter m should lower or" + " equal than min(|a|_1, |b|_1).") + + if G0 is None: + G0 = np.outer(p, q) * m / (np.sum(p) * np.sum(q)) # make sure |G0|=m, G01_m\leq p, G0.T1_n\leq q. + + else: + G0 = nx.to_numpy(G0_) + # Check marginals of G0 + assert np.all(G0.sum(1) <= p) + assert np.all(G0.sum(0) <= q) + + q_extended = np.append(q, [(np.sum(p) - m) / nb_dummies] * nb_dummies) + p_extended = np.append(p, [(np.sum(q) - m) / nb_dummies] * nb_dummies) + + # cg for GW is implemented using numpy on CPU + np_ = NumpyBackend() + + fC1, fC2, hC1, hC2 = _transform_matrix(C1, C2, loss_fun, np_) + fC2t = fC2.T + if not symmetric: + fC1t, hC1t, hC2t = fC1.T, hC1.T, hC2.T + + ones_p = np_.ones(p.shape[0], type_as=p) + ones_q = np_.ones(q.shape[0], type_as=q) + + def f(G): + pG = G.sum(1) + qG = G.sum(0) + constC1 = np.outer(np.dot(fC1, pG), ones_q) + constC2 = np.outer(ones_p, np.dot(qG, fC2t)) + return gwloss(constC1 + constC2, hC1, hC2, G, np_) + + if symmetric: + def df(G): + pG = G.sum(1) + qG = G.sum(0) + constC1 = np.outer(np.dot(fC1, pG), ones_q) + constC2 = np.outer(ones_p, np.dot(qG, fC2t)) + return gwggrad(constC1 + constC2, hC1, hC2, G, np_) + else: + + def df(G): + pG = G.sum(1) + qG = G.sum(0) + constC1 = np.outer(np.dot(fC1, pG), ones_q) + constC2 = np.outer(ones_p, np.dot(qG, fC2t)) + constC1t = np.outer(np.dot(fC1t, pG), ones_q) + constC2t = np.outer(ones_p, np.dot(qG, fC2)) + + return 0.5 * ( + gwggrad(constC1 + constC2, hC1, hC2, G, np_) + + gwggrad(constC1t + constC2t, hC1t, hC2t, G, np_)) + + def line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs): + df_Gc = df(deltaG + G) + return solve_partial_gromov_linesearch( + G, deltaG, cost_G, df_G, df_Gc, M=0., reg=1., nx=np_, **kwargs) + + if not nx.is_floating_point(C10): + warnings.warn( + "Input structure matrix consists of integers. The transport plan will be " + "casted accordingly, possibly resulting in a loss of precision. " + "If this behaviour is unwanted, please make sure your input " + "structure matrix consists of floating point elements.", + stacklevel=2 + ) + + if log: + res, log = partial_cg(p, q, p_extended, q_extended, 0., 1., f, df, G0, + line_search, log=True, numItermax=numItermax, + stopThr=tol, stopThr2=0., warn=warn, **kwargs) + log['partial_gw_dist'] = nx.from_numpy(log['loss'][-1], type_as=C10) + return nx.from_numpy(res, type_as=C10), log + else: + return nx.from_numpy( + partial_cg(p, q, p_extended, q_extended, 0., 1., f, df, G0, + line_search, log=False, numItermax=numItermax, + stopThr=tol, stopThr2=0., **kwargs), type_as=C10) + + +def partial_gromov_wasserstein2( + C1, C2, p=None, q=None, m=None, loss_fun='square_loss', nb_dummies=1, G0=None, + thres=1, numItermax=1e4, tol=1e-7, symmetric=None, warn=False, log=False, + verbose=False, **kwargs): + r""" + Returns the Partial Gromov-Wasserstein discrepancy between + :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`. + + The function solves the following optimization problem using Conditional Gradient: + + .. math:: + \mathbf{PGW} = \mathop{\min}_\mathbf{T} \quad \sum_{i,j,k,l} + L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} + + \mathbf{T}^T \mathbf{1} &= \mathbf{q} + + \mathbf{T} &\geq 0 + + \mathbf{1}^T \mathbf{T}^T \mathbf{1} = m &\leq \min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\} + + where : + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space. + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space. + - :math:`\mathbf{p}`: Distribution in the source space. + - :math:`\mathbf{q}`: Distribution in the target space. + - `m` is the amount of mass to be transported + - `L`: Loss function to account for the misfit between the similarity matrices. + + + The formulation of the problem has been proposed in + :ref:`[29] ` + + Note that when using backends, this loss function is differentiable wrt the + matrices (C1, C2). + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + .. note:: All computations in the conjugate gradient solver are done with + numpy to limit memory overhead. + .. note:: This function will cast the computed transport plan to the data + type of the provided input :math:`\mathbf{C}_1`. Casting to an integer + tensor might result in a loss of precision. If this behaviour is + unwanted, please make sure to provide a floating point input. + + Parameters + ---------- + C1 : ndarray, shape (ns, ns) + Metric cost matrix in the source space + C2 : ndarray, shape (nt, nt) + Metric cost matrix in the target space + p : ndarray, shape (ns,) + Distribution in the source space + q : ndarray, shape (nt,) + Distribution in the target space + m : float, optional + Amount of mass to be transported + (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`) + loss_fun : str, optional + Loss function used for the solver either 'square_loss' or 'kl_loss'. + nb_dummies : int, optional + Number of dummy points to add (avoid instabilities in the EMD solver) + G0 : ndarray, shape (ns, nt), optional + Initialization of the transportation matrix + thres : float, optional + quantile of the gradient matrix to populate the cost matrix when 0 + (default: 1) + numItermax : int, optional + Max number of iterations + tol : float, optional + tolerance for stopping iterations + symmetric : bool, optional + Either C1 and C2 are to be assumed symmetric or not. + If let to its default None value, a symmetry test will be conducted. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric). + warn: bool, optional. + Whether to raise a warning when EMD did not converge. + log : bool, optional + return log if True + verbose : bool, optional + Print information along iterations + **kwargs : dict + parameters can be directly passed to the emd solver + + + .. warning:: + When dealing with a large number of points, the EMD solver may face + some instabilities, especially when the mass associated to the dummy + point is large. To avoid them, increase the number of dummy points + (allows a smoother repartition of the mass over the points). + + + Returns + ------- + partial_gw_dist : float + partial GW discrepancy + log : dict + log dictionary returned only if `log` is `True` + + + Examples + -------- + >>> from ot.gromov import partial_gromov_wasserstein2 + >>> import scipy as sp + >>> a = np.array([0.25] * 4) + >>> b = np.array([0.25] * 4) + >>> x = np.array([1,2,100,200]).reshape((-1,1)) + >>> y = np.array([3,2,98,199]).reshape((-1,1)) + >>> C1 = sp.spatial.distance.cdist(x, x) + >>> C2 = sp.spatial.distance.cdist(y, y) + >>> np.round(partial_gromov_wasserstein2(C1, C2, a, b),2) + 3.38 + >>> np.round(partial_gromov_wasserstein2(C1, C2, a, b, m=0.25),2) + 0.0 + + + .. _references-partial-gromov-wasserstein2: + References + ---------- + .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal + Transport with Applications on Positive-Unlabeled Learning". + NeurIPS. + + """ + # simple get_backend as the full one will be handled in gromov_wasserstein + nx = get_backend(C1, C2) + + # init marginals if set as None + if p is None: + p = unif(C1.shape[0], type_as=C1) + if q is None: + q = unif(C2.shape[0], type_as=C1) + + T, log_pgw = partial_gromov_wasserstein( + C1, C2, p, q, m, loss_fun, nb_dummies, G0, thres, + numItermax, tol, symmetric, warn, True, verbose, **kwargs) + + log_pgw['T'] = T + pgw = log_pgw['partial_gw_dist'] + + if loss_fun == 'square_loss': + gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T)) + gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T)) + elif loss_fun == 'kl_loss': + gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T)) + gC2 = - nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q) + + pgw = nx.set_gradients(pgw, (C1, C2), (gC1, gC2)) + + if log: + return pgw, log_pgw + else: + return pgw + + +def solve_partial_gromov_linesearch( + G, deltaG, cost_G, df_G, df_Gc, M, reg, alpha_min=None, alpha_max=None, + nx=None, **kwargs): + """ + Solve the linesearch in the FW iterations of partial (F)GW following eq.5 of :ref:`[29]`. + + Parameters + ---------- + + G : array-like, shape(ns,nt) + The transport map at a given iteration of the FW + deltaG : array-like (ns,nt) + Difference between the optimal map `Gc` found by linearization in the + FW algorithm and the value at a given iteration + cost_G : float + Value of the cost at `G` + df_G : array-like (ns,nt) + Gradient of the GW cost at `G` + df_Gc : array-like (ns,nt) + Gradient of the GW cost at `Gc` + M : array-like (ns,nt) + Cost matrix between the features. + reg : float + Regularization parameter. + alpha_min : float, optional + Minimum value for alpha + alpha_max : float, optional + Maximum value for alpha + nx : backend, optional + If let to its default value None, a backend test will be conducted. + + Returns + ------- + alpha : float + The optimal step size of the FW + fc : int + nb of function call. Useless here + cost_G : float + The value of the cost for the next iteration + df_G : array-like (ns,nt) + Updated gradient of the GW cost + + References + ---------- + .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal + Transport with Applications on Positive-Unlabeled Learning". + NeurIPS. + + """ + if nx is None: + if isinstance(M, int) or isinstance(M, float): + nx = get_backend(G, deltaG, df_G, df_Gc) + else: + nx = get_backend(G, deltaG, df_G, df_Gc, M) + + df_deltaG = df_Gc - df_G + cost_deltaG = 0.5 * nx.sum(df_deltaG * deltaG) + + a = reg * cost_deltaG + # formula to check for partial FGW + b = nx.sum(M * deltaG) + reg * nx.sum(df_G * deltaG) + + alpha = solve_1d_linesearch_quad(a, b) + if alpha_min is not None or alpha_max is not None: + alpha = np.clip(alpha, alpha_min, alpha_max) + + # the new cost is deduced from the line search quadratic function + cost_G = cost_G + a * (alpha ** 2) + b * alpha + + # update the gradient for next cg iteration + df_G = df_G + alpha * df_deltaG + return alpha, 1, cost_G, df_G + + +def entropic_partial_gromov_wasserstein( + C1, C2, p=None, q=None, reg=1., m=None, loss_fun='square_loss', G0=None, + numItermax=1000, tol=1e-7, symmetric=None, log=False, verbose=False): + r""" + Returns the partial Gromov-Wasserstein transport between + :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` + + The function solves the following optimization problem: + + .. math:: + \gamma = \mathop{\arg \min}_{\gamma} \quad \sum_{i,j,k,l} + L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l})\cdot + \gamma_{i,j}\cdot\gamma_{k,l} + \mathrm{reg} \cdot\Omega(\gamma) + + .. math:: + s.t. \ \gamma &\geq 0 + + \gamma \mathbf{1} &\leq \mathbf{a} + + \gamma^T \mathbf{1} &\leq \mathbf{b} + + \mathbf{1}^T \gamma^T \mathbf{1} = m + &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} + + where : + + - :math:`\mathbf{C_1}` is the metric cost matrix in the source space + - :math:`\mathbf{C_2}` is the metric cost matrix in the target space + - :math:`\mathbf{p}` and :math:`\mathbf{q}` are the sample weights + - `L`: quadratic loss function + - :math:`\Omega` is the entropic regularization term, + :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - `m` is the amount of mass to be transported + + The formulation of the GW problem has been proposed in + :ref:`[12] ` and the + partial GW in :ref:`[29] ` + + Parameters + ---------- + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,), optional + Distribution in the source space. + If let to its default value None, uniform distribution is taken. + q : array-like, shape (nt,), optional + Distribution in the target space. + If let to its default value None, uniform distribution is taken. + reg: float, optional. Default is 1. + entropic regularization parameter + m : float, optional + Amount of mass to be transported (default: + :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`) + loss_fun : str, optional + Loss function used for the solver either 'square_loss' or 'kl_loss'. + G0 : array-like, shape (ns, nt), optional + Initialization of the transportation matrix + numItermax : int, optional + Max number of iterations + tol : float, optional + Stop threshold on error (>0) + symmetric : bool, optional + Either C1 and C2 are to be assumed symmetric or not. + If let to its default None value, a symmetry test will be conducted. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric). + log : bool, optional + return log if True + verbose : bool, optional + Print information along iterations + + Examples + -------- + >>> from ot.gromov import entropic_partial_gromov_wasserstein + >>> import scipy as sp + >>> a = np.array([0.25] * 4) + >>> b = np.array([0.25] * 4) + >>> x = np.array([1,2,100,200]).reshape((-1,1)) + >>> y = np.array([3,2,98,199]).reshape((-1,1)) + >>> C1 = sp.spatial.distance.cdist(x, x) + >>> C2 = sp.spatial.distance.cdist(y, y) + >>> np.round(entropic_partial_gromov_wasserstein(C1, C2, a, b, 1e2), 2) + array([[0.12, 0.13, 0. , 0. ], + [0.13, 0.12, 0. , 0. ], + [0. , 0. , 0.25, 0. ], + [0. , 0. , 0. , 0.25]]) + >>> np.round(entropic_partial_gromov_wasserstein(C1, C2, a, b, 1e2,0.25), 2) + array([[0.02, 0.03, 0. , 0.03], + [0.03, 0.03, 0. , 0.03], + [0. , 0. , 0.03, 0. ], + [0.02, 0.02, 0. , 0.03]]) + + Returns + ------- + :math: `gamma` : (dim_a, dim_b) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary returned only if `log` is `True` + + + .. _references-entropic-partial-gromov-wasserstein: + References + ---------- + .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, + "Gromov-Wasserstein averaging of kernel and distance matrices." + International Conference on Machine Learning (ICML). 2016. + + .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal + Transport with Applications on Positive-Unlabeled Learning". + NeurIPS. + + See Also + -------- + ot.partial.partial_gromov_wasserstein: exact Partial Gromov-Wasserstein + """ + + arr = [C1, C2, G0] + if p is not None: + p = list_to_array(p) + arr.append(p) + if q is not None: + q = list_to_array(q) + arr.append(q) + + nx = get_backend(*arr) + + if p is None: + p = nx.ones(C1.shape[0], type_as=C1) / C1.shape[0] + if q is None: + q = nx.ones(C2.shape[0], type_as=C2) / C2.shape[0] + + if m is None: + m = min(nx.sum(p), nx.sum(q)) + elif m < 0: + raise ValueError("Problem infeasible. Parameter m should be greater" + " than 0.") + elif m > min(nx.sum(p), nx.sum(q)): + raise ValueError("Problem infeasible. Parameter m should lower or" + " equal than min(|a|_1, |b|_1).") + + if G0 is None: + G0 = nx.outer(p, q) * m / (nx.sum(p) * nx.sum(q)) # make sure |G0|=m, G01_m\leq p, G0.T1_n\leq q. + + else: + # Check marginals of G0 + assert nx.any(nx.sum(G0, 1) <= p) + assert nx.any(nx.sum(G0, 0) <= q) + + if symmetric is None: + symmetric = np.allclose(C1, C1.T, atol=1e-10) and np.allclose(C2, C2.T, atol=1e-10) + + # Setup gradient computation + fC1, fC2, hC1, hC2 = _transform_matrix(C1, C2, loss_fun, nx) + fC2t = fC2.T + if not symmetric: + fC1t, hC1t, hC2t = fC1.T, hC1.T, hC2.T + + ones_p = nx.ones(p.shape[0], type_as=p) + ones_q = nx.ones(q.shape[0], type_as=q) + + def f(G): + pG = nx.sum(G, 1) + qG = nx.sum(G, 0) + constC1 = nx.outer(nx.dot(fC1, pG), ones_q) + constC2 = nx.outer(ones_p, nx.dot(qG, fC2t)) + return gwloss(constC1 + constC2, hC1, hC2, G, nx) + + if symmetric: + def df(G): + pG = nx.sum(G, 1) + qG = nx.sum(G, 0) + constC1 = nx.outer(nx.dot(fC1, pG), ones_q) + constC2 = nx.outer(ones_p, nx.dot(qG, fC2t)) + return gwggrad(constC1 + constC2, hC1, hC2, G, nx) + else: + + def df(G): + pG = nx.sum(G, 1) + qG = nx.sum(G, 0) + constC1 = nx.outer(nx.dot(fC1, pG), ones_q) + constC2 = nx.outer(ones_p, nx.dot(qG, fC2t)) + constC1t = nx.outer(nx.dot(fC1t, pG), ones_q) + constC2t = nx.outer(ones_p, nx.dot(qG, fC2)) + + return 0.5 * ( + gwggrad(constC1 + constC2, hC1, hC2, G, nx) + + gwggrad(constC1t + constC2t, hC1t, hC2t, G, nx)) + + cpt = 0 + err = 1 + + loge = {'err': []} + + while (err > tol and cpt < numItermax): + Gprev = G0 + M_entr = df(G0) + G0 = entropic_partial_wasserstein(p, q, M_entr, reg, m) + if cpt % 10 == 0: # to speed up the computations + err = np.linalg.norm(G0 - Gprev) + if log: + loge['err'].append(err) + if verbose: + if cpt % 200 == 0: + print('{:5s}|{:12s}|{:12s}'.format( + 'It.', 'Err', 'Loss') + '\n' + '-' * 31) + print('{:5d}|{:8e}|{:8e}'.format(cpt, err, f(G0))) + + cpt += 1 + + if log: + loge['partial_gw_dist'] = f(G0) + return G0, loge + else: + return G0 + + +def entropic_partial_gromov_wasserstein2( + C1, C2, p=None, q=None, reg=1., m=None, loss_fun='square_loss', G0=None, + numItermax=1000, tol=1e-7, symmetric=None, log=False, verbose=False): + r""" + Returns the partial Gromov-Wasserstein discrepancy between + :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` + + The function solves the following optimization problem: + + .. math:: + PGW = \min_{\gamma} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, + \mathbf{C_2}_{j,l})\cdot + \gamma_{i,j}\cdot\gamma_{k,l} + \mathrm{reg} \cdot\Omega(\gamma) + + .. math:: + s.t. \ \gamma &\geq 0 + + \gamma \mathbf{1} &\leq \mathbf{a} + + \gamma^T \mathbf{1} &\leq \mathbf{b} + + \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} + + where : + + - :math:`\mathbf{C_1}` is the metric cost matrix in the source space + - :math:`\mathbf{C_2}` is the metric cost matrix in the target space + - :math:`\mathbf{p}` and :math:`\mathbf{q}` are the sample weights + - `L`: Loss function to account for the misfit between the similarity matrices. + - :math:`\Omega` is the entropic regularization term, + :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - `m` is the amount of mass to be transported + + The formulation of the GW problem has been proposed in + :ref:`[12] ` and the + partial GW in :ref:`[29] ` + + + Parameters + ---------- + C1 : ndarray, shape (ns, ns) + Metric cost matrix in the source space + C2 : ndarray, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,), optional + Distribution in the source space. + If let to its default value None, uniform distribution is taken. + q : array-like, shape (nt,), optional + Distribution in the target space. + If let to its default value None, uniform distribution is taken. + reg: float + entropic regularization parameter + m : float, optional + Amount of mass to be transported (default: + :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`) + loss_fun : str, optional + Loss function used for the solver either 'square_loss' or 'kl_loss'. + G0 : ndarray, shape (ns, nt), optional + Initialization of the transportation matrix + numItermax : int, optional + Max number of iterations + tol : float, optional + Stop threshold on error (>0) + symmetric : bool, optional + Either C1 and C2 are to be assumed symmetric or not. + If let to its default None value, a symmetry test will be conducted. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric). + log : bool, optional + return log if True + verbose : bool, optional + Print information along iterations + + + Returns + ------- + partial_gw_dist: float + Partial Gromov-Wasserstein distance + log : dict + log dictionary returned only if `log` is `True` + + Examples + -------- + >>> from ot.gromov import entropic_partial_gromov_wasserstein2 + >>> import scipy as sp + >>> a = np.array([0.25] * 4) + >>> b = np.array([0.25] * 4) + >>> x = np.array([1,2,100,200]).reshape((-1,1)) + >>> y = np.array([3,2,98,199]).reshape((-1,1)) + >>> C1 = sp.spatial.distance.cdist(x, x) + >>> C2 = sp.spatial.distance.cdist(y, y) + >>> np.round(entropic_partial_gromov_wasserstein2(C1, C2, a, b, 1e2), 2) + 3.75 + + + .. _references-entropic-partial-gromov-wasserstein2: + References + ---------- + .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, + "Gromov-Wasserstein averaging of kernel and distance matrices." + International Conference on Machine Learning (ICML). 2016. + + .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal + Transport with Applications on Positive-Unlabeled Learning". + NeurIPS. + """ + + partial_gw, log_gw = entropic_partial_gromov_wasserstein( + C1, C2, p, q, reg, m, loss_fun, G0, numItermax, tol, + symmetric, True, verbose) + + log_gw['T'] = partial_gw + + if log: + return log_gw['partial_gw_dist'], log_gw + else: + return log_gw['partial_gw_dist'] diff --git a/ot/gromov/_semirelaxed.py b/ot/gromov/_semirelaxed.py index 96f776cb1..c509a1046 100644 --- a/ot/gromov/_semirelaxed.py +++ b/ot/gromov/_semirelaxed.py @@ -164,7 +164,7 @@ def df(G): marginal_product_2 = nx.outer(ones_p, nx.dot(qG, fC2)) return 0.5 * (gwggrad(constC + marginal_product_1, hC1, hC2, G, nx) + gwggrad(constCt + marginal_product_2, hC1t, hC2t, G, nx)) - def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): + def line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs): return solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, ones_p, M=0., reg=1., fC2t=fC2t, nx=nx, **kwargs) if log: @@ -436,7 +436,7 @@ def df(G): marginal_product_2 = nx.outer(ones_p, nx.dot(qG, fC2)) return 0.5 * (gwggrad(constC + marginal_product_1, hC1, hC2, G, nx) + gwggrad(constCt + marginal_product_2, hC1t, hC2t, G, nx)) - def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): + def line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs): return solve_semirelaxed_gromov_linesearch( G, deltaG, cost_G, hC1, hC2, ones_p, M=(1 - alpha) * M, reg=alpha, fC2t=fC2t, nx=nx, **kwargs) diff --git a/ot/gromov/_utils.py b/ot/gromov/_utils.py index fb07bb1ef..31cd0fd90 100644 --- a/ot/gromov/_utils.py +++ b/ot/gromov/_utils.py @@ -34,11 +34,11 @@ import warnings -def init_matrix(C1, C2, p, q, loss_fun='square_loss', nx=None): - r"""Return loss matrices and tensors for Gromov-Wasserstein fast computation +def _transform_matrix(C1, C2, loss_fun='square_loss', nx=None): + r"""Return transformed structure matrices for Gromov-Wasserstein fast computation - Returns the value of :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2}) \otimes \mathbf{T}` with the - selected loss function as the loss function of Gromov-Wasserstein discrepancy. + Returns the matrices involved in the computation of :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2})` + with the selected loss function as the loss function of Gromov-Wasserstein discrepancy. The matrices are computed as described in Proposition 1 in :ref:`[12] ` @@ -46,7 +46,6 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss', nx=None): - :math:`\mathbf{C_1}`: Metric cost matrix in the source space - :math:`\mathbf{C_2}`: Metric cost matrix in the target space - - :math:`\mathbf{T}`: A coupling between those two spaces The square-loss function :math:`L(a, b) = |a - b|^2` is read as : @@ -82,10 +81,6 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss', nx=None): Metric cost matrix in the source space C2 : array-like, shape (nt, nt) Metric cost matrix in the target space - p : array-like, shape (ns,) - Probability distribution in the source space - q : array-like, shape (nt,) - Probability distribution in the target space loss_fun : str, optional Name of loss function to use: either 'square_loss' or 'kl_loss' (default='square_loss') nx : backend, optional @@ -93,15 +88,17 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss', nx=None): Returns ------- - constC : array-like, shape (ns, nt) - Constant :math:`\mathbf{C}` matrix in Eq. (6) + fC1 : array-like, shape (ns, ns) + :math:`\mathbf{f1}(\mathbf{C1})` matrix in Eq. (6) + fC2 : array-like, shape (nt, nt) + :math:`\mathbf{f2}(\mathbf{C2})` matrix in Eq. (6) hC1 : array-like, shape (ns, ns) :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6) hC2 : array-like, shape (nt, nt) :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6) - .. _references-init-matrix: + .. _references-transform_matrix: References ---------- .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, @@ -110,8 +107,8 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss', nx=None): """ if nx is None: - C1, C2, p, q = list_to_array(C1, C2, p, q) - nx = get_backend(C1, C2, p, q) + C1, C2 = list_to_array(C1, C2) + nx = get_backend(C1, C2) if loss_fun == 'square_loss': def f1(a): @@ -127,7 +124,7 @@ def h2(b): return 2 * b elif loss_fun == 'kl_loss': def f1(a): - return a * nx.log(a + 1e-16) - a + return a * nx.log(a + 1e-18) - a def f2(b): return b @@ -136,21 +133,107 @@ def h1(a): return a def h2(b): - return nx.log(b + 1e-16) + return nx.log(b + 1e-18) else: raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.") + fC1 = f1(C1) + fC2 = f2(C2) + hC1 = h1(C1) + hC2 = h2(C2) + + return fC1, fC2, hC1, hC2 + + +def init_matrix(C1, C2, p, q, loss_fun='square_loss', nx=None): + r"""Return loss matrices and tensors for Gromov-Wasserstein fast computation + + Returns the value of :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2}) \otimes \mathbf{T}` with the + selected loss function as the loss function of Gromov-Wasserstein discrepancy. + + The matrices are computed as described in Proposition 1 in :ref:`[12] ` + + Where : + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{T}`: A coupling between those two spaces + + The square-loss function :math:`L(a, b) = |a - b|^2` is read as : + + .. math:: + + L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b) + + \mathrm{with} \ f_1(a) &= a^2 + + f_2(b) &= b^2 + + h_1(a) &= a + + h_2(b) &= 2b + + The kl-loss function :math:`L(a, b) = a \log\left(\frac{a}{b}\right) - a + b` is read as : + + .. math:: + + L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b) + + \mathrm{with} \ f_1(a) &= a \log(a) - a + + f_2(b) &= b + + h_1(a) &= a + + h_2(b) &= \log(b) + + Parameters + ---------- + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,) + Probability distribution in the source space + q : array-like, shape (nt,) + Probability distribution in the target space + loss_fun : str, optional + Name of loss function to use: either 'square_loss' or 'kl_loss' (default='square_loss') + nx : backend, optional + If let to its default value None, a backend test will be conducted. + + Returns + ------- + constC : array-like, shape (ns, nt) + Constant :math:`\mathbf{C}` matrix in Eq. (6) + hC1 : array-like, shape (ns, ns) + :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6) + hC2 : array-like, shape (nt, nt) + :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6) + + + .. _references-init-matrix: + References + ---------- + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, + "Gromov-Wasserstein averaging of kernel and distance matrices." + International Conference on Machine Learning (ICML). 2016. + + """ + if nx is None: + C1, C2, p, q = list_to_array(C1, C2, p, q) + nx = get_backend(C1, C2, p, q) + + fC1, fC2, hC1, hC2 = _transform_matrix(C1, C2, loss_fun, nx) constC1 = nx.dot( - nx.dot(f1(C1), nx.reshape(p, (-1, 1))), + nx.dot(fC1, nx.reshape(p, (-1, 1))), nx.ones((1, len(q)), type_as=q) ) constC2 = nx.dot( nx.ones((len(p), 1), type_as=p), - nx.dot(nx.reshape(q, (1, -1)), f2(C2).T) + nx.dot(nx.reshape(q, (1, -1)), fC2.T) ) constC = constC1 + constC2 - hC1 = h1(C1) - hC2 = h2(C2) return constC, hC1, hC2 @@ -353,39 +436,12 @@ def init_matrix_semirelaxed(C1, C2, p, loss_fun='square_loss', nx=None): C1, C2, p = list_to_array(C1, C2, p) nx = get_backend(C1, C2, p) - if loss_fun == 'square_loss': - def f1(a): - return (a**2) - - def f2(b): - return (b**2) - - def h1(a): - return a + fC1, fC2, hC1, hC2 = _transform_matrix(C1, C2, loss_fun, nx) - def h2(b): - return 2 * b - elif loss_fun == 'kl_loss': - def f1(a): - return a * nx.log(a + 1e-16) - a - - def f2(b): - return b - - def h1(a): - return a - - def h2(b): - return nx.log(b + 1e-16) - else: - raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.") - - constC = nx.dot(nx.dot(f1(C1), nx.reshape(p, (-1, 1))), + constC = nx.dot(nx.dot(fC1, nx.reshape(p, (-1, 1))), nx.ones((1, C2.shape[0]), type_as=p)) - hC1 = h1(C1) - hC2 = h2(C2) - fC2t = f2(C2).T + fC2t = fC2.T return constC, hC1, hC2, fC2t diff --git a/ot/optim.py b/ot/optim.py index a5f88bb29..d4db59b68 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -127,7 +127,7 @@ def phi(alpha1): def generic_conditional_gradient(a, b, M, f, df, reg1, reg2, lp_solver, line_search, G0=None, numItermax=200, stopThr=1e-9, - stopThr2=1e-9, verbose=False, log=False, **kwargs): + stopThr2=1e-9, verbose=False, log=False, nx=None, **kwargs): r""" Solve the general regularized OT problem or its semi-relaxed version with conditional gradient or generalized conditional gradient depending on the @@ -189,13 +189,44 @@ def generic_conditional_gradient(a, b, M, f, df, reg1, reg2, lp_solver, line_sea Entropic Regularization term >0. Ignored if set to None. lp_solver: function, linear program solver for direction finding of the (generalized) conditional gradient. - If set to emd will solve the general regularized OT problem using cg. - If set to lp_semi_relaxed_OT will solve the general regularized semi-relaxed OT problem using cg. - If set to sinkhorn will solve the general regularized OT problem using generalized cg. + This function must take the form `lp_solver(a, b, Mi, **kwargs)` with p: + `a` and `b` are sample weights in both domains; `Mi` is the gradient of + the regularized objective; optimal arguments via kwargs. + It must output an admissible transport plan. + + For instance, for the general regularized OT problem with conditional gradient :ref:`[1] `: + + def lp_solver(a, b, M, **kwargs): + return ot.emd(a, b, M) + + or with the generalized conditional gradient instead :ref:`[5, 7] `: + + def lp_solver(a, b, Mi, **kwargs): + return ot.sinkhorn(a, b, Mi) + line_search: function, - Function to find the optimal step. Currently used instances are: - line_search_armijo (generic solver). solve_gromov_linesearch for (F)GW problem. - solve_semirelaxed_gromov_linesearch for sr(F)GW problem. gcg_linesearch for the Generalized cg. + Function to find the optimal step. This function must take the form + `line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs)` with: `cost` + the cost function, `G` the transport plan, `deltaG` the conditional + gradient direction given by lp_solver, `Mi` the gradient of regularized + objective, `cost_G` the cost at G, `df_G` the gradient of the regularizer + at G. Two types of outputs are supported: + + Instances such as `ot.optim.line_search_armijo` (generic solver), + `ot.gromov.solve_gromov_linesearch` (FGW problems), + `solve_semirelaxed_gromov_linesearch` (srFGW problems) and + `gcg_linesearch` (generalized cg), output : the line-search step alpha, + the number of iterations used in the solver if applicable and the loss + value at step alpha. These can be called e.g as: + + def line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs): + return ot.optim.line_search_armijo(cost, G, deltaG, Mi, cost_G, **kwargs) + + Instances such as `ot.gromov.solve_partial_gromov_linesearch` for partial + (F)GW problems add as finale output, the next step gradient reading as + a convex combination of previously computed gradients, taking advantage of the regularizer + quadratic form. + G0 : array-like, shape (ns,nt), optional initial guess (default is indep joint density) numItermax : int, optional @@ -208,6 +239,8 @@ def generic_conditional_gradient(a, b, M, f, df, reg1, reg2, lp_solver, line_sea Print information along iterations log : bool, optional record log if True + nx : backend, optional + If let to its default value None, the backend will be deduced from other inputs. **kwargs : dict Parameters for linesearch @@ -236,10 +269,11 @@ def generic_conditional_gradient(a, b, M, f, df, reg1, reg2, lp_solver, line_sea ot.bregman.sinkhorn : Entropic regularized optimal transport """ - if isinstance(M, int) or isinstance(M, float): - nx = get_backend(a, b) - else: - nx = get_backend(a, b, M) + if nx is None: + if isinstance(M, int) or isinstance(M, float): + nx = get_backend(a, b) + else: + nx = get_backend(a, b, M) loop = 1 @@ -262,6 +296,7 @@ def cost(G): if log: log['loss'].append(cost_G) + df_G = None it = 0 if verbose: @@ -274,20 +309,27 @@ def cost(G): it += 1 old_cost_G = cost_G # problem linearization - Mi = M + reg1 * df(G) + if df_G is None: + df_G = df(G) + Mi = M + reg1 * df_G if not (reg2 is None): Mi = Mi + reg2 * (1 + nx.log(G)) - # set M positive - Mi = Mi + nx.min(Mi) # solve linear program Gc, innerlog_ = lp_solver(a, b, Mi, **kwargs) - # line search deltaG = Gc - G - alpha, fc, cost_G = line_search(cost, G, deltaG, Mi, cost_G, **kwargs) + res_line_search = line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs) + if len(res_line_search) == 3: + # the line-search does not allow to update the gradient + alpha, fc, cost_G = res_line_search + df_G = None + else: + # the line-search allows to update the gradient directly + # e.g. while using quadratic losses as the gromov-wasserstein loss + alpha, fc, cost_G, df_G = res_line_search G = G + alpha * deltaG @@ -316,9 +358,9 @@ def cost(G): return G -def cg(a, b, M, reg, f, df, G0=None, line_search=line_search_armijo, +def cg(a, b, M, reg, f, df, G0=None, line_search=None, numItermax=200, numItermaxEmd=100000, stopThr=1e-9, stopThr2=1e-9, - verbose=False, log=False, **kwargs): + verbose=False, log=False, nx=None, **kwargs): r""" Solve the general regularized OT problem with conditional gradient @@ -357,7 +399,7 @@ def cg(a, b, M, reg, f, df, G0=None, line_search=line_search_armijo, initial guess (default is indep joint density) line_search: function, Function to find the optimal step. - Default is line_search_armijo. + Default is None and calls a wrapper to line_search_armijo. numItermax : int, optional Max number of iterations numItermaxEmd : int, optional @@ -370,6 +412,8 @@ def cg(a, b, M, reg, f, df, G0=None, line_search=line_search_armijo, Print information along iterations log : bool, optional record log if True + nx : backend, optional + If let to its default value None, the backend will be deduced from other inputs. **kwargs : dict Parameters for linesearch @@ -393,17 +437,26 @@ def cg(a, b, M, reg, f, df, G0=None, line_search=line_search_armijo, ot.bregman.sinkhorn : Entropic regularized optimal transport """ + if nx is None: + if isinstance(M, int) or isinstance(M, float): + nx = get_backend(a, b) + else: + nx = get_backend(a, b, M) + + if line_search is None: + def line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs): + return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=nx, **kwargs) def lp_solver(a, b, M, **kwargs): return emd(a, b, M, numItermaxEmd, log=True) return generic_conditional_gradient(a, b, M, f, df, reg, None, lp_solver, line_search, G0=G0, numItermax=numItermax, stopThr=stopThr, - stopThr2=stopThr2, verbose=verbose, log=log, **kwargs) + stopThr2=stopThr2, verbose=verbose, log=log, nx=nx, **kwargs) -def semirelaxed_cg(a, b, M, reg, f, df, G0=None, line_search=line_search_armijo, - numItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs): +def semirelaxed_cg(a, b, M, reg, f, df, G0=None, line_search=None, + numItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, nx=None, **kwargs): r""" Solve the general regularized and semi-relaxed OT problem with conditional gradient @@ -440,7 +493,7 @@ def semirelaxed_cg(a, b, M, reg, f, df, G0=None, line_search=line_search_armijo, initial guess (default is indep joint density) line_search: function, Function to find the optimal step. - Default is the armijo line-search. + Default is None and calls a wrapper to line_search_armijo. numItermax : int, optional Max number of iterations stopThr : float, optional @@ -451,6 +504,8 @@ def semirelaxed_cg(a, b, M, reg, f, df, G0=None, line_search=line_search_armijo, Print information along iterations log : bool, optional record log if True + nx : backend, optional + If let to its default value None, the backend will be deduced from other inputs. **kwargs : dict Parameters for linesearch @@ -471,8 +526,15 @@ def semirelaxed_cg(a, b, M, reg, f, df, G0=None, line_search=line_search_armijo, International Conference on Learning Representations (ICLR), 2021. """ + if nx is None: + if isinstance(M, int) or isinstance(M, float): + nx = get_backend(a, b) + else: + nx = get_backend(a, b, M) - nx = get_backend(a, b) + if line_search is None: + def line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs): + return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=nx, **kwargs) def lp_solver(a, b, Mi, **kwargs): # get minimum by rows as binary mask @@ -485,6 +547,110 @@ def lp_solver(a, b, Mi, **kwargs): # return by default an empty inner_log return Gc, {} + return generic_conditional_gradient(a, b, M, f, df, reg, None, lp_solver, line_search, G0=G0, + numItermax=numItermax, stopThr=stopThr, + stopThr2=stopThr2, verbose=verbose, log=log, nx=nx, **kwargs) + + +def partial_cg(a, b, a_extended, b_extended, M, reg, f, df, G0=None, line_search=line_search_armijo, + numItermax=200, stopThr=1e-9, stopThr2=1e-9, warn=True, verbose=False, log=False, **kwargs): + r""" + Solve the general regularized partial OT problem with conditional gradient + + The function solves the following optimization problem: + + .. math:: + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg} \cdot f(\gamma) + + s.t. \ \gamma \mathbf{1} &= \mathbf{a} + + \gamma \mathbf{1} &= \mathbf{b} + + \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\} + + \gamma &\geq 0 + + where : + + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix + - :math:`f` is the regularization term (and `df` is its gradient) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights + - `m` is the amount of mass to be transported + + The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] ` + + Parameters + ---------- + a : array-like, shape (ns,) + samples weights in the source domain + b : array-like, shape (nt,) + currently estimated samples weights in the target domain + a_extended : array-like, shape (ns + nb_dummies,) + samples weights in the source domain with added dummy nodes + b_extended : array-like, shape (nt + nb_dummies,) + currently estimated samples weights in the target domain with added dummy nodes + M : array-like, shape (ns, nt) + loss matrix + reg : float + Regularization term >0 + G0 : array-like, shape (ns,nt), optional + initial guess (default is indep joint density) + line_search: function, + Function to find the optimal step. + Default is the armijo line-search. + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on the relative variation (>0) + stopThr2 : float, optional + Stop threshold on the absolute variation (>0) + warn: bool, optional. + Whether to raise a warning when EMD did not converge. + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + **kwargs : dict + Parameters for linesearch + + Returns + ------- + gamma : (ns x nt) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + + .. _references-partial-cg: + References + ---------- + + .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal + Transport with Applications on Positive-Unlabeled Learning". + NeurIPS. + + """ + n, m = a.shape[0], b.shape[0] + n_extended, m_extended = a_extended.shape[0], b_extended.shape[0] + nb_dummies = n_extended - n + + def lp_solver(a, b, Mi, **kwargs): + # add dummy nodes to Mi + Mi_extended = np.zeros((n_extended, m_extended), dtype=Mi.dtype) + Mi_extended[:n, :m] = Mi + Mi_extended[-nb_dummies:, -nb_dummies:] = np.max(M) * 1e2 + + G_extended, log_ = emd(a_extended, b_extended, Mi_extended, numItermax, log=True) + Gc = G_extended[:n, :m] + + if warn: + if log_['warning'] is not None: + raise ValueError("Error in the EMD resolution: try to increase the" + " number of dummy points") + + return Gc, log_ + return generic_conditional_gradient(a, b, M, f, df, reg, None, lp_solver, line_search, G0=G0, numItermax=numItermax, stopThr=stopThr, stopThr2=stopThr2, verbose=verbose, log=log, **kwargs) @@ -569,7 +735,7 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, def lp_solver(a, b, Mi, **kwargs): return sinkhorn(a, b, Mi, reg1, numItermax=numInnerItermax, log=True, **kwargs) - def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): + def line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs): return line_search_armijo(cost, G, deltaG, Mi, cost_G, **kwargs) return generic_conditional_gradient(a, b, M, f, df, reg2, reg1, lp_solver, line_search, G0=G0, diff --git a/ot/partial.py b/ot/partial.py index a3b25a856..a409a74d1 100755 --- a/ot/partial.py +++ b/ot/partial.py @@ -4,13 +4,14 @@ """ # Author: Laetitia Chapel -# Yikun Bai < yikun.bai@vanderbilt.edu > -# Cédric Vincent-Cuaz +# Yikun Bai < yikun.bai@vanderbilt.edu > +# Cédric Vincent-Cuaz from .utils import list_to_array from .backend import get_backend from .lp import emd import numpy as np +import warnings # License: MIT License @@ -413,10 +414,170 @@ def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): return nx.sum(partial_gw * M) +def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000, + stopThr=1e-100, verbose=False, log=False): + r""" + Solves the partial optimal transport problem + and returns the OT plan + + The function considers the following problem: + + .. math:: + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, + \mathbf{M} \rangle_F + \mathrm{reg} \cdot\Omega(\gamma) + + s.t. \gamma \mathbf{1} &\leq \mathbf{a} \\ + \gamma^T \mathbf{1} &\leq \mathbf{b} \\ + \gamma &\geq 0 \\ + \mathbf{1}^T \gamma^T \mathbf{1} = m + &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} \\ + + where : + + - :math:`\mathbf{M}` is the metric cost matrix + - :math:`\Omega` is the entropic regularization term, + :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights + - `m` is the amount of mass to be transported + + The formulation of the problem has been proposed in + :ref:`[3] ` (prop. 5) + + + Parameters + ---------- + a : np.ndarray (dim_a,) + Unnormalized histogram of dimension `dim_a` + b : np.ndarray (dim_b,) + Unnormalized histograms of dimension `dim_b` + M : np.ndarray (dim_a, dim_b) + cost matrix + reg : float + Regularization term > 0 + m : float, optional + Amount of mass to be transported + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + gamma : (dim_a, dim_b) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary returned only if `log` is `True` + + + Examples + -------- + >>> import ot + >>> a = [.1, .2] + >>> b = [.1, .1] + >>> M = [[0., 1.], [2., 3.]] + >>> np.round(entropic_partial_wasserstein(a, b, M, 1, 0.1), 2) + array([[0.06, 0.02], + [0.01, 0. ]]) + + + .. _references-entropic-partial-wasserstein: + References + ---------- + .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. + (2015). Iterative Bregman projections for regularized transportation + problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138. + + See Also + -------- + ot.partial.partial_wasserstein: exact Partial Wasserstein + """ + + a, b, M = list_to_array(a, b, M) + + nx = get_backend(a, b, M) + + dim_a, dim_b = M.shape + dx = nx.ones(dim_a, type_as=a) + dy = nx.ones(dim_b, type_as=b) + + if len(a) == 0: + a = nx.ones(dim_a, type_as=a) / dim_a + if len(b) == 0: + b = nx.ones(dim_b, type_as=b) / dim_b + + if m is None: + m = nx.min(nx.stack((nx.sum(a), nx.sum(b)))) * 1.0 + if m < 0: + raise ValueError("Problem infeasible. Parameter m should be greater" + " than 0.") + if m > nx.min(nx.stack((nx.sum(a), nx.sum(b)))): + raise ValueError("Problem infeasible. Parameter m should lower or" + " equal than min(|a|_1, |b|_1).") + + log_e = {'err': []} + + if nx.__name__ == "numpy": + # Next 3 lines equivalent to K=nx.exp(-M/reg), but faster to compute + K = np.empty(M.shape, dtype=M.dtype) + np.divide(M, -reg, out=K) + np.exp(K, out=K) + np.multiply(K, m / np.sum(K), out=K) + else: + K = nx.exp(-M / reg) + K = K * m / nx.sum(K) + + err, cpt = 1, 0 + q1 = nx.ones(K.shape, type_as=K) + q2 = nx.ones(K.shape, type_as=K) + q3 = nx.ones(K.shape, type_as=K) + + while (err > stopThr and cpt < numItermax): + Kprev = K + K = K * q1 + K1 = nx.dot(nx.diag(nx.minimum(a / nx.sum(K, axis=1), dx)), K) + q1 = q1 * Kprev / K1 + K1prev = K1 + K1 = K1 * q2 + K2 = nx.dot(K1, nx.diag(nx.minimum(b / nx.sum(K1, axis=0), dy))) + q2 = q2 * K1prev / K2 + K2prev = K2 + K2 = K2 * q3 + K = K2 * (m / nx.sum(K2)) + q3 = q3 * K2prev / K + + if nx.any(nx.isnan(K)) or nx.any(nx.isinf(K)): + print('Warning: numerical errors at iteration', cpt) + break + if cpt % 10 == 0: + err = nx.norm(Kprev - K) + if log: + log_e['err'].append(err) + if verbose: + if cpt % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 11) + print('{:5d}|{:8e}|'.format(cpt, err)) + + cpt = cpt + 1 + log_e['partial_w_dist'] = nx.sum(M * K) + if log: + return K, log_e + else: + return K + + def gwgrad_partial(C1, C2, T): """Compute the GW gradient. Note: we can not use the trick in :ref:`[12] ` as the marginals may not sum to 1. + .. note:: This function will be deprecated in a near future, please use + `ot.gromov.gwggrad` instead. + Parameters ---------- C1: array of shape (n_p,n_p) @@ -441,6 +602,11 @@ def gwgrad_partial(C1, C2, T): "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. """ + warnings.warn( + "This function will be deprecated in a near future, please use " + "ot.gromov.gwggrad` instead.", + stacklevel=2 + ) cC1 = np.dot(C1 ** 2 / 2, np.dot(T, np.ones(C2.shape[0]).reshape(-1, 1))) cC2 = np.dot(np.dot(np.ones(C1.shape[0]).reshape(1, -1), T), C2 ** 2 / 2) constC = cC1 + cC2 @@ -452,6 +618,9 @@ def gwgrad_partial(C1, C2, T): def gwloss_partial(C1, C2, T): """Compute the GW loss. + .. note:: This function will be deprecated in a near future, please use + `ot.gromov.gwloss` instead. + Parameters ---------- C1: array of shape (n_p,n_p) @@ -467,6 +636,12 @@ def gwloss_partial(C1, C2, T): ------- GW loss """ + warnings.warn( + "This function will be deprecated in a near future, please use " + "ot.gromov.gwloss` instead.", + stacklevel=2 + ) + g = gwgrad_partial(C1, C2, T) * 0.5 return np.sum(g * T) @@ -502,6 +677,8 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None, The formulation of the problem has been proposed in :ref:`[29] ` + .. note:: This function will be deprecated in a near future, please use + `ot.gromov.partial_gromov_wasserstein` instead. Parameters ---------- @@ -573,6 +750,11 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None, NeurIPS. """ + warnings.warn( + "This function will be deprecated in a near future, please use " + "ot.gromov.partial_gromov_wasserstein` instead.", + stacklevel=2 + ) if m is None: m = np.min((np.sum(p), np.sum(q))) @@ -683,6 +865,8 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None, The formulation of the problem has been proposed in :ref:`[29] ` + .. note:: This function will be deprecated in a near future, please use + `ot.gromov.partial_gromov_wasserstein2` instead. Parameters ---------- @@ -755,6 +939,11 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None, NeurIPS. """ + warnings.warn( + "This function will be deprecated in a near future, please use " + "ot.gromov.partial_gromov_wasserstein2` instead.", + stacklevel=2 + ) partial_gw, log_gw = partial_gromov_wasserstein(C1, C2, p, q, m, nb_dummies, G0, thres, @@ -769,163 +958,6 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None, return log_gw['partial_gw_dist'] -def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000, - stopThr=1e-100, verbose=False, log=False): - r""" - Solves the partial optimal transport problem - and returns the OT plan - - The function considers the following problem: - - .. math:: - \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, - \mathbf{M} \rangle_F + \mathrm{reg} \cdot\Omega(\gamma) - - s.t. \gamma \mathbf{1} &\leq \mathbf{a} \\ - \gamma^T \mathbf{1} &\leq \mathbf{b} \\ - \gamma &\geq 0 \\ - \mathbf{1}^T \gamma^T \mathbf{1} = m - &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} \\ - - where : - - - :math:`\mathbf{M}` is the metric cost matrix - - :math:`\Omega` is the entropic regularization term, - :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights - - `m` is the amount of mass to be transported - - The formulation of the problem has been proposed in - :ref:`[3] ` (prop. 5) - - - Parameters - ---------- - a : np.ndarray (dim_a,) - Unnormalized histogram of dimension `dim_a` - b : np.ndarray (dim_b,) - Unnormalized histograms of dimension `dim_b` - M : np.ndarray (dim_a, dim_b) - cost matrix - reg : float - Regularization term > 0 - m : float, optional - Amount of mass to be transported - numItermax : int, optional - Max number of iterations - stopThr : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - - - Returns - ------- - gamma : (dim_a, dim_b) ndarray - Optimal transportation matrix for the given parameters - log : dict - log dictionary returned only if `log` is `True` - - - Examples - -------- - >>> import ot - >>> a = [.1, .2] - >>> b = [.1, .1] - >>> M = [[0., 1.], [2., 3.]] - >>> np.round(entropic_partial_wasserstein(a, b, M, 1, 0.1), 2) - array([[0.06, 0.02], - [0.01, 0. ]]) - - - .. _references-entropic-partial-wasserstein: - References - ---------- - .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. - (2015). Iterative Bregman projections for regularized transportation - problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138. - - See Also - -------- - ot.partial.partial_wasserstein: exact Partial Wasserstein - """ - - a, b, M = list_to_array(a, b, M) - - nx = get_backend(a, b, M) - - dim_a, dim_b = M.shape - dx = nx.ones(dim_a, type_as=a) - dy = nx.ones(dim_b, type_as=b) - - if len(a) == 0: - a = nx.ones(dim_a, type_as=a) / dim_a - if len(b) == 0: - b = nx.ones(dim_b, type_as=b) / dim_b - - if m is None: - m = nx.min(nx.stack((nx.sum(a), nx.sum(b)))) * 1.0 - if m < 0: - raise ValueError("Problem infeasible. Parameter m should be greater" - " than 0.") - if m > nx.min(nx.stack((nx.sum(a), nx.sum(b)))): - raise ValueError("Problem infeasible. Parameter m should lower or" - " equal than min(|a|_1, |b|_1).") - - log_e = {'err': []} - - if nx.__name__ == "numpy": - # Next 3 lines equivalent to K=nx.exp(-M/reg), but faster to compute - K = np.empty(M.shape, dtype=M.dtype) - np.divide(M, -reg, out=K) - np.exp(K, out=K) - np.multiply(K, m / np.sum(K), out=K) - else: - K = nx.exp(-M / reg) - K = K * m / nx.sum(K) - - err, cpt = 1, 0 - q1 = nx.ones(K.shape, type_as=K) - q2 = nx.ones(K.shape, type_as=K) - q3 = nx.ones(K.shape, type_as=K) - - while (err > stopThr and cpt < numItermax): - Kprev = K - K = K * q1 - K1 = nx.dot(nx.diag(nx.minimum(a / nx.sum(K, axis=1), dx)), K) - q1 = q1 * Kprev / K1 - K1prev = K1 - K1 = K1 * q2 - K2 = nx.dot(K1, nx.diag(nx.minimum(b / nx.sum(K1, axis=0), dy))) - q2 = q2 * K1prev / K2 - K2prev = K2 - K2 = K2 * q3 - K = K2 * (m / nx.sum(K2)) - q3 = q3 * K2prev / K - - if nx.any(nx.isnan(K)) or nx.any(nx.isinf(K)): - print('Warning: numerical errors at iteration', cpt) - break - if cpt % 10 == 0: - err = nx.norm(Kprev - K) - if log: - log_e['err'].append(err) - if verbose: - if cpt % 200 == 0: - print( - '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 11) - print('{:5d}|{:8e}|'.format(cpt, err)) - - cpt = cpt + 1 - log_e['partial_w_dist'] = nx.sum(M * K) - if log: - return K, log_e - else: - return K - - def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None, numItermax=1000, tol=1e-7, log=False, verbose=False): @@ -964,6 +996,9 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None, :ref:`[12] ` and the partial GW in :ref:`[29] ` + .. note:: This function will be deprecated in a near future, please use + `ot.gromov.entropic_partial_gromov_wasserstein` instead. + Parameters ---------- C1 : ndarray, shape (ns, ns) @@ -1035,6 +1070,12 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None, ot.partial.partial_gromov_wasserstein: exact Partial Gromov-Wasserstein """ + warnings.warn( + "This function will be deprecated in a near future, please use " + "ot.gromov.entropic_partial_gromov_wasserstein` instead.", + stacklevel=2 + ) + if G0 is None: G0 = np.outer(p, q) @@ -1113,6 +1154,8 @@ def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None, :ref:`[12] ` and the partial GW in :ref:`[29] ` + .. note:: This function will be deprecated in a near future, please use + `ot.gromov.entropic_partial_gromov_wasserstein2` instead. Parameters ---------- @@ -1173,6 +1216,11 @@ def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None, Transport with Applications on Positive-Unlabeled Learning". NeurIPS. """ + warnings.warn( + "This function will be deprecated in a near future, please use " + "ot.gromov.entropic_partial_gromov_wasserstein2` instead.", + stacklevel=2 + ) partial_gw, log_gw = entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m, G0, numItermax, diff --git a/ot/solvers.py b/ot/solvers.py index ac2fcbb88..37b1b93df 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -17,8 +17,9 @@ entropic_gromov_wasserstein2, entropic_fused_gromov_wasserstein2, semirelaxed_gromov_wasserstein2, semirelaxed_fused_gromov_wasserstein2, entropic_semirelaxed_fused_gromov_wasserstein2, - entropic_semirelaxed_gromov_wasserstein2) -from .partial import partial_gromov_wasserstein2, entropic_partial_gromov_wasserstein2 + entropic_semirelaxed_gromov_wasserstein2, + partial_gromov_wasserstein2, + entropic_partial_gromov_wasserstein2) from .gaussian import empirical_bures_wasserstein_distance from .factored import factored_optimal_transport from .lowrank import lowrank_sinkhorn @@ -779,10 +780,6 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, if unbalanced > nx.sum(a) or unbalanced > nx.sum(b): raise (ValueError('Partial GW mass given in reg is too large')) - if loss.lower() != 'l2': - raise (NotImplementedError('Partial GW only implemented with L2 loss')) - if symmetric is not None: - raise (NotImplementedError('Partial GW only implemented with symmetric=True')) # default values for solver if max_iter is None: @@ -790,7 +787,7 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, if tol is None: tol = 1e-7 - value, log = partial_gromov_wasserstein2(Ca, Cb, a, b, m=unbalanced, log=True, numItermax=max_iter, G0=plan_init, tol=tol, verbose=verbose) + value, log = partial_gromov_wasserstein2(Ca, Cb, a, b, m=unbalanced, loss_fun=loss_fun, log=True, numItermax=max_iter, G0=plan_init, tol=tol, symmetric=symmetric, verbose=verbose) value_quad = value plan = log['T'] @@ -902,10 +899,6 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, if unbalanced > nx.sum(a) or unbalanced > nx.sum(b): raise (ValueError('Partial GW mass given in reg is too large')) - if loss.lower() != 'l2': - raise (NotImplementedError('Partial GW only implemented with L2 loss')) - if symmetric is not None: - raise (NotImplementedError('Partial GW only implemented with symmetric=True')) # default values for solver if max_iter is None: @@ -913,7 +906,7 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, if tol is None: tol = 1e-7 - value_quad, log = entropic_partial_gromov_wasserstein2(Ca, Cb, a, b, reg=reg, m=unbalanced, log=True, numItermax=max_iter, G0=plan_init, tol=tol, verbose=verbose) + value_quad, log = entropic_partial_gromov_wasserstein2(Ca, Cb, a, b, reg=reg, loss_fun=loss_fun, m=unbalanced, log=True, numItermax=max_iter, G0=plan_init, tol=tol, symmetric=symmetric, verbose=verbose) value_quad = value plan = log['T'] diff --git a/test/gromov/test_gw.py b/test/gromov/test_gw.py index 5b858f307..4f3dff14b 100644 --- a/test/gromov/test_gw.py +++ b/test/gromov/test_gw.py @@ -317,7 +317,7 @@ def f(G): def df(G): return ot.gromov.gwggrad(constCb, hC1b, hC2b, G, None) - def line_search(cost, G, deltaG, Mi, cost_G): + def line_search(cost, G, deltaG, Mi, cost_G, df_G): return ot.gromov.solve_gromov_linesearch(G, deltaG, cost_G, C1b, C2b, M=0., reg=1., nx=None) # feed the precomputed local optimum Gb to cg res, log = ot.optim.cg(pb, qb, 0., 1., f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9) @@ -751,12 +751,12 @@ def f(G): def df(G): return ot.gromov.gwggrad(constCb, hC1b, hC2b, G, None) - def line_search(cost, G, deltaG, Mi, cost_G): + def line_search(cost, G, deltaG, Mi, cost_G, df_G): return ot.gromov.solve_gromov_linesearch(G, deltaG, cost_G, C1b, C2b, M=(1 - alpha) * Mb, reg=alpha, nx=None) # feed the precomputed local optimum Gb to cg res, log = ot.optim.cg(pb, qb, (1 - alpha) * Mb, alpha, f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9) - def line_search(cost, G, deltaG, Mi, cost_G): + def line_search(cost, G, deltaG, Mi, cost_G, df_G): return ot.optim.line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=None) # feed the precomputed local optimum Gb to cg res_armijo, log_armijo = ot.optim.cg(pb, qb, (1 - alpha) * Mb, alpha, f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9) diff --git a/test/gromov/test_partial.py b/test/gromov/test_partial.py new file mode 100644 index 000000000..9a6712666 --- /dev/null +++ b/test/gromov/test_partial.py @@ -0,0 +1,306 @@ +""" Tests for gromov._partial.py """ + +# Author: +# Laetitia Chapel +# Cédric Vincent-Cuat +# +# License: MIT License + +import numpy as np +import scipy as sp +import ot +import pytest + + +def test_raise_errors(): + + n_samples = 20 # nb samples (gaussian) + n_noise = 20 # nb of samples (noise) + + mu = np.array([0, 0]) + cov = np.array([[1, 0], [0, 2]]) + + rng = np.random.RandomState(42) + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=rng) + xs = np.append(xs, (rng.rand(n_noise, 2) + 1) * 4).reshape((-1, 2)) + xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=rng) + xt = np.append(xt, (rng.rand(n_noise, 2) + 1) * -3).reshape((-1, 2)) + + M = ot.dist(xs, xt) + + p = ot.unif(n_samples + n_noise) + q = ot.unif(n_samples + n_noise) + + with pytest.raises(ValueError): + ot.gromov.partial_gromov_wasserstein(M, M, p, q, m=2, log=True) + + with pytest.raises(ValueError): + ot.gromov.partial_gromov_wasserstein(M, M, p, q, m=-1, log=True) + + with pytest.raises(ValueError): + ot.gromov.entropic_partial_gromov_wasserstein(M, M, p, q, reg=1, m=2, + log=True) + + with pytest.raises(ValueError): + ot.gromov.entropic_partial_gromov_wasserstein(M, M, p, q, reg=1, m=-1, + log=True) + + +def test_partial_gromov_wasserstein(nx): + rng = np.random.RandomState(42) + n_samples = 20 # nb samples + n_noise = 10 # nb of samples (noise) + + p = ot.unif(n_samples + n_noise) + psub = ot.unif(n_samples - 5 + n_noise) + q = ot.unif(n_samples + n_noise) + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + mu_t = np.array([0, 0, 0]) + cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + + # clean samples + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=rng) + P = sp.linalg.sqrtm(cov_t) + xt = rng.randn(n_samples, 3).dot(P) + mu_t + # add noise + xs = np.concatenate((xs, ((rng.rand(n_noise, 2) + 1) * 4)), axis=0) + xt = np.concatenate((xt, ((rng.rand(n_noise, 3) + 1) * 10)), axis=0) + xt2 = xs[::-1].copy() + + C1 = ot.dist(xs, xs) + C1sub = ot.dist(xs[5:], xs[5:]) + + C2 = ot.dist(xt, xt) + C3 = ot.dist(xt2, xt2) + + m = 2. / 3. + + C1b, C1subb, C2b, C3b, pb, psubb, qb = nx.from_numpy(C1, C1sub, C2, C3, p, psub, q) + + G0 = np.outer(p, q) * m / (np.sum(p) * np.sum(q)) # make sure |G0|=m, G01_m\leq p, G0.T1_n\leq q. + G0b = nx.from_numpy(G0) + + # check consistency across backends and stability w.r.t loss/marginals/sym + list_sym = [True, None] + for i, loss_fun in enumerate(['square_loss', 'kl_loss']): + + res, log = ot.gromov.partial_gromov_wasserstein( + C1, C3, p=p, q=None, m=m, loss_fun=loss_fun, n_dummies=1, + G0=G0, log=True, symmetric=list_sym[i], warn=True, verbose=True) + + resb, logb = ot.gromov.partial_gromov_wasserstein( + C1b, C3b, p=None, q=qb, m=m, loss_fun=loss_fun, n_dummies=1, + G0=G0b, log=True, symmetric=False, warn=True, verbose=True) + + resb_ = nx.to_numpy(resb) + assert np.all(res.sum(1) <= p) # cf convergence wasserstein + assert np.all(res.sum(0) <= q) # cf convergence wasserstein + + try: + # precision error while doubling numbers of computations with symmetric=False + # some instability can occur with kl. to investigate further. + # changing log offset in _transform_matrix was a way to solve it + # but it also negatively affects some other solvers in the API + np.testing.assert_allclose(res, resb_, rtol=1e-4) + except AssertionError: + pass + + # tests with different number of samples across spaces + m = 2. / 3. + res, log = ot.gromov.partial_gromov_wasserstein( + C1, C1sub, p=p, q=psub, m=m, log=True) + + resb, logb = ot.gromov.partial_gromov_wasserstein( + C1b, C1subb, p=pb, q=psubb, m=m, log=True) + + resb_ = nx.to_numpy(resb) + np.testing.assert_allclose(res, resb_, rtol=1e-4) + assert np.all(res.sum(1) <= p) # cf convergence wasserstein + assert np.all(res.sum(0) <= psub) # cf convergence wasserstein + np.testing.assert_allclose( + np.sum(res), m, atol=1e-15) + + # Edge cases - tests with m=1 set by default (coincide with gw) + m = 1 + res0 = ot.gromov.partial_gromov_wasserstein( + C1, C2, p, q, m=m, log=False) + res0b, log0b = ot.gromov.partial_gromov_wasserstein( + C1b, C2b, pb, qb, m=None, log=True) + G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss') + np.testing.assert_allclose(G, res0, rtol=1e-4) + np.testing.assert_allclose(res0b, res0, rtol=1e-4) + + # tests for pGW2 + for loss_fun in ['square_loss', 'kl_loss']: + w0, log0 = ot.gromov.partial_gromov_wasserstein2( + C1, C2, p=None, q=q, m=m, loss_fun=loss_fun, log=True) + w0_val = ot.gromov.partial_gromov_wasserstein2( + C1b, C2b, p=pb, q=None, m=m, loss_fun=loss_fun, log=False) + np.testing.assert_allclose(w0, w0_val, rtol=1e-4) + + # tests integers + C1_int = C1.astype(int) + C1b_int = nx.from_numpy(C1_int) + C2_int = C2.astype(int) + C2b_int = nx.from_numpy(C2_int) + + res0b, log0b = ot.gromov.partial_gromov_wasserstein( + C1b_int, C2b_int, pb, qb, m=m, log=True) + + assert nx.to_numpy(res0b).dtype == C1_int.dtype + + +def test_partial_partial_gromov_linesearch(nx): + rng = np.random.RandomState(42) + n_samples = 20 # nb samples + n_noise = 10 # nb of samples (noise) + + p = ot.unif(n_samples + n_noise) + q = ot.unif(n_samples + n_noise) + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + mu_t = np.array([0, 0, 0]) + cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=rng) + xs = np.concatenate((xs, ((rng.rand(n_noise, 2) + 1) * 4)), axis=0) + P = sp.linalg.sqrtm(cov_t) + xt = rng.randn(n_samples, 3).dot(P) + mu_t + xt = np.concatenate((xt, ((rng.rand(n_noise, 3) + 1) * 10)), axis=0) + xt2 = xs[::-1].copy() + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + C3 = ot.dist(xt2, xt2) + + m = 2. / 3. + + C1b, C2b, C3b, pb, qb = nx.from_numpy(C1, C2, C3, p, q) + G0 = np.outer(p, q) * m / (np.sum(p) * np.sum(q)) # make sure |G0|=m, G01_m\leq p, G0.T1_n\leq q. + G0b = nx.from_numpy(G0) + + # computing necessary inputs to the line-search + Gb, _ = ot.gromov.partial_gromov_wasserstein( + C1b, C2b, pb, qb, m=m, log=True) + + deltaGb = Gb - G0b + fC1, fC2, hC1, hC2 = ot.gromov._utils._transform_matrix(C1b, C2b, 'square_loss') + fC2t = fC2.T + + ones_p = nx.ones(p.shape[0], type_as=pb) + ones_q = nx.ones(p.shape[0], type_as=pb) + + constC1 = nx.outer(nx.dot(fC1, pb), ones_q) + constC2 = nx.outer(ones_p, nx.dot(qb, fC2t)) + cost_G0b = ot.gromov.gwloss(constC1 + constC2, hC1, hC2, G0b) + + df_G0b = ot.gromov.gwggrad(constC1 + constC2, hC1, hC2, G0b) + df_Gb = ot.gromov.gwggrad(constC1 + constC2, hC1, hC2, Gb) + + # perform line-search + alpha, _, cost_Gb, _ = ot.gromov.solve_partial_gromov_linesearch( + G0b, deltaGb, cost_G0b, df_G0b, df_Gb, 0., 1., + alpha_min=0., alpha_max=1.) + + np.testing.assert_allclose(alpha, 1., rtol=1e-4) + + +@pytest.skip_backend("jax", reason="test very slow with jax backend") +@pytest.skip_backend("tf", reason="test very slow with tf backend") +def test_entropic_partial_gromov_wasserstein(nx): + rng = np.random.RandomState(42) + n_samples = 20 # nb samples + n_noise = 10 # nb of samples (noise) + + p = ot.unif(n_samples + n_noise) + psub = ot.unif(n_samples - 5 + n_noise) + q = ot.unif(n_samples + n_noise) + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + mu_t = np.array([0, 0, 0]) + cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + + # clean samples + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=rng) + P = sp.linalg.sqrtm(cov_t) + xt = rng.randn(n_samples, 3).dot(P) + mu_t + # add noise + xs = np.concatenate((xs, ((rng.rand(n_noise, 2) + 1) * 4)), axis=0) + xt = np.concatenate((xt, ((rng.rand(n_noise, 3) + 1) * 10)), axis=0) + xt2 = xs[::-1].copy() + + C1 = ot.dist(xs, xs) + C1sub = ot.dist(xs[5:], xs[5:]) + + C2 = ot.dist(xt, xt) + C3 = ot.dist(xt2, xt2) + + m = 2. / 3. + + C1b, C1subb, C2b, C3b, pb, psubb, qb = nx.from_numpy(C1, C1sub, C2, C3, p, psub, q) + G0 = np.outer(p, q) * m / (np.sum(p) * np.sum(q)) # make sure |G0|=m, G01_m\leq p, G0.T1_n\leq q. + G0b = nx.from_numpy(G0) + + # check consistency across backends and stability w.r.t loss/marginals/sym + list_sym = [True, None] + for i, loss_fun in enumerate(['square_loss', 'kl_loss']): + res, log = ot.gromov.entropic_partial_gromov_wasserstein( + C1, C3, p=p, q=None, reg=1e4, m=m, loss_fun=loss_fun, G0=None, + log=True, symmetric=list_sym[i], verbose=True) + + resb, logb = ot.gromov.entropic_partial_gromov_wasserstein( + C1b, C3b, p=None, q=qb, reg=1e4, m=m, loss_fun=loss_fun, G0=G0b, + log=True, symmetric=False, verbose=True) + + resb_ = nx.to_numpy(resb) + try: # some instability can occur with kl. to investigate further. + np.testing.assert_allclose(res, resb_, rtol=1e-4) + except AssertionError: + pass + + assert np.all(res.sum(1) <= p) # cf convergence wasserstein + assert np.all(res.sum(0) <= q) # cf convergence wasserstein + + # tests with m is None + res = ot.gromov.entropic_partial_gromov_wasserstein( + C1, C3, p=p, q=None, reg=1e4, G0=None, log=False, + symmetric=list_sym[i], verbose=True) + + resb = ot.gromov.entropic_partial_gromov_wasserstein( + C1b, C3b, p=None, q=qb, reg=1e4, G0=None, log=False, + symmetric=False, verbose=True) + + resb_ = nx.to_numpy(resb) + np.testing.assert_allclose(res, resb_, rtol=1e-4) + np.testing.assert_allclose( + np.sum(res), 1., rtol=1e-4) + + # tests with different number of samples across spaces + m = 0.5 + res, log = ot.gromov.entropic_partial_gromov_wasserstein( + C1, C1sub, p=p, q=psub, reg=1e4, m=m, log=True) + + resb, logb = ot.gromov.entropic_partial_gromov_wasserstein( + C1b, C1subb, p=pb, q=psubb, reg=1e4, m=m, log=True) + + resb_ = nx.to_numpy(resb) + np.testing.assert_allclose(res, resb_, rtol=1e-4) + assert np.all(res.sum(1) <= p) # cf convergence wasserstein + assert np.all(res.sum(0) <= psub) # cf convergence wasserstein + np.testing.assert_allclose( + np.sum(res), m, rtol=1e-4) + + # tests for pGW2 + for loss_fun in ['square_loss', 'kl_loss']: + w0, log0 = ot.gromov.entropic_partial_gromov_wasserstein2( + C1, C2, p=None, q=q, reg=1e4, m=m, loss_fun=loss_fun, log=True) + w0_val = ot.gromov.entropic_partial_gromov_wasserstein2( + C1b, C2b, p=pb, q=None, reg=1e4, m=m, loss_fun=loss_fun, log=False) + np.testing.assert_allclose(w0, w0_val, rtol=1e-8) diff --git a/test/gromov/test_semirelaxed.py b/test/gromov/test_semirelaxed.py index acc49be5f..0a9e25d17 100644 --- a/test/gromov/test_semirelaxed.py +++ b/test/gromov/test_semirelaxed.py @@ -196,7 +196,7 @@ def df(G): marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb)) return ot.gromov.gwggrad(constCb + marginal_product, hC1b, hC2b, G, nx=None) - def line_search(cost, G, deltaG, Mi, cost_G): + def line_search(cost, G, deltaG, Mi, cost_G, df_G): return ot.gromov.solve_semirelaxed_gromov_linesearch( G, deltaG, cost_G, hC1b, hC2b, ones_pb, 0., 1., fC2t=fC2tb, nx=None) # feed the precomputed local optimum Gb to semirelaxed_cg @@ -423,7 +423,7 @@ def df(G): marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb)) return ot.gromov.gwggrad(constCb + marginal_product, hC1b, hC2b, G, nx=None) - def line_search(cost, G, deltaG, Mi, cost_G): + def line_search(cost, G, deltaG, Mi, cost_G, df_G): return ot.gromov.solve_semirelaxed_gromov_linesearch( G, deltaG, cost_G, C1b, C2b, ones_pb, M=(1 - alpha) * Mb, reg=alpha, nx=None) # feed the precomputed local optimum Gb to semirelaxed_cg diff --git a/test/test_solvers.py b/test/test_solvers.py index 61dda87a7..a338f93a6 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -338,14 +338,10 @@ def test_solve_gromov_not_implemented(nx): # detect partial not implemented and error detect in value with pytest.raises(ValueError): ot.solve_gromov(Ca, Cb, unbalanced_type='partial', unbalanced=1.5) - with pytest.raises(NotImplementedError): - ot.solve_gromov(Ca, Cb, unbalanced_type='partial', unbalanced=0.5, symmetric=False) with pytest.raises(NotImplementedError): ot.solve_gromov(Ca, Cb, M, unbalanced_type='partial', unbalanced=0.5) with pytest.raises(ValueError): ot.solve_gromov(Ca, Cb, reg=1, unbalanced_type='partial', unbalanced=1.5) - with pytest.raises(NotImplementedError): - ot.solve_gromov(Ca, Cb, reg=1, unbalanced_type='partial', unbalanced=0.5, symmetric=False) def test_solve_sample(nx):