From 7ba5f03bae0f317a6763ca4b3ae0c8a2eba911cf Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Mon, 20 Feb 2023 18:04:09 +0100 Subject: [PATCH 1/8] Allow warmstart in sinkhorn and sinkhorn_log --- ot/bregman.py | 37 ++++++++++++++++++++++++++----------- 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/ot/bregman.py b/ot/bregman.py index c33c92c83..fc715bb9c 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -364,7 +364,7 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, raise ValueError("Unknown method '%s'." % method) -def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, +def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, warmstart=None, verbose=False, log=False, warn=True, **kwargs): r""" @@ -409,6 +409,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, Max number of iterations stopThr : float, optional Stop threshold on error (>0) + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual vectors. If provided, the dual vectors must be in logarithm form, + i.e. warmstart = (log_u, log_v), but not (u, v). verbose : bool, optional Print information along iterations log : bool, optional @@ -474,12 +477,15 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, # we assume that no distances are null except those of the diagonal of # distances - if n_hists: - u = nx.ones((dim_a, n_hists), type_as=M) / dim_a - v = nx.ones((dim_b, n_hists), type_as=M) / dim_b + if warmstart is None: + if n_hists: + u = nx.ones((dim_a, n_hists), type_as=M) / dim_a + v = nx.ones((dim_b, n_hists), type_as=M) / dim_b + else: + u = nx.ones(dim_a, type_as=M) / dim_a + v = nx.ones(dim_b, type_as=M) / dim_b else: - u = nx.ones(dim_a, type_as=M) / dim_a - v = nx.ones(dim_b, type_as=M) / dim_b + u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1]) K = nx.exp(M / (-reg)) @@ -546,7 +552,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, return u.reshape((-1, 1)) * K * v.reshape((1, -1)) -def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, +def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, warmstart=None, verbose=False, log=False, warn=True, **kwargs): r""" Solve the entropic regularization optimal transport problem in log space @@ -590,6 +596,9 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, Max number of iterations stopThr : float, optional Stop threshold on error (>0) + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual vectors. If provided, the dual vectors must be in logarithm form, + i.e. warmstart = (log_u, log_v), but not (u, v). verbose : bool, optional Print information along iterations log : bool, optional @@ -656,6 +665,10 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, else: n_hists = 0 + # in case of multiple historgrams + if n_hists > 1 and warmstart is None: + warmstart = [None] * n_hists + if n_hists: # we do not want to use tensors sor we do a loop lst_loss = [] @@ -663,7 +676,7 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, lst_v = [] for k in range(n_hists): - res = sinkhorn_log(a, b[:, k], M, reg, numItermax=numItermax, + res = sinkhorn_log(a, b[:, k], M, reg, numItermax=numItermax, warmstart=warmstart[k], stopThr=stopThr, verbose=verbose, log=log, **kwargs) if log: @@ -691,9 +704,11 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, # we assume that no distances are null except those of the diagonal of # distances - - u = nx.zeros(dim_a, type_as=M) - v = nx.zeros(dim_b, type_as=M) + if warmstart is None: + u = nx.zeros(dim_a, type_as=M) + v = nx.zeros(dim_b, type_as=M) + else: + u, v = warmstart def get_logT(u, v): if n_hists: From eabeabe0ba99822c2b89fed91345ba6817eb256f Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Wed, 22 Feb 2023 12:01:08 +0100 Subject: [PATCH 2/8] Added argument for warmstart of dual vectors in Sinkhorn-based methods in --- RELEASES.md | 4 +- ot/bregman.py | 116 ++++++++++++++++++++++++++++--------------- test/test_bregman.py | 76 ++++++++++++++++++++++++++++ 3 files changed, 154 insertions(+), 42 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index 4ed362556..9a377f140 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -8,8 +8,8 @@ - Added Generalized Wasserstein Barycenter solver + example (PR #372), fixed graphical details on the example (PR #376) - Added Free Support Sinkhorn Barycenter + example (PR #387) - New API for OT solver using function `ot.solve` (PR #388) -- Backend version of `ot.partial` and `ot.smooth` (PR #388) - +- Backend version of `ot.partial` and `ot.smooth` (PR #388) +- Added argument for warmstart of dual vectors in Sinkhorn-based methods in `ot.bregman` (PR #) #### Closed issues diff --git a/ot/bregman.py b/ot/bregman.py index fc715bb9c..215ade0be 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -24,7 +24,7 @@ from .backend import get_backend -def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, +def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, warmstart=None, stopThr=1e-9, verbose=False, log=False, warn=True, **kwargs): r""" @@ -93,6 +93,9 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, those function for specific parameters numItermax : int, optional Max number of iterations + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual vectors. If provided, the dual vectors must be already taken the logarithm, + i.e. warmstart = (log_u, log_v), but not (u, v). stopThr : float, optional Stop threshold on error (>0) verbose : bool, optional @@ -154,27 +157,27 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, """ if method.lower() == 'sinkhorn': - return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, + return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, warmstart=warmstart, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) elif method.lower() == 'sinkhorn_log': - return sinkhorn_log(a, b, M, reg, numItermax=numItermax, + return sinkhorn_log(a, b, M, reg, numItermax=numItermax, warmstart=warmstart, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) elif method.lower() == 'greenkhorn': - return greenkhorn(a, b, M, reg, numItermax=numItermax, + return greenkhorn(a, b, M, reg, numItermax=numItermax, warmstart=warmstart, stopThr=stopThr, verbose=verbose, log=log, warn=warn) elif method.lower() == 'sinkhorn_stabilized': - return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, + return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, warmstart=warmstart, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) elif method.lower() == 'sinkhorn_epsilon_scaling': return sinkhorn_epsilon_scaling(a, b, M, reg, - numItermax=numItermax, + numItermax=numItermax, warmstart=warmstart, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) @@ -182,7 +185,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, raise ValueError("Unknown method '%s'." % method) -def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, +def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, warmstart=None, stopThr=1e-9, verbose=False, log=False, warn=False, **kwargs): r""" Solve the entropic regularization optimal transport problem and return the loss @@ -252,6 +255,9 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, 'sinkhorn_stabilized', see those function for specific parameters numItermax : int, optional Max number of iterations + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual vectors. If provided, the dual vectors must be already taken the logarithm, + i.e. warmstart = (log_u, log_v), but not (u, v). stopThr : float, optional Stop threshold on error (>0) verbose : bool, optional @@ -322,17 +328,17 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, if len(b.shape) < 2: if method.lower() == 'sinkhorn': - res = sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, + res = sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, warmstart=warmstart, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) elif method.lower() == 'sinkhorn_log': - res = sinkhorn_log(a, b, M, reg, numItermax=numItermax, + res = sinkhorn_log(a, b, M, reg, numItermax=numItermax, warmstart=warmstart, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) elif method.lower() == 'sinkhorn_stabilized': - res = sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, + res = sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, warmstart=warmstart, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) @@ -346,17 +352,17 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, else: if method.lower() == 'sinkhorn': - return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, + return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, warmstart=warmstart, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) elif method.lower() == 'sinkhorn_log': - return sinkhorn_log(a, b, M, reg, numItermax=numItermax, + return sinkhorn_log(a, b, M, reg, numItermax=numItermax, warmstart=warmstart, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) elif method.lower() == 'sinkhorn_stabilized': - return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, + return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, warmstart=warmstart, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) @@ -364,7 +370,7 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, raise ValueError("Unknown method '%s'." % method) -def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, warmstart=None, +def sinkhorn_knopp(a, b, M, reg, numItermax=1000, warmstart=None, stopThr=1e-9, verbose=False, log=False, warn=True, **kwargs): r""" @@ -407,11 +413,11 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, warmstart=None, Regularization term >0 numItermax : int, optional Max number of iterations - stopThr : float, optional - Stop threshold on error (>0) warmstart: tuple of arrays, shape (dim_a, dim_b), optional - Initialization of dual vectors. If provided, the dual vectors must be in logarithm form, + Initialization of dual vectors. If provided, the dual vectors must be already taken the logarithm, i.e. warmstart = (log_u, log_v), but not (u, v). + stopThr : float, optional + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -552,7 +558,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, warmstart=None, return u.reshape((-1, 1)) * K * v.reshape((1, -1)) -def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, warmstart=None, verbose=False, +def sinkhorn_log(a, b, M, reg, numItermax=1000, warmstart=None, stopThr=1e-9, verbose=False, log=False, warn=True, **kwargs): r""" Solve the entropic regularization optimal transport problem in log space @@ -594,11 +600,11 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, warmstart=None, ve Regularization term >0 numItermax : int, optional Max number of iterations - stopThr : float, optional - Stop threshold on error (>0) warmstart: tuple of arrays, shape (dim_a, dim_b), optional - Initialization of dual vectors. If provided, the dual vectors must be in logarithm form, + Initialization of dual vectors. If provided, the dual vectors must be already taken the logarithm, i.e. warmstart = (log_u, log_v), but not (u, v). + stopThr : float, optional + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -761,7 +767,7 @@ def get_logT(u, v): return nx.exp(get_logT(u, v)) -def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, +def greenkhorn(a, b, M, reg, numItermax=10000, warmstart=None, stopThr=1e-9, verbose=False, log=False, warn=True): r""" Solve the entropic regularization optimal transport problem and return the OT matrix @@ -804,6 +810,9 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, Regularization term >0 numItermax : int, optional Max number of iterations + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual vectors. If provided, the dual vectors must be already taken the logarithm, + i.e. warmstart = (log_u, log_v), but not (u, v). stopThr : float, optional Stop threshold on error (>0) log : bool, optional @@ -868,8 +877,11 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, K = nx.exp(-M / reg) - u = nx.full((dim_a,), 1. / dim_a, type_as=K) - v = nx.full((dim_b,), 1. / dim_b, type_as=K) + if warmstart is None: + u = nx.full((dim_a,), 1. / dim_a, type_as=K) + v = nx.full((dim_b,), 1. / dim_b, type_as=K) + else: + u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1]) G = u[:, None] * K * v[None, :] viol = nx.sum(G, axis=1) - a @@ -2872,7 +2884,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', - numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False, + numIterMax=10000, warmstart=None, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False, log=False, warn=True, **kwargs): r''' Solve the entropic regularization optimal transport problem and return the @@ -2911,6 +2923,9 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', samples weights in the target domain numItermax : int, optional Max number of iterations + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual vectors. If provided, the dual vectors must be already taken the logarithm, + i.e. warmstart = (log_u, log_v), but not (u, v). stopThr : float, optional Stop threshold on error (>0) isLazy: boolean, optional @@ -2976,7 +2991,10 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', dict_log = {"err": []} log_a, log_b = nx.log(a), nx.log(b) - f, g = nx.zeros((ns,), type_as=a), nx.zeros((nt,), type_as=a) + if warmstart is None: + f, g = nx.zeros((ns,), type_as=a), nx.zeros((nt,), type_as=a) + else: + f, g = warmstart if isinstance(batchSize, int): bs, bt = batchSize, batchSize @@ -3048,17 +3066,17 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', else: M = dist(X_s, X_t, metric=metric) if log: - pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, + pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, warmstart=warmstart, stopThr=stopThr, verbose=verbose, log=True, **kwargs) return pi, log else: - pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, + pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, warmstart=warmstart, stopThr=stopThr, verbose=verbose, log=False, **kwargs) return pi def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', - numIterMax=10000, stopThr=1e-9, isLazy=False, + numIterMax=10000, warmstart=None, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False, log=False, warn=True, **kwargs): r''' Solve the entropic regularization optimal transport problem from empirical @@ -3101,6 +3119,9 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', samples weights in the target domain numItermax : int, optional Max number of iterations + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual vectors. If provided, the dual vectors must be already taken the logarithm, + i.e. warmstart = (log_u, log_v), but not (u, v). stopThr : float, optional Stop threshold on error (>0) isLazy: boolean, optional @@ -3167,7 +3188,8 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', if isLazy: if log: f, g, dict_log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, - numIterMax=numIterMax, + numIterMax=numIterMax, + warmstart=warmstart, stopThr=stopThr, isLazy=isLazy, batchSize=batchSize, @@ -3175,7 +3197,9 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', warn=warn) else: f, g = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, - numIterMax=numIterMax, stopThr=stopThr, + numIterMax=numIterMax, + warmstart=warmstart, + stopThr=stopThr, isLazy=isLazy, batchSize=batchSize, verbose=verbose, log=log, warn=warn) @@ -3203,19 +3227,19 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', M = dist(X_s, X_t, metric=metric) if log: - sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, + sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, warmstart=warmstart, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) return sinkhorn_loss, log else: - sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, + sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, warmstart=warmstart, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) return sinkhorn_loss def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', - numIterMax=10000, stopThr=1e-9, + numIterMax=10000, warmstart=None, stopThr=1e-9, verbose=False, log=False, warn=True, **kwargs): r''' @@ -3286,6 +3310,9 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli samples weights in the target domain numItermax : int, optional Max number of iterations + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual vectors. If provided, the dual vectors must be already taken the logarithm, + i.e. warmstart = (log_u, log_v), but not (u, v). stopThr : float, optional Stop threshold on error (>0) verbose : bool, optional @@ -3323,20 +3350,26 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli X_s, X_t = list_to_array(X_s, X_t) nx = get_backend(X_s, X_t) + if warmstart is None: + warmstart_a, warmstart_b = None, None + else: + u, v = warmstart + warmstart_a = (u, u) + warmstart_b = (v, v) if log: sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, - numIterMax=numIterMax, + numIterMax=numIterMax, warmstart=warmstart, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, - numIterMax=numIterMax, + numIterMax=numIterMax, warmstart=warmstart_a, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, - numIterMax=numIterMax, + numIterMax=numIterMax, warmstart=warmstart_b, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) @@ -3354,17 +3387,20 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli else: sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, - numIterMax=numIterMax, stopThr=stopThr, + numIterMax=numIterMax, warmstart=warmstart, + stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, - numIterMax=numIterMax, stopThr=stopThr, + numIterMax=numIterMax, warmstart=warmstart_a, + stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, - numIterMax=numIterMax, stopThr=stopThr, + numIterMax=numIterMax, warmstart=warmstart_b, + stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) diff --git a/test/test_bregman.py b/test/test_bregman.py index ce1564225..9fb09798f 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -1013,3 +1013,79 @@ def test_convolutional_barycenter_non_square(nx): np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02) np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02) np.testing.assert_allclose(b, b_np) + +def test_sinkhorn_warmstart(): + m, n = 10, 20 + a = ot.unif(m) + b = ot.unif(n) + + Xs = np.arange(m) * 1.0 + Xt = np.arange(n) * 1.0 + M = ot.dist(Xs.reshape(-1,1), Xt.reshape(-1,1)) + + # Generate warmstart from dual vectors of unregularized OT + _, log = ot.lp.emd(a, b, M, log=True) + warmstart = (log["u"], log["v"]) + + reg = 1 + + # Optimal plan with uniform warmstart + pi_unif, _ = ot.bregman.sinkhorn(a, b, M, reg, method="sinkhorn", warmstart=None, log=True) + # Optimal plan with warmstart generated from unregularized OT + pi_sh, _ = ot.bregman.sinkhorn(a, b, M, reg, method="sinkhorn", warmstart=warmstart, log=True) + pi_sh_log, _ = ot.bregman.sinkhorn(a, b, M, reg, method="sinkhorn_log", warmstart=warmstart, log=True) + pi_sh_stab, _ = ot.bregman.sinkhorn(a, b, M, reg, method="sinkhorn_stabilized", warmstart=warmstart, log=True) + pi_sh_sc, _ = ot.bregman.sinkhorn(a, b, M, reg, method="sinkhorn_epsilon_scaling", warmstart=warmstart, log=True) + + np.testing.assert_allclose(pi_unif, pi_sh, atol=1e-05) + np.testing.assert_allclose(pi_unif, pi_sh_log, atol=1e-05) + np.testing.assert_allclose(pi_unif, pi_sh_stab, atol=1e-05) + np.testing.assert_allclose(pi_unif, pi_sh_sc, atol=1e-05) + +def test_empirical_sinkhorn_warmstart(): + m, n = 10, 20 + Xs = np.arange(m).reshape(-1,1) * 1.0 + Xt = np.arange(n).reshape(-1,1) * 1.0 + M = ot.dist(Xs, Xt) + + # Generate warmstart from dual vectors of unregularized OT + a = ot.unif(m) + b = ot.unif(n) + _, log = ot.lp.emd(a, b, M, log=True) + warmstart = (log["u"], log["v"]) + + reg = 1 + + # Optimal plan with uniform warmstart + f, g, _ = ot.bregman.empirical_sinkhorn(X_s=Xs, X_t=Xt, reg=reg, isLazy=True, warmstart=None, log=True) + pi_unif = np.exp(f[:, None] + g[None, :] - M / reg) + # Optimal plan with warmstart generated from unregularized OT + f, g, _ = ot.bregman.empirical_sinkhorn(X_s=Xs, X_t=Xt, reg=reg, isLazy=True, warmstart=warmstart, log=True) + pi_ws_lazy = np.exp(f[:, None] + g[None, :] - M / reg) + pi_ws_not_lazy, _ = ot.bregman.empirical_sinkhorn(X_s=Xs, X_t=Xt, reg=reg, isLazy=False, warmstart=warmstart, log=True) + + np.testing.assert_allclose(pi_unif, pi_ws_lazy, atol=1e-05) + np.testing.assert_allclose(pi_unif, pi_ws_not_lazy, atol=1e-05) + +def test_empirical_sinkhorn_divergence_warmstart(): + m, n = 10, 20 + Xs = np.arange(m).reshape(-1,1) * 1.0 + Xt = np.arange(n).reshape(-1,1) * 1.0 + M = ot.dist(Xs, Xt) + + # Generate warmstart from dual vectors of unregularized OT + a = ot.unif(m) + b = ot.unif(n) + _, log = ot.lp.emd(a, b, M, log=True) + warmstart = (log["u"], log["v"]) + + reg = 1 + + # Optimal plan with uniform warmstart + sd_unif, _ = ot.bregman.empirical_sinkhorn_divergence(X_s=Xs, X_t=Xt, reg=reg, isLazy=True, warmstart=None, log=True) + # Optimal plan with warmstart generated from unregularized OT + sd_ws_lazy, _ = ot.bregman.empirical_sinkhorn_divergence(X_s=Xs, X_t=Xt, reg=reg, isLazy=True, warmstart=warmstart, log=True) + sd_ws_not_lazy, _ = ot.bregman.empirical_sinkhorn_divergence(X_s=Xs, X_t=Xt, reg=reg, isLazy=False, warmstart=warmstart, log=True) + + np.testing.assert_allclose(sd_unif, sd_ws_lazy, atol=1e-05) + np.testing.assert_allclose(sd_unif, sd_ws_not_lazy, atol=1e-05) \ No newline at end of file From cdd9373cb4ca8384750d3324b78fbb0b31ce4f11 Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Wed, 22 Feb 2023 12:29:38 +0100 Subject: [PATCH 3/8] Add the number of the PR --- RELEASES.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/RELEASES.md b/RELEASES.md index 9a377f140..4eb17d024 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -9,7 +9,7 @@ - Added Free Support Sinkhorn Barycenter + example (PR #387) - New API for OT solver using function `ot.solve` (PR #388) - Backend version of `ot.partial` and `ot.smooth` (PR #388) -- Added argument for warmstart of dual vectors in Sinkhorn-based methods in `ot.bregman` (PR #) +- Added argument for warmstart of dual vectors in Sinkhorn-based methods in `ot.bregman` (PR #437) #### Closed issues From f3d36b2705013409ac69b346585e311bc25fcfb7 Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Wed, 22 Feb 2023 12:34:50 +0100 Subject: [PATCH 4/8] [WIP] CO-Optimal Transport --- README.md | 4 +- RELEASES.md | 1 + ot/__init__.py | 4 +- ot/coot.py | 408 +++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 415 insertions(+), 2 deletions(-) create mode 100644 ot/coot.py diff --git a/README.md b/README.md index 7c9475b80..cf9fcbdc3 100644 --- a/README.md +++ b/README.md @@ -292,4 +292,6 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer [42] Delon, J., Gozlan, N., and Saint-Dizier, A. [Generalized Wasserstein barycenters between probability measures living on different subspaces](https://arxiv.org/pdf/2105.09755). arXiv preprint arXiv:2105.09755, 2021. -[43] Álvarez-Esteban, Pedro C., et al. [A fixed-point approach to barycenters in Wasserstein space.](https://arxiv.org/pdf/1511.05355.pdf) Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. \ No newline at end of file +[43] Álvarez-Esteban, Pedro C., et al. [A fixed-point approach to barycenters in Wasserstein space.](https://arxiv.org/pdf/1511.05355.pdf) Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. + +[44] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020). [CO-Optimal Transport](https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf). Advances in Neural Information Processing Systems, 33. \ No newline at end of file diff --git a/RELEASES.md b/RELEASES.md index 4eb17d024..ef3a832fb 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -10,6 +10,7 @@ - New API for OT solver using function `ot.solve` (PR #388) - Backend version of `ot.partial` and `ot.smooth` (PR #388) - Added argument for warmstart of dual vectors in Sinkhorn-based methods in `ot.bregman` (PR #437) +- Added CO-Optimal Transport solver and example (PR #) #### Closed issues diff --git a/ot/__init__.py b/ot/__init__.py index 0b55e0c56..8676990f1 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -36,6 +36,7 @@ from . import factored from . import solvers from . import gaussian +from . import coot # OT functions from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d @@ -49,6 +50,7 @@ from .weak import weak_optimal_transport from .factored import factored_optimal_transport from .solvers import solve +from .coot import co_optimal_transport, co_optimal_transport2 # utils functions from .utils import dist, unif, tic, toc, toq @@ -64,4 +66,4 @@ 'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2', 'max_sliced_wasserstein_distance', 'weak_optimal_transport', 'factored_optimal_transport', 'solve', - 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers'] + 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers', 'coot', 'co_optimal_transport', 'co_optimal_transport2'] diff --git a/ot/coot.py b/ot/coot.py new file mode 100644 index 000000000..9d7e5d37e --- /dev/null +++ b/ot/coot.py @@ -0,0 +1,408 @@ +# -*- coding: utf-8 -*- +""" +Fused CO-Optimal Transport and entropic Fused CO-Optimal Transport solvers +""" + +# Author: Quang Huy Tran +# +# License: MIT License + +import numpy as np +from functools import partial +from .lp import emd +from .utils import list_to_array +from .backend import get_backend +from .bregman import sinkhorn + +def co_optimal_transport(X, Y, px=(None, None), py=(None, None), eps=(0, 0), alpha=(1, 1), D=(None, None), + dict_init=None, log=False, verbose=False, early_stopping_tol=1e-6, eval_bcd=1, tol_bcd=1e-7, + nits_bcd=100, nits_sinkhorn=500, tol_sinkhorn=1e-7, method_sinkhorn="sinkhorn"): + r""" + Return the sample and feature transport plans between + :math:`(\mathbf{X}, \mathbf{p}_{xs}, \mathbf{p}_{xf})` and + :math:`(\mathbf{Y}, \mathbf{p}_{ys}, \mathbf{p}_{yf})`. + + The function solves the following optimization problem: + + .. math:: + \mathbf{COOT}_{\varepsilon} = \mathop{\arg \min}_{\mathbf{P}, \mathbf{Q}} + \quad \sum_{i,j,k,l} + (\mathbf{X}_{i,k} - \mathbf{Y}_{j,l})^2 \mathbf{P}_{i,j} \mathbf{Q}_{k,l} + + \alpha_1 \sum_{i,j} \mathbf{P}_{i,j} \mathbf{D^{(s)}}_{i, j} + + \alpha_2 \sum_{k, l} \mathbf{Q}_{k,l} \mathbf{D^{(f)}}_{k, l} + + \varepsilon_1 \mathbf{KL}(\mathbf{P} | \mathbf{p}_{xs} \mathbf{p}_{ys}^T) + + \varepsilon_2 \mathbf{KL}(\mathbf{Q} | \mathbf{p}_{xf} \mathbf{p}_{yf}^T) + + Where : + + - :math:`\mathbf{X}`: Data matrix in the source space + - :math:`\mathbf{Y}`: Data matrix in the target space + - :math:`\mathbf{D^{(s)}}`: Sample matrix + - :math:`\mathbf{D^{(f)}}`: Feature matrix + - :math:`\mathbf{p}_{xs}`: distribution of the samples in the source space + - :math:`\mathbf{p}_{xf}`: distribution of the features in the source space + - :math:`\mathbf{p}_{ys}`: distribution of the samples in the target space + - :math:`\mathbf{p}_{yf}`: distribution of the features in the target space + + .. note:: This function allows epsilons to be zero. In that case, the EMD solver will be used. + + Parameters + ---------- + X : (sx, fx) array-like, float + First input matrix. + Y : (sy, fy) array-like, float + Second input matrix. + px : (sx, fx) tuple, float, optional (default = None) + Histogram assigned on rows (samples) and columns (features) of X. + Uniform distribution by default. + py : (sy, fy) tuple, float, optional (default = None) + Histogram assigned on rows (samples) and columns (features) of Y. + Uniform distribution by default. + eps : (scalar, scalar) tuple, float or int + Regularisation parameters for entropic approximation of sample and feature couplings. + Allow the case where eps contains 0. In that case, the EMD solver is used instead of + Sinkhorn solver. + alpha : (scalar, scalar) tuple, float or int, optional (default = (1,1)) + Interpolation parameter for fused CO-Optimal Transport w.r.t the sample and feature couplings. + D : tuple of matrices (sx, sy) and (fx, fy), float, optional (default = None) + Sample and feature matrices, in case of fused CO-Optimal Transport. + dict_init : dictionary, optional (default = None) + Dictionary containing 4 keys: + + "duals_sample" and "duals_feature" whose values are + tuples of 2 vectors of size (sx, sy) and (fx, fy). + Initialization of sample and feature dual vectors if using Sinkhorn algorithm. + Zero vectors by default. + + "pi_sample" and "pi_feature" whose values are matrices of size (sx, sy) and (fx, fy). + Initialization of sample and feature couplings. + Uniform distributions by default. + log : bool, optional (default = False) + If True then the cost and 4 dual vectors are recorded. + verbose : bool, optional (default = False) + If True then print the cost. + early_stopping_tol : float, optional (default = 1e-6) + Tolerance for the early stopping, if the absolute value of the difference between the + last 2 recorded costs is smaller than the tolerance, then stop training. + eval_bcd : int, optional (default = 1) + Multiplier of iteration at which the cost is calculated. For example, + if eval_bcd = 10, then the cost is calculated at iterations 10, 20, 30, etc... + tol_bcd : float, optional (default = 1e-7) + Tolerance of BCD scheme. + nits_bcd : int, optional (default = 100) + Number of BCD iterations. + tol_sinkhorn : float, optional (default = 1e-7) + Tolerance of Sinkhorn algorithm. + nits_sinkhorn : int, optional (default = 100) + Number of Sinkhorn iterations. + method_sinkhorn : string, optional (default = "sinkhorn") + Method used in POT's Sinkhorn solver. Only support "sinkhorn" and "sinkhorn_log". + + Returns + ------- + pi_sample : (sx, sy) array-like, float + Sample coupling matrix. + pi_feature : (fx, fy) array-like, float + Feature coupling matrix. + + if log is True, then return additionally a dictionary whose keys are: + duals_sample : (sx, sy) tuple, float + Pair of dual vectors when solving OT problem w.r.t the sample coupling. + duals_feature : (fx, fy) tuple, float + Pair of dual vectors when solving OT problem w.r.t the feature coupling. + log_cost : list, float + List of costs (without taking into account the entropic regularization terms). + log_ent_cost : list, float + List of entropic costs. + + References + ---------- + .. [44] I. Redko, T. Vayer, R. Flamary, and N. Courty, CO-Optimal Transport, + Advances in Neural Information Processing Systems, 33 (2020). + """ + + def compute_kl(p, log_q): + kl = nx.sum(p * nx.log(p + 1.0 * (p==0))) - nx.sum(p * log_q) + return kl + + def emd_solver(cost, p1_np, p2_np): + cost_np = nx.to_numpy(cost) + pi_np, log = emd(p1_np, p2_np, cost_np, log=True) + + f1 = nx.from_numpy(log["u"], type_as=cost) + f2 = nx.from_numpy(log["v"], type_as=cost) + pi = nx.from_numpy(pi_np, type_as=cost) + + return pi, (f1, f2) + + def get_cost(ot_cost, pi_sample, pi_feature, log_pxy_samp, log_pxy_feat, D_samp, alpha_samp, eps): + eps_samp, eps_feat = eps + + # UCOOT part + cost = nx.sum(ot_cost * pi_feature) + if alpha_samp != 0: + cost = cost + alpha_samp * nx.sum(D_samp * pi_sample) + + # Entropic part + ent_cost = cost + if eps_samp != 0: + ent_cost = ent_cost + eps_samp * compute_kl(pi_sample, log_pxy_samp) + if eps_feat != 0: + ent_cost = ent_cost + eps_feat * compute_kl(pi_feature, log_pxy_feat) + + return cost, ent_cost + + ######################################## + ############# Main function ############ + ######################################## + + if method_sinkhorn not in ["sinkhorn", "sinkhorn_log"]: + raise ValueError("Method {} is not supported in CO-Optimal Transport.".format(method_sinkhorn)) + + X, Y = list_to_array(X, Y) + nx = get_backend(X, Y) + + # constant input variables + eps_samp, eps_feat = eps + alpha_samp, alpha_feat = alpha + if D is None: + D = (None, None) + D_samp, D_feat = D + if D_samp is None or alpha_samp == 0: + D_samp, alpha_samp = 0, 0 + if D_feat is None or alpha_feat == 0: + D_feat, alpha_feat = 0, 0 + + sx, fx = X.shape # s for sample and f for feature + sy, fy = Y.shape # s for sample and f for feature + + # measures on rows and columns + px_samp, px_feat = px + py_samp, py_feat = py + + if px_samp is None: + px_samp = nx.ones(sx, type_as=X) / sx + px_samp_np = np.ones(sx) / sx # create + else: + px_samp_np = nx.to_numpy(px_samp) + + if px_feat is None: + px_feat = nx.ones(fx, type_as=X) / fx + px_feat_np = np.ones(fx) / fx + else: + px_feat_np = nx.to_numpy(px_feat) + + if py_samp is None: + py_samp = nx.ones(sy, type_as=Y) / sy + py_samp_np = np.ones(sy) / sy + else: + py_samp_np = nx.to_numpy(py_samp) + + if py_feat is None: + py_feat = nx.ones(fy, type_as=Y) / fy + py_feat_np = np.ones(fy) / fy + else: + py_feat_np = nx.to_numpy(py_feat) + + pxy_samp = px_samp[:, None] * py_samp[None, :] + pxy_feat = px_feat[:, None] * py_feat[None, :] + + # precalculate cost constants + XY_sqr = (X ** 2 @ px_feat)[:,None] + (Y ** 2 @ py_feat)[None,:] + alpha_samp * D_samp + XY_sqr_T = ((X.T)**2 @ px_samp)[:,None] + ((Y.T)**2 @ py_samp)[None,:] + alpha_feat * D_feat + + # initialise coupling and dual vectors + if dict_init is None: + pi_sample, pi_feature = pxy_samp, pxy_feat # size sx x sy and size fx x fy + duals_samp = (nx.zeros(sx, type_as=X), nx.zeros(sy, type_as=Y)) # shape sx, sy + duals_feat = (nx.zeros(fx, type_as=X), nx.zeros(fy, type_as=Y)) # shape fx, fy + else: + pi_sample, pi_feature = dict_init["pi_sample"], dict_init["pi_feature"] + duals_samp, duals_feat = dict_init["duals_sample"], dict_init["duals_feature"] + + # create shortcuts of functions + self_sinkhorn = partial(sinkhorn, method=method_sinkhorn, numItermax=nits_sinkhorn, stopThr=tol_sinkhorn, log=True) + self_get_cost = partial(get_cost, log_pxy_samp=nx.log(pxy_samp), log_pxy_feat=nx.log(pxy_feat), + D_samp=D_samp, alpha_samp=alpha_samp, eps=eps) + + # initialise log + log_coot = [] + log_entropic_coot = [float("inf")] + err = tol_bcd + 1e-3 + + for idx in range(nits_bcd): + pi_sample_prev = nx.copy(pi_sample) + + # update pi_sample (sample coupling) + ot_cost = XY_sqr - 2 * X @ pi_feature @ Y.T # size sx x sy + if eps_samp > 0: + pi_sample, dict_log = self_sinkhorn(a=px_samp, b=py_samp, M=ot_cost, reg=eps_samp, warmstart=duals_samp) + duals_samp = (nx.log(dict_log["u"]), nx.log(dict_log["v"])) + elif eps_samp == 0: + pi_sample, duals_samp = emd_solver(ot_cost, px_samp_np, py_samp_np) + + # update pi_feature (feature coupling) + ot_cost = XY_sqr_T - 2 * X.T @ pi_sample @ Y # size fx x fy + if eps_feat > 0: + pi_feature, dict_log = self_sinkhorn(a=px_feat, b=py_feat, M=ot_cost, reg=eps_feat, warmstart=duals_feat) + duals_feat = (nx.log(dict_log["u"]), nx.log(dict_log["v"])) + elif eps_feat == 0: + pi_feature, duals_feat = emd_solver(ot_cost, px_feat_np, py_feat_np) + + if idx % eval_bcd == 0: + # update error + err = nx.sum(nx.max(pi_sample - pi_sample_prev)) + coot, ent_coot = self_get_cost(ot_cost, pi_sample, pi_feature) + log_coot.append(coot) + log_entropic_coot.append(ent_coot) + + if err < tol_bcd or abs(log_entropic_coot[-2] - log_entropic_coot[-1]) < early_stopping_tol: + break + + if verbose: + print("Unregularized CO-Optimal Transport cost at iteration {}: {}".format(idx+1, coot)) + + # sanity check + if nx.sum(nx.isnan(pi_sample)) > 0 or nx.sum(nx.isnan(pi_feature)) > 0: + print("There is NaN in coupling.") + + if log: + dict_log = {"duals_sample": duals_samp, \ + "duals_feature": duals_feat, \ + "log_cost": log_coot, \ + "log_ent_cost": log_entropic_coot[1:]} + + return pi_sample, pi_feature, dict_log + + else: + return pi_sample, pi_feature + +def co_optimal_transport2(X, Y, px=(None, None), py=(None, None), eps=(1e-2, 1e-2), alpha=(1, 1), D=(None, None), + dict_init=None, log=False, verbose=False, early_stopping_tol=1e-6, eval_bcd=2, tol_bcd=1e-7, + nits_bcd=100, nits_sinkhorn=500, tol_sinkhorn=1e-7, method_sinkhorn="sinkhorn"): + """ + Return the COOT distance and its entropic approximation between + :math:`(\mathbf{X}, \mathbf{p}_{xs}, \mathbf{p}_{xf})` and + :math:`(\mathbf{Y}, \mathbf{p}_{ys}, \mathbf{p}_{yf})`. + + The function solves the following optimization problem: + + .. math:: + \mathbf{COOT}_{\varepsilon} = \mathop{\arg \min}_{\mathbf{P}, \mathbf{Q}} + \quad \sum_{i,j,k,l} + (\mathbf{X}_{i,k} - \mathbf{Y}_{j,l})^2 \mathbf{P}_{i,j} \mathbf{Q}_{k,l} + + \alpha_1 \sum_{i,j} \mathbf{P}_{i,j} \mathbf{D^{(s)}}_{i, j} + + \alpha_2 \sum_{k, l} \mathbf{Q}_{k,l} \mathbf{D^{(f)}}_{k, l} + + \varepsilon_1 \mathbf{KL}(\mathbf{P} | \mathbf{p}_{xs} \mathbf{p}_{ys}^T) + + \varepsilon_2 \mathbf{KL}(\mathbf{Q} | \mathbf{p}_{xf} \mathbf{p}_{yf}^T) + + Where : + + - :math:`\mathbf{X}`: Data matrix in the source space + - :math:`\mathbf{Y}`: Data matrix in the target space + - :math:`\mathbf{D^{(s)}}`: Sample matrix + - :math:`\mathbf{D^{(f)}}`: Feature matrix + - :math:`\mathbf{p}_{xs}`: distribution of the samples in the source space + - :math:`\mathbf{p}_{xf}`: distribution of the features in the source space + - :math:`\mathbf{p}_{ys}`: distribution of the samples in the target space + - :math:`\mathbf{p}_{yf}`: distribution of the features in the target space + + .. note:: This function allows epsilons to be zero. In that case, the EMD solver will be used. + + Parameters + ---------- + X : (sx, fx) array-like, float + First input matrix. + Y : (sy, fy) array-like, float + Second input matrix. + px : (sx, fx) tuple, float, optional (default = None) + Histogram assigned on rows (samples) and columns (features) of X. + Uniform distribution by default. + py : (sy, fy) tuple, float, optional (default = None) + Histogram assigned on rows (samples) and columns (features) of Y. + Uniform distribution by default. + eps : (scalar, scalar) tuple, float or int + Regularisation parameters for entropic approximation of sample and feature couplings. + Allow the case where eps contains 0. In that case, the EMD solver is used instead of + Sinkhorn solver. + alpha : (scalar, scalar) tuple, float or int, optional (default = (1,1)) + Interpolation parameter for fused CO-Optimal Transport w.r.t the sample and feature couplings. + D : tuple of matrices (sx, sy) and (fx, fy), float, optional (default = (None, None)) + Sample and feature matrices, in case of fused CO-Optimal Transport. + dict_init : dictionary, optional (default = None) + Dictionary containing 4 keys: + + "duals_sample" and "duals_feature" whose values are + tuples of 2 vectors of size (sx, sy) and (fx, fy). + Initialization of sample and feature dual vectors if using Sinkhorn algorithm. + Zero vectors by default. + + "pi_sample" and "pi_feature" whose values are matrices of size (sx, sy) and (fx, fy). + Initialization of sample and feature couplings. + Uniform distributions by default. + log : bool, optional (default = False) + If True then the cost and 4 dual vectors are recorded. + verbose : bool, optional (default = False) + If True then print the cost. + early_stopping_tol : float, optional (default = 1e-6) + Tolerance for the early stopping, if the absolute value of the difference between the + last 2 recorded costs is smaller than the tolerance, then stop training. + eval_bcd : int, optional (default = 1) + Multiplier of iteration at which the cost is calculated. For example, + if eval_bcd = 10, then the cost is calculated at iterations 10, 20, 30, etc... + tol_bcd : float, optional (default = 1e-7) + Tolerance of BCD scheme. + nits_bcd : int, optional (default = 100) + Number of BCD iterations. + tol_sinkhorn : float, optional (default = 1e-7) + Tolerance of Sinkhorn algorithm. + nits_sinkhorn : int, optional (default = 100) + Number of Sinkhorn iterations. + method_sinkhorn : string, optional (default = "sinkhorn") + Method used in POT's Sinkhorn solver. Only support "sinkhorn" and "sinkhorn_log". + + Returns + ------- + Entropic COOT : float + + If log is True, then also return the dictionary output of COOT. + + References + ---------- + .. [44] I. Redko, T. Vayer, R. Flamary, and N. Courty, CO-Optimal Transport, + Advances in Neural Information Processing Systems, 33 (2020). + """ + + pi_sample, pi_feature, dict_log = \ + co_optimal_transport(X, Y, px, py, eps, alpha, D, dict_init, True, verbose, early_stopping_tol, \ + eval_bcd, tol_bcd, nits_bcd, nits_sinkhorn, tol_sinkhorn, method_sinkhorn) + + X, Y = list_to_array(X, Y) + nx = get_backend(X, Y) + + sx, fx = X.shape + sy, fy = Y.shape + + px_samp, px_feat = px + py_samp, py_feat = py + + if px_samp is None: + px_samp = nx.ones(sx, type_as=X) / sx + if px_feat is None: + px_feat = nx.ones(fx, type_as=X) / fx + if py_samp is None: + py_samp = nx.ones(sy, type_as=Y) / sy + if py_feat is None: + py_feat = nx.ones(fy, type_as=Y) / fy + + vx_samp, vy_samp = dict_log["duals_sample"] + vx_feat, vy_feat = dict_log["duals_feature"] + + gradX = 2 * X * (px_samp[:,None] * px_feat[None,:]) - 2 * pi_sample @ Y @ pi_feature.T # shape (sx, fx) + gradY = 2 * Y * (py_samp[:,None] * py_feat[None,:]) - 2 * pi_sample.T @ X @ pi_feature # shape (sy, fy) + + ent_coot = dict_log["log_ent_cost"][-1] + ent_coot = nx.set_gradients(ent_coot, (px_samp, px_feat, py_samp, py_feat, X, Y), \ + (vx_samp, vx_feat, vy_samp, vy_feat, gradX, gradY)) + + if log: + return ent_coot, dict_log + + else: + return ent_coot \ No newline at end of file From 53e50e583b68ac29ab2501ec8df29ab27f1389df Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Wed, 22 Feb 2023 14:39:46 +0100 Subject: [PATCH 5/8] Revert "[WIP] CO-Optimal Transport" This reverts commit f3d36b2705013409ac69b346585e311bc25fcfb7. --- README.md | 4 +- RELEASES.md | 1 - ot/__init__.py | 4 +- ot/coot.py | 408 ------------------------------------------------- 4 files changed, 2 insertions(+), 415 deletions(-) delete mode 100644 ot/coot.py diff --git a/README.md b/README.md index cf9fcbdc3..7c9475b80 100644 --- a/README.md +++ b/README.md @@ -292,6 +292,4 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer [42] Delon, J., Gozlan, N., and Saint-Dizier, A. [Generalized Wasserstein barycenters between probability measures living on different subspaces](https://arxiv.org/pdf/2105.09755). arXiv preprint arXiv:2105.09755, 2021. -[43] Álvarez-Esteban, Pedro C., et al. [A fixed-point approach to barycenters in Wasserstein space.](https://arxiv.org/pdf/1511.05355.pdf) Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. - -[44] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020). [CO-Optimal Transport](https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf). Advances in Neural Information Processing Systems, 33. \ No newline at end of file +[43] Álvarez-Esteban, Pedro C., et al. [A fixed-point approach to barycenters in Wasserstein space.](https://arxiv.org/pdf/1511.05355.pdf) Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. \ No newline at end of file diff --git a/RELEASES.md b/RELEASES.md index ef3a832fb..4eb17d024 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -10,7 +10,6 @@ - New API for OT solver using function `ot.solve` (PR #388) - Backend version of `ot.partial` and `ot.smooth` (PR #388) - Added argument for warmstart of dual vectors in Sinkhorn-based methods in `ot.bregman` (PR #437) -- Added CO-Optimal Transport solver and example (PR #) #### Closed issues diff --git a/ot/__init__.py b/ot/__init__.py index 8676990f1..0b55e0c56 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -36,7 +36,6 @@ from . import factored from . import solvers from . import gaussian -from . import coot # OT functions from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d @@ -50,7 +49,6 @@ from .weak import weak_optimal_transport from .factored import factored_optimal_transport from .solvers import solve -from .coot import co_optimal_transport, co_optimal_transport2 # utils functions from .utils import dist, unif, tic, toc, toq @@ -66,4 +64,4 @@ 'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2', 'max_sliced_wasserstein_distance', 'weak_optimal_transport', 'factored_optimal_transport', 'solve', - 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers', 'coot', 'co_optimal_transport', 'co_optimal_transport2'] + 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers'] diff --git a/ot/coot.py b/ot/coot.py deleted file mode 100644 index 9d7e5d37e..000000000 --- a/ot/coot.py +++ /dev/null @@ -1,408 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Fused CO-Optimal Transport and entropic Fused CO-Optimal Transport solvers -""" - -# Author: Quang Huy Tran -# -# License: MIT License - -import numpy as np -from functools import partial -from .lp import emd -from .utils import list_to_array -from .backend import get_backend -from .bregman import sinkhorn - -def co_optimal_transport(X, Y, px=(None, None), py=(None, None), eps=(0, 0), alpha=(1, 1), D=(None, None), - dict_init=None, log=False, verbose=False, early_stopping_tol=1e-6, eval_bcd=1, tol_bcd=1e-7, - nits_bcd=100, nits_sinkhorn=500, tol_sinkhorn=1e-7, method_sinkhorn="sinkhorn"): - r""" - Return the sample and feature transport plans between - :math:`(\mathbf{X}, \mathbf{p}_{xs}, \mathbf{p}_{xf})` and - :math:`(\mathbf{Y}, \mathbf{p}_{ys}, \mathbf{p}_{yf})`. - - The function solves the following optimization problem: - - .. math:: - \mathbf{COOT}_{\varepsilon} = \mathop{\arg \min}_{\mathbf{P}, \mathbf{Q}} - \quad \sum_{i,j,k,l} - (\mathbf{X}_{i,k} - \mathbf{Y}_{j,l})^2 \mathbf{P}_{i,j} \mathbf{Q}_{k,l} - + \alpha_1 \sum_{i,j} \mathbf{P}_{i,j} \mathbf{D^{(s)}}_{i, j} - + \alpha_2 \sum_{k, l} \mathbf{Q}_{k,l} \mathbf{D^{(f)}}_{k, l} - + \varepsilon_1 \mathbf{KL}(\mathbf{P} | \mathbf{p}_{xs} \mathbf{p}_{ys}^T) - + \varepsilon_2 \mathbf{KL}(\mathbf{Q} | \mathbf{p}_{xf} \mathbf{p}_{yf}^T) - - Where : - - - :math:`\mathbf{X}`: Data matrix in the source space - - :math:`\mathbf{Y}`: Data matrix in the target space - - :math:`\mathbf{D^{(s)}}`: Sample matrix - - :math:`\mathbf{D^{(f)}}`: Feature matrix - - :math:`\mathbf{p}_{xs}`: distribution of the samples in the source space - - :math:`\mathbf{p}_{xf}`: distribution of the features in the source space - - :math:`\mathbf{p}_{ys}`: distribution of the samples in the target space - - :math:`\mathbf{p}_{yf}`: distribution of the features in the target space - - .. note:: This function allows epsilons to be zero. In that case, the EMD solver will be used. - - Parameters - ---------- - X : (sx, fx) array-like, float - First input matrix. - Y : (sy, fy) array-like, float - Second input matrix. - px : (sx, fx) tuple, float, optional (default = None) - Histogram assigned on rows (samples) and columns (features) of X. - Uniform distribution by default. - py : (sy, fy) tuple, float, optional (default = None) - Histogram assigned on rows (samples) and columns (features) of Y. - Uniform distribution by default. - eps : (scalar, scalar) tuple, float or int - Regularisation parameters for entropic approximation of sample and feature couplings. - Allow the case where eps contains 0. In that case, the EMD solver is used instead of - Sinkhorn solver. - alpha : (scalar, scalar) tuple, float or int, optional (default = (1,1)) - Interpolation parameter for fused CO-Optimal Transport w.r.t the sample and feature couplings. - D : tuple of matrices (sx, sy) and (fx, fy), float, optional (default = None) - Sample and feature matrices, in case of fused CO-Optimal Transport. - dict_init : dictionary, optional (default = None) - Dictionary containing 4 keys: - + "duals_sample" and "duals_feature" whose values are - tuples of 2 vectors of size (sx, sy) and (fx, fy). - Initialization of sample and feature dual vectors if using Sinkhorn algorithm. - Zero vectors by default. - + "pi_sample" and "pi_feature" whose values are matrices of size (sx, sy) and (fx, fy). - Initialization of sample and feature couplings. - Uniform distributions by default. - log : bool, optional (default = False) - If True then the cost and 4 dual vectors are recorded. - verbose : bool, optional (default = False) - If True then print the cost. - early_stopping_tol : float, optional (default = 1e-6) - Tolerance for the early stopping, if the absolute value of the difference between the - last 2 recorded costs is smaller than the tolerance, then stop training. - eval_bcd : int, optional (default = 1) - Multiplier of iteration at which the cost is calculated. For example, - if eval_bcd = 10, then the cost is calculated at iterations 10, 20, 30, etc... - tol_bcd : float, optional (default = 1e-7) - Tolerance of BCD scheme. - nits_bcd : int, optional (default = 100) - Number of BCD iterations. - tol_sinkhorn : float, optional (default = 1e-7) - Tolerance of Sinkhorn algorithm. - nits_sinkhorn : int, optional (default = 100) - Number of Sinkhorn iterations. - method_sinkhorn : string, optional (default = "sinkhorn") - Method used in POT's Sinkhorn solver. Only support "sinkhorn" and "sinkhorn_log". - - Returns - ------- - pi_sample : (sx, sy) array-like, float - Sample coupling matrix. - pi_feature : (fx, fy) array-like, float - Feature coupling matrix. - - if log is True, then return additionally a dictionary whose keys are: - duals_sample : (sx, sy) tuple, float - Pair of dual vectors when solving OT problem w.r.t the sample coupling. - duals_feature : (fx, fy) tuple, float - Pair of dual vectors when solving OT problem w.r.t the feature coupling. - log_cost : list, float - List of costs (without taking into account the entropic regularization terms). - log_ent_cost : list, float - List of entropic costs. - - References - ---------- - .. [44] I. Redko, T. Vayer, R. Flamary, and N. Courty, CO-Optimal Transport, - Advances in Neural Information Processing Systems, 33 (2020). - """ - - def compute_kl(p, log_q): - kl = nx.sum(p * nx.log(p + 1.0 * (p==0))) - nx.sum(p * log_q) - return kl - - def emd_solver(cost, p1_np, p2_np): - cost_np = nx.to_numpy(cost) - pi_np, log = emd(p1_np, p2_np, cost_np, log=True) - - f1 = nx.from_numpy(log["u"], type_as=cost) - f2 = nx.from_numpy(log["v"], type_as=cost) - pi = nx.from_numpy(pi_np, type_as=cost) - - return pi, (f1, f2) - - def get_cost(ot_cost, pi_sample, pi_feature, log_pxy_samp, log_pxy_feat, D_samp, alpha_samp, eps): - eps_samp, eps_feat = eps - - # UCOOT part - cost = nx.sum(ot_cost * pi_feature) - if alpha_samp != 0: - cost = cost + alpha_samp * nx.sum(D_samp * pi_sample) - - # Entropic part - ent_cost = cost - if eps_samp != 0: - ent_cost = ent_cost + eps_samp * compute_kl(pi_sample, log_pxy_samp) - if eps_feat != 0: - ent_cost = ent_cost + eps_feat * compute_kl(pi_feature, log_pxy_feat) - - return cost, ent_cost - - ######################################## - ############# Main function ############ - ######################################## - - if method_sinkhorn not in ["sinkhorn", "sinkhorn_log"]: - raise ValueError("Method {} is not supported in CO-Optimal Transport.".format(method_sinkhorn)) - - X, Y = list_to_array(X, Y) - nx = get_backend(X, Y) - - # constant input variables - eps_samp, eps_feat = eps - alpha_samp, alpha_feat = alpha - if D is None: - D = (None, None) - D_samp, D_feat = D - if D_samp is None or alpha_samp == 0: - D_samp, alpha_samp = 0, 0 - if D_feat is None or alpha_feat == 0: - D_feat, alpha_feat = 0, 0 - - sx, fx = X.shape # s for sample and f for feature - sy, fy = Y.shape # s for sample and f for feature - - # measures on rows and columns - px_samp, px_feat = px - py_samp, py_feat = py - - if px_samp is None: - px_samp = nx.ones(sx, type_as=X) / sx - px_samp_np = np.ones(sx) / sx # create - else: - px_samp_np = nx.to_numpy(px_samp) - - if px_feat is None: - px_feat = nx.ones(fx, type_as=X) / fx - px_feat_np = np.ones(fx) / fx - else: - px_feat_np = nx.to_numpy(px_feat) - - if py_samp is None: - py_samp = nx.ones(sy, type_as=Y) / sy - py_samp_np = np.ones(sy) / sy - else: - py_samp_np = nx.to_numpy(py_samp) - - if py_feat is None: - py_feat = nx.ones(fy, type_as=Y) / fy - py_feat_np = np.ones(fy) / fy - else: - py_feat_np = nx.to_numpy(py_feat) - - pxy_samp = px_samp[:, None] * py_samp[None, :] - pxy_feat = px_feat[:, None] * py_feat[None, :] - - # precalculate cost constants - XY_sqr = (X ** 2 @ px_feat)[:,None] + (Y ** 2 @ py_feat)[None,:] + alpha_samp * D_samp - XY_sqr_T = ((X.T)**2 @ px_samp)[:,None] + ((Y.T)**2 @ py_samp)[None,:] + alpha_feat * D_feat - - # initialise coupling and dual vectors - if dict_init is None: - pi_sample, pi_feature = pxy_samp, pxy_feat # size sx x sy and size fx x fy - duals_samp = (nx.zeros(sx, type_as=X), nx.zeros(sy, type_as=Y)) # shape sx, sy - duals_feat = (nx.zeros(fx, type_as=X), nx.zeros(fy, type_as=Y)) # shape fx, fy - else: - pi_sample, pi_feature = dict_init["pi_sample"], dict_init["pi_feature"] - duals_samp, duals_feat = dict_init["duals_sample"], dict_init["duals_feature"] - - # create shortcuts of functions - self_sinkhorn = partial(sinkhorn, method=method_sinkhorn, numItermax=nits_sinkhorn, stopThr=tol_sinkhorn, log=True) - self_get_cost = partial(get_cost, log_pxy_samp=nx.log(pxy_samp), log_pxy_feat=nx.log(pxy_feat), - D_samp=D_samp, alpha_samp=alpha_samp, eps=eps) - - # initialise log - log_coot = [] - log_entropic_coot = [float("inf")] - err = tol_bcd + 1e-3 - - for idx in range(nits_bcd): - pi_sample_prev = nx.copy(pi_sample) - - # update pi_sample (sample coupling) - ot_cost = XY_sqr - 2 * X @ pi_feature @ Y.T # size sx x sy - if eps_samp > 0: - pi_sample, dict_log = self_sinkhorn(a=px_samp, b=py_samp, M=ot_cost, reg=eps_samp, warmstart=duals_samp) - duals_samp = (nx.log(dict_log["u"]), nx.log(dict_log["v"])) - elif eps_samp == 0: - pi_sample, duals_samp = emd_solver(ot_cost, px_samp_np, py_samp_np) - - # update pi_feature (feature coupling) - ot_cost = XY_sqr_T - 2 * X.T @ pi_sample @ Y # size fx x fy - if eps_feat > 0: - pi_feature, dict_log = self_sinkhorn(a=px_feat, b=py_feat, M=ot_cost, reg=eps_feat, warmstart=duals_feat) - duals_feat = (nx.log(dict_log["u"]), nx.log(dict_log["v"])) - elif eps_feat == 0: - pi_feature, duals_feat = emd_solver(ot_cost, px_feat_np, py_feat_np) - - if idx % eval_bcd == 0: - # update error - err = nx.sum(nx.max(pi_sample - pi_sample_prev)) - coot, ent_coot = self_get_cost(ot_cost, pi_sample, pi_feature) - log_coot.append(coot) - log_entropic_coot.append(ent_coot) - - if err < tol_bcd or abs(log_entropic_coot[-2] - log_entropic_coot[-1]) < early_stopping_tol: - break - - if verbose: - print("Unregularized CO-Optimal Transport cost at iteration {}: {}".format(idx+1, coot)) - - # sanity check - if nx.sum(nx.isnan(pi_sample)) > 0 or nx.sum(nx.isnan(pi_feature)) > 0: - print("There is NaN in coupling.") - - if log: - dict_log = {"duals_sample": duals_samp, \ - "duals_feature": duals_feat, \ - "log_cost": log_coot, \ - "log_ent_cost": log_entropic_coot[1:]} - - return pi_sample, pi_feature, dict_log - - else: - return pi_sample, pi_feature - -def co_optimal_transport2(X, Y, px=(None, None), py=(None, None), eps=(1e-2, 1e-2), alpha=(1, 1), D=(None, None), - dict_init=None, log=False, verbose=False, early_stopping_tol=1e-6, eval_bcd=2, tol_bcd=1e-7, - nits_bcd=100, nits_sinkhorn=500, tol_sinkhorn=1e-7, method_sinkhorn="sinkhorn"): - """ - Return the COOT distance and its entropic approximation between - :math:`(\mathbf{X}, \mathbf{p}_{xs}, \mathbf{p}_{xf})` and - :math:`(\mathbf{Y}, \mathbf{p}_{ys}, \mathbf{p}_{yf})`. - - The function solves the following optimization problem: - - .. math:: - \mathbf{COOT}_{\varepsilon} = \mathop{\arg \min}_{\mathbf{P}, \mathbf{Q}} - \quad \sum_{i,j,k,l} - (\mathbf{X}_{i,k} - \mathbf{Y}_{j,l})^2 \mathbf{P}_{i,j} \mathbf{Q}_{k,l} - + \alpha_1 \sum_{i,j} \mathbf{P}_{i,j} \mathbf{D^{(s)}}_{i, j} - + \alpha_2 \sum_{k, l} \mathbf{Q}_{k,l} \mathbf{D^{(f)}}_{k, l} - + \varepsilon_1 \mathbf{KL}(\mathbf{P} | \mathbf{p}_{xs} \mathbf{p}_{ys}^T) - + \varepsilon_2 \mathbf{KL}(\mathbf{Q} | \mathbf{p}_{xf} \mathbf{p}_{yf}^T) - - Where : - - - :math:`\mathbf{X}`: Data matrix in the source space - - :math:`\mathbf{Y}`: Data matrix in the target space - - :math:`\mathbf{D^{(s)}}`: Sample matrix - - :math:`\mathbf{D^{(f)}}`: Feature matrix - - :math:`\mathbf{p}_{xs}`: distribution of the samples in the source space - - :math:`\mathbf{p}_{xf}`: distribution of the features in the source space - - :math:`\mathbf{p}_{ys}`: distribution of the samples in the target space - - :math:`\mathbf{p}_{yf}`: distribution of the features in the target space - - .. note:: This function allows epsilons to be zero. In that case, the EMD solver will be used. - - Parameters - ---------- - X : (sx, fx) array-like, float - First input matrix. - Y : (sy, fy) array-like, float - Second input matrix. - px : (sx, fx) tuple, float, optional (default = None) - Histogram assigned on rows (samples) and columns (features) of X. - Uniform distribution by default. - py : (sy, fy) tuple, float, optional (default = None) - Histogram assigned on rows (samples) and columns (features) of Y. - Uniform distribution by default. - eps : (scalar, scalar) tuple, float or int - Regularisation parameters for entropic approximation of sample and feature couplings. - Allow the case where eps contains 0. In that case, the EMD solver is used instead of - Sinkhorn solver. - alpha : (scalar, scalar) tuple, float or int, optional (default = (1,1)) - Interpolation parameter for fused CO-Optimal Transport w.r.t the sample and feature couplings. - D : tuple of matrices (sx, sy) and (fx, fy), float, optional (default = (None, None)) - Sample and feature matrices, in case of fused CO-Optimal Transport. - dict_init : dictionary, optional (default = None) - Dictionary containing 4 keys: - + "duals_sample" and "duals_feature" whose values are - tuples of 2 vectors of size (sx, sy) and (fx, fy). - Initialization of sample and feature dual vectors if using Sinkhorn algorithm. - Zero vectors by default. - + "pi_sample" and "pi_feature" whose values are matrices of size (sx, sy) and (fx, fy). - Initialization of sample and feature couplings. - Uniform distributions by default. - log : bool, optional (default = False) - If True then the cost and 4 dual vectors are recorded. - verbose : bool, optional (default = False) - If True then print the cost. - early_stopping_tol : float, optional (default = 1e-6) - Tolerance for the early stopping, if the absolute value of the difference between the - last 2 recorded costs is smaller than the tolerance, then stop training. - eval_bcd : int, optional (default = 1) - Multiplier of iteration at which the cost is calculated. For example, - if eval_bcd = 10, then the cost is calculated at iterations 10, 20, 30, etc... - tol_bcd : float, optional (default = 1e-7) - Tolerance of BCD scheme. - nits_bcd : int, optional (default = 100) - Number of BCD iterations. - tol_sinkhorn : float, optional (default = 1e-7) - Tolerance of Sinkhorn algorithm. - nits_sinkhorn : int, optional (default = 100) - Number of Sinkhorn iterations. - method_sinkhorn : string, optional (default = "sinkhorn") - Method used in POT's Sinkhorn solver. Only support "sinkhorn" and "sinkhorn_log". - - Returns - ------- - Entropic COOT : float - - If log is True, then also return the dictionary output of COOT. - - References - ---------- - .. [44] I. Redko, T. Vayer, R. Flamary, and N. Courty, CO-Optimal Transport, - Advances in Neural Information Processing Systems, 33 (2020). - """ - - pi_sample, pi_feature, dict_log = \ - co_optimal_transport(X, Y, px, py, eps, alpha, D, dict_init, True, verbose, early_stopping_tol, \ - eval_bcd, tol_bcd, nits_bcd, nits_sinkhorn, tol_sinkhorn, method_sinkhorn) - - X, Y = list_to_array(X, Y) - nx = get_backend(X, Y) - - sx, fx = X.shape - sy, fy = Y.shape - - px_samp, px_feat = px - py_samp, py_feat = py - - if px_samp is None: - px_samp = nx.ones(sx, type_as=X) / sx - if px_feat is None: - px_feat = nx.ones(fx, type_as=X) / fx - if py_samp is None: - py_samp = nx.ones(sy, type_as=Y) / sy - if py_feat is None: - py_feat = nx.ones(fy, type_as=Y) / fy - - vx_samp, vy_samp = dict_log["duals_sample"] - vx_feat, vy_feat = dict_log["duals_feature"] - - gradX = 2 * X * (px_samp[:,None] * px_feat[None,:]) - 2 * pi_sample @ Y @ pi_feature.T # shape (sx, fx) - gradY = 2 * Y * (py_samp[:,None] * py_feat[None,:]) - 2 * pi_sample.T @ X @ pi_feature # shape (sy, fy) - - ent_coot = dict_log["log_ent_cost"][-1] - ent_coot = nx.set_gradients(ent_coot, (px_samp, px_feat, py_samp, py_feat, X, Y), \ - (vx_samp, vx_feat, vy_samp, vy_feat, gradX, gradY)) - - if log: - return ent_coot, dict_log - - else: - return ent_coot \ No newline at end of file From 96ea795efca7f273f07e79f35d3cc719e4cbf773 Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Wed, 22 Feb 2023 17:23:06 +0100 Subject: [PATCH 6/8] reformat with PEP8 --- ot/bregman.py | 103 +++++++++++++---------- test/test_bregman.py | 191 ++++++++++++++++++++++++++++--------------- 2 files changed, 186 insertions(+), 108 deletions(-) diff --git a/ot/bregman.py b/ot/bregman.py index 215ade0be..e0f6bd177 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -24,7 +24,7 @@ from .backend import get_backend -def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, warmstart=None, +def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, warmstart=None, stopThr=1e-9, verbose=False, log=False, warn=True, **kwargs): r""" @@ -157,27 +157,27 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, warmstart=None, """ if method.lower() == 'sinkhorn': - return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, warmstart=warmstart, + return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, warmstart=warmstart, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) elif method.lower() == 'sinkhorn_log': - return sinkhorn_log(a, b, M, reg, numItermax=numItermax, warmstart=warmstart, + return sinkhorn_log(a, b, M, reg, numItermax=numItermax, warmstart=warmstart, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) elif method.lower() == 'greenkhorn': - return greenkhorn(a, b, M, reg, numItermax=numItermax, warmstart=warmstart, + return greenkhorn(a, b, M, reg, numItermax=numItermax, warmstart=warmstart, stopThr=stopThr, verbose=verbose, log=log, warn=warn) elif method.lower() == 'sinkhorn_stabilized': - return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, warmstart=warmstart, + return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, warmstart=warmstart, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) elif method.lower() == 'sinkhorn_epsilon_scaling': return sinkhorn_epsilon_scaling(a, b, M, reg, - numItermax=numItermax, warmstart=warmstart, + numItermax=numItermax, warmstart=warmstart, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) @@ -185,7 +185,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, warmstart=None, raise ValueError("Unknown method '%s'." % method) -def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, warmstart=None, +def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, warmstart=None, stopThr=1e-9, verbose=False, log=False, warn=False, **kwargs): r""" Solve the entropic regularization optimal transport problem and return the loss @@ -328,17 +328,17 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, warmstart=None, if len(b.shape) < 2: if method.lower() == 'sinkhorn': - res = sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, warmstart=warmstart, + res = sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, warmstart=warmstart, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) elif method.lower() == 'sinkhorn_log': - res = sinkhorn_log(a, b, M, reg, numItermax=numItermax, warmstart=warmstart, + res = sinkhorn_log(a, b, M, reg, numItermax=numItermax, warmstart=warmstart, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) elif method.lower() == 'sinkhorn_stabilized': - res = sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, warmstart=warmstart, + res = sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, warmstart=warmstart, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) @@ -352,17 +352,17 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, warmstart=None, else: if method.lower() == 'sinkhorn': - return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, warmstart=warmstart, + return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, warmstart=warmstart, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) elif method.lower() == 'sinkhorn_log': - return sinkhorn_log(a, b, M, reg, numItermax=numItermax, warmstart=warmstart, + return sinkhorn_log(a, b, M, reg, numItermax=numItermax, warmstart=warmstart, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) elif method.lower() == 'sinkhorn_stabilized': - return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, warmstart=warmstart, + return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, warmstart=warmstart, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) @@ -370,7 +370,7 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, warmstart=None, raise ValueError("Unknown method '%s'." % method) -def sinkhorn_knopp(a, b, M, reg, numItermax=1000, warmstart=None, stopThr=1e-9, +def sinkhorn_knopp(a, b, M, reg, numItermax=1000, warmstart=None, stopThr=1e-9, verbose=False, log=False, warn=True, **kwargs): r""" @@ -1101,7 +1101,8 @@ def get_Gamma(alpha, beta, u, v): # remove numerical problems and store them in K if nx.max(nx.abs(u)) > tau or nx.max(nx.abs(v)) > tau: if n_hists: - alpha, beta = alpha + reg * nx.max(nx.log(u), 1), beta + reg * nx.max(nx.log(v)) + alpha, beta = alpha + reg * \ + nx.max(nx.log(u), 1), beta + reg * nx.max(nx.log(v)) else: alpha, beta = alpha + reg * nx.log(u), beta + reg * nx.log(v) if n_hists: @@ -1325,13 +1326,15 @@ def get_reg(n): # exponential decreasing # we can speed up the process by checking for the error only all # the 10th iterations transp = G - err = nx.norm(nx.sum(transp, axis=0) - b) ** 2 + nx.norm(nx.sum(transp, axis=1) - a) ** 2 + err = nx.norm(nx.sum(transp, axis=0) - b) ** 2 + \ + nx.norm(nx.sum(transp, axis=1) - a) ** 2 if log: log['err'].append(err) if verbose: if ii % (print_period * 10) == 0: - print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5s}|{:12s}'.format( + 'It.', 'Err') + '\n' + '-' * 19) print('{:5d}|{:8e}|'.format(ii, err)) if err <= stopThr and ii > numItermin: @@ -1675,8 +1678,10 @@ def free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_ini for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights): M_i = dist(X, measure_locations_i) - T_i = sinkhorn(b, measure_weights_i, M_i, reg=reg, numItermax=numInnerItermax, **kwargs) - T_sum = T_sum + weight_i * 1. / b[:, None] * nx.dot(T_i, measure_locations_i) + T_i = sinkhorn(b, measure_weights_i, M_i, reg=reg, + numItermax=numInnerItermax, **kwargs) + T_sum = T_sum + weight_i * 1. / \ + b[:, None] * nx.dot(T_i, measure_locations_i) displacement_square_norm = nx.sum((T_sum - X) ** 2) if log: @@ -1685,7 +1690,8 @@ def free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_ini X = T_sum if verbose: - print('iteration %d, displacement_square_norm=%f\n', iter_count, displacement_square_norm) + print('iteration %d, displacement_square_norm=%f\n', + iter_count, displacement_square_norm) iter_count += 1 @@ -2240,7 +2246,8 @@ def convol_imgs(imgs): if verbose: if ii % 200 == 0: - print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5s}|{:12s}'.format( + 'It.', 'Err') + '\n' + '-' * 19) print('{:5d}|{:8e}|'.format(ii, err)) if err < stopThr: break @@ -2318,7 +2325,8 @@ def convol_img(log_img): if verbose: if ii % 200 == 0: - print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5s}|{:12s}'.format( + 'It.', 'Err') + '\n' + '-' * 19) print('{:5d}|{:8e}|'.format(ii, err)) if err < stopThr: break @@ -2477,7 +2485,8 @@ def convol_imgs(imgs): if verbose: if ii % 200 == 0: - print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5s}|{:12s}'.format( + 'It.', 'Err') + '\n' + '-' * 19) print('{:5d}|{:8e}|'.format(ii, err)) # debiased Sinkhorn does not converge monotonically @@ -2557,7 +2566,8 @@ def convol_img(log_img): if verbose: if ii % 200 == 0: - print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5s}|{:12s}'.format( + 'It.', 'Err') + '\n' + '-' * 19) print('{:5d}|{:8e}|'.format(ii, err)) if err < stopThr and ii > 20: break @@ -3001,7 +3011,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', elif isinstance(batchSize, tuple) and len(batchSize) == 2: bs, bt = batchSize[0], batchSize[1] else: - raise ValueError("Batch size must be in integer or a tuple of two integers") + raise ValueError( + "Batch size must be in integer or a tuple of two integers") range_s, range_t = range(0, ns, bs), range(0, nt, bt) @@ -3039,7 +3050,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', M = dist(X_s_np[i:i + bs, :], X_t_np, metric=metric) M = nx.from_numpy(M, type_as=a) m1_cols.append( - nx.sum(nx.exp(f[i:i + bs, None] + g[None, :] - M / reg), axis=1) + nx.sum(nx.exp(f[i:i + bs, None] + + g[None, :] - M / reg), axis=1) ) m1 = nx.concatenate(m1_cols, axis=0) err = nx.sum(nx.abs(m1 - a)) @@ -3047,7 +3059,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', dict_log["err"].append(err) if verbose and (i_ot + 1) % 100 == 0: - print("Error in marginal at iteration {} = {}".format(i_ot + 1, err)) + print("Error in marginal at iteration {} = {}".format( + i_ot + 1, err)) if err <= stopThr: break @@ -3188,8 +3201,8 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', if isLazy: if log: f, g, dict_log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, - numIterMax=numIterMax, - warmstart=warmstart, + numIterMax=numIterMax, + warmstart=warmstart, stopThr=stopThr, isLazy=isLazy, batchSize=batchSize, @@ -3197,8 +3210,8 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', warn=warn) else: f, g = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, - numIterMax=numIterMax, - warmstart=warmstart, + numIterMax=numIterMax, + warmstart=warmstart, stopThr=stopThr, isLazy=isLazy, batchSize=batchSize, verbose=verbose, log=log, @@ -3227,12 +3240,12 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', M = dist(X_s, X_t, metric=metric) if log: - sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, warmstart=warmstart, + sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, warmstart=warmstart, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) return sinkhorn_loss, log else: - sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, warmstart=warmstart, + sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, warmstart=warmstart, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) return sinkhorn_loss @@ -3359,21 +3372,22 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli if log: sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, - numIterMax=numIterMax, warmstart=warmstart, + numIterMax=numIterMax, warmstart=warmstart, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, - numIterMax=numIterMax, warmstart=warmstart_a, + numIterMax=numIterMax, warmstart=warmstart_a, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, - numIterMax=numIterMax, warmstart=warmstart_b, + numIterMax=numIterMax, warmstart=warmstart_b, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) - sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b) + sinkhorn_div = sinkhorn_loss_ab - 0.5 * \ + (sinkhorn_loss_a + sinkhorn_loss_b) log = {} log['sinkhorn_loss_ab'] = sinkhorn_loss_ab @@ -3387,24 +3401,25 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli else: sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, - numIterMax=numIterMax, warmstart=warmstart, + numIterMax=numIterMax, warmstart=warmstart, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, - numIterMax=numIterMax, warmstart=warmstart_a, + numIterMax=numIterMax, warmstart=warmstart_a, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, - numIterMax=numIterMax, warmstart=warmstart_b, + numIterMax=numIterMax, warmstart=warmstart_b, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) - sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b) + sinkhorn_div = sinkhorn_loss_ab - 0.5 * \ + (sinkhorn_loss_a + sinkhorn_loss_b) return nx.maximum(0, sinkhorn_div) @@ -3572,7 +3587,8 @@ def projection(u, epsilon): epsilon_u_square = a[0] / aK_sort[ns_budget - 1] else: aK_sort = nx.from_numpy( - bottleneck.partition(nx.to_numpy(K_sum_cols), ns_budget - 1)[ns_budget - 1], + bottleneck.partition(nx.to_numpy( + K_sum_cols), ns_budget - 1)[ns_budget - 1], type_as=M ) epsilon_u_square = a[0] / aK_sort @@ -3582,7 +3598,8 @@ def projection(u, epsilon): epsilon_v_square = b[0] / bK_sort[nt_budget - 1] else: bK_sort = nx.from_numpy( - bottleneck.partition(nx.to_numpy(K_sum_rows), nt_budget - 1)[nt_budget - 1], + bottleneck.partition(nx.to_numpy( + K_sum_rows), nt_budget - 1)[nt_budget - 1], type_as=M ) epsilon_v_square = b[0] / bK_sort diff --git a/test/test_bregman.py b/test/test_bregman.py index 9fb09798f..bd5b65570 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -59,10 +59,12 @@ def test_convergence_warning(method): with pytest.warns(UserWarning): ot.barycenter(A, M, 1, method=method, stopThr=0, numItermax=1) with pytest.warns(UserWarning): - ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1, warn=True) + ot.sinkhorn2(a1, a2, M, 1, method=method, + stopThr=0, numItermax=1, warn=True) with warnings.catch_warnings(): warnings.simplefilter("error") - ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1, warn=False) + ot.sinkhorn2(a1, a2, M, 1, method=method, + stopThr=0, numItermax=1, warn=False) def test_not_implemented_method(): @@ -266,12 +268,16 @@ def test_sinkhorn_variants(nx): ub, M_nx = nx.from_numpy(u, M) G = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10) - Gl = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) - G0 = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn', stopThr=1e-10)) - Gs = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10)) + Gl = nx.to_numpy(ot.sinkhorn( + ub, ub, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) + G0 = nx.to_numpy(ot.sinkhorn( + ub, ub, M_nx, 1, method='sinkhorn', stopThr=1e-10)) + Gs = nx.to_numpy(ot.sinkhorn( + ub, ub, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10)) Ges = nx.to_numpy(ot.sinkhorn( ub, ub, M_nx, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10)) - G_green = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='greenkhorn', stopThr=1e-10)) + G_green = nx.to_numpy(ot.sinkhorn( + ub, ub, M_nx, 1, method='greenkhorn', stopThr=1e-10)) # check values np.testing.assert_allclose(G, G0, atol=1e-05) @@ -371,9 +377,12 @@ def test_sinkhorn_variants_multi_b(nx): ub, bb, M_nx = nx.from_numpy(u, b, M) G = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10) - Gl = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) - G0 = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10)) - Gs = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10)) + Gl = nx.to_numpy(ot.sinkhorn( + ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) + G0 = nx.to_numpy(ot.sinkhorn( + ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10)) + Gs = nx.to_numpy(ot.sinkhorn( + ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10)) # check values np.testing.assert_allclose(G, G0, atol=1e-05) @@ -399,9 +408,12 @@ def test_sinkhorn2_variants_multi_b(nx): ub, bb, M_nx = nx.from_numpy(u, b, M) G = ot.sinkhorn2(u, b, M, 1, method='sinkhorn', stopThr=1e-10) - Gl = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) - G0 = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10)) - Gs = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10)) + Gl = nx.to_numpy(ot.sinkhorn2( + ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) + G0 = nx.to_numpy(ot.sinkhorn2( + ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10)) + Gs = nx.to_numpy(ot.sinkhorn2( + ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10)) # check values np.testing.assert_allclose(G, G0, atol=1e-05) @@ -419,12 +431,16 @@ def test_sinkhorn_variants_log(): M = ot.dist(x, x) - G0, log0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10, log=True) - Gl, logl = ot.sinkhorn(u, u, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True) - Gs, logs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True) + G0, log0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', + stopThr=1e-10, log=True) + Gl, logl = ot.sinkhorn( + u, u, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True) + Gs, logs = ot.sinkhorn( + u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True) Ges, loges = ot.sinkhorn( u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True,) - G_green, loggreen = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10, log=True) + G_green, loggreen = ot.sinkhorn( + u, u, M, 1, method='greenkhorn', stopThr=1e-10, log=True) # check values np.testing.assert_allclose(G0, Gs, atol=1e-05) @@ -446,7 +462,8 @@ def test_sinkhorn_variants_log_multib(verbose, warn): M = ot.dist(x, x) - G0, log0 = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10, log=True) + G0, log0 = ot.sinkhorn(u, b, M, 1, method='sinkhorn', + stopThr=1e-10, log=True) Gl, logl = ot.sinkhorn(u, b, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True, verbose=verbose, warn=warn) Gs, logs = ot.sinkhorn(u, b, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True, @@ -485,8 +502,10 @@ def test_barycenter(nx, method, verbose, warn): ot.bregman.barycenter(A_nx, M_nx, reg, weights, method=method) else: # wasserstein - bary_wass_np = ot.bregman.barycenter(A, M, reg, weights, method=method, verbose=verbose, warn=warn) - bary_wass, _ = ot.bregman.barycenter(A_nx, M_nx, reg, weights_nx, method=method, log=True) + bary_wass_np = ot.bregman.barycenter( + A, M, reg, weights, method=method, verbose=verbose, warn=warn) + bary_wass, _ = ot.bregman.barycenter( + A_nx, M_nx, reg, weights_nx, method=method, log=True) bary_wass = nx.to_numpy(bary_wass) np.testing.assert_allclose(1, np.sum(bary_wass)) @@ -514,7 +533,8 @@ def test_free_support_sinkhorn_barycenter(): # Calculate free support barycenter w/ Sinkhorn algorithm. We set the entropic regularization # term to 1, but this should be, in general, fine-tuned to the problem. - X = ot.bregman.free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, reg=1) + X = ot.bregman.free_support_sinkhorn_barycenter( + measures_locations, measures_weights, X_init, reg=1) # Verifies if calculated barycenter matches ground-truth np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7) @@ -545,8 +565,10 @@ def test_barycenter_assymetric_cost(nx, method, verbose, warn): ot.bregman.barycenter(A_nx, M_nx, reg, method=method) else: # wasserstein - bary_wass_np = ot.bregman.barycenter(A, M, reg, method=method, verbose=verbose, warn=warn) - bary_wass, _ = ot.bregman.barycenter(A_nx, M_nx, reg, method=method, log=True) + bary_wass_np = ot.bregman.barycenter( + A, M, reg, method=method, verbose=verbose, warn=warn) + bary_wass, _ = ot.bregman.barycenter( + A_nx, M_nx, reg, method=method, log=True) bary_wass = nx.to_numpy(bary_wass) np.testing.assert_allclose(1, np.sum(bary_wass)) @@ -581,17 +603,20 @@ def test_barycenter_debiased(nx, method, verbose, warn): reg = 1e-2 if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": with pytest.raises(NotImplementedError): - ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights, method=method) + ot.bregman.barycenter_debiased( + A_nx, M_nx, reg, weights, method=method) else: bary_wass_np = ot.bregman.barycenter_debiased(A, M, reg, weights, method=method, verbose=verbose, warn=warn) - bary_wass, _ = ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights_nx, method=method, log=True) + bary_wass, _ = ot.bregman.barycenter_debiased( + A_nx, M_nx, reg, weights_nx, method=method, log=True) bary_wass = nx.to_numpy(bary_wass) np.testing.assert_allclose(1, np.sum(bary_wass), atol=1e-3) np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-5) - ot.bregman.barycenter_debiased(A_nx, M_nx, reg, log=True, verbose=False) + ot.bregman.barycenter_debiased( + A_nx, M_nx, reg, log=True, verbose=False) @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) @@ -616,7 +641,8 @@ def test_convergence_warning_barycenters(method): weights = np.array([1 - alpha, alpha]) reg = 0.1 with pytest.warns(UserWarning): - ot.bregman.barycenter_debiased(A, M, reg, weights, method=method, numItermax=1) + ot.bregman.barycenter_debiased( + A, M, reg, weights, method=method, numItermax=1) with pytest.warns(UserWarning): ot.bregman.barycenter(A, M, reg, weights, method=method, numItermax=1) with pytest.warns(UserWarning): @@ -648,7 +674,8 @@ def test_barycenter_stabilization(nx): # wasserstein reg = 1e-2 - bar_np = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn", stopThr=1e-8, verbose=True) + bar_np = ot.bregman.barycenter( + A, M, reg, weights, method="sinkhorn", stopThr=1e-8, verbose=True) bar_stable = nx.to_numpy(ot.bregman.barycenter( A_nx, M_nx, reg, weights_b, method="sinkhorn_stabilized", stopThr=1e-8, verbose=True @@ -683,8 +710,10 @@ def test_wasserstein_bary_2d(nx, method): with pytest.raises(NotImplementedError): ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method) else: - bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method, verbose=True, log=True) - bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)) + bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d( + A, reg, method=method, verbose=True, log=True) + bary_wass = nx.to_numpy( + ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)) np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) @@ -713,10 +742,13 @@ def test_wasserstein_bary_2d_debiased(nx, method): reg = 1e-2 if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": with pytest.raises(NotImplementedError): - ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method) + ot.bregman.convolutional_barycenter2d_debiased( + A_nx, reg, method=method) else: - bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method, verbose=True, log=True) - bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)) + bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased( + A, reg, method=method, verbose=True, log=True) + bary_wass = nx.to_numpy( + ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)) np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) @@ -750,7 +782,8 @@ def test_unmix(nx): # wasserstein reg = 1e-3 um_np = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01) - um = nx.to_numpy(ot.bregman.unmix(ab, Db, M_nx, M0b, h0b, reg, 1, alpha=0.01)) + um = nx.to_numpy(ot.bregman.unmix( + ab, Db, M_nx, M0b, h0b, reg, 1, alpha=0.01)) np.testing.assert_allclose(1, np.sum(um), rtol=1e-03, atol=1e-03) np.testing.assert_allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03) @@ -781,10 +814,12 @@ def test_empirical_sinkhorn(nx): sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True) sinkhorn_log = nx.to_numpy(sinkhorn_log) - G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean')) + G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn( + X_sb, X_tb, 1, metric='euclidean')) sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1)) - loss_emp_sinkhorn = nx.to_numpy(ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1)) + loss_emp_sinkhorn = nx.to_numpy( + ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1)) loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1)) # check constraints @@ -817,23 +852,27 @@ def test_lazy_empirical_sinkhorn(nx): ab, bb, X_sb, X_tb, M_nx, M_mb = nx.from_numpy(a, b, X_s, X_t, M, M_m) - f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True) + f, g = ot.bregman.empirical_sinkhorn( + X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True) f, g = nx.to_numpy(f), nx.to_numpy(g) G_sqe = np.exp(f[:, None] + g[None, :] - M / 1) sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1)) - f, g, log_es = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) + f, g, log_es = ot.bregman.empirical_sinkhorn( + X_sb, X_tb, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) f, g = nx.to_numpy(f), nx.to_numpy(g) G_log = np.exp(f[:, None] + g[None, :] - M / 0.1) sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True) sinkhorn_log = nx.to_numpy(sinkhorn_log) - f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean', numIterMax=numIterMax, isLazy=True, batchSize=1) + f, g = ot.bregman.empirical_sinkhorn( + X_sb, X_tb, 1, metric='euclidean', numIterMax=numIterMax, isLazy=True, batchSize=1) f, g = nx.to_numpy(f), nx.to_numpy(g) G_m = np.exp(f[:, None] + g[None, :] - M_m / 1) sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1)) - loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) + loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2( + X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) loss_emp_sinkhorn = nx.to_numpy(loss_emp_sinkhorn) loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1)) @@ -865,22 +904,27 @@ def test_empirical_sinkhorn_divergence(nx): M_s = ot.dist(X_s, X_s) M_t = ot.dist(X_t, X_t) - ab, bb, X_sb, X_tb, M_nx, M_sb, M_tb = nx.from_numpy(a, b, X_s, X_t, M, M_s, M_t) + ab, bb, X_sb, X_tb, M_nx, M_sb, M_tb = nx.from_numpy( + a, b, X_s, X_t, M, M_s, M_t) - emp_sinkhorn_div = nx.to_numpy(ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb)) + emp_sinkhorn_div = nx.to_numpy( + ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb)) sinkhorn_div = nx.to_numpy( ot.sinkhorn2(ab, bb, M_nx, 1) - 1 / 2 * ot.sinkhorn2(ab, ab, M_sb, 1) - 1 / 2 * ot.sinkhorn2(bb, bb, M_tb, 1) ) - emp_sinkhorn_div_np = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, a=a, b=b) + emp_sinkhorn_div_np = ot.bregman.empirical_sinkhorn_divergence( + X_s, X_t, 1, a=a, b=b) # check constraints - np.testing.assert_allclose(emp_sinkhorn_div, emp_sinkhorn_div_np, atol=1e-05) + np.testing.assert_allclose( + emp_sinkhorn_div, emp_sinkhorn_div_np, atol=1e-05) np.testing.assert_allclose( emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn - ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb, log=True) + ot.bregman.empirical_sinkhorn_divergence( + X_sb, X_tb, 1, a=ab, b=bb, log=True) @pytest.mark.skipif(not torch, reason="No torch available") @@ -902,7 +946,8 @@ def test_empirical_sinkhorn_divergence_gradient(): X_sb.requires_grad = True X_tb.requires_grad = True - emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb) + emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence( + X_sb, X_tb, 1, a=ab, b=bb) emp_sinkhorn_div.backward() @@ -931,7 +976,8 @@ def test_stabilized_vs_sinkhorn_multidim(nx): ab, bb, M_nx = nx.from_numpy(a, b, M) - G_np, _ = ot.bregman.sinkhorn(a, b, M, reg=epsilon, method="sinkhorn", log=True) + G_np, _ = ot.bregman.sinkhorn( + a, b, M, reg=epsilon, method="sinkhorn", log=True) G, log = ot.bregman.sinkhorn(ab, bb, M_nx, reg=epsilon, method="sinkhorn_stabilized", log=True) @@ -996,7 +1042,8 @@ def test_screenkhorn(nx): # sinkhorn G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-1)) # screenkhorn - G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, M_nx, 1e-1, uniform=True, verbose=True)) + G_screen = nx.to_numpy(ot.bregman.screenkhorn( + ab, bb, M_nx, 1e-1, uniform=True, verbose=True)) # check marginals np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02) np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02) @@ -1014,6 +1061,7 @@ def test_convolutional_barycenter_non_square(nx): np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02) np.testing.assert_allclose(b, b_np) + def test_sinkhorn_warmstart(): m, n = 10, 20 a = ot.unif(m) @@ -1021,7 +1069,7 @@ def test_sinkhorn_warmstart(): Xs = np.arange(m) * 1.0 Xt = np.arange(n) * 1.0 - M = ot.dist(Xs.reshape(-1,1), Xt.reshape(-1,1)) + M = ot.dist(Xs.reshape(-1, 1), Xt.reshape(-1, 1)) # Generate warmstart from dual vectors of unregularized OT _, log = ot.lp.emd(a, b, M, log=True) @@ -1030,22 +1078,28 @@ def test_sinkhorn_warmstart(): reg = 1 # Optimal plan with uniform warmstart - pi_unif, _ = ot.bregman.sinkhorn(a, b, M, reg, method="sinkhorn", warmstart=None, log=True) + pi_unif, _ = ot.bregman.sinkhorn( + a, b, M, reg, method="sinkhorn", warmstart=None, log=True) # Optimal plan with warmstart generated from unregularized OT - pi_sh, _ = ot.bregman.sinkhorn(a, b, M, reg, method="sinkhorn", warmstart=warmstart, log=True) - pi_sh_log, _ = ot.bregman.sinkhorn(a, b, M, reg, method="sinkhorn_log", warmstart=warmstart, log=True) - pi_sh_stab, _ = ot.bregman.sinkhorn(a, b, M, reg, method="sinkhorn_stabilized", warmstart=warmstart, log=True) - pi_sh_sc, _ = ot.bregman.sinkhorn(a, b, M, reg, method="sinkhorn_epsilon_scaling", warmstart=warmstart, log=True) + pi_sh, _ = ot.bregman.sinkhorn( + a, b, M, reg, method="sinkhorn", warmstart=warmstart, log=True) + pi_sh_log, _ = ot.bregman.sinkhorn( + a, b, M, reg, method="sinkhorn_log", warmstart=warmstart, log=True) + pi_sh_stab, _ = ot.bregman.sinkhorn( + a, b, M, reg, method="sinkhorn_stabilized", warmstart=warmstart, log=True) + pi_sh_sc, _ = ot.bregman.sinkhorn( + a, b, M, reg, method="sinkhorn_epsilon_scaling", warmstart=warmstart, log=True) np.testing.assert_allclose(pi_unif, pi_sh, atol=1e-05) np.testing.assert_allclose(pi_unif, pi_sh_log, atol=1e-05) np.testing.assert_allclose(pi_unif, pi_sh_stab, atol=1e-05) np.testing.assert_allclose(pi_unif, pi_sh_sc, atol=1e-05) + def test_empirical_sinkhorn_warmstart(): m, n = 10, 20 - Xs = np.arange(m).reshape(-1,1) * 1.0 - Xt = np.arange(n).reshape(-1,1) * 1.0 + Xs = np.arange(m).reshape(-1, 1) * 1.0 + Xt = np.arange(n).reshape(-1, 1) * 1.0 M = ot.dist(Xs, Xt) # Generate warmstart from dual vectors of unregularized OT @@ -1057,22 +1111,26 @@ def test_empirical_sinkhorn_warmstart(): reg = 1 # Optimal plan with uniform warmstart - f, g, _ = ot.bregman.empirical_sinkhorn(X_s=Xs, X_t=Xt, reg=reg, isLazy=True, warmstart=None, log=True) + f, g, _ = ot.bregman.empirical_sinkhorn( + X_s=Xs, X_t=Xt, reg=reg, isLazy=True, warmstart=None, log=True) pi_unif = np.exp(f[:, None] + g[None, :] - M / reg) # Optimal plan with warmstart generated from unregularized OT - f, g, _ = ot.bregman.empirical_sinkhorn(X_s=Xs, X_t=Xt, reg=reg, isLazy=True, warmstart=warmstart, log=True) + f, g, _ = ot.bregman.empirical_sinkhorn( + X_s=Xs, X_t=Xt, reg=reg, isLazy=True, warmstart=warmstart, log=True) pi_ws_lazy = np.exp(f[:, None] + g[None, :] - M / reg) - pi_ws_not_lazy, _ = ot.bregman.empirical_sinkhorn(X_s=Xs, X_t=Xt, reg=reg, isLazy=False, warmstart=warmstart, log=True) + pi_ws_not_lazy, _ = ot.bregman.empirical_sinkhorn( + X_s=Xs, X_t=Xt, reg=reg, isLazy=False, warmstart=warmstart, log=True) np.testing.assert_allclose(pi_unif, pi_ws_lazy, atol=1e-05) np.testing.assert_allclose(pi_unif, pi_ws_not_lazy, atol=1e-05) + def test_empirical_sinkhorn_divergence_warmstart(): m, n = 10, 20 - Xs = np.arange(m).reshape(-1,1) * 1.0 - Xt = np.arange(n).reshape(-1,1) * 1.0 + Xs = np.arange(m).reshape(-1, 1) * 1.0 + Xt = np.arange(n).reshape(-1, 1) * 1.0 M = ot.dist(Xs, Xt) - + # Generate warmstart from dual vectors of unregularized OT a = ot.unif(m) b = ot.unif(n) @@ -1082,10 +1140,13 @@ def test_empirical_sinkhorn_divergence_warmstart(): reg = 1 # Optimal plan with uniform warmstart - sd_unif, _ = ot.bregman.empirical_sinkhorn_divergence(X_s=Xs, X_t=Xt, reg=reg, isLazy=True, warmstart=None, log=True) + sd_unif, _ = ot.bregman.empirical_sinkhorn_divergence( + X_s=Xs, X_t=Xt, reg=reg, isLazy=True, warmstart=None, log=True) # Optimal plan with warmstart generated from unregularized OT - sd_ws_lazy, _ = ot.bregman.empirical_sinkhorn_divergence(X_s=Xs, X_t=Xt, reg=reg, isLazy=True, warmstart=warmstart, log=True) - sd_ws_not_lazy, _ = ot.bregman.empirical_sinkhorn_divergence(X_s=Xs, X_t=Xt, reg=reg, isLazy=False, warmstart=warmstart, log=True) + sd_ws_lazy, _ = ot.bregman.empirical_sinkhorn_divergence( + X_s=Xs, X_t=Xt, reg=reg, isLazy=True, warmstart=warmstart, log=True) + sd_ws_not_lazy, _ = ot.bregman.empirical_sinkhorn_divergence( + X_s=Xs, X_t=Xt, reg=reg, isLazy=False, warmstart=warmstart, log=True) np.testing.assert_allclose(sd_unif, sd_ws_lazy, atol=1e-05) - np.testing.assert_allclose(sd_unif, sd_ws_not_lazy, atol=1e-05) \ No newline at end of file + np.testing.assert_allclose(sd_unif, sd_ws_not_lazy, atol=1e-05) From 96449a0105299efb1ce676e52b5397c13965ca50 Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Thu, 23 Feb 2023 09:50:34 +0100 Subject: [PATCH 7/8] Fix W291 trailing whitespace error in pep8 test --- ot/bregman.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/ot/bregman.py b/ot/bregman.py index e0f6bd177..2dd465acd 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -94,7 +94,8 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, warmstart=None, numItermax : int, optional Max number of iterations warmstart: tuple of arrays, shape (dim_a, dim_b), optional - Initialization of dual vectors. If provided, the dual vectors must be already taken the logarithm, + Initialization of dual vectors. If provided, + the dual vectors must be already taken the logarithm, i.e. warmstart = (log_u, log_v), but not (u, v). stopThr : float, optional Stop threshold on error (>0) @@ -256,7 +257,8 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, warmstart=None, numItermax : int, optional Max number of iterations warmstart: tuple of arrays, shape (dim_a, dim_b), optional - Initialization of dual vectors. If provided, the dual vectors must be already taken the logarithm, + Initialization of dual vectors. If provided, + the dual vectors must be already taken the logarithm, i.e. warmstart = (log_u, log_v), but not (u, v). stopThr : float, optional Stop threshold on error (>0) @@ -414,7 +416,8 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, warmstart=None, stopThr=1e-9, numItermax : int, optional Max number of iterations warmstart: tuple of arrays, shape (dim_a, dim_b), optional - Initialization of dual vectors. If provided, the dual vectors must be already taken the logarithm, + Initialization of dual vectors. If provided, + the dual vectors must be already taken the logarithm, i.e. warmstart = (log_u, log_v), but not (u, v). stopThr : float, optional Stop threshold on error (>0) @@ -601,7 +604,8 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, warmstart=None, stopThr=1e-9, ve numItermax : int, optional Max number of iterations warmstart: tuple of arrays, shape (dim_a, dim_b), optional - Initialization of dual vectors. If provided, the dual vectors must be already taken the logarithm, + Initialization of dual vectors. If provided, + the dual vectors must be already taken the logarithm, i.e. warmstart = (log_u, log_v), but not (u, v). stopThr : float, optional Stop threshold on error (>0) @@ -811,7 +815,8 @@ def greenkhorn(a, b, M, reg, numItermax=10000, warmstart=None, stopThr=1e-9, ver numItermax : int, optional Max number of iterations warmstart: tuple of arrays, shape (dim_a, dim_b), optional - Initialization of dual vectors. If provided, the dual vectors must be already taken the logarithm, + Initialization of dual vectors. If provided, + the dual vectors must be already taken the logarithm, i.e. warmstart = (log_u, log_v), but not (u, v). stopThr : float, optional Stop threshold on error (>0) @@ -2934,7 +2939,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numItermax : int, optional Max number of iterations warmstart: tuple of arrays, shape (dim_a, dim_b), optional - Initialization of dual vectors. If provided, the dual vectors must be already taken the logarithm, + Initialization of dual vectors. If provided, + the dual vectors must be already taken the logarithm, i.e. warmstart = (log_u, log_v), but not (u, v). stopThr : float, optional Stop threshold on error (>0) @@ -3133,7 +3139,8 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numItermax : int, optional Max number of iterations warmstart: tuple of arrays, shape (dim_a, dim_b), optional - Initialization of dual vectors. If provided, the dual vectors must be already taken the logarithm, + Initialization of dual vectors. If provided, + the dual vectors must be already taken the logarithm, i.e. warmstart = (log_u, log_v), but not (u, v). stopThr : float, optional Stop threshold on error (>0) @@ -3324,7 +3331,8 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli numItermax : int, optional Max number of iterations warmstart: tuple of arrays, shape (dim_a, dim_b), optional - Initialization of dual vectors. If provided, the dual vectors must be already taken the logarithm, + Initialization of dual vectors. If provided, + the dual vectors must be already taken the logarithm, i.e. warmstart = (log_u, log_v), but not (u, v). stopThr : float, optional Stop threshold on error (>0) From 2571802624e78504a291d6dbcb8686e4c84711f2 Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Thu, 23 Feb 2023 11:56:28 +0100 Subject: [PATCH 8/8] Rearange position of warmstart argument and edit its description --- RELEASES.md | 86 +++++++++--------- ot/bregman.py | 212 ++++++++++++++++++++----------------------- test/test_bregman.py | 22 ++--- 3 files changed, 152 insertions(+), 168 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index 68e2eacf5..292d1df21 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -13,7 +13,7 @@ - Added Free Support Sinkhorn Barycenter + example (PR #387) - New API for OT solver using function `ot.solve` (PR #388) - Backend version of `ot.partial` and `ot.smooth` (PR #388) -- Added argument for warmstart of dual vectors in Sinkhorn-based methods in `ot.bregman` (PR #437) +- Added argument for warmstart of dual potentials in Sinkhorn-based methods in `ot.bregman` (PR #437) #### Closed issues @@ -35,10 +35,10 @@ roughly 2^31) (PR #381) - Fixed an issue where the doc could not be built due to some changes in matplotlib's API (Issue #403, PR #402) - Replaced Numpy C Compiler with Setuptools C Compiler due to deprecation issues (Issue #408, PR #409) - Fixed weak optimal transport docstring (Issue #404, PR #410) -- Fixed error with parameter `log=True`for `SinkhornLpl1Transport` (Issue #412, +- Fixed error with parameter `log=True`for `SinkhornLpl1Transport` (Issue #412, PR #413) - Fixed an issue about `warn` parameter in `sinkhorn2` (PR #417) -- Fix an issue where the parameter `stopThr` in `empirical_sinkhorn_divergence` was rendered useless by subcalls +- Fix an issue where the parameter `stopThr` in `empirical_sinkhorn_divergence` was rendered useless by subcalls that explicitly specified `stopThr=1e-9` (Issue #421, PR #422). - Fixed a bug breaking an example where we would try to make an array of arrays of different shapes (Issue #424, PR #425) @@ -88,7 +88,7 @@ and [Factored coupling OT](https://pythonot.github.io/auto_examples/others/plot_ - Remove deprecated `ot.gpu` submodule (PR #361) - Update examples in the gallery (PR #359) -- Add stochastic loss and OT plan computation for regularized OT and +- Add stochastic loss and OT plan computation for regularized OT and backend examples(PR #360) - Implementation of factored OT with emd and sinkhorn (PR #358) - A brand new logo for POT (PR #357) @@ -104,9 +104,9 @@ and [Factored coupling OT](https://pythonot.github.io/auto_examples/others/plot_ #### Closed issues -- Fix mass gradient of `ot.emd2` and `ot.gromov_wasserstein2` so that they are +- Fix mass gradient of `ot.emd2` and `ot.gromov_wasserstein2` so that they are centered (Issue #364, PR #363) -- Fix bug in instantiating an `autograd` function `ValFunction` (Issue #337, +- Fix bug in instantiating an `autograd` function `ValFunction` (Issue #337, PR #338) - Fix POT ABI compatibility with old and new numpy (Issue #346, PR #349) - Warning when feeding integer cost matrix to EMD solver resulting in an integer transport plan (Issue #345, PR #343) @@ -156,21 +156,21 @@ As always we want to that the contributors who helped make POT better (and bug f - Fix bug in older Numpy ABI (<1.20) (Issue #308, PR #326) - Fix bug in `ot.dist` function when non euclidean distance (Issue #305, PR #306) -- Fix gradient scaling for functions using `nx.set_gradients` (Issue #309, +- Fix gradient scaling for functions using `nx.set_gradients` (Issue #309, PR #310) -- Fix bug in generalized Conditional gradient solver and SinkhornL1L2 +- Fix bug in generalized Conditional gradient solver and SinkhornL1L2 (Issue #311, PR #313) - Fix log error in `gromov_barycenters` (Issue #317, PR #3018) ## 0.8.0 *November 2021* -This new stable release introduces several important features. +This new stable release introduces several important features. First we now have an OpenMP compatible exact ot solver in `ot.emd`. The OpenMP version is used when the parameter `numThreads` is greater than one and can lead to nice -speedups on multi-core machines. +speedups on multi-core machines. Second we have introduced a backend mechanism that allows to use standard POT function seamlessly on Numpy, Pytorch and Jax arrays. Other backends are coming @@ -189,7 +189,7 @@ for a [sliced Wasserstein gradient flow](https://PythonOT.github.io/auto_examples/backends/plot_sliced_wass_grad_flow_pytorch.html) and [optimizing the Gromov-Wassersein distance](https://PythonOT.github.io/auto_examples/backends/plot_optim_gromov_pytorch.html). Note that the Jax backend is still in early development and quite slow at the moment, we strongly recommend for Jax users to use the [OTT -toolbox](https://github.com/google-research/ott) when possible. +toolbox](https://github.com/google-research/ott) when possible. As a result of this new feature, the old `ot.gpu` submodule is now deprecated since GPU implementations can be done using GPU arrays on the torch backends. @@ -212,7 +212,7 @@ Finally POT was accepted for publication in the Journal of Machine Learning Research (JMLR) open source software track and we ask the POT users to cite [this paper](https://www.jmlr.org/papers/v22/20-451.html) from now on. The documentation has been improved in particular by adding a "Why OT?" section to the quick start guide and several new examples illustrating -the new features. The documentation now has two version : the stable version +the new features. The documentation now has two version : the stable version [https://pythonot.github.io/](https://pythonot.github.io/) corresponding to the last release and the master version [https://pythonot.github.io/master](https://pythonot.github.io/master) that corresponds to the current master branch on GitHub. @@ -222,7 +222,7 @@ As usual, we want to thank all the POT contributors (now 37 people have contributed to the toolbox). But for this release we thank in particular Nathan Cassereau and Kamel Guerda from the AI support team at [IDRIS](http://www.idris.fr/) for their support to the development of the -backend and OpenMP implementations. +backend and OpenMP implementations. #### New features @@ -289,7 +289,7 @@ repository for the new documentation is now hosted at This is the first release where the Python 2.7 tests have been removed. Most of the toolbox should still work but we do not offer support for Python 2.7 and -will close related Issues. +will close related Issues. A lot of changes have been done to the documentation that is now hosted on [https://PythonOT.github.io/](https://PythonOT.github.io/) instead of @@ -322,7 +322,7 @@ problems. This release is also the moment to thank all the POT contributors (old and new) for helping making POT such a nice toolbox. A lot of changes (also in the API) -are coming for the next versions. +are coming for the next versions. #### Features @@ -351,14 +351,14 @@ are coming for the next versions. - Log bugs for Gromov-Wassertein solver (Issue #107, fixed in PR #108) - Weight issues in barycenter function (PR #106) -## 0.6.0 +## 0.6.0 *July 2019* -This is the first official stable release of POT and this means a jump to 0.6! +This is the first official stable release of POT and this means a jump to 0.6! The library has been used in the wild for a while now and we have reached a state where a lot of fundamental OT solvers are available and tested. It has been quite stable in the last months -but kept the beta flag in its Pypi classifiers until now. +but kept the beta flag in its Pypi classifiers until now. Note that this release will be the last one supporting officially Python 2.7 (See https://python3statement.org/ for more reasons). For next release we will keep @@ -387,7 +387,7 @@ graphs](https://github.com/rflamary/POT/blob/master/notebooks/plot_barycenter_fg A lot of work has been done on the documentation with several new examples corresponding to the new features and a lot of corrections for the -docstrings. But the most visible change is a new +docstrings. But the most visible change is a new [quick start guide](https://pot.readthedocs.io/en/latest/quickstart.html) for POT that gives several pointers about which function or classes allow to solve which specific OT problem. When possible a link is provided to relevant examples. @@ -425,29 +425,29 @@ bring new features and solvers to the library. - Issue #72 Macosx build problem -## 0.5.0 +## 0.5.0 *Sep 2018* -POT is 2 years old! This release brings numerous new features to the +POT is 2 years old! This release brings numerous new features to the toolbox as listed below but also several bug correction. -Among the new features, we can highlight a [non-regularized Gromov-Wasserstein -solver](https://github.com/rflamary/POT/blob/master/notebooks/plot_gromov.ipynb), -a new [greedy variant of sinkhorn](https://pot.readthedocs.io/en/latest/all.html#ot.bregman.greenkhorn), -[non-regularized](https://pot.readthedocs.io/en/latest/all.html#ot.lp.barycenter), +Among the new features, we can highlight a [non-regularized Gromov-Wasserstein +solver](https://github.com/rflamary/POT/blob/master/notebooks/plot_gromov.ipynb), +a new [greedy variant of sinkhorn](https://pot.readthedocs.io/en/latest/all.html#ot.bregman.greenkhorn), +[non-regularized](https://pot.readthedocs.io/en/latest/all.html#ot.lp.barycenter), [convolutional (2D)](https://github.com/rflamary/POT/blob/master/notebooks/plot_convolutional_barycenter.ipynb) and [free support](https://github.com/rflamary/POT/blob/master/notebooks/plot_free_support_barycenter.ipynb) - Wasserstein barycenters and [smooth](https://github.com/rflamary/POT/blob/prV0.5/notebooks/plot_OT_1D_smooth.ipynb) - and [stochastic](https://pot.readthedocs.io/en/latest/all.html#ot.stochastic.sgd_entropic_regularization) + Wasserstein barycenters and [smooth](https://github.com/rflamary/POT/blob/prV0.5/notebooks/plot_OT_1D_smooth.ipynb) + and [stochastic](https://pot.readthedocs.io/en/latest/all.html#ot.stochastic.sgd_entropic_regularization) implementation of entropic OT. -POT 0.5 also comes with a rewriting of ot.gpu using the cupy framework instead of -the unmaintained cudamat. Note that while we tried to keed changes to the -minimum, the OTDA classes were deprecated. If you are happy with the cudamat +POT 0.5 also comes with a rewriting of ot.gpu using the cupy framework instead of +the unmaintained cudamat. Note that while we tried to keed changes to the +minimum, the OTDA classes were deprecated. If you are happy with the cudamat implementation, we recommend you stay with stable release 0.4 for now. -The code quality has also improved with 92% code coverage in tests that is now -printed to the log in the Travis builds. The documentation has also been +The code quality has also improved with 92% code coverage in tests that is now +printed to the log in the Travis builds. The documentation has also been greatly improved with new modules and examples/notebooks. This new release is so full of new stuff and corrections thanks to the old @@ -466,24 +466,24 @@ and new POT contributors (you can see the list in the [readme](https://github.co * Stochastic OT in the dual and semi-dual (PR #52 and PR #62) * Free support barycenters (PR #56) * Speed-up Sinkhorn function (PR #57 and PR #58) -* Add convolutional Wassersein barycenters for 2D images (PR #64) +* Add convolutional Wassersein barycenters for 2D images (PR #64) * Add Greedy Sinkhorn variant (Greenkhorn) (PR #66) * Big ot.gpu update with cupy implementation (instead of un-maintained cudamat) (PR #67) #### Deprecation -Deprecated OTDA Classes were removed from ot.da and ot.gpu for version 0.5 -(PR #48 and PR #67). The deprecation message has been for a year here since +Deprecated OTDA Classes were removed from ot.da and ot.gpu for version 0.5 +(PR #48 and PR #67). The deprecation message has been for a year here since 0.4 and it is time to pull the plug. #### Closed issues * Issue #35 : remove import plot from ot/__init__.py (See PR #41) * Issue #43 : Unusable parameter log for EMDTransport (See PR #44) -* Issue #55 : UnicodeDecodeError: 'ascii' while installing with pip +* Issue #55 : UnicodeDecodeError: 'ascii' while installing with pip -## 0.4 +## 0.4 *15 Sep 2017* This release contains a lot of contribution from new contributors. @@ -493,14 +493,14 @@ This release contains a lot of contribution from new contributors. * Automatic notebooks and doc update (PR #27) * Add gromov Wasserstein solver and Gromov Barycenters (PR #23) -* emd and emd2 can now return dual variables and have max_iter (PR #29 and PR #25) +* emd and emd2 can now return dual variables and have max_iter (PR #29 and PR #25) * New domain adaptation classes compatible with scikit-learn (PR #22) * Proper tests with pytest on travis (PR #19) * PEP 8 tests (PR #13) #### Closed issues -* emd convergence problem du to fixed max iterations (#24) +* emd convergence problem du to fixed max iterations (#24) * Semi supervised DA error (#26) ## 0.3.1 @@ -508,7 +508,7 @@ This release contains a lot of contribution from new contributors. * Correct bug in emd on windows -## 0.3 +## 0.3 *7 Jul 2017* * emd* and sinkhorn* are now performed in parallel for multiple target distributions @@ -521,7 +521,7 @@ This release contains a lot of contribution from new contributors. * GPU implementations for sinkhorn and group lasso regularization -## V0.2 +## V0.2 *7 Apr 2017* * New dimensionality reduction method (WDA) @@ -529,7 +529,7 @@ This release contains a lot of contribution from new contributors. -## 0.1.11 +## 0.1.11 *5 Jan 2017* * Add sphinx gallery for better documentation @@ -537,7 +537,7 @@ This release contains a lot of contribution from new contributors. * Add simple tic() toc() functions for timing -## 0.1.10 +## 0.1.10 *7 Nov 2016* * numerical stabilization for sinkhorn (log domain and epsilon scaling) diff --git a/ot/bregman.py b/ot/bregman.py index 2dd465acd..192a9e2b0 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -24,9 +24,8 @@ from .backend import get_backend -def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, warmstart=None, - stopThr=1e-9, verbose=False, log=False, warn=True, - **kwargs): +def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, + verbose=False, log=False, warn=True, warmstart=None, **kwargs): r""" Solve the entropic regularization optimal transport problem and return the OT matrix @@ -93,10 +92,6 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, warmstart=None, those function for specific parameters numItermax : int, optional Max number of iterations - warmstart: tuple of arrays, shape (dim_a, dim_b), optional - Initialization of dual vectors. If provided, - the dual vectors must be already taken the logarithm, - i.e. warmstart = (log_u, log_v), but not (u, v). stopThr : float, optional Stop threshold on error (>0) verbose : bool, optional @@ -105,6 +100,9 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, warmstart=None, record log if True warn : bool, optional if True, raises a warning if the algorithm doesn't convergence. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors) Returns ------- @@ -158,36 +156,35 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, warmstart=None, """ if method.lower() == 'sinkhorn': - return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, warmstart=warmstart, + return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, - warn=warn, + warn=warn, warmstart=warmstart, **kwargs) elif method.lower() == 'sinkhorn_log': - return sinkhorn_log(a, b, M, reg, numItermax=numItermax, warmstart=warmstart, + return sinkhorn_log(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, - warn=warn, + warn=warn, warmstart=warmstart, **kwargs) elif method.lower() == 'greenkhorn': - return greenkhorn(a, b, M, reg, numItermax=numItermax, warmstart=warmstart, + return greenkhorn(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, - warn=warn) + warn=warn, warmstart=warmstart) elif method.lower() == 'sinkhorn_stabilized': - return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, warmstart=warmstart, - stopThr=stopThr, verbose=verbose, - log=log, warn=warn, + return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, warmstart=warmstart, + verbose=verbose, log=log, warn=warn, **kwargs) elif method.lower() == 'sinkhorn_epsilon_scaling': - return sinkhorn_epsilon_scaling(a, b, M, reg, - numItermax=numItermax, warmstart=warmstart, - stopThr=stopThr, verbose=verbose, - log=log, warn=warn, + return sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, warmstart=warmstart, + verbose=verbose, log=log, warn=warn, **kwargs) else: raise ValueError("Unknown method '%s'." % method) -def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, warmstart=None, - stopThr=1e-9, verbose=False, log=False, warn=False, **kwargs): +def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, + stopThr=1e-9, verbose=False, log=False, warn=False, warmstart=None, **kwargs): r""" Solve the entropic regularization optimal transport problem and return the loss @@ -256,10 +253,6 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, warmstart=None, 'sinkhorn_stabilized', see those function for specific parameters numItermax : int, optional Max number of iterations - warmstart: tuple of arrays, shape (dim_a, dim_b), optional - Initialization of dual vectors. If provided, - the dual vectors must be already taken the logarithm, - i.e. warmstart = (log_u, log_v), but not (u, v). stopThr : float, optional Stop threshold on error (>0) verbose : bool, optional @@ -268,6 +261,9 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, warmstart=None, record log if True warn : bool, optional if True, raises a warning if the algorithm doesn't convergence. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors) Returns ------- @@ -330,19 +326,19 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, warmstart=None, if len(b.shape) < 2: if method.lower() == 'sinkhorn': - res = sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, warmstart=warmstart, + res = sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, - log=log, warn=warn, + log=log, warn=warn, warmstart=warmstart, **kwargs) elif method.lower() == 'sinkhorn_log': - res = sinkhorn_log(a, b, M, reg, numItermax=numItermax, warmstart=warmstart, + res = sinkhorn_log(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, - log=log, warn=warn, + log=log, warn=warn, warmstart=warmstart, **kwargs) elif method.lower() == 'sinkhorn_stabilized': - res = sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, warmstart=warmstart, - stopThr=stopThr, verbose=verbose, - log=log, warn=warn, + res = sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, warmstart=warmstart, + verbose=verbose, log=log, warn=warn, **kwargs) else: raise ValueError("Unknown method '%s'." % method) @@ -354,27 +350,26 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, warmstart=None, else: if method.lower() == 'sinkhorn': - return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, warmstart=warmstart, + return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, - log=log, warn=warn, + log=log, warn=warn, warmstart=warmstart, **kwargs) elif method.lower() == 'sinkhorn_log': - return sinkhorn_log(a, b, M, reg, numItermax=numItermax, warmstart=warmstart, + return sinkhorn_log(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, - log=log, warn=warn, + log=log, warn=warn, warmstart=warmstart, **kwargs) elif method.lower() == 'sinkhorn_stabilized': - return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, warmstart=warmstart, - stopThr=stopThr, verbose=verbose, - log=log, warn=warn, + return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, warmstart=warmstart, + verbose=verbose, log=log, warn=warn, **kwargs) else: raise ValueError("Unknown method '%s'." % method) -def sinkhorn_knopp(a, b, M, reg, numItermax=1000, warmstart=None, stopThr=1e-9, - verbose=False, log=False, warn=True, - **kwargs): +def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, + verbose=False, log=False, warn=True, warmstart=None, **kwargs): r""" Solve the entropic regularization optimal transport problem and return the OT matrix @@ -415,10 +410,6 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, warmstart=None, stopThr=1e-9, Regularization term >0 numItermax : int, optional Max number of iterations - warmstart: tuple of arrays, shape (dim_a, dim_b), optional - Initialization of dual vectors. If provided, - the dual vectors must be already taken the logarithm, - i.e. warmstart = (log_u, log_v), but not (u, v). stopThr : float, optional Stop threshold on error (>0) verbose : bool, optional @@ -427,6 +418,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, warmstart=None, stopThr=1e-9, record log if True warn : bool, optional if True, raises a warning if the algorithm doesn't convergence. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors) Returns ------- @@ -561,8 +555,8 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, warmstart=None, stopThr=1e-9, return u.reshape((-1, 1)) * K * v.reshape((1, -1)) -def sinkhorn_log(a, b, M, reg, numItermax=1000, warmstart=None, stopThr=1e-9, verbose=False, - log=False, warn=True, **kwargs): +def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, + log=False, warn=True, warmstart=None, **kwargs): r""" Solve the entropic regularization optimal transport problem in log space and return the OT matrix @@ -603,10 +597,6 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, warmstart=None, stopThr=1e-9, ve Regularization term >0 numItermax : int, optional Max number of iterations - warmstart: tuple of arrays, shape (dim_a, dim_b), optional - Initialization of dual vectors. If provided, - the dual vectors must be already taken the logarithm, - i.e. warmstart = (log_u, log_v), but not (u, v). stopThr : float, optional Stop threshold on error (>0) verbose : bool, optional @@ -615,6 +605,9 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, warmstart=None, stopThr=1e-9, ve record log if True warn : bool, optional if True, raises a warning if the algorithm doesn't convergence. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors) Returns ------- @@ -686,8 +679,8 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, warmstart=None, stopThr=1e-9, ve lst_v = [] for k in range(n_hists): - res = sinkhorn_log(a, b[:, k], M, reg, numItermax=numItermax, warmstart=warmstart[k], - stopThr=stopThr, verbose=verbose, log=log, **kwargs) + res = sinkhorn_log(a, b[:, k], M, reg, numItermax=numItermax, stopThr=stopThr, + verbose=verbose, log=log, warmstart=warmstart[k], **kwargs) if log: lst_loss.append(nx.sum(M * res[0])) @@ -771,8 +764,8 @@ def get_logT(u, v): return nx.exp(get_logT(u, v)) -def greenkhorn(a, b, M, reg, numItermax=10000, warmstart=None, stopThr=1e-9, verbose=False, - log=False, warn=True): +def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, + log=False, warn=True, warmstart=None): r""" Solve the entropic regularization optimal transport problem and return the OT matrix @@ -814,16 +807,15 @@ def greenkhorn(a, b, M, reg, numItermax=10000, warmstart=None, stopThr=1e-9, ver Regularization term >0 numItermax : int, optional Max number of iterations - warmstart: tuple of arrays, shape (dim_a, dim_b), optional - Initialization of dual vectors. If provided, - the dual vectors must be already taken the logarithm, - i.e. warmstart = (log_u, log_v), but not (u, v). stopThr : float, optional Stop threshold on error (>0) log : bool, optional record log if True warn : bool, optional if True, raises a warning if the algorithm doesn't convergence. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors) Returns ------- @@ -2899,8 +2891,8 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', - numIterMax=10000, warmstart=None, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False, - log=False, warn=True, **kwargs): + numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False, + log=False, warn=True, warmstart=None, **kwargs): r''' Solve the entropic regularization optimal transport problem and return the OT matrix from empirical data @@ -2938,10 +2930,6 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', samples weights in the target domain numItermax : int, optional Max number of iterations - warmstart: tuple of arrays, shape (dim_a, dim_b), optional - Initialization of dual vectors. If provided, - the dual vectors must be already taken the logarithm, - i.e. warmstart = (log_u, log_v), but not (u, v). stopThr : float, optional Stop threshold on error (>0) isLazy: boolean, optional @@ -2957,6 +2945,9 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', record log if True warn : bool, optional if True, raises a warning if the algorithm doesn't convergence. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors) Returns @@ -3085,18 +3076,18 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', else: M = dist(X_s, X_t, metric=metric) if log: - pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, warmstart=warmstart, stopThr=stopThr, - verbose=verbose, log=True, **kwargs) + pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, + verbose=verbose, log=True, warmstart=warmstart, **kwargs) return pi, log else: - pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, warmstart=warmstart, stopThr=stopThr, - verbose=verbose, log=False, **kwargs) + pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, + verbose=verbose, log=False, warmstart=warmstart, **kwargs) return pi def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', - numIterMax=10000, warmstart=None, stopThr=1e-9, isLazy=False, - batchSize=100, verbose=False, log=False, warn=True, **kwargs): + numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, + verbose=False, log=False, warn=True, warmstart=None, **kwargs): r''' Solve the entropic regularization optimal transport problem from empirical data and return the OT loss @@ -3138,10 +3129,6 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', samples weights in the target domain numItermax : int, optional Max number of iterations - warmstart: tuple of arrays, shape (dim_a, dim_b), optional - Initialization of dual vectors. If provided, - the dual vectors must be already taken the logarithm, - i.e. warmstart = (log_u, log_v), but not (u, v). stopThr : float, optional Stop threshold on error (>0) isLazy: boolean, optional @@ -3157,7 +3144,9 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', record log if True warn : bool, optional if True, raises a warning if the algorithm doesn't convergence. - + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors) Returns ------- @@ -3209,20 +3198,20 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', if log: f, g, dict_log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, numIterMax=numIterMax, - warmstart=warmstart, stopThr=stopThr, isLazy=isLazy, batchSize=batchSize, verbose=verbose, log=log, - warn=warn) + warn=warn, + warmstart=warmstart) else: f, g = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, numIterMax=numIterMax, - warmstart=warmstart, stopThr=stopThr, isLazy=isLazy, batchSize=batchSize, verbose=verbose, log=log, - warn=warn) + warn=warn, + warmstart=warmstart) bs = batchSize if isinstance(batchSize, int) else batchSize[0] range_s = range(0, ns, bs) @@ -3247,21 +3236,20 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', M = dist(X_s, X_t, metric=metric) if log: - sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, warmstart=warmstart, + sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, - warn=warn, **kwargs) + warn=warn, warmstart=warmstart, **kwargs) return sinkhorn_loss, log else: - sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, warmstart=warmstart, + sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, - warn=warn, **kwargs) + warn=warn, warmstart=warmstart, **kwargs) return sinkhorn_loss def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', - numIterMax=10000, warmstart=None, stopThr=1e-9, - verbose=False, log=False, warn=True, - **kwargs): + numIterMax=10000, stopThr=1e-9, verbose=False, + log=False, warn=True, warmstart=None, **kwargs): r''' Compute the sinkhorn divergence loss from empirical data @@ -3330,10 +3318,6 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli samples weights in the target domain numItermax : int, optional Max number of iterations - warmstart: tuple of arrays, shape (dim_a, dim_b), optional - Initialization of dual vectors. If provided, - the dual vectors must be already taken the logarithm, - i.e. warmstart = (log_u, log_v), but not (u, v). stopThr : float, optional Stop threshold on error (>0) verbose : bool, optional @@ -3342,6 +3326,9 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli record log if True warn : bool, optional if True, raises a warning if the algorithm doesn't convergence. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors) Returns ------- @@ -3380,19 +3367,19 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli if log: sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, - numIterMax=numIterMax, warmstart=warmstart, - stopThr=stopThr, verbose=verbose, - log=log, warn=warn, **kwargs) + numIterMax=numIterMax, stopThr=stopThr, + verbose=verbose, log=log, warn=warn, + warmstart=warmstart, **kwargs) sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, - numIterMax=numIterMax, warmstart=warmstart_a, - stopThr=stopThr, verbose=verbose, - log=log, warn=warn, **kwargs) + numIterMax=numIterMax, stopThr=stopThr, + verbose=verbose, log=log, warn=warn, + warmstart=warmstart_a, **kwargs) sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, - numIterMax=numIterMax, warmstart=warmstart_b, - stopThr=stopThr, verbose=verbose, - log=log, warn=warn, **kwargs) + numIterMax=numIterMax, stopThr=stopThr, + verbose=verbose, log=log, warn=warn, + warmstart=warmstart_b, **kwargs) sinkhorn_div = sinkhorn_loss_ab - 0.5 * \ (sinkhorn_loss_a + sinkhorn_loss_b) @@ -3409,22 +3396,19 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli else: sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, - numIterMax=numIterMax, warmstart=warmstart, - stopThr=stopThr, - verbose=verbose, log=log, - warn=warn, **kwargs) + numIterMax=numIterMax, stopThr=stopThr, + verbose=verbose, log=log, warn=warn, + warmstart=warmstart, **kwargs) sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, - numIterMax=numIterMax, warmstart=warmstart_a, - stopThr=stopThr, - verbose=verbose, log=log, - warn=warn, **kwargs) + numIterMax=numIterMax, stopThr=stopThr, + verbose=verbose, log=log, warn=warn, + warmstart=warmstart_a, **kwargs) sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, - numIterMax=numIterMax, warmstart=warmstart_b, - stopThr=stopThr, - verbose=verbose, log=log, - warn=warn, **kwargs) + numIterMax=numIterMax, stopThr=stopThr, + verbose=verbose, log=log, warn=warn, + warmstart=warmstart_b, **kwargs) sinkhorn_div = sinkhorn_loss_ab - 0.5 * \ (sinkhorn_loss_a + sinkhorn_loss_b) diff --git a/test/test_bregman.py b/test/test_bregman.py index bd5b65570..f01bb144f 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -1079,16 +1079,16 @@ def test_sinkhorn_warmstart(): # Optimal plan with uniform warmstart pi_unif, _ = ot.bregman.sinkhorn( - a, b, M, reg, method="sinkhorn", warmstart=None, log=True) + a, b, M, reg, method="sinkhorn", log=True, warmstart=None) # Optimal plan with warmstart generated from unregularized OT pi_sh, _ = ot.bregman.sinkhorn( - a, b, M, reg, method="sinkhorn", warmstart=warmstart, log=True) + a, b, M, reg, method="sinkhorn", log=True, warmstart=warmstart) pi_sh_log, _ = ot.bregman.sinkhorn( - a, b, M, reg, method="sinkhorn_log", warmstart=warmstart, log=True) + a, b, M, reg, method="sinkhorn_log", log=True, warmstart=warmstart) pi_sh_stab, _ = ot.bregman.sinkhorn( - a, b, M, reg, method="sinkhorn_stabilized", warmstart=warmstart, log=True) + a, b, M, reg, method="sinkhorn_stabilized", log=True, warmstart=warmstart) pi_sh_sc, _ = ot.bregman.sinkhorn( - a, b, M, reg, method="sinkhorn_epsilon_scaling", warmstart=warmstart, log=True) + a, b, M, reg, method="sinkhorn_epsilon_scaling", log=True, warmstart=warmstart) np.testing.assert_allclose(pi_unif, pi_sh, atol=1e-05) np.testing.assert_allclose(pi_unif, pi_sh_log, atol=1e-05) @@ -1112,14 +1112,14 @@ def test_empirical_sinkhorn_warmstart(): # Optimal plan with uniform warmstart f, g, _ = ot.bregman.empirical_sinkhorn( - X_s=Xs, X_t=Xt, reg=reg, isLazy=True, warmstart=None, log=True) + X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=None) pi_unif = np.exp(f[:, None] + g[None, :] - M / reg) # Optimal plan with warmstart generated from unregularized OT f, g, _ = ot.bregman.empirical_sinkhorn( - X_s=Xs, X_t=Xt, reg=reg, isLazy=True, warmstart=warmstart, log=True) + X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=warmstart) pi_ws_lazy = np.exp(f[:, None] + g[None, :] - M / reg) pi_ws_not_lazy, _ = ot.bregman.empirical_sinkhorn( - X_s=Xs, X_t=Xt, reg=reg, isLazy=False, warmstart=warmstart, log=True) + X_s=Xs, X_t=Xt, reg=reg, isLazy=False, log=True, warmstart=warmstart) np.testing.assert_allclose(pi_unif, pi_ws_lazy, atol=1e-05) np.testing.assert_allclose(pi_unif, pi_ws_not_lazy, atol=1e-05) @@ -1141,12 +1141,12 @@ def test_empirical_sinkhorn_divergence_warmstart(): # Optimal plan with uniform warmstart sd_unif, _ = ot.bregman.empirical_sinkhorn_divergence( - X_s=Xs, X_t=Xt, reg=reg, isLazy=True, warmstart=None, log=True) + X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=None) # Optimal plan with warmstart generated from unregularized OT sd_ws_lazy, _ = ot.bregman.empirical_sinkhorn_divergence( - X_s=Xs, X_t=Xt, reg=reg, isLazy=True, warmstart=warmstart, log=True) + X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=warmstart) sd_ws_not_lazy, _ = ot.bregman.empirical_sinkhorn_divergence( - X_s=Xs, X_t=Xt, reg=reg, isLazy=False, warmstart=warmstart, log=True) + X_s=Xs, X_t=Xt, reg=reg, isLazy=False, log=True, warmstart=warmstart) np.testing.assert_allclose(sd_unif, sd_ws_lazy, atol=1e-05) np.testing.assert_allclose(sd_unif, sd_ws_not_lazy, atol=1e-05)