From 631c64266320428f553fd017266b22d46208201b Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Tue, 26 Oct 2021 20:13:38 +0200 Subject: [PATCH 01/25] add debiased sinkhorn barycenter + make loops pythonic --- ot/bregman.py | 533 +++++++++++++++++++++++++++++++++++++------------- 1 file changed, 399 insertions(+), 134 deletions(-) diff --git a/ot/bregman.py b/ot/bregman.py index b59ee1b4c..278ccec00 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -383,9 +383,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, K = nx.exp(M / (-reg)) Kp = (1 / a).reshape(-1, 1) * K - cpt = 0 + err = 1 - while (err > stopThr and cpt < numItermax): + for ii in range(numItermax): uprev = u vprev = v KtransposeU = nx.dot(K.T, u) @@ -397,11 +397,11 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop - print('Warning: numerical errors at iteration', cpt) + print('Warning: numerical errors at iteration', ii) u = uprev v = vprev break - if cpt % 10 == 0: + if ii % 10 == 0: # we can speed up the process by checking for the error only all # the 10th iterations if n_hists: @@ -413,12 +413,14 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, if log: log['err'].append(err) + if err < stopThr: + break if verbose: - if cpt % 200 == 0: + if ii % 200 == 0: print( '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) - cpt = cpt + 1 + print('{:5d}|{:8e}|'.format(ii, err)) + if log: log['u'] = u log['v'] = v @@ -692,7 +694,7 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, dim_a = len(a) dim_b = len(b) - cpt = 0 + if log: log = {'err': []} @@ -722,11 +724,9 @@ def get_Gamma(alpha, beta, u, v): # print(np.min(K)) K = get_K(alpha, beta) - transp = K - loop = 1 - cpt = 0 + transp = K err = 1 - while loop: + for ii in range(numItermax): uprev = u vprev = v @@ -749,7 +749,7 @@ def 
get_Gamma(alpha, beta, u, v): v = nx.ones(dim_b, type_as=M) / dim_b K = get_K(alpha, beta) - if cpt % print_period == 0: + if ii % print_period == 0: # we can speed up the process by checking for the error only all # the 10th iterations if n_hists: @@ -764,28 +764,25 @@ def get_Gamma(alpha, beta, u, v): if log: log['err'].append(err) + if err < stopThr: + break if verbose: - if cpt % (print_period * 20) == 0: + if ii % (print_period * 20) == 0: print( '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) + print('{:5d}|{:8e}|'.format(ii, err)) if err <= stopThr: - loop = False - - if cpt >= numItermax: - loop = False + break if nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)): # we have reached the machine precision # come back to previous solution and quit loop - print('Warning: numerical errors at iteration', cpt) + print('Warning: numerical errors at iteration', ii) u = uprev v = vprev break - cpt = cpt + 1 - if log: if n_hists: alpha = alpha[:, None] @@ -824,29 +821,19 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, r""" Solve the entropic regularization optimal transport problem with log stabilization and epsilon scaling. - The function solves the following optimization problem: - .. math:: \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) - s.t. 
\ \gamma 1 = a - \gamma^T 1= b - \gamma\geq 0 where : - - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) - - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[2] ` but with the log stabilization proposed in :ref:`[10] ` and the log scaling proposed in :ref:`[9] ` algorithm 3.2 - - Parameters ---------- a : array-like, shape (dim_a,) @@ -873,17 +860,14 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, Print information along iterations log : bool, optional record log if True - Returns ------- gamma : array-like, shape (dim_a, dim_b) Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters - Examples -------- - >>> import ot >>> a=[.5, .5] >>> b=[.5, .5] @@ -891,23 +875,16 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, >>> ot.bregman.sinkhorn_epsilon_scaling(a, b, M, 1) array([[0.36552929, 0.13447071], [0.13447071, 0.36552929]]) - .. _references-sinkhorn-epsilon-scaling: References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 - .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. 
- - See Also -------- ot.lp.emd : Unregularized OT ot.optim.cg : General regularized OT - """ a, b, M = list_to_array(a, b, M) @@ -927,7 +904,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numItermin = 35 numItermax = max(numItermin, numItermax) # ensure that last velue is exact - cpt = 0 + ii = 0 if log: log = {'err': []} @@ -942,12 +919,10 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, def get_reg(n): # exponential decreasing return (epsilon0 - reg) * np.exp(-n) + reg - loop = 1 - cpt = 0 err = 1 - while loop: + for ii in range(numItermax): - regi = get_reg(cpt) + regi = get_reg(ii) G, logi = sinkhorn_stabilized(a, b, M, regi, numItermax=numInnerItermax, stopThr=1e-9, @@ -957,10 +932,7 @@ def get_reg(n): # exponential decreasing alpha = logi['alpha'] beta = logi['beta'] - if cpt >= numItermax: - loop = False - - if cpt % (print_period) == 0: # spsion nearly converged + if ii % (print_period) == 0: # spsion nearly converged # we can speed up the process by checking for the error only all # the 10th iterations transp = G @@ -969,15 +941,14 @@ def get_reg(n): # exponential decreasing log['err'].append(err) if verbose: - if cpt % (print_period * 10) == 0: + if ii % (print_period * 10) == 0: print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) + print('{:5d}|{:8e}|'.format(ii, err)) - if err <= stopThr and cpt > numItermin: - loop = False + if err <= stopThr and ii > numItermin: + break - cpt = cpt + 1 - # print('err=',err,' cpt=',cpt) + # print('err=',err,' ii=',ii) if log: log['alpha'] = alpha log['beta'] = beta @@ -986,7 +957,6 @@ def get_reg(n): # exponential decreasing else: return G - def geometricBar(weights, alldistribT): """return the weighted geometric mean of distributions""" weights, alldistribT = list_to_array(weights, alldistribT) @@ -1023,11 +993,13 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, The function 
solves the following optimization problem: .. math:: - \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) + \mathbf{a} = arg\min_\mathbf{a} \sum_i OT_{reg}(\mathbf{a},\mathbf{a}_i) where : - - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`) + - :math:`OT_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`) + if `method` is `sinkhorn` or `sinkhorn_stabilized`. If `method`is `debiased`, :math:`OT_{reg}(\cdot,\cdot)` is the entropic + sinkhorn divergence (see :py:func:`ot.bregman.empirical_sinkhorn_divergence`) - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT @@ -1042,7 +1014,7 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, reg : float Regularization term > 0 method : str (optional) - method used for the solver either 'sinkhorn' or 'sinkhorn_stabilized' + method used for the solver either 'sinkhorn' or 'sinkhorn_stabilized' or 'debiased' weights : array-like, shape (n_hists,) Weights of each histogram :math:`a_i` on the simplex (barycentric coodinates) numItermax : int, optional @@ -1081,6 +1053,11 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) + elif method.lower() == 'debiased': + return barycenter_sinkhorn_debiased(A, M, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) else: raise ValueError("Unknown method '%s'." % method) @@ -1153,38 +1130,144 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, # M = M/np.median(M) # suggested by G. 
Peyre K = nx.exp(-M / reg) - cpt = 0 + err = 1 UKv = nx.dot(K, (A.T / nx.sum(K, axis=0)).T) u = (geometricMean(UKv) / UKv.T).T - while (err > stopThr and cpt < numItermax): - cpt = cpt + 1 + for ii in range(numItermax): + UKv = u * nx.dot(K, A / nx.dot(K, u)) u = (u.T * geometricBar(weights, UKv)).T / UKv - if cpt % 10 == 1: + if ii % 10 == 1: err = nx.sum(nx.std(UKv, axis=1)) # log and verbose print if log: log['err'].append(err) + if err < stopThr: + break if verbose: - if cpt % 200 == 0: + if ii % 200 == 0: print( '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) + print('{:5d}|{:8e}|'.format(ii, err)) if log: - log['niter'] = cpt + log['niter'] = ii return geometricBar(weights, UKv), log else: return geometricBar(weights, UKv) + +def barycenter_sinkhorn_debiased(A, M, reg, weights=None, numItermax=1000, + stopThr=1e-4, verbose=False, log=False): + r"""Compute the entropic regularized wasserstein barycenter of distributions A + + The function solves the following optimization problem: + + .. 
math:: + \mathbf{a} = arg\min_\mathbf{a} \sum_i S_{reg}(\mathbf{a},\mathbf{a}_i) + + where : + + - :math:`S_{reg}(\cdot,\cdot)` is the debiased entropic regularized Wasserstein distance, i.e. the Sinkhorn divergence (see :py:func:`ot.bregman.empirical_sinkhorn_divergence`) + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` + - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT + + The algorithm used for solving the problem is the debiased Sinkhorn scaling algorithm of Janati et al. (ICML 2020), which adds a debiasing update to the iterative Bregman projections proposed in :ref:`[3] ` + + Parameters + ---------- + A : array-like, shape (dim, n_hists) + `n_hists` training distributions :math:`a_i` of size `dim` + M : array-like, shape (dim, dim) + loss matrix for OT + reg : float + Regularization term > 0 + weights : array-like, shape (n_hists,) + Weights of each histogram :math:`a_i` on the simplex (barycentric coordinates) + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + a : (dim,) array-like + Wasserstein barycenter + log : dict + log dictionary return only if log==True in parameters + + + .. _references-barycenter-sinkhorn: + References + ---------- + + .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138. + + """ + + A, M = list_to_array(A, M) + + nx = get_backend(A, M) + + if weights is None: + weights = nx.ones((A.shape[1],), type_as=A) / A.shape[1] + else: + assert (len(weights) == A.shape[1]) + + if log: + log = {'err': []} + + # M = M/np.median(M) # suggested by G. 
Peyre + K = nx.exp(-M / reg) + + err = 1 + + UKv = nx.dot(K, (A.T / nx.sum(K, axis=0)).T) + + u = (geometricMean(UKv) / UKv.T).T + c = nx.ones(A.shape[0], type_as=A) + for ii in range(numItermax): + UKv = nx.dot(K, A / nx.dot(K, u)) + bar = c * geometricBar(weights, UKv) + u = bar[:, None] / UKv + c = (c * bar / nx.dot(K, c)) ** 0.5 + + if ii % 10 == 1: + err = nx.sum(nx.std(UKv, axis=1)) + + # log and verbose print + if log: + log['err'].append(err) + + if err < stopThr: + break + if verbose: + if ii % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + + if log: + log['niter'] = ii + return bar, log + else: + return bar + + + def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, stopThr=1e-4, verbose=False, log=False): r"""Compute the entropic regularized wasserstein barycenter of distributions A with stabilization. @@ -1258,12 +1341,12 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, K = nx.exp(-M / reg) - cpt = 0 + err = 1. 
alpha = nx.zeros((dim,), type_as=M) beta = nx.zeros((dim,), type_as=M) q = nx.ones((dim,), type_as=M) / dim - while (err > stopThr and cpt < numItermax): + for ii in range(numItermax): qprev = q Kv = nx.dot(K, v) u = A / (Kv + 1e-16) @@ -1284,28 +1367,28 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop - warnings.warn('Numerical errors at iteration %s' % cpt) + warnings.warn('Numerical errors at iteration %s' % ii) q = qprev break - if (cpt % 10 == 0 and not absorbing) or cpt == 0: + if (ii % 10 == 0 and not absorbing) or ii == 0: # we can speed up the process by checking for the error only all # the 10th iterations err = nx.max(nx.abs(u * Kv - A)) if log: log['err'].append(err) - if verbose: - if cpt % 50 == 0: + if err < stopThr: + if ii % 50 == 0: print( '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) + print('{:5d}|{:8e}|'.format(ii, err)) - cpt += 1 + ii += 1 if err > stopThr: warnings.warn("Stabilized Unbalanced Sinkhorn did not converge." + "Try a larger entropy `reg`" + "Or a larger absorption threshold `tau`.") if log: - log['niter'] = cpt + log['niter'] = ii log['logu'] = np.log(u + 1e-16) log['logv'] = np.log(v + 1e-16) return q, log @@ -1313,9 +1396,79 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, return q -def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, - stopThr=1e-9, stabThr=1e-30, verbose=False, - log=False): +def convolutional_barycenter2d(A, reg, weights=None, method="sinkhorn", numItermax=10000, + stopThr=1e-5, verbose=False, log=False, **kwargs): + r"""Compute the entropic regularized wasserstein barycenter of distributions A + where A is a collection of 2D images. + + The function solves the following optimization problem: + + .. 
math:: + \mathbf{a} = arg\min_\mathbf{a} \sum_i OT_{reg}(\mathbf{a},\mathbf{a}_i) + + where : + + - :math:`OT_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`) + if `method` is `sinkhorn`. If `method` is `debiased`, :math:`OT_{reg}(\cdot,\cdot)` is the entropic + sinkhorn divergence (see :py:func:`ot.bregman.empirical_sinkhorn_divergence`) - :math:`\mathbf{a}_i` are training distributions (2D images) in the last two dimensions of matrix :math:`\mathbf{A}` + - `reg` is the regularization strength scalar value + + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[21] ` + + Parameters + ---------- + A : array-like, shape (n_hists, width, height) + `n` distributions (2D images) of size `width` x `height` + reg : float + Regularization term >0 + weights : array-like, shape (n_hists,) + Weights of each image on the simplex (barycentric coordinates) + method : string, optional + method used for the solver either 'sinkhorn' or 'debiased' + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (> 0) + stabThr : float, optional + Stabilization threshold to avoid numerical precision issue + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + Returns + ------- + a : array-like, shape (width, height) + 2D Wasserstein barycenter + log : dict + log dictionary return only if log==True in parameters + + + .. _references-convolutional-barycenter-2d: + References + ---------- + + .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). Convolutional wasserstein distances: Efficient optimal transportation on geometric domains. ACM Transactions on Graphics (TOG), 34(4), 66 + .. [28] Janati, H., Cuturi, M., Gramfort, A. Debiased Sinkhorn barycenters.
Proceedings of the 37th International Conference on Machine Learning, PMLR 119:4692-4701, 2020 + """ + + if method.lower() == 'sinkhorn': + return convolutional_barycenter2d_sinkhorn(A, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + elif method.lower() == 'debiased': + return convolutional_barycenter2d_debiased(A, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + else: + raise ValueError("Unknown method '%s'." % method) + + +def convolutional_barycenter2d_sinkhorn(A, reg, weights=None, numItermax=10000, + stopThr=1e-9, stabThr=1e-30, verbose=False, + log=False): r"""Compute the entropic regularized wasserstein barycenter of distributions A where A is a collection of 2D images. @@ -1380,61 +1533,178 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, if log: log = {'err': []} - b = nx.zeros(A.shape[1:], type_as=A) + bar = nx.ones(A.shape[1:], type_as=A) + bar /= bar.sum() U = nx.ones(A.shape, type_as=A) - KV = nx.ones(A.shape, type_as=A) + V = nx.ones(A.shape, type_as=A) + err = 1 + + # build the convolution operator + # this is equivalent to blurring on horizontal then vertical directions + t = nx.linspace(0, 1, A.shape[1]) + [Y, X] = nx.meshgrid(t, t) + K1 = nx.exp(-(X - Y) ** 2 / reg) + + t = nx.linspace(0, 1, A.shape[2]) + [Y, X] = nx.meshgrid(t, t) + K2 = nx.exp(-(X - Y) ** 2 / reg) + + def convol_imgs(imgs): + kx = nx.einsum("...ij,kjl->kil", K1, imgs) + kxy = nx.einsum("...ij,klj->kli", K2, kx) + return kxy + + KV = convol_imgs(V) + for ii in range(numItermax): + bold = bar + U = A / KV + KU = convol_imgs(U) + bar = nx.exp((weights[:, None, None] * nx.log(KU + stabThr)).sum(axis=0)) + V = bar[None] / KU + KV = convol_imgs(V) + if ii % 10 == 9: + err = abs(bold - bar).max() / max(1., bar.max()) + # log and verbose print + if log: + log['err'].append(err) + + if verbose: + if ii % 200 == 0: + print('{:5s}|{:12s}'.format('It.', 
'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + if err < stopThr: + break + + if log: + log['niter'] = ii + log['U'] = U + return bar, log + else: + return bar + - cpt = 0 +def convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000, + stopThr=1e-4, stabThr=1e-15, verbose=False, + log=False): + r"""Compute the debiased sinkhorn barycenter of distributions A + where A is a collection of 2D images. + + The function solves the following optimization problem: + + .. math:: + \mathbf{a} = arg\min_\mathbf{a} \sum_i S_{reg}(\mathbf{a},\mathbf{a}_i) + + where : + + - :math:`S_{reg}(\cdot,\cdot)` is the debiased entropic regularized Wasserstein distance (see :py:func:`ot.bregman.empirical_sinkhorn_divergence`) + - :math:`\mathbf{a}_i` are training distributions (2D images) in the last two dimensions of matrix :math:`\mathbf{A}` + - `reg` is the regularization strength scalar value + + The algorithm used for solving the problem is the debiased Sinkhorn scaling algorithm as proposed in :ref:`[28] ` + + Parameters + ---------- + A : array-like, shape (n_hists, width, height) + `n` distributions (2D images) of size `width` x `height` + reg : float + Regularization term >0 + weights : array-like, shape (n_hists,) + Weights of each image on the simplex (barycentric coordinates) + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (> 0) + stabThr : float, optional + Stabilization threshold to avoid numerical precision issue + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + Returns + ------- + a : array-like, shape (width, height) + 2D Wasserstein barycenter + log : dict + log dictionary return only if log==True in parameters + + + .. _references-sinkhorn-debiased: + References + ---------- + + .. [28] Janati, H., Cuturi, M., Gramfort, A. Debiased Sinkhorn barycenters.
Proceedings of the 37th International Conference on Machine Learning, PMLR 119:4692-4701, 2020 + """ + + A = list_to_array(A) + + nx = get_backend(A) + + if weights is None: + weights = nx.ones((A.shape[0],), type_as=A) / A.shape[0] + else: + assert (len(weights) == A.shape[0]) + + if log: + log = {'err': []} + + bar = nx.ones(A.shape[1:], type_as=A) + bar /= bar.sum() + U = nx.ones(A.shape, type_as=A) + V = nx.ones(A.shape, type_as=A) + c = nx.ones(A.shape[1:], type_as=A) err = 1 # build the convolution operator # this is equivalent to blurring on horizontal then vertical directions t = nx.linspace(0, 1, A.shape[1]) [Y, X] = nx.meshgrid(t, t) - xi1 = nx.exp(-(X - Y) ** 2 / reg) + K1 = nx.exp(-(X - Y) ** 2 / reg) t = nx.linspace(0, 1, A.shape[2]) [Y, X] = nx.meshgrid(t, t) - xi2 = nx.exp(-(X - Y) ** 2 / reg) - - def K(x): - return nx.dot(nx.dot(xi1, x), xi2) - - while (err > stopThr and cpt < numItermax): - - bold = b - cpt = cpt + 1 - - b = nx.zeros(A.shape[1:], type_as=A) - KV_cols = [] - for r in range(A.shape[0]): - KV_col_r = K(A[r, :, :] / nx.maximum(stabThr, K(U[r, :, :]))) - b += weights[r] * nx.log(nx.maximum(stabThr, U[r, :, :] * KV_col_r)) - KV_cols.append(KV_col_r) - KV = nx.stack(KV_cols) - b = nx.exp(b) - - U = nx.stack([ - b / nx.maximum(stabThr, KV[r, :, :]) - for r in range(A.shape[0]) - ]) - if cpt % 10 == 1: - err = nx.sum(nx.abs(bold - b)) + K2 = nx.exp(-(X - Y) ** 2 / reg) + + + def convol_imgs(imgs): + kx = nx.einsum("...ij,kjl->kil", K1, imgs) + kxy = nx.einsum("...ij,klj->kli", K2, kx) + return kxy + + KV = convol_imgs(V) + for ii in range(numItermax): + bold = bar + U = A / KV + KU = convol_imgs(U) + bar = c * nx.exp((weights[:, None, None] * nx.log(KU + stabThr)).sum(axis=0)) + V = bar[None] / KU + KV = convol_imgs(V) + for _ in range(10): + c = (c * bar / convol_imgs(c[None]).squeeze()) ** 0.5 + + if ii % 10 == 9: + err = abs(bold - bar).max() / max(1., bar.max()) # log and verbose print if log: log['err'].append(err) if verbose: - if cpt 
% 200 == 0: + if ii % 200 == 0: print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) + print('{:5d}|{:8e}|'.format(ii, err)) + + # debiased Sinkhorn does not converge monotonically + # guarantee a few iterations are done before stopping + if err < stopThr and ii > 20: + break if log: - log['niter'] = cpt + log['niter'] = ii log['U'] = U - return b, log + return bar, log else: - return b + return bar + def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, @@ -1519,12 +1789,11 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, old = h0 err = 1 - cpt = 0 # log = {'niter':0, 'all_err':[]} if log: log = {'err': []} - while (err > stopThr and cpt < numItermax): + for ii in range(numItermax): K = projC(K, a) K0 = projC(K0, h0) new = nx.sum(K0, axis=1) @@ -1542,14 +1811,13 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, log['err'].append(err) if verbose: - if cpt % 200 == 0: + if ii % 200 == 0: print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) - - cpt = cpt + 1 - + print('{:5d}|{:8e}|'.format(ii, err)) + if err < stopThr: + break if log: - log['niter'] = cpt + log['niter'] = ii return nx.sum(K0, axis=1), log else: return nx.sum(K0, axis=1) @@ -1673,11 +1941,10 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, # uniform target distribution a = nx.from_numpy(unif(np.shape(Xt)[0])) - cpt = 0 # iterations count err = 1 old_bary = nx.ones((nbclasses,), type_as=Xs[0]) - while (err > stopThr and cpt < numItermax): + for ii in range(numItermax): bary = nx.zeros((nbclasses,), type_as=Xs[0]) @@ -1695,21 +1962,23 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, K[d] = projR(K[d], new) err = nx.norm(bary - old_bary) - cpt = cpt + 1 + old_bary = bary if log: log['err'].append(err) + if err < stopThr: + break if verbose: - if cpt % 200 == 0: + if ii % 200 == 0: 
print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) + print('{:5d}|{:8e}|'.format(ii, err)) bary = bary / nx.sum(bary) if log: - log['niter'] = cpt + log['niter'] = ii log['M'] = M log['D1'] = D1 log['D2'] = D2 @@ -2410,13 +2679,11 @@ def projection(u, epsilon): cst_u = kappa * epsilon * nx.sum(K_IJc, axis=1) cst_v = epsilon * nx.sum(K_IcJ, axis=0) / kappa - cpt = 1 - while cpt < 5: # 5 iterations + for ii in range(5): # 5 iterations K_IJ_v = nx.dot(K_IJ.T, u0) + cst_v v0 = b_J / (kappa * K_IJ_v) KIJ_u = nx.dot(K_IJ, v0) + cst_u u0 = (kappa * a_I) / KIJ_u - cpt += 1 u0 = projection(u0, epsilon / kappa) v0 = projection(v0, epsilon * kappa) @@ -2429,13 +2696,11 @@ def restricted_sinkhorn(usc, vsc, max_iter=5): """ Restricted Sinkhorn Algorithm as a warm-start initialized point for L-BFGS-B (see Algorithm 1 in supplementary of [26]) """ - cpt = 1 - while cpt < max_iter: + for ii in range(max_iter): K_IJ_v = nx.dot(K_IJ.T, usc) + cst_v vsc = b_J / (kappa * K_IJ_v) KIJ_u = nx.dot(K_IJ, vsc) + cst_u usc = (kappa * a_I) / KIJ_u - cpt += 1 usc = projection(usc, epsilon / kappa) vsc = projection(vsc, epsilon * kappa) From 598a7a74ffb7f20c40840985da84c09414710c25 Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Tue, 26 Oct 2021 20:13:54 +0200 Subject: [PATCH 02/25] add debiased arg in tests --- test/test_bregman.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/test/test_bregman.py b/test/test_bregman.py index 942cb6d0d..ef5c1b2e4 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -184,7 +184,7 @@ def test_sinkhorn_variants_log(): print(G0, G_green) -@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", "debiased"]) def test_barycenter(nx, method): n_bins = 100 # nb bins @@ -254,7 +254,8 @@ def test_barycenter_stabilization(nx): np.testing.assert_allclose(bar, bar_np) -def 
test_wasserstein_bary_2d(nx): +@pytest.mark.parametrize("method", ["sinkhorn", "debiased"]) +def test_wasserstein_bary_2d(nx, method): size = 100 # size of a square image a1 = np.random.randn(size, size) a1 += a1.min() @@ -271,11 +272,11 @@ def test_wasserstein_bary_2d(nx): # wasserstein reg = 1e-2 - bary_wass_np = ot.bregman.convolutional_barycenter2d(A, reg) - bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(Ab, reg)) + bary_wass_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method) + bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(Ab, reg, method=method)) - np.testing.assert_allclose(1, np.sum(bary_wass)) - np.testing.assert_allclose(bary_wass, bary_wass_np) + np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) # help in checking if log and verbose do not bug the function ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) From c69da39a2cd2c3b8f20a36111bc060756e431b15 Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Tue, 26 Oct 2021 20:14:12 +0200 Subject: [PATCH 03/25] add 1d and 2d examples of debiased barycenters --- .../barycenters/plot_debiased_barycenter.py | 137 ++++++++++++++++++ 1 file changed, 137 insertions(+) create mode 100644 examples/barycenters/plot_debiased_barycenter.py diff --git a/examples/barycenters/plot_debiased_barycenter.py b/examples/barycenters/plot_debiased_barycenter.py new file mode 100644 index 000000000..1307b23f5 --- /dev/null +++ b/examples/barycenters/plot_debiased_barycenter.py @@ -0,0 +1,137 @@ +# -*- coding: utf-8 -*- +""" +================================= +Debiased Sinkhorn barycenter demo +================================= + +This example illustrates the computation of the debiased Sinkhorn barycenter +as proposed in [28]. + + +[28] Janati, H., Cuturi, M., Gramfort, A. Debiased Sinkhorn barycenters.
Proceedings of the 37th + International Conference on Machine Learning, PMLR 119:4692-4701, 2020 + +""" + +# Author: Hicham Janati +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 4 + +import numpy as np +import matplotlib.pylab as plt +import ot +from ot.bregman import barycenter, convolutional_barycenter2d +# necessary for 3d plot even if not used +from mpl_toolkits.mplot3d import Axes3D # noqa +from matplotlib.collections import PolyCollection + +############################################################################## +# Debiased barycenter of 1D Gaussians +# ------------------------------------ + +#%% parameters + +n = 100 # nb bins + +# bin positions +x = np.arange(n, dtype=np.float64) + +# Gaussian distributions +a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std +a2 = ot.datasets.make_1D_gauss(n, m=60, s=8) + +# creating matrix A containing all distributions +A = np.vstack((a1, a2)).T +n_distributions = A.shape[1] + +# loss matrix + normalization +M = ot.utils.dist0(n) +M /= M.max() + +#%% barycenter computation + +alpha = 0.2 # 0<=alpha<=1 +weights = np.array([1 - alpha, alpha]) + +epsilons = [5e-3, 1e-2, 5e-2] + + +bars = [barycenter(A, M, reg, weights, method="sinkhorn") for reg in epsilons] +bars_debiased = [barycenter(A, M, reg, weights, method="debiased") for reg in epsilons] + +labels = ["Sinkhorn barycenter", "Debiased barycenter"] +colors = ["indianred", "gold"] + +f, axes = plt.subplots(1, len(epsilons), tight_layout=True, sharey=True, figsize=(12, 4)) +for ax, eps, bar, bar_debiased in zip(axes, epsilons, bars, bars_debiased): + ax.plot(A[:, 0], color="k", ls="--", label="Input data", alpha=0.3) + ax.plot(A[:, 1], color="k", ls="--", alpha=0.3) + for data, label, color in zip([bar, bar_debiased], labels, colors): + ax.plot(data, color=color, label=label, lw=2) + ax.set_title(r"$\varepsilon = %.3f$" % eps) +plt.legend() +plt.show() + + +############################################################################## 
+# Debiased barycenter of 2D images +# --------------------------------- + + + +f1 = 1 - plt.imread('../../data/redcross.png')[:, :, 2] +f2 = 1 - plt.imread('../../data/tooth.png')[:, :, 2] +f3 = 1 - plt.imread('../../data/duck.png')[:, :, 2] + +A = [] +f1 = f1 / np.sum(f1) +f2 = f2 / np.sum(f2) +f3 = f3 / np.sum(f3) + +A.append(f1) +A.append(f2) +A.append(f3) + +A = np.array(A) + +############################################################################## +# Display the input images + +f, axes = plt.subplots(1, 3, figsize=(12, 4)) +for ax, img in zip(axes, A): + ax.imshow(img, cmap="Greys") +plt.show() + + +############################################################################## +# Barycenter computation and visualization +# ---------------------------------------- +# + +bars_sinkhorn, bars_debiased = [], [] +epsilons = [5e-3, 7e-3, 1e-2] +for eps in epsilons: + bar = convolutional_barycenter2d(A, eps, method="sinkhorn") + bar_debiased = convolutional_barycenter2d(A, eps, method="debiased") + bars_sinkhorn.append(bar) + bars_debiased.append(bar_debiased) + +titles = ["Sinkhorn", "Debiased"] +all_bars = [bars_sinkhorn, bars_debiased] +f, axes = plt.subplots(2, 3, figsize=(12, 8)) +for jj, (method, ax_row, bars) in enumerate(zip(titles, axes, all_bars)): + for ii, (ax, img, eps) in enumerate(zip(ax_row, bars, epsilons)): + ax.imshow(img, cmap="Greys") + if jj == 0: + ax.set_title(r"$\varepsilon = %.3f$" % eps, fontsize=13) + ax.set_xticks([]) + ax.set_yticks([]) + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + ax.spines['bottom'].set_visible(False) + ax.spines['left'].set_visible(False) + if ii == 0: + ax.set_ylabel(method, fontsize=15) +plt.show() \ No newline at end of file From 72966905456b5dc69c787b3dacda7a49ad5fd29a Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Tue, 26 Oct 2021 20:42:29 +0200 Subject: [PATCH 04/25] fix doctest --- ot/bregman.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ot/bregman.py 
b/ot/bregman.py index 278ccec00..91a1ef454 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -875,6 +875,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, >>> ot.bregman.sinkhorn_epsilon_scaling(a, b, M, 1) array([[0.36552929, 0.13447071], [0.13447071, 0.36552929]]) + .. _references-sinkhorn-epsilon-scaling: References ---------- From 3253d5504018668440c5e9f4d69849114927999b Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Tue, 26 Oct 2021 21:45:45 +0200 Subject: [PATCH 05/25] fix flake8 --- .../barycenters/plot_debiased_barycenter.py | 7 +----- ot/bregman.py | 23 ++++++------------- 2 files changed, 8 insertions(+), 22 deletions(-) diff --git a/examples/barycenters/plot_debiased_barycenter.py b/examples/barycenters/plot_debiased_barycenter.py index 1307b23f5..dc05fb0fa 100644 --- a/examples/barycenters/plot_debiased_barycenter.py +++ b/examples/barycenters/plot_debiased_barycenter.py @@ -23,9 +23,6 @@ import matplotlib.pylab as plt import ot from ot.bregman import barycenter, convolutional_barycenter2d -# necessary for 3d plot even if not used -from mpl_toolkits.mplot3d import Axes3D # noqa -from matplotlib.collections import PolyCollection ############################################################################## # Debiased barycenter of 1D Gaussians @@ -79,8 +76,6 @@ # Debiased barycenter of 2D images # --------------------------------- - - f1 = 1 - plt.imread('../../data/redcross.png')[:, :, 2] f2 = 1 - plt.imread('../../data/tooth.png')[:, :, 2] f3 = 1 - plt.imread('../../data/duck.png')[:, :, 2] @@ -134,4 +129,4 @@ ax.spines['left'].set_visible(False) if ii == 0: ax.set_ylabel(method, fontsize=15) -plt.show() \ No newline at end of file +plt.show() diff --git a/ot/bregman.py b/ot/bregman.py index 91a1ef454..74c7d56f1 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -383,7 +383,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, K = nx.exp(M / (-reg)) Kp = (1 / a).reshape(-1, 1) * K - + err = 1 for ii in range(numItermax): 
uprev = u @@ -420,7 +420,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, print( '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) print('{:5d}|{:8e}|'.format(ii, err)) - + if log: log['u'] = u log['v'] = v @@ -694,7 +694,6 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, dim_a = len(a) dim_b = len(b) - if log: log = {'err': []} @@ -721,10 +720,8 @@ def get_Gamma(alpha, beta, u, v): return nx.exp(-(M - alpha.reshape((dim_a, 1)) - beta.reshape((1, dim_b))) / reg + nx.log(u.reshape((dim_a, 1))) + nx.log(v.reshape((1, dim_b)))) - # print(np.min(K)) - K = get_K(alpha, beta) - transp = K + transp = K err = 1 for ii in range(numItermax): @@ -949,7 +946,6 @@ def get_reg(n): # exponential decreasing if err <= stopThr and ii > numItermin: break - # print('err=',err,' ii=',ii) if log: log['alpha'] = alpha log['beta'] = beta @@ -958,6 +954,7 @@ def get_reg(n): # exponential decreasing else: return G + def geometricBar(weights, alldistribT): """return the weighted geometric mean of distributions""" weights, alldistribT = list_to_array(weights, alldistribT) @@ -1131,7 +1128,6 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, # M = M/np.median(M) # suggested by G. 
Peyre K = nx.exp(-M / reg) - err = 1 UKv = nx.dot(K, (A.T / nx.sum(K, axis=0)).T) @@ -1139,7 +1135,7 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, u = (geometricMean(UKv) / UKv.T).T for ii in range(numItermax): - + UKv = u * nx.dot(K, A / nx.dot(K, u)) u = (u.T * geometricBar(weights, UKv)).T / UKv @@ -1165,7 +1161,6 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, return geometricBar(weights, UKv) - def barycenter_sinkhorn_debiased(A, M, reg, weights=None, numItermax=1000, stopThr=1e-4, verbose=False, log=False): r"""Compute the entropic regularized wasserstein barycenter of distributions A @@ -1268,7 +1263,6 @@ def barycenter_sinkhorn_debiased(A, M, reg, weights=None, numItermax=1000, return bar - def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, stopThr=1e-4, verbose=False, log=False): r"""Compute the entropic regularized wasserstein barycenter of distributions A with stabilization. @@ -1342,7 +1336,6 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, K = nx.exp(-M / reg) - err = 1. 
alpha = nx.zeros((dim,), type_as=M) beta = nx.zeros((dim,), type_as=M) @@ -1666,7 +1659,6 @@ def convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000, [Y, X] = nx.meshgrid(t, t) K2 = nx.exp(-(X - Y) ** 2 / reg) - def convol_imgs(imgs): kx = nx.einsum("...ij,kjl->kil", K1, imgs) kxy = nx.einsum("...ij,klj->kli", K2, kx) @@ -1693,7 +1685,7 @@ def convol_imgs(imgs): if ii % 200 == 0: print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) print('{:5d}|{:8e}|'.format(ii, err)) - + # debiased Sinkhorn does not converge monotonically # guarantee a few iterations are done before stopping if err < stopThr and ii > 20: @@ -1707,7 +1699,6 @@ def convol_imgs(imgs): return bar - def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, stopThr=1e-3, verbose=False, log=False): r""" @@ -1963,7 +1954,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, K[d] = projR(K[d], new) err = nx.norm(bary - old_bary) - + old_bary = bary if log: From 6751b74edabf595e065984c17ef2fe50e96c24c4 Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Wed, 27 Oct 2021 12:01:42 +0200 Subject: [PATCH 06/25] pep8 + make func private + add convergence warnings --- ot/bregman.py | 559 ++++++++++++++++++++++++++++++++++---------------- 1 file changed, 378 insertions(+), 181 deletions(-) diff --git a/ot/bregman.py b/ot/bregman.py index 74c7d56f1..2b8e995e7 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -42,8 +42,10 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source 
and target + weights (histograms, both sum to 1) .. note:: This function is backend-compatible and will work on arrays from all compatible backends. @@ -73,7 +75,8 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, samples weights in the source domain b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists) samples in the target domain, compute sinkhorn with multiple targets - and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log) + and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix + (return OT loss + dual variables in log) M : array-like, shape (dim_a, dim_b) loss matrix reg : float @@ -112,11 +115,14 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, + Advances in Neural Information Processing Systems (NIPS) 26, 2013 - .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy + Regularized Transport Problems. arXiv preprint arXiv:1610.06519. - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms + for unbalanced transport problems. arXiv preprint arXiv:1607.05816. 
@@ -125,8 +131,10 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, ot.lp.emd : Unregularized OT ot.optim.cg : General regularized OT ot.bregman.sinkhorn_knopp : Classic Sinkhorn :ref:`[2] ` - ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn :ref:`[9] ` :ref:`[10] ` - ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling :ref:`[9] ` :ref:`[10] ` + ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn + :ref:`[9] ` :ref:`[10] ` + ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling + :ref:`[9] ` :ref:`[10] ` """ @@ -168,13 +176,16 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target + weights (histograms, both sum to 1) .. note:: This function is backend-compatible and will work on arrays from all compatible backends. 
- The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[2] ` + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix + scaling algorithm as proposed in :ref:`[2] ` **Choosing a Sinkhorn solver** @@ -198,13 +209,15 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, samples weights in the source domain b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists) samples in the target domain, compute sinkhorn with multiple targets - and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log) + and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix + (return OT loss + dual variables in log) M : array-like, shape (dim_a, dim_b) loss matrix reg : float Regularization term >0 method : str - method used for the solver either 'sinkhorn', 'sinkhorn_stabilized', see those function for specific parameters + method used for the solver either 'sinkhorn', 'sinkhorn_stabilized', + see those function for specific parameters numItermax : int, optional Max number of iterations stopThr : float, optional @@ -237,13 +250,18 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of + Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 - .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms + for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. 
arXiv preprint arXiv:1607.05816. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. - .. [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017 + .. [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation + algorithms for optimal transport via Sinkhorn iteration, Advances in Neural + Information Processing Systems (NIPS) 31, 2017 @@ -253,7 +271,8 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, ot.optim.cg : General regularized OT ot.bregman.sinkhorn_knopp : Classic Sinkhorn :ref:`[2] ` ot.bregman.greenkhorn : Greenkhorn :ref:`[21] ` - ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn :ref:`[9] ` :ref:`[10] ` + ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn :ref:`[9] ` + :ref:`[10] ` """ @@ -291,10 +310,13 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target + weights (histograms, both sum to 1) - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[2] ` + The algorithm used for solving the problem is the Sinkhorn-Knopp + matrix scaling algorithm as proposed in :ref:`[2] ` Parameters @@ -303,7 +325,8 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, samples weights in the source domain b : array-like, shape (dim_b,) or array-like, 
shape (dim_b, n_hists) samples in the target domain, compute sinkhorn with multiple targets - and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log) + and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix + (return OT loss + dual variables in log) M : array-like, shape (dim_a, dim_b) loss matrix reg : float @@ -340,7 +363,8 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation + of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 See Also @@ -420,8 +444,12 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, print( '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) print('{:5d}|{:8e}|'.format(ii, err)) - + else: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: + log['n_iter'] = ii log['u'] = u log['v'] = v @@ -445,7 +473,9 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, r""" Solve the entropic regularization optimal transport problem and return the OT matrix - The algorithm used is based on the paper :ref:`[22] ` which is a stochastic version of the Sinkhorn-Knopp algorithm :ref:`[2] ` + The algorithm used is based on the paper :ref:`[22] ` + which is a stochastic version of the Sinkhorn-Knopp + algorithm :ref:`[2] ` The function solves the following optimization problem: @@ -460,8 +490,10 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`\mathbf{a}` and 
:math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target + weights (histograms, both sum to 1) Parameters @@ -470,7 +502,8 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, samples weights in the source domain b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists) samples in the target domain, compute sinkhorn with multiple targets - and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log) + and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix + (return OT loss + dual variables in log) M : array-like, shape (dim_a, dim_b) loss matrix reg : float @@ -505,9 +538,12 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, + Advances in Neural Information Processing Systems (NIPS) 26, 2013 - .. [22] J. Altschuler, J.Weed, P. Rigollet : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017 + .. [22] J. Altschuler, J.Weed, P. Rigollet : Near-linear time approximation algorithms + for optimal transport via Sinkhorn iteration, Advances in Neural Information + Processing Systems (NIPS) 31, 2017 See Also @@ -521,7 +557,8 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, nx = get_backend(M, a, b) if nx.__name__ == "jax": - raise TypeError("JAX arrays have been received. Greenkhorn is not compatible with JAX") + raise TypeError("JAX arrays have been received. 
Greenkhorn is not " + "compatible with JAX") if len(a) == 0: a = nx.ones((M.shape[0],), type_as=M) / M.shape[0] @@ -545,7 +582,7 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log['u'] = u log['v'] = v - for i in range(numItermax): + for ii in range(numItermax): i_1 = nx.argmax(nx.abs(viol)) i_2 = nx.argmax(nx.abs(viol_2)) m_viol_1 = nx.abs(viol[i_1]) @@ -569,14 +606,16 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, viol += (-old_v + new_v) * K[:, i_2] * u viol_2[i_2] = new_v * K[:, i_2].dot(u) - b[i_2] v[i_2] = new_v - # print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2))) if stopThr_val <= stopThr: break else: - print('Warning: Algorithm did not converge') + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: + log["n_iter"] = ii log['u'] = u log['v'] = v @@ -605,13 +644,17 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target + weights (histograms, both sum to 1) The algorithm used for solving the problem is the Sinkhorn-Knopp matrix - scaling algorithm as proposed in :ref:`[2] ` but with the log stabilization - proposed in :ref:`[10] ` an defined in :ref:`[9] ` (Algo 3.1) . + scaling algorithm as proposed in :ref:`[2] ` + but with the log stabilization + proposed in :ref:`[10] ` an defined in + :ref:`[9] ` (Algo 3.1) . 
Parameters @@ -625,7 +668,8 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, reg : float Regularization term >0 tau : float - threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}` for log scaling + threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}` + for log scaling warmstart : table of vectors if given then starting values for alpha and beta log scalings numItermax : int, optional @@ -660,11 +704,15 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of + Optimal Transport, Advances in Neural Information Processing + Systems (NIPS) 26, 2013 - .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms + for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. 
See Also @@ -708,7 +756,9 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, u = nx.ones((dim_a, n_hists), type_as=M) / dim_a v = nx.ones((dim_b, n_hists), type_as=M) / dim_b else: - u, v = nx.ones(dim_a, type_as=M) / dim_a, nx.ones(dim_b, type_as=M) / dim_b + u, v = nx.ones(dim_a, type_as=M), nx.ones(dim_b, type_as=M) + u /= dim_a + v /= dim_b def get_K(alpha, beta): """log space computation""" @@ -775,17 +825,21 @@ def get_Gamma(alpha, beta, u, v): if nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)): # we have reached the machine precision # come back to previous solution and quit loop - print('Warning: numerical errors at iteration', ii) + warnings.warn('Numerical errors at iteration', ii) u = uprev v = vprev break - + else: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: if n_hists: alpha = alpha[:, None] beta = beta[:, None] logu = alpha / reg + nx.log(u) logv = beta / reg + nx.log(v) + log["n_iter"] = ii log['logu'] = logu log['logv'] = logv log['alpha'] = alpha + reg * nx.log(u) @@ -826,11 +880,16 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, \gamma\geq 0 where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights + (histograms, both sum to 1) The algorithm used for solving the problem is the Sinkhorn-Knopp matrix - scaling algorithm as proposed in :ref:`[2] ` but with the log stabilization - proposed in :ref:`[10] ` and the log scaling proposed in :ref:`[9] ` 
algorithm 3.2 + scaling algorithm as proposed in :ref:`[2] ` + but with the log stabilization + proposed in :ref:`[10] ` and the log scaling + proposed in :ref:`[9] ` algorithm 3.2 + Parameters ---------- a : array-like, shape (dim_a,) @@ -842,7 +901,8 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, reg : float Regularization term >0 tau : float - threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{b}` for log scaling + threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{b}` + for log scaling warmstart : tuple of vectors if given then starting values for alpha and beta log scalings numItermax : int, optional @@ -876,9 +936,12 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, .. _references-sinkhorn-epsilon-scaling: References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 - .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal + Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for + Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. See Also -------- ot.lp.emd : Unregularized OT @@ -945,11 +1008,15 @@ def get_reg(n): # exponential decreasing if err <= stopThr and ii > numItermin: break - + else: + warnings.warn("Sinkhorn did not converge. 
You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: log['alpha'] = alpha log['beta'] = beta log['warmstart'] = (log['alpha'], log['beta']) + log['n_iter'] = ii return G, log else: return G @@ -995,13 +1062,18 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, where : - - :math:`OT_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`) - if `method` is `sinkhorn` or `sinkhorn_stabilized`. If `method`is `debiased`, :math:`OT_{reg}(\cdot,\cdot)` is the entropic + - :math:`OT_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein + distance (see :py:func:`ot.bregman.sinkhorn`) + if `method` is `sinkhorn` or `sinkhorn_stabilized`. If `method`is `debiased`, + :math:`OT_{reg}(\cdot,\cdot)` is the entropic sinkhorn divergence (see :py:func:`ot.bregman.empirical_sinkhorn_divergence`) - - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` - - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix + :math:`\mathbf{A}` + - `reg` and :math:`\mathbf{M}` are respectively the regularization term and + the cost matrix for OT - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[3] ` + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling + algorithm as proposed in :ref:`[3] ` Parameters ---------- @@ -1037,7 +1109,9 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, References ---------- - .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138. + .. [3] Benamou, J. 
D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). + Iterative Bregman projections for regularized transportation problems. + SIAM Journal on Scientific Computing, 37(2), A1111-A1138. """ @@ -1052,10 +1126,10 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, stopThr=stopThr, verbose=verbose, log=log, **kwargs) elif method.lower() == 'debiased': - return barycenter_sinkhorn_debiased(A, M, reg, weights=weights, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + return _barycenter_sinkhorn_debiased(A, M, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) else: raise ValueError("Unknown method '%s'." % method) @@ -1071,11 +1145,15 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, where : - - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`) - - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` - - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance + (see :py:func:`ot.bregman.sinkhorn`) + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix + :math:`\mathbf{A}` + - `reg` and :math:`\mathbf{M}` are respectively the regularization term and + the cost matrix for OT - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[3] ` + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix + scaling algorithm as proposed in :ref:`[3]`. Parameters ---------- @@ -1109,7 +1187,9 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, References ---------- - .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. 
SIAM Journal on Scientific Computing, 37(2), A1111-A1138. + .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). + Iterative Bregman projections for regularized transportation problems. + SIAM Journal on Scientific Computing, 37(2), A1111-A1138. """ @@ -1125,7 +1205,6 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, if log: log = {'err': []} - # M = M/np.median(M) # suggested by G. Peyre K = nx.exp(-M / reg) err = 1 @@ -1153,7 +1232,10 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, print( '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) print('{:5d}|{:8e}|'.format(ii, err)) - + else: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: log['niter'] = ii return geometricBar(weights, UKv), log @@ -1161,8 +1243,8 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, return geometricBar(weights, UKv) -def barycenter_sinkhorn_debiased(A, M, reg, weights=None, numItermax=1000, - stopThr=1e-4, verbose=False, log=False): +def _barycenter_sinkhorn_debiased(A, M, reg, weights=None, numItermax=1000, + stopThr=1e-4, verbose=False, log=False): r"""Compute the entropic regularized wasserstein barycenter of distributions A The function solves the following optimization problem: @@ -1172,11 +1254,15 @@ def barycenter_sinkhorn_debiased(A, M, reg, weights=None, numItermax=1000, where : - - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`) - - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` - - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein + distance (see :py:func:`ot.bregman.sinkhorn`) + - :math:`\mathbf{a}_i` are training distributions in the columns of 
matrix + :math:`\mathbf{A}` + - `reg` and :math:`\mathbf{M}` are respectively the regularization term + and the cost matrix for OT - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[3] ` + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix + scaling algorithm as proposed in :ref:`[3] ` Parameters ---------- @@ -1210,7 +1296,9 @@ def barycenter_sinkhorn_debiased(A, M, reg, weights=None, numItermax=1000, References ---------- - .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138. + .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). + Iterative Bregman projections for regularized transportation problems. + SIAM Journal on Scientific Computing, 37(2), A1111-A1138. """ @@ -1226,7 +1314,6 @@ def barycenter_sinkhorn_debiased(A, M, reg, weights=None, numItermax=1000, if log: log = {'err': []} - # M = M/np.median(M) # suggested by G. Peyre K = nx.exp(-M / reg) err = 1 @@ -1255,7 +1342,10 @@ def barycenter_sinkhorn_debiased(A, M, reg, weights=None, numItermax=1000, print( '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) print('{:5d}|{:8e}|'.format(ii, err)) - + else: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: log['niter'] = ii return bar, log @@ -1265,7 +1355,7 @@ def barycenter_sinkhorn_debiased(A, M, reg, weights=None, numItermax=1000, def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, stopThr=1e-4, verbose=False, log=False): - r"""Compute the entropic regularized wasserstein barycenter of distributions A with stabilization. + r"""Compute the entropic wasserstein barycenter with stabilization. 
The function solves the following optimization problem: @@ -1274,11 +1364,15 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, where : - - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`) - - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` - - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein + distance (see :py:func:`ot.bregman.sinkhorn`) + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix + :math:`\mathbf{A}` + - `reg` and :math:`\mathbf{M}` are respectively the regularization term and + the cost matrix for OT - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[3] ` + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling + algorithm as proposed in :ref:`[3] ` Parameters ---------- @@ -1289,7 +1383,8 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, reg : float Regularization term > 0 tau : float - threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}` for log scaling + threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}` + for log scaling weights : array-like, shape (n_hists,) Weights of each histogram :math:`a_i` on the simplex (barycentric coodinates) numItermax : int, optional @@ -1314,7 +1409,9 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, References ---------- - .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138. + .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). 
+ Iterative Bregman projections for regularized transportation problems. + SIAM Journal on Scientific Computing, 37(2), A1111-A1138. """ @@ -1376,8 +1473,7 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) print('{:5d}|{:8e}|'.format(ii, err)) - ii += 1 - if err > stopThr: + else: warnings.warn("Stabilized Unbalanced Sinkhorn did not converge." + "Try a larger entropy `reg`" + "Or a larger absorption threshold `tau`.") @@ -1402,12 +1498,17 @@ def convolutional_barycenter2d(A, reg, weights=None, method="sinkhorn", numIterm where : - - :math:`OT_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`) - if `method` is `sinkhorn` or `sinkhorn_stabilized`. If `method`is `debiased`, :math:`OT_{reg}(\cdot,\cdot)` is the entropic - sinkhorn divergence (see :py:func:`ot.bregman.empirical_sinkhorn_divergence`) - :math:`\mathbf{a}_i` are training distributions (2D images) in the mast two dimensions of matrix :math:`\mathbf{A}` + - :math:`OT_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein + distance (see :py:func:`ot.bregman.sinkhorn`) + if `method` is `sinkhorn` or `sinkhorn_stabilized`. If `method`is `debiased`, + :math:`OT_{reg}(\cdot,\cdot)` is the entropic + sinkhorn divergence (see :py:func:`ot.bregman.empirical_sinkhorn_divergence`) + - :math:`\mathbf{a}_i` are training distributions (2D images) in the mast two dimensions + of matrix :math:`\mathbf{A}` - `reg` is the regularization strength scalar value - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[21] ` + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm + as proposed in :ref:`[21] ` Parameters ---------- @@ -1442,27 +1543,32 @@ def convolutional_barycenter2d(A, reg, weights=None, method="sinkhorn", numIterm References ---------- - .. 
[21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). Convolutional wasserstein distances: Efficient optimal transportation on geometric domains. ACM Transactions on Graphics (TOG), 34(4), 66 - .. [28] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th International Conference on Machine Learning, PMLR 119:4692-4701, 2020 + .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, + A., Nguyen, A. & Guibas, L. (2015). Convolutional wasserstein distances: + Efficient optimal transportation on geometric domains. ACM Transactions + on Graphics (TOG), 34(4), 66 + .. [28] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th + International Conference on Machine Learning, PMLR 119:4692-4701, 2020 """ if method.lower() == 'sinkhorn': - return convolutional_barycenter2d_sinkhorn(A, reg, weights=weights, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, - **kwargs) + return _convolutional_barycenter2d_sinkhorn(A, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, + **kwargs) elif method.lower() == 'debiased': - return convolutional_barycenter2d_debiased(A, reg, weights=weights, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + return _convolutional_barycenter2d_debiased(A, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) else: raise ValueError("Unknown method '%s'." % method) -def convolutional_barycenter2d_sinkhorn(A, reg, weights=None, numItermax=10000, - stopThr=1e-9, stabThr=1e-30, verbose=False, - log=False): +def _convolutional_barycenter2d_sinkhorn(A, reg, weights=None, numItermax=10000, + stopThr=1e-9, stabThr=1e-30, verbose=False, + log=False): r"""Compute the entropic regularized wasserstein barycenter of distributions A where A is a collection of 2D images. 
@@ -1473,11 +1579,14 @@ def convolutional_barycenter2d_sinkhorn(A, reg, weights=None, numItermax=10000, where : - - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`) - - :math:`\mathbf{a}_i` are training distributions (2D images) in the mast two dimensions of matrix :math:`\mathbf{A}` + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance + (see :py:func:`ot.bregman.sinkhorn`) + - :math:`\mathbf{a}_i` are training distributions (2D images) in the mast two + dimensions of matrix :math:`\mathbf{A}` - `reg` is the regularization strength scalar value - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[21] ` + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling + algorithm as proposed in :ref:`[21] ` Parameters ---------- @@ -1510,7 +1619,10 @@ def convolutional_barycenter2d_sinkhorn(A, reg, weights=None, numItermax=10000, References ---------- - .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). Convolutional wasserstein distances: Efficient optimal transportation on geometric domains. ACM Transactions on Graphics (TOG), 34(4), 66 + .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, + A. & Guibas, L. (2015). Convolutional wasserstein distances: Efficient + optimal transportation on geometric domains. ACM Transactions on Graphics + (TOG), 34(4), 66 """ @@ -1568,7 +1680,10 @@ def convol_imgs(imgs): print('{:5d}|{:8e}|'.format(ii, err)) if err < stopThr: break - + else: + warnings.warn("Convolutional Sinkhorn did not converge. 
" + "Try a larger number of iterations `numItermax` " + "or a larger entropy `reg`.") if log: log['niter'] = ii log['U'] = U @@ -1577,9 +1692,9 @@ def convol_imgs(imgs): return bar -def convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000, - stopThr=1e-4, stabThr=1e-15, verbose=False, - log=False): +def _convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000, + stopThr=1e-4, stabThr=1e-15, verbose=False, + log=False): r"""Compute the debiased sinkhorn barycenter of distributions A where A is a collection of 2D images. @@ -1590,11 +1705,14 @@ def convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000, where : - - :math:`S_{reg}(\cdot,\cdot)` is the debiased entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn_debiased`) - - :math:`\mathbf{a}_i` are training distributions (2D images) in the mast two dimensions of matrix :math:`\mathbf{A}` + - :math:`S_{reg}(\cdot,\cdot)` is the debiased entropic regularized Wasserstein + distance (see :py:func:`ot.bregman.sinkhorn_debiased`) + - :math:`\mathbf{a}_i` are training distributions (2D images) in the mast two + dimensions of matrix :math:`\mathbf{A}` - `reg` is the regularization strength scalar value - The algorithm used for solving the problem is the debiased Sinkhorn scaling algorithm as proposed in :ref:`[28] ` + The algorithm used for solving the problem is the debiased Sinkhorn scaling + algorithm as proposed in :ref:`[28] ` Parameters ---------- @@ -1627,7 +1745,8 @@ def convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000, References ---------- - .. [28] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th International Conference on Machine Learning, PMLR 119:4692-4701, 2020 + .. [28] Janati, H., Cuturi, M., Gramfort, A. 
Proceedings of the 37th International + Conference on Machine Learning, PMLR 119:4692-4701, 2020 """ A = list_to_array(A) @@ -1690,7 +1809,10 @@ def convol_imgs(imgs): # guarantee a few iterations are done before stopping if err < stopThr and ii > 20: break - + else: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: log['niter'] = ii log['U'] = U @@ -1713,16 +1835,21 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, where : - - :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance with M loss matrix (see :py:func:`ot.bregman.sinkhorn`) - - :math:`\mathbf{D}` is a dictionary of `n_atoms` atoms of dimension `dim_a`, its expected shape is `(dim_a, n_atoms)` + - :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance + with M loss matrix (see :py:func:`ot.bregman.sinkhorn`) + - :math:`\mathbf{D}` is a dictionary of `n_atoms` atoms of dimension `dim_a`, + its expected shape is `(dim_a, n_atoms)` - :math:`\mathbf{h}` is the estimated unmixing of dimension `n_atoms` - :math:`\mathbf{a}` is an observed distribution of dimension `dim_a` - :math:`\mathbf{h}_0` is a prior on :math:`\mathbf{h}` of dimension `dim_prior` - - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix (`dim_a`, `dim_a`) for OT data fitting - - `reg`:math:`_0` and :math:`\mathbf{M_0}` are respectively the regularization term and the cost matrix (`dim_prior`, `n_atoms`) regularization + - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the + cost matrix (`dim_a`, `dim_a`) for OT data fitting + - `reg`:math:`_0` and :math:`\mathbf{M_0}` are respectively the regularization + term and the cost matrix (`dim_prior`, `n_atoms`) regularization - :math:`\\alpha` weight data fitting and regularization - The optimization problem is solved following the algorithm described 
in :ref:`[4] ` + The optimization problem is solved following the algorithm described + in :ref:`[4] ` Parameters @@ -1765,7 +1892,9 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, References ---------- - .. [4] S. Nakhostin, N. Courty, R. Flamary, D. Tuia, T. Corpetti, Supervised planetary unmixing with optimal transport, Whorkshop on Hyperspectral Image and Signal Processing : Evolution in Remote Sensing (WHISPERS), 2016. + .. [4] S. Nakhostin, N. Courty, R. Flamary, D. Tuia, T. Corpetti, Supervised planetary + unmixing with optimal transport, Whorkshop on Hyperspectral Image and Signal Processing : + Evolution in Remote Sensing (WHISPERS), 2016. """ @@ -1808,6 +1937,10 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, print('{:5d}|{:8e}|'.format(ii, err)) if err < stopThr: break + else: + warnings.warn("Unmixing algorithm did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: log['niter'] = ii return nx.sum(K0, axis=1), log @@ -1817,7 +1950,8 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, stopThr=1e-6, verbose=False, log=False, **kwargs): - r'''Joint OT and proportion estimation for multi-source target shift as proposed in :ref:`[27] ` + r'''Joint OT and proportion estimation for multi-source target shift as + proposed in :ref:`[27] ` The function solves the following optimization problem: @@ -1831,16 +1965,23 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, where : - :math:`\lambda_k` is the weight of `k`-th source domain - - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`) - - :math:`\mathbf{D}_2^{(k)}` is a matrix of weights related to `k`-th source domain defined as in [p. 
5, :ref:`27 `], its expected shape is :math:`(n_k, C)` where :math:`n_k` is the number of elements in the `k`-th source domain and `C` is the number of classes + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance + (see :py:func:`ot.bregman.sinkhorn`) + - :math:`\mathbf{D}_2^{(k)}` is a matrix of weights related to `k`-th source domain + defined as in [p. 5, :ref:`27 `], its expected shape + is :math:`(n_k, C)` where :math:`n_k` is the number of elements in the `k`-th source + domain and `C` is the number of classes - :math:`\mathbf{h}` is a vector of estimated proportions in the target domain of size `C` - :math:`\mathbf{a}` is a uniform vector of weights in the target domain of size `n` - - :math:`\mathbf{D}_1^{(k)}` is a matrix of class assignments defined as in [p. 5, :ref:`27 `], its expected shape is :math:`(n_k, C)` + - :math:`\mathbf{D}_1^{(k)}` is a matrix of class assignments defined as in + [p. 5, :ref:`27 `], its expected shape is :math:`(n_k, C)` - The problem consist in solving a Wasserstein barycenter problem to estimate the proportions :math:`\mathbf{h}` in the target domain. + The problem consist in solving a Wasserstein barycenter problem to estimate + the proportions :math:`\mathbf{h}` in the target domain. The algorithm used for solving the problem is the Iterative Bregman projections algorithm - with two sets of marginal constraints related to the unknown vector :math:`\mathbf{h}` and uniform target distribution. + with two sets of marginal constraints related to the unknown vector + :math:`\mathbf{h}` and uniform target distribution. Parameters ---------- @@ -1966,7 +2107,10 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, if ii % 200 == 0: print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) print('{:5d}|{:8e}|'.format(ii, err)) - + else: + warnings.warn("Algorithm did not converge. 
You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") bary = bary / nx.sum(bary) if log: @@ -2000,7 +2144,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', where : - :math:`\mathbf{M}` is the (`n_samples_a`, `n_samples_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) @@ -2021,7 +2166,9 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', stopThr : float, optional Stop threshold on error (>0) isLazy: boolean, optional - If True, then only calculate the cost matrix by block and return the dual potentials only (to save memory). If False, calculate full cost matrix and return outputs of sinkhorn function. + If True, then only calculate the cost matrix by block and return + the dual potentials only (to save memory). If False, calculate full + cost matrix and return outputs of sinkhorn function. batchSize: int or tuple of 2 int, optional Size of the batches used to compute the sinkhorn update without memory overhead. When a tuple is provided it sets the size of the left/right batches. @@ -2054,11 +2201,14 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal + Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 - .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + .. 
[9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for + Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. ''' X_s, X_t = list_to_array(X_s, X_t) @@ -2133,7 +2283,10 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', if err <= stopThr: break - + else: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: dict_log["u"] = f dict_log["v"] = g @@ -2145,15 +2298,18 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', M = dist(nx.to_numpy(X_s), nx.to_numpy(X_t), metric=metric) M = nx.from_numpy(M, type_as=a) if log: - pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs) + pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, + verbose=verbose, log=True, **kwargs) return pi, log else: - pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs) + pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, + verbose=verbose, log=False, **kwargs) return pi -def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, - isLazy=False, batchSize=100, verbose=False, log=False, **kwargs): +def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', + numIterMax=10000, stopThr=1e-9, isLazy=False, + batchSize=100, verbose=False, log=False, **kwargs): r''' Solve the entropic regularization optimal transport problem from empirical data and return the OT loss @@ -2172,7 
+2328,8 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num where : - :math:`\mathbf{M}` is the (`n_samples_a`, `n_samples_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) @@ -2193,7 +2350,9 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num stopThr : float, optional Stop threshold on error (>0) isLazy: boolean, optional - If True, then only calculate the cost matrix by block and return the dual potentials only (to save memory). If False, calculate full cost matrix and return outputs of sinkhorn function. + If True, then only calculate the cost matrix by block and return + the dual potentials only (to save memory). If False, calculate + full cost matrix and return outputs of sinkhorn function. batchSize: int or tuple of 2 int, optional Size of the batches used to compute the sinkhorn update without memory overhead. When a tuple is provided it sets the size of the left/right batches. @@ -2226,11 +2385,14 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal + Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 - .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for + Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. - .. 
[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. ''' X_s, X_t = list_to_array(X_s, X_t) @@ -2245,11 +2407,17 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num if isLazy: if log: - f, g, dict_log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, numIterMax=numIterMax, stopThr=stopThr, - isLazy=isLazy, batchSize=batchSize, verbose=verbose, log=log) + f, g, dict_log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, + numIterMax=numIterMax, + stopThr=stopThr, + isLazy=isLazy, + batchSize=batchSize, + verbose=verbose, log=log) else: - f, g = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, numIterMax=numIterMax, stopThr=stopThr, - isLazy=isLazy, batchSize=batchSize, verbose=verbose, log=log) + f, g = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, + numIterMax=numIterMax, stopThr=stopThr, + isLazy=isLazy, batchSize=batchSize, + verbose=verbose, log=log) bs = batchSize if isinstance(batchSize, int) else batchSize[0] range_s = range(0, ns, bs) @@ -2275,16 +2443,19 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num M = nx.from_numpy(M, type_as=a) if log: - sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, + sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, + stopThr=stopThr, verbose=verbose, log=log, **kwargs) return sinkhorn_loss, log else: - sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, + sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, + stopThr=stopThr, verbose=verbose, log=log, **kwargs) return sinkhorn_loss -def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, 
metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, +def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', + numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs): r''' Compute the sinkhorn divergence loss from empirical data @@ -2322,8 +2493,11 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli \gamma_b\geq 0 where : - - :math:`\mathbf{M}` (resp. :math:`\mathbf{M_a}`, :math:`\mathbf{M_b}`) is the (`n_samples_a`, `n_samples_b`) metric cost matrix (resp (`n_samples_a, n_samples_a`) and (`n_samples_b`, `n_samples_b`)) - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{M}` (resp. :math:`\mathbf{M_a}`, :math:`\mathbf{M_b}`) + is the (`n_samples_a`, `n_samples_b`) metric cost matrix + (resp (`n_samples_a, n_samples_a`) and (`n_samples_b`, `n_samples_b`)) + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) @@ -2368,17 +2542,26 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli References ---------- - .. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018 + .. 
[23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative + Models with Sinkhorn Divergences, Proceedings of the Twenty-First + International Conference on Artficial Intelligence and Statistics, + (AISTATS) 21, 2018 ''' if log: - sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, - stopThr=1e-9, verbose=verbose, log=log, **kwargs) + sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, + numIterMax=numIterMax, + stopThr=1e-9, verbose=verbose, + log=log, **kwargs) - sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, numIterMax=numIterMax, - stopThr=1e-9, verbose=verbose, log=log, **kwargs) + sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, + numIterMax=numIterMax, + stopThr=1e-9, verbose=verbose, + log=log, **kwargs) - sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, numIterMax=numIterMax, - stopThr=1e-9, verbose=verbose, log=log, **kwargs) + sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, + numIterMax=numIterMax, + stopThr=1e-9, verbose=verbose, + log=log, **kwargs) sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b) @@ -2393,13 +2576,16 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli return max(0, sinkhorn_div), log else: - sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, + sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, + numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) - sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, numIterMax=numIterMax, stopThr=1e-9, + sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, + numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) - sinkhorn_loss_b = empirical_sinkhorn2(X_t, 
X_t, reg, b, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, + sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, + numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b) @@ -2411,7 +2597,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res r""" Screening Sinkhorn Algorithm for Regularized Optimal Transport - The function solves an approximate dual of Sinkhorn divergence :ref:`[2] ` which is written as the following optimization problem: + The function solves an approximate dual of Sinkhorn divergence :ref:`[2] + ` which is written as the following optimization problem: .. math:: @@ -2429,7 +2616,9 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res e^{v_j} \geq \epsilon \kappa, \forall j \in \{1, \ldots, nt\} - The parameters `kappa` and `epsilon` are determined w.r.t the couple number budget of points (`ns_budget`, `nt_budget`), see Equation (5) in :ref:`[26] ` + The parameters `kappa` and `epsilon` are determined w.r.t the couple number + budget of points (`ns_budget`, `nt_budget`), see Equation (5) + in :ref:`[26] ` Parameters @@ -2455,7 +2644,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res If it is None then 50% of the target sample points will be kept uniform : `bool`, default=False - If `True`, the source and target distribution are supposed to be uniform, i.e., :math:`a_i = 1 / ns` and :math:`b_j = 1 / nt` + If `True`, the source and target distribution are supposed to be uniform, + i.e., :math:`a_i = 1 / ns` and :math:`b_j = 1 / nt` restricted : `bool`, default=True If `True`, a warm-start initialization for the L-BFGS-B solver @@ -2471,14 +2661,16 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res Final objective function accuracy in LBFGS solver verbose : `bool`, default=False - If `True`, display 
informations about the cardinals of the active sets and the parameters kappa - and epsilon + If `True`, display informations about the cardinals of the active sets + and the parameters kappa and epsilon Dependency ---------- - To gain more efficiency, screenkhorn needs to call the "Bottleneck" package (https://pypi.org/project/Bottleneck/) - in the screening pre-processing step. If Bottleneck isn't installed, the following error message appears: + To gain more efficiency, screenkhorn needs to call the "Bottleneck" + package (https://pypi.org/project/Bottleneck/) + in the screening pre-processing step. If Bottleneck isn't installed, + the following error message appears: "Bottleneck module doesn't exist. Install it from https://pypi.org/project/Bottleneck/" @@ -2495,9 +2687,11 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res References ----------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, + Advances in Neural Information Processing Systems (NIPS) 26, 2013 - .. [26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019). Screening Sinkhorn Algorithm for Regularized Optimal Transport (NIPS) 33, 2019 + .. [26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019). + Screening Sinkhorn Algorithm for Regularized Optimal Transport (NIPS) 33, 2019 """ # check if bottleneck module exists @@ -2505,14 +2699,16 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res import bottleneck except ImportError: warnings.warn( - "Bottleneck module is not installed. Install it from https://pypi.org/project/Bottleneck/ for better performance.") + "Bottleneck module is not installed. 
Install it from" + " https://pypi.org/project/Bottleneck/ for better performance.") bottleneck = np a, b, M = list_to_array(a, b, M) nx = get_backend(M, a, b) if nx.__name__ == "jax": - raise TypeError("JAX arrays have been received but screenkhorn is not compatible with JAX.") + raise TypeError("JAX arrays have been received but screenkhorn is not " + "compatible with JAX.") ns, nt = M.shape @@ -2614,7 +2810,8 @@ def projection(u, epsilon): if verbose: print("epsilon = %s\n" % epsilon) print("kappa = %s\n" % kappa) - print('Cardinality of selected points: |Isel| = %s \t |Jsel| = %s \n' % (sum(Isel), sum(Jsel))) + print('Cardinality of selected points: |Isel| = %s \t |Jsel| = %s \n' + % (sum(Isel), sum(Jsel))) # Ic, Jc: complementary of the active sets I and J Ic = ~Isel @@ -2671,7 +2868,7 @@ def projection(u, epsilon): cst_u = kappa * epsilon * nx.sum(K_IJc, axis=1) cst_v = epsilon * nx.sum(K_IcJ, axis=0) / kappa - for ii in range(5): # 5 iterations + for _ in range(5): # 5 iterations K_IJ_v = nx.dot(K_IJ.T, u0) + cst_v v0 = b_J / (kappa * K_IJ_v) KIJ_u = nx.dot(K_IJ, v0) + cst_u @@ -2686,9 +2883,9 @@ def projection(u, epsilon): def restricted_sinkhorn(usc, vsc, max_iter=5): """ - Restricted Sinkhorn Algorithm as a warm-start initialized point for L-BFGS-B (see Algorithm 1 in supplementary of [26]) + Restricted Sinkhorn Algorithm as a warm-start initialized pointfor L-BFGS-B) """ - for ii in range(max_iter): + for _ in range(max_iter): K_IJ_v = nx.dot(K_IJ.T, usc) + cst_v vsc = b_J / (kappa * K_IJ_v) KIJ_u = nx.dot(K_IJ, vsc) + cst_u From c8d0e340d811b62eb564e1b82d70dc0fea063448 Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Wed, 27 Oct 2021 12:27:53 +0200 Subject: [PATCH 07/25] remove rel paths + add rng + pylab to pyplot --- .../plot_convolutional_barycenter.py | 43 ++++--- .../barycenters/plot_debiased_barycenter.py | 12 +- .../plot_otda_color_images.py | 116 ++++++++--------- .../plot_otda_linear_mapping.py | 71 ++++++----- 
.../plot_otda_mapping_colors_images.py | 118 +++++++++--------- examples/gromov/plot_gromov_barycenter.py | 89 ++++++------- 6 files changed, 234 insertions(+), 215 deletions(-) diff --git a/examples/barycenters/plot_convolutional_barycenter.py b/examples/barycenters/plot_convolutional_barycenter.py index cbcd4a13f..1f32adfad 100644 --- a/examples/barycenters/plot_convolutional_barycenter.py +++ b/examples/barycenters/plot_convolutional_barycenter.py @@ -13,10 +13,11 @@ # Author: Nicolas Courty # # License: MIT License - +import os +from pathlib import Path import numpy as np -import pylab as pl +from matplotlib import pyplot as plt import ot ############################################################################## @@ -25,11 +26,13 @@ # # The four distributions are constructed from 4 simple images +data_path = os.path.join(Path(__file__).parent.parent.parent, 'data') + +f1 = 1 - plt.imread(os.path.join(data_path, 'redcross.png'))[:, :, 2] +f2 = 1 - plt.imread(os.path.join(data_path, 'tooth.png'))[:, :, 2] +f3 = 1 - plt.imread(os.path.join(data_path, 'heart.png'))[:, :, 2] +f4 = 1 - plt.imread(os.path.join(data_path, 'duck.png'))[:, :, 2] -f1 = 1 - pl.imread('../../data/redcross.png')[:, :, 2] -f2 = 1 - pl.imread('../../data/duck.png')[:, :, 2] -f3 = 1 - pl.imread('../../data/heart.png')[:, :, 2] -f4 = 1 - pl.imread('../../data/tooth.png')[:, :, 2] A = [] f1 = f1 / np.sum(f1) @@ -57,14 +60,14 @@ # ---------------------------------------- # -pl.figure(figsize=(10, 10)) -pl.title('Convolutional Wasserstein Barycenters in POT') +plt.figure(figsize=(10, 10)) +plt.title('Convolutional Wasserstein Barycenters in POT') cm = 'Blues' # regularization parameter reg = 0.004 for i in range(nb_images): for j in range(nb_images): - pl.subplot(nb_images, nb_images, i * nb_images + j + 1) + plt.subplot(nb_images, nb_images, i * nb_images + j + 1) tx = float(i) / (nb_images - 1) ty = float(j) / (nb_images - 1) @@ -74,19 +77,19 @@ weights = (1 - ty) * tmp1 + ty * tmp2 if i == 0 
and j == 0: - pl.imshow(f1, cmap=cm) - pl.axis('off') + plt.imshow(f1, cmap=cm) + plt.axis('off') elif i == 0 and j == (nb_images - 1): - pl.imshow(f3, cmap=cm) - pl.axis('off') + plt.imshow(f3, cmap=cm) + plt.axis('off') elif i == (nb_images - 1) and j == 0: - pl.imshow(f2, cmap=cm) - pl.axis('off') + plt.imshow(f2, cmap=cm) + plt.axis('off') elif i == (nb_images - 1) and j == (nb_images - 1): - pl.imshow(f4, cmap=cm) - pl.axis('off') + plt.imshow(f4, cmap=cm) + plt.axis('off') else: # call to barycenter computation - pl.imshow(ot.bregman.convolutional_barycenter2d(A, reg, weights), cmap=cm) - pl.axis('off') -pl.show() + plt.imshow(ot.bregman.convolutional_barycenter2d(A, reg, weights), cmap=cm) + plt.axis('off') +plt.show() diff --git a/examples/barycenters/plot_debiased_barycenter.py b/examples/barycenters/plot_debiased_barycenter.py index dc05fb0fa..6f9eef452 100644 --- a/examples/barycenters/plot_debiased_barycenter.py +++ b/examples/barycenters/plot_debiased_barycenter.py @@ -18,6 +18,8 @@ # License: MIT License # sphinx_gallery_thumbnail_number = 4 +import os +from pathlib import Path import numpy as np import matplotlib.pylab as plt @@ -57,7 +59,6 @@ bars = [barycenter(A, M, reg, weights, method="sinkhorn") for reg in epsilons] bars_debiased = [barycenter(A, M, reg, weights, method="debiased") for reg in epsilons] - labels = ["Sinkhorn barycenter", "Debiased barycenter"] colors = ["indianred", "gold"] @@ -75,10 +76,10 @@ ############################################################################## # Debiased barycenter of 2D images # --------------------------------- - -f1 = 1 - plt.imread('../../data/redcross.png')[:, :, 2] -f2 = 1 - plt.imread('../../data/tooth.png')[:, :, 2] -f3 = 1 - plt.imread('../../data/duck.png')[:, :, 2] +data_path = os.path.join(Path(__file__).parent.parent.parent, 'data') +f1 = 1 - plt.imread(os.path.join(data_path, 'redcross.png'))[:, :, 2] +f2 = 1 - plt.imread(os.path.join(data_path, 'tooth.png'))[:, :, 2] +f3 = 1 - 
plt.imread(os.path.join(data_path, 'duck.png'))[:, :, 2] A = [] f1 = f1 / np.sum(f1) @@ -97,6 +98,7 @@ f, axes = plt.subplots(1, 3, figsize=(12, 4)) for ax, img in zip(axes, A): ax.imshow(img, cmap="Greys") + ax.axis("off") plt.show() diff --git a/examples/domain-adaptation/plot_otda_color_images.py b/examples/domain-adaptation/plot_otda_color_images.py index 6218b1333..84d0d7f4d 100644 --- a/examples/domain-adaptation/plot_otda_color_images.py +++ b/examples/domain-adaptation/plot_otda_color_images.py @@ -19,12 +19,15 @@ # sphinx_gallery_thumbnail_number = 2 +import os +from pathlib import Path + import numpy as np -import matplotlib.pylab as pl +from matplotlib import pyplot as plt import ot -r = np.random.RandomState(42) +rng = np.random.RandomState(42) def im2mat(img): @@ -46,16 +49,17 @@ def minmax(img): # ------------- # Loading images -I1 = pl.imread('../../data/ocean_day.jpg').astype(np.float64) / 256 -I2 = pl.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256 +data_path = os.path.join(Path(__file__).parent.parent.parent, 'data') +I1 = plt.imread(os.path.join(data_path, 'ocean_day.jpg')).astype(np.float64) / 256 +I2 = plt.imread(os.path.join(data_path, 'ocean_sunset.jpg')).astype(np.float64) / 256 X1 = im2mat(I1) X2 = im2mat(I2) # training samples nb = 500 -idx1 = r.randint(X1.shape[0], size=(nb,)) -idx2 = r.randint(X2.shape[0], size=(nb,)) +idx1 = rng.randint(X1.shape[0], size=(nb,)) +idx2 = rng.randint(X2.shape[0], size=(nb,)) Xs = X1[idx1, :] Xt = X2[idx2, :] @@ -65,39 +69,39 @@ def minmax(img): # Plot original image # ------------------- -pl.figure(1, figsize=(6.4, 3)) +plt.figure(1, figsize=(6.4, 3)) -pl.subplot(1, 2, 1) -pl.imshow(I1) -pl.axis('off') -pl.title('Image 1') +plt.subplot(1, 2, 1) +plt.imshow(I1) +plt.axis('off') +plt.title('Image 1') -pl.subplot(1, 2, 2) -pl.imshow(I2) -pl.axis('off') -pl.title('Image 2') +plt.subplot(1, 2, 2) +plt.imshow(I2) +plt.axis('off') +plt.title('Image 2') 
############################################################################## # Scatter plot of colors # ---------------------- -pl.figure(2, figsize=(6.4, 3)) +plt.figure(2, figsize=(6.4, 3)) -pl.subplot(1, 2, 1) -pl.scatter(Xs[:, 0], Xs[:, 2], c=Xs) -pl.axis([0, 1, 0, 1]) -pl.xlabel('Red') -pl.ylabel('Blue') -pl.title('Image 1') +plt.subplot(1, 2, 1) +plt.scatter(Xs[:, 0], Xs[:, 2], c=Xs) +plt.axis([0, 1, 0, 1]) +plt.xlabel('Red') +plt.ylabel('Blue') +plt.title('Image 1') -pl.subplot(1, 2, 2) -pl.scatter(Xt[:, 0], Xt[:, 2], c=Xt) -pl.axis([0, 1, 0, 1]) -pl.xlabel('Red') -pl.ylabel('Blue') -pl.title('Image 2') -pl.tight_layout() +plt.subplot(1, 2, 2) +plt.scatter(Xt[:, 0], Xt[:, 2], c=Xt) +plt.axis([0, 1, 0, 1]) +plt.xlabel('Red') +plt.ylabel('Blue') +plt.title('Image 2') +plt.tight_layout() ############################################################################## @@ -130,37 +134,37 @@ def minmax(img): # Plot new images # --------------- -pl.figure(3, figsize=(8, 4)) +plt.figure(3, figsize=(8, 4)) -pl.subplot(2, 3, 1) -pl.imshow(I1) -pl.axis('off') -pl.title('Image 1') +plt.subplot(2, 3, 1) +plt.imshow(I1) +plt.axis('off') +plt.title('Image 1') -pl.subplot(2, 3, 2) -pl.imshow(I1t) -pl.axis('off') -pl.title('Image 1 Adapt') +plt.subplot(2, 3, 2) +plt.imshow(I1t) +plt.axis('off') +plt.title('Image 1 Adapt') -pl.subplot(2, 3, 3) -pl.imshow(I1te) -pl.axis('off') -pl.title('Image 1 Adapt (reg)') +plt.subplot(2, 3, 3) +plt.imshow(I1te) +plt.axis('off') +plt.title('Image 1 Adapt (reg)') -pl.subplot(2, 3, 4) -pl.imshow(I2) -pl.axis('off') -pl.title('Image 2') +plt.subplot(2, 3, 4) +plt.imshow(I2) +plt.axis('off') +plt.title('Image 2') -pl.subplot(2, 3, 5) -pl.imshow(I2t) -pl.axis('off') -pl.title('Image 2 Adapt') +plt.subplot(2, 3, 5) +plt.imshow(I2t) +plt.axis('off') +plt.title('Image 2 Adapt') -pl.subplot(2, 3, 6) -pl.imshow(I2te) -pl.axis('off') -pl.title('Image 2 Adapt (reg)') -pl.tight_layout() +plt.subplot(2, 3, 6) +plt.imshow(I2te) +plt.axis('off') 
+plt.title('Image 2 Adapt (reg)') +plt.tight_layout() -pl.show() +plt.show() diff --git a/examples/domain-adaptation/plot_otda_linear_mapping.py b/examples/domain-adaptation/plot_otda_linear_mapping.py index be475107c..df36fc00c 100644 --- a/examples/domain-adaptation/plot_otda_linear_mapping.py +++ b/examples/domain-adaptation/plot_otda_linear_mapping.py @@ -13,9 +13,11 @@ # License: MIT License # sphinx_gallery_thumbnail_number = 2 +import os +from pathlib import Path import numpy as np -import pylab as pl +from matplotlib import pyplot as plt import ot ############################################################################## @@ -26,17 +28,19 @@ d = 2 sigma = .1 +rng = np.random.RandomState(42) + # source samples -angles = np.random.rand(n, 1) * 2 * np.pi +angles = rng.rand(n, 1) * 2 * np.pi xs = np.concatenate((np.sin(angles), np.cos(angles)), - axis=1) + sigma * np.random.randn(n, 2) + axis=1) + sigma * rng.randn(n, 2) xs[:n // 2, 1] += 2 # target samples -anglet = np.random.rand(n, 1) * 2 * np.pi +anglet = rng.rand(n, 1) * 2 * np.pi xt = np.concatenate((np.sin(anglet), np.cos(anglet)), - axis=1) + sigma * np.random.randn(n, 2) + axis=1) + sigma * rng.randn(n, 2) xt[:n // 2, 1] += 2 @@ -48,9 +52,9 @@ # Plot data # --------- -pl.figure(1, (5, 5)) -pl.plot(xs[:, 0], xs[:, 1], '+') -pl.plot(xt[:, 0], xt[:, 1], 'o') +plt.figure(1, (5, 5)) +plt.plot(xs[:, 0], xs[:, 1], '+') +plt.plot(xt[:, 0], xt[:, 1], 'o') ############################################################################## @@ -66,13 +70,13 @@ # Plot transported samples # ------------------------ -pl.figure(1, (5, 5)) -pl.clf() -pl.plot(xs[:, 0], xs[:, 1], '+') -pl.plot(xt[:, 0], xt[:, 1], 'o') -pl.plot(xst[:, 0], xst[:, 1], '+') +plt.figure(1, (5, 5)) +plt.clf() +plt.plot(xs[:, 0], xs[:, 1], '+') +plt.plot(xt[:, 0], xt[:, 1], 'o') +plt.plot(xst[:, 0], xst[:, 1], '+') -pl.show() +plt.show() ############################################################################## # Load image data @@ -94,8 
+98,9 @@ def minmax(img): # Loading images -I1 = pl.imread('../../data/ocean_day.jpg').astype(np.float64) / 256 -I2 = pl.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256 +data_path = os.path.join(Path(__file__).parent.parent.parent, 'data') +I1 = plt.imread(os.path.join(data_path, 'ocean_day.jpg')).astype(np.float64) / 256 +I2 = plt.imread(os.path.join(data_path, 'ocean_sunset.jpg')).astype(np.float64) / 256 X1 = im2mat(I1) @@ -123,24 +128,24 @@ def minmax(img): # Plot transformed images # ----------------------- -pl.figure(2, figsize=(10, 7)) +plt.figure(2, figsize=(10, 7)) -pl.subplot(2, 2, 1) -pl.imshow(I1) -pl.axis('off') -pl.title('Im. 1') +plt.subplot(2, 2, 1) +plt.imshow(I1) +plt.axis('off') +plt.title('Im. 1') -pl.subplot(2, 2, 2) -pl.imshow(I2) -pl.axis('off') -pl.title('Im. 2') +plt.subplot(2, 2, 2) +plt.imshow(I2) +plt.axis('off') +plt.title('Im. 2') -pl.subplot(2, 2, 3) -pl.imshow(I1t) -pl.axis('off') -pl.title('Mapping Im. 1') +plt.subplot(2, 2, 3) +plt.imshow(I1t) +plt.axis('off') +plt.title('Mapping Im. 1') -pl.subplot(2, 2, 4) -pl.imshow(I2t) -pl.axis('off') -pl.title('Inverse mapping Im. 2') +plt.subplot(2, 2, 4) +plt.imshow(I2t) +plt.axis('off') +plt.title('Inverse mapping Im. 
2') diff --git a/examples/domain-adaptation/plot_otda_mapping_colors_images.py b/examples/domain-adaptation/plot_otda_mapping_colors_images.py index 72010a674..d8b1ff36a 100644 --- a/examples/domain-adaptation/plot_otda_mapping_colors_images.py +++ b/examples/domain-adaptation/plot_otda_mapping_colors_images.py @@ -21,12 +21,14 @@ # License: MIT License # sphinx_gallery_thumbnail_number = 3 +import os +from pathlib import Path import numpy as np -import matplotlib.pylab as pl +from matplotlib import pyplot as plt import ot -r = np.random.RandomState(42) +rng = np.random.RandomState(42) def im2mat(img): @@ -48,17 +50,17 @@ def minmax(img): # ------------- # Loading images -I1 = pl.imread('../../data/ocean_day.jpg').astype(np.float64) / 256 -I2 = pl.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256 - +data_path = os.path.join(Path(__file__).parent.parent.parent, 'data') +I1 = plt.imread(os.path.join(data_path, 'ocean_day.jpg')).astype(np.float64) / 256 +I2 = plt.imread(os.path.join(data_path, 'ocean_sunset.jpg')).astype(np.float64) / 256 X1 = im2mat(I1) X2 = im2mat(I2) # training samples nb = 500 -idx1 = r.randint(X1.shape[0], size=(nb,)) -idx2 = r.randint(X2.shape[0], size=(nb,)) +idx1 = rng.randint(X1.shape[0], size=(nb,)) +idx2 = rng.randint(X2.shape[0], size=(nb,)) Xs = X1[idx1, :] Xt = X2[idx2, :] @@ -99,76 +101,76 @@ def minmax(img): # Plot original images # -------------------- -pl.figure(1, figsize=(6.4, 3)) -pl.subplot(1, 2, 1) -pl.imshow(I1) -pl.axis('off') -pl.title('Image 1') +plt.figure(1, figsize=(6.4, 3)) +plt.subplot(1, 2, 1) +plt.imshow(I1) +plt.axis('off') +plt.title('Image 1') -pl.subplot(1, 2, 2) -pl.imshow(I2) -pl.axis('off') -pl.title('Image 2') -pl.tight_layout() +plt.subplot(1, 2, 2) +plt.imshow(I2) +plt.axis('off') +plt.title('Image 2') +plt.tight_layout() ############################################################################## # Plot pixel values distribution # ------------------------------ -pl.figure(2, figsize=(6.4, 5)) 
+plt.figure(2, figsize=(6.4, 5)) -pl.subplot(1, 2, 1) -pl.scatter(Xs[:, 0], Xs[:, 2], c=Xs) -pl.axis([0, 1, 0, 1]) -pl.xlabel('Red') -pl.ylabel('Blue') -pl.title('Image 1') +plt.subplot(1, 2, 1) +plt.scatter(Xs[:, 0], Xs[:, 2], c=Xs) +plt.axis([0, 1, 0, 1]) +plt.xlabel('Red') +plt.ylabel('Blue') +plt.title('Image 1') -pl.subplot(1, 2, 2) -pl.scatter(Xt[:, 0], Xt[:, 2], c=Xt) -pl.axis([0, 1, 0, 1]) -pl.xlabel('Red') -pl.ylabel('Blue') -pl.title('Image 2') -pl.tight_layout() +plt.subplot(1, 2, 2) +plt.scatter(Xt[:, 0], Xt[:, 2], c=Xt) +plt.axis([0, 1, 0, 1]) +plt.xlabel('Red') +plt.ylabel('Blue') +plt.title('Image 2') +plt.tight_layout() ############################################################################## # Plot transformed images # ----------------------- -pl.figure(2, figsize=(10, 5)) +plt.figure(2, figsize=(10, 5)) -pl.subplot(2, 3, 1) -pl.imshow(I1) -pl.axis('off') -pl.title('Im. 1') +plt.subplot(2, 3, 1) +plt.imshow(I1) +plt.axis('off') +plt.title('Im. 1') -pl.subplot(2, 3, 4) -pl.imshow(I2) -pl.axis('off') -pl.title('Im. 2') +plt.subplot(2, 3, 4) +plt.imshow(I2) +plt.axis('off') +plt.title('Im. 
2') -pl.subplot(2, 3, 2) -pl.imshow(Image_emd) -pl.axis('off') -pl.title('EmdTransport') +plt.subplot(2, 3, 2) +plt.imshow(Image_emd) +plt.axis('off') +plt.title('EmdTransport') -pl.subplot(2, 3, 5) -pl.imshow(Image_sinkhorn) -pl.axis('off') -pl.title('SinkhornTransport') +plt.subplot(2, 3, 5) +plt.imshow(Image_sinkhorn) +plt.axis('off') +plt.title('SinkhornTransport') -pl.subplot(2, 3, 3) -pl.imshow(Image_mapping_linear) -pl.axis('off') -pl.title('MappingTransport (linear)') +plt.subplot(2, 3, 3) +plt.imshow(Image_mapping_linear) +plt.axis('off') +plt.title('MappingTransport (linear)') -pl.subplot(2, 3, 6) -pl.imshow(Image_mapping_gaussian) -pl.axis('off') -pl.title('MappingTransport (gaussian)') -pl.tight_layout() +plt.subplot(2, 3, 6) +plt.imshow(Image_mapping_gaussian) +plt.axis('off') +plt.title('MappingTransport (gaussian)') +plt.tight_layout() -pl.show() +plt.show() diff --git a/examples/gromov/plot_gromov_barycenter.py b/examples/gromov/plot_gromov_barycenter.py index e2d88baf3..af0b8f1a8 100755 --- a/examples/gromov/plot_gromov_barycenter.py +++ b/examples/gromov/plot_gromov_barycenter.py @@ -13,11 +13,13 @@ # # License: MIT License +import os +from pathlib import Path import numpy as np import scipy as sp -import matplotlib.pylab as pl +from matplotlib import pyplot as plt from sklearn import manifold from sklearn.decomposition import PCA @@ -89,17 +91,18 @@ def im2mat(img): return img.reshape((img.shape[0] * img.shape[1], img.shape[2])) -square = pl.imread('../../data/square.png').astype(np.float64)[:, :, 2] -cross = pl.imread('../../data/cross.png').astype(np.float64)[:, :, 2] -triangle = pl.imread('../../data/triangle.png').astype(np.float64)[:, :, 2] -star = pl.imread('../../data/star.png').astype(np.float64)[:, :, 2] +data_path = os.path.join(Path(__file__).parent.parent.parent, 'data') + +square = plt.imread(os.path.join(data_path, 'square.png')).astype(np.float64)[:, :, 2] +cross = plt.imread(os.path.join(data_path, 
'cross.png')).astype(np.float64)[:, :, 2] +triangle = plt.imread(os.path.join(data_path, 'triangle.png')).astype(np.float64)[:, :, 2] +star = plt.imread(os.path.join(data_path, 'star.png')).astype(np.float64)[:, :, 2] shapes = [square, cross, triangle, star] S = 4 xs = [[] for i in range(S)] - for nb in range(4): for i in range(8): for j in range(8): @@ -184,64 +187,64 @@ def im2mat(img): npost23 = [clf.fit_transform(npost23[s]) for s in range(2)] -fig = pl.figure(figsize=(10, 10)) +fig = plt.figure(figsize=(10, 10)) -ax1 = pl.subplot2grid((4, 4), (0, 0)) -pl.xlim((-1, 1)) -pl.ylim((-1, 1)) +ax1 = plt.subplot2grid((4, 4), (0, 0)) +plt.xlim((-1, 1)) +plt.ylim((-1, 1)) ax1.scatter(npos[0][:, 0], npos[0][:, 1], color='r') -ax2 = pl.subplot2grid((4, 4), (0, 1)) -pl.xlim((-1, 1)) -pl.ylim((-1, 1)) +ax2 = plt.subplot2grid((4, 4), (0, 1)) +plt.xlim((-1, 1)) +plt.ylim((-1, 1)) ax2.scatter(npost01[1][:, 0], npost01[1][:, 1], color='b') -ax3 = pl.subplot2grid((4, 4), (0, 2)) -pl.xlim((-1, 1)) -pl.ylim((-1, 1)) +ax3 = plt.subplot2grid((4, 4), (0, 2)) +plt.xlim((-1, 1)) +plt.ylim((-1, 1)) ax3.scatter(npost01[0][:, 0], npost01[0][:, 1], color='b') -ax4 = pl.subplot2grid((4, 4), (0, 3)) -pl.xlim((-1, 1)) -pl.ylim((-1, 1)) +ax4 = plt.subplot2grid((4, 4), (0, 3)) +plt.xlim((-1, 1)) +plt.ylim((-1, 1)) ax4.scatter(npos[1][:, 0], npos[1][:, 1], color='r') -ax5 = pl.subplot2grid((4, 4), (1, 0)) -pl.xlim((-1, 1)) -pl.ylim((-1, 1)) +ax5 = plt.subplot2grid((4, 4), (1, 0)) +plt.xlim((-1, 1)) +plt.ylim((-1, 1)) ax5.scatter(npost02[1][:, 0], npost02[1][:, 1], color='b') -ax6 = pl.subplot2grid((4, 4), (1, 3)) -pl.xlim((-1, 1)) -pl.ylim((-1, 1)) +ax6 = plt.subplot2grid((4, 4), (1, 3)) +plt.xlim((-1, 1)) +plt.ylim((-1, 1)) ax6.scatter(npost13[1][:, 0], npost13[1][:, 1], color='b') -ax7 = pl.subplot2grid((4, 4), (2, 0)) -pl.xlim((-1, 1)) -pl.ylim((-1, 1)) +ax7 = plt.subplot2grid((4, 4), (2, 0)) +plt.xlim((-1, 1)) +plt.ylim((-1, 1)) ax7.scatter(npost02[0][:, 0], npost02[0][:, 1], color='b') -ax8 
= pl.subplot2grid((4, 4), (2, 3)) -pl.xlim((-1, 1)) -pl.ylim((-1, 1)) +ax8 = plt.subplot2grid((4, 4), (2, 3)) +plt.xlim((-1, 1)) +plt.ylim((-1, 1)) ax8.scatter(npost13[0][:, 0], npost13[0][:, 1], color='b') -ax9 = pl.subplot2grid((4, 4), (3, 0)) -pl.xlim((-1, 1)) -pl.ylim((-1, 1)) +ax9 = plt.subplot2grid((4, 4), (3, 0)) +plt.xlim((-1, 1)) +plt.ylim((-1, 1)) ax9.scatter(npos[2][:, 0], npos[2][:, 1], color='r') -ax10 = pl.subplot2grid((4, 4), (3, 1)) -pl.xlim((-1, 1)) -pl.ylim((-1, 1)) +ax10 = plt.subplot2grid((4, 4), (3, 1)) +plt.xlim((-1, 1)) +plt.ylim((-1, 1)) ax10.scatter(npost23[1][:, 0], npost23[1][:, 1], color='b') -ax11 = pl.subplot2grid((4, 4), (3, 2)) -pl.xlim((-1, 1)) -pl.ylim((-1, 1)) +ax11 = plt.subplot2grid((4, 4), (3, 2)) +plt.xlim((-1, 1)) +plt.ylim((-1, 1)) ax11.scatter(npost23[0][:, 0], npost23[0][:, 1], color='b') -ax12 = pl.subplot2grid((4, 4), (3, 3)) -pl.xlim((-1, 1)) -pl.ylim((-1, 1)) +ax12 = plt.subplot2grid((4, 4), (3, 3)) +plt.xlim((-1, 1)) +plt.ylim((-1, 1)) ax12.scatter(npos[3][:, 0], npos[3][:, 1], color='r') From 3a6d2a9a60f5322edda093bed9a3f0ac5bd6fc25 Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Wed, 27 Oct 2021 12:28:13 +0200 Subject: [PATCH 08/25] fix stopping criterion debiased --- ot/bregman.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/ot/bregman.py b/ot/bregman.py index 2b8e995e7..e5a968243 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1322,20 +1322,25 @@ def _barycenter_sinkhorn_debiased(A, M, reg, weights=None, numItermax=1000, u = (geometricMean(UKv) / UKv.T).T c = nx.ones(A.shape[0], type_as=A) + bar = nx.ones(A.shape[0], type_as=A) + for ii in range(numItermax): + bold = bar UKv = nx.dot(K, A / nx.dot(K, u)) bar = c * geometricBar(weights, UKv) u = bar[:, None] / UKv c = (c * bar / nx.dot(K, c)) ** 0.5 - if ii % 10 == 1: - err = nx.sum(nx.std(UKv, axis=1)) + if ii % 10 == 9: + err = abs(bar - bold).max() / max(bar.max(), 1.) 
# log and verbose print if log: log['err'].append(err) - if err < stopThr: + # debiased Sinkhorn does not converge monotonically + # guarantee a few iterations are done before stopping + if err < stopThr and ii > 20: break if verbose: if ii % 200 == 0: From e2ac99eb1e26d730975d6226354772e66b177744 Mon Sep 17 00:00:00 2001 From: Alexandre Gramfort Date: Wed, 27 Oct 2021 13:56:18 +0200 Subject: [PATCH 09/25] pass alex --- .../plot_convolutional_barycenter.py | 39 ++++++++----------- .../barycenters/plot_debiased_barycenter.py | 27 ++++++------- ot/bregman.py | 4 -- 3 files changed, 28 insertions(+), 42 deletions(-) diff --git a/examples/barycenters/plot_convolutional_barycenter.py b/examples/barycenters/plot_convolutional_barycenter.py index 1f32adfad..9f39adb35 100644 --- a/examples/barycenters/plot_convolutional_barycenter.py +++ b/examples/barycenters/plot_convolutional_barycenter.py @@ -6,8 +6,8 @@ Convolutional Wasserstein Barycenter example ============================================ -This example is designed to illustrate how the Convolutional Wasserstein Barycenter -function of POT works. +This example is designed to illustrate how the Convolutional Wasserstein +Barycenter function of POT works. 
""" # Author: Nicolas Courty @@ -17,7 +17,7 @@ from pathlib import Path import numpy as np -from matplotlib import pyplot as plt +import matplotlib.pyplot as plt import ot ############################################################################## @@ -33,17 +33,11 @@ f3 = 1 - plt.imread(os.path.join(data_path, 'heart.png'))[:, :, 2] f4 = 1 - plt.imread(os.path.join(data_path, 'duck.png'))[:, :, 2] - -A = [] f1 = f1 / np.sum(f1) f2 = f2 / np.sum(f2) f3 = f3 / np.sum(f3) f4 = f4 / np.sum(f4) -A.append(f1) -A.append(f2) -A.append(f3) -A.append(f4) -A = np.array(A) +A = np.array([f1, f2, f3, f4]) nb_images = 5 @@ -60,14 +54,13 @@ # ---------------------------------------- # -plt.figure(figsize=(10, 10)) -plt.title('Convolutional Wasserstein Barycenters in POT') +fig, axes = plt.subplots(nb_images, nb_images, figsize=(7, 7)) +plt.suptitle('Convolutional Wasserstein Barycenters in POT') cm = 'Blues' # regularization parameter reg = 0.004 for i in range(nb_images): for j in range(nb_images): - plt.subplot(nb_images, nb_images, i * nb_images + j + 1) tx = float(i) / (nb_images - 1) ty = float(j) / (nb_images - 1) @@ -77,19 +70,19 @@ weights = (1 - ty) * tmp1 + ty * tmp2 if i == 0 and j == 0: - plt.imshow(f1, cmap=cm) - plt.axis('off') + axes[i, j].imshow(f1, cmap=cm) elif i == 0 and j == (nb_images - 1): - plt.imshow(f3, cmap=cm) - plt.axis('off') + axes[i, j].imshow(f3, cmap=cm) elif i == (nb_images - 1) and j == 0: - plt.imshow(f2, cmap=cm) - plt.axis('off') + axes[i, j].imshow(f2, cmap=cm) elif i == (nb_images - 1) and j == (nb_images - 1): - plt.imshow(f4, cmap=cm) - plt.axis('off') + axes[i, j].imshow(f4, cmap=cm) else: # call to barycenter computation - plt.imshow(ot.bregman.convolutional_barycenter2d(A, reg, weights), cmap=cm) - plt.axis('off') + axes[i, j].imshow( + ot.bregman.convolutional_barycenter2d(A, reg, weights), + cmap=cm + ) + axes[i, j].axis('off') +plt.tight_layout() plt.show() diff --git a/examples/barycenters/plot_debiased_barycenter.py 
b/examples/barycenters/plot_debiased_barycenter.py index 6f9eef452..66fdaf847 100644 --- a/examples/barycenters/plot_debiased_barycenter.py +++ b/examples/barycenters/plot_debiased_barycenter.py @@ -1,16 +1,15 @@ # -*- coding: utf-8 -*- """ -============================== +================================= Debiased Sinkhorn barycenter demo -============================== +================================= This example illustrates the computation of the debiased Sinkhorn barycenter -as proposed in [28]. +as proposed in [28]_. -[28] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th - International Conference on Machine Learning, PMLR 119:4692-4701, 2020 - +.. [28] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th + International Conference on Machine Learning, PMLR 119:4692-4701, 2020 """ # Author: Hicham Janati @@ -22,7 +21,8 @@ from pathlib import Path import numpy as np -import matplotlib.pylab as plt +import matplotlib.pyplot as plt + import ot from ot.bregman import barycenter, convolutional_barycenter2d @@ -81,24 +81,20 @@ f2 = 1 - plt.imread(os.path.join(data_path, 'tooth.png'))[:, :, 2] f3 = 1 - plt.imread(os.path.join(data_path, 'duck.png'))[:, :, 2] -A = [] f1 = f1 / np.sum(f1) f2 = f2 / np.sum(f2) f3 = f3 / np.sum(f3) -A.append(f1) -A.append(f2) -A.append(f3) - -A = np.array(A) +A = np.array([f1, f2, f3]) ############################################################################## # Display the input images -f, axes = plt.subplots(1, 3, figsize=(12, 4)) +fig, axes = plt.subplots(1, 3, figsize=(7, 4)) for ax, img in zip(axes, A): ax.imshow(img, cmap="Greys") ax.axis("off") +fig.tight_layout() plt.show() @@ -117,7 +113,7 @@ titles = ["Sinkhorn", "Debiased"] all_bars = [bars_sinkhorn, bars_debiased] -f, axes = plt.subplots(2, 3, figsize=(12, 8)) +fig, axes = plt.subplots(2, 3, figsize=(8, 6)) for jj, (method, ax_row, bars) in enumerate(zip(titles, axes, all_bars)): for ii, (ax, img, eps) in enumerate(zip(ax_row, bars, epsilons)): 
ax.imshow(img, cmap="Greys") @@ -131,4 +127,5 @@ ax.spines['left'].set_visible(False) if ii == 0: ax.set_ylabel(method, fontsize=15) +fig.tight_layout() plt.show() diff --git a/ot/bregman.py b/ot/bregman.py index b84db065f..a11b1dd49 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1852,8 +1852,6 @@ def _convolutional_barycenter2d_sinkhorn(A, reg, weights=None, numItermax=10000, A. & Guibas, L. (2015). Convolutional wasserstein distances: Efficient optimal transportation on geometric domains. ACM Transactions on Graphics (TOG), 34(4), 66 - - """ A = list_to_array(A) @@ -2124,7 +2122,6 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, .. [4] S. Nakhostin, N. Courty, R. Flamary, D. Tuia, T. Corpetti, Supervised planetary unmixing with optimal transport, Whorkshop on Hyperspectral Image and Signal Processing : Evolution in Remote Sensing (WHISPERS), 2016. - """ a, D, M, M0, h0 = list_to_array(a, D, M, M0, h0) @@ -2248,7 +2245,6 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, .. [27] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia "Optimal transport for multi-source domain adaptation under target shift", International Conference on Artificial Intelligence and Statistics (AISTATS), 2019. 
- ''' Xs = list_to_array(*Xs) From 366ff62cbf8f31aa35ce8df2eecd5d9a9599d39f Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Fri, 29 Oct 2021 23:16:58 +0200 Subject: [PATCH 10/25] change params with new API --- examples/barycenters/plot_barycenter_1D.py | 59 +++++++++---------- .../barycenters/plot_debiased_barycenter.py | 12 ++-- 2 files changed, 34 insertions(+), 37 deletions(-) diff --git a/examples/barycenters/plot_barycenter_1D.py b/examples/barycenters/plot_barycenter_1D.py index 63dc4603f..00bcfe67c 100644 --- a/examples/barycenters/plot_barycenter_1D.py +++ b/examples/barycenters/plot_barycenter_1D.py @@ -21,7 +21,7 @@ # sphinx_gallery_thumbnail_number = 4 import numpy as np -import matplotlib.pylab as pl +import matplotlib.pyplot as plt import ot # necessary for 3d plot even if not used from mpl_toolkits.mplot3d import Axes3D # noqa @@ -56,11 +56,11 @@ #%% plot the distributions -pl.figure(1, figsize=(6.4, 3)) -for i in range(n_distributions): - pl.plot(x, A[:, i]) -pl.title('Distributions') -pl.tight_layout() +# plt.figure(1, figsize=(6.4, 3)) +# for i in range(n_distributions): +# plt.plot(x, A[:, i]) +# plt.title('Distributions') +# plt.tight_layout() ############################################################################## # Barycenter computation @@ -78,24 +78,20 @@ reg = 1e-3 bary_wass = ot.bregman.barycenter(A, M, reg, weights) -pl.figure(2) -pl.clf() -pl.subplot(2, 1, 1) -for i in range(n_distributions): - pl.plot(x, A[:, i]) -pl.title('Distributions') +f, (ax1, ax2) = plt.subplots(2, 1, tight_layout=True) +ax1.plot(x, A, color="black") +ax1.set_title('Distributions') -pl.subplot(2, 1, 2) -pl.plot(x, bary_l2, 'r', label='l2') -pl.plot(x, bary_wass, 'g', label='Wasserstein') -pl.legend() -pl.title('Barycenters') -pl.tight_layout() +ax2.plot(x, bary_l2, 'r', label='l2') +ax2.plot(x, bary_wass, 'g', label='Wasserstein') +ax2.set_title('Barycenters') + +plt.legend() +plt.show() 
############################################################################## # Barycentric interpolation # ------------------------- - #%% barycenter interpolation n_alpha = 11 @@ -106,24 +102,23 @@ B_wass = np.copy(B_l2) -for i in range(0, n_alpha): +for i in range(n_alpha): alpha = alpha_list[i] weights = np.array([1 - alpha, alpha]) B_l2[:, i] = A.dot(weights) B_wass[:, i] = ot.bregman.barycenter(A, M, reg, weights) #%% plot interpolation +plt.figure() -pl.figure(3) - -cmap = pl.cm.get_cmap('viridis') +cmap = plt.cm.get_cmap('viridis') verts = [] zs = alpha_list for i, z in enumerate(zs): ys = B_l2[:, i] verts.append(list(zip(x, ys))) -ax = pl.gcf().gca(projection='3d') +ax = plt.gcf().gca(projection='3d') poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list]) poly.set_alpha(0.7) @@ -134,18 +129,18 @@ ax.set_ylim3d(0, 1) ax.set_zlabel('') ax.set_zlim3d(0, B_l2.max() * 1.01) -pl.title('Barycenter interpolation with l2') -pl.tight_layout() +plt.title('Barycenter interpolation with l2') +plt.tight_layout() -pl.figure(4) -cmap = pl.cm.get_cmap('viridis') +plt.figure(4) +cmap = plt.cm.get_cmap('viridis') verts = [] zs = alpha_list for i, z in enumerate(zs): ys = B_wass[:, i] verts.append(list(zip(x, ys))) -ax = pl.gcf().gca(projection='3d') +ax = plt.gcf().gca(projection='3d') poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list]) poly.set_alpha(0.7) @@ -156,7 +151,7 @@ ax.set_ylim3d(0, 1) ax.set_zlabel('') ax.set_zlim3d(0, B_l2.max() * 1.01) -pl.title('Barycenter interpolation with Wasserstein') -pl.tight_layout() +plt.title('Barycenter interpolation with Wasserstein') +plt.tight_layout() -pl.show() +plt.show() diff --git a/examples/barycenters/plot_debiased_barycenter.py b/examples/barycenters/plot_debiased_barycenter.py index 66fdaf847..f9206b235 100644 --- a/examples/barycenters/plot_debiased_barycenter.py +++ b/examples/barycenters/plot_debiased_barycenter.py @@ -24,7 +24,9 @@ import matplotlib.pyplot as plt import ot -from 
ot.bregman import barycenter, convolutional_barycenter2d +from ot.bregman import (barycenter, barycenter_debiased, + convolutional_barycenter2d, + convolutional_barycenter2d_debiased) ############################################################################## # Debiased barycenter of 1D Gaussians @@ -57,8 +59,8 @@ epsilons = [5e-3, 1e-2, 5e-2] -bars = [barycenter(A, M, reg, weights, method="sinkhorn") for reg in epsilons] -bars_debiased = [barycenter(A, M, reg, weights, method="debiased") for reg in epsilons] +bars = [barycenter(A, M, reg, weights) for reg in epsilons] +bars_debiased = [barycenter_debiased(A, M, reg, weights) for reg in epsilons] labels = ["Sinkhorn barycenter", "Debiased barycenter"] colors = ["indianred", "gold"] @@ -106,8 +108,8 @@ bars_sinkhorn, bars_debiased = [], [] epsilons = [5e-3, 7e-3, 1e-2] for eps in epsilons: - bar = convolutional_barycenter2d(A, eps, method="sinkhorn") - bar_debiased = convolutional_barycenter2d(A, eps, method="debiased") + bar = convolutional_barycenter2d(A, eps) + bar_debiased = convolutional_barycenter2d_debiased(A, eps) bars_sinkhorn.append(bar) bars_debiased.append(bar_debiased) From 5297495058d662a5f4eee779c5b33dd20343cb30 Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Fri, 29 Oct 2021 23:17:15 +0200 Subject: [PATCH 11/25] add logdomain barycenters + separate debiased API --- ot/bregman.py | 605 +++++++++++++++++++++++++++++++++++--------------- 1 file changed, 429 insertions(+), 176 deletions(-) diff --git a/ot/bregman.py b/ot/bregman.py index a11b1dd49..7a9b7b728 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -488,7 +488,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, "increase the number of iterations `numItermax` " "or the regularization parameter `reg`.") if log: - log['n_iter'] = ii + log['niter'] = ii log['u'] = u log['v'] = v @@ -639,7 +639,7 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, if log: log = {'err': []} - Mr = M / (-reg) + Mr = - M / reg # we assume that no distances are 
null except those of the diagonal of # distances @@ -656,14 +656,13 @@ def get_logT(u, v): loga = nx.log(a) logb = nx.log(b) - cpt = 0 err = 1 - while (err > stopThr and cpt < numItermax): + for ii in range(numItermax): v = logb - nx.logsumexp(Mr + u[:, None], 0) u = loga - nx.logsumexp(Mr + v[None, :], 1) - if cpt % 10 == 0: + if ii % 10 == 0: # we can speed up the process by checking for the error only all # the 10th iterations @@ -674,13 +673,19 @@ def get_logT(u, v): log['err'].append(err) if verbose: - if cpt % 200 == 0: + if ii % 200 == 0: print( '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) - cpt = cpt + 1 + print('{:5d}|{:8e}|'.format(ii, err)) + if err < stopThr: + break + else: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: + log['niter'] = ii log['log_u'] = u log['log_v'] = v log['u'] = nx.exp(u) @@ -1240,7 +1245,7 @@ def get_reg(n): # exponential decreasing log['alpha'] = alpha log['beta'] = beta log['warmstart'] = (log['alpha'], log['beta']) - log['n_iter'] = ii + log['niter'] = ii return G, log else: return G @@ -1288,9 +1293,7 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, - :math:`OT_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`) - if `method` is `sinkhorn` or `sinkhorn_stabilized`. If `method`is `debiased`, - :math:`OT_{reg}(\cdot,\cdot)` is the entropic - sinkhorn divergence (see :py:func:`ot.bregman.empirical_sinkhorn_divergence`) + if `method` is `sinkhorn` or `sinkhorn_stabilized` or `sinkhorn_log`. 
- :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` - `reg` and :math:`\mathbf{M}` are respectively the regularization term and @@ -1308,7 +1311,7 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, reg : float Regularization term > 0 method : str (optional) - method used for the solver either 'sinkhorn' or 'sinkhorn_stabilized' or 'debiased' + method used for the solver either 'sinkhorn' or 'sinkhorn_stabilized' or 'sinkhorn_log' weights : array-like, shape (n_hists,) Weights of each histogram :math:`a_i` on the simplex (barycentric coodinates) numItermax : int, optional @@ -1349,11 +1352,11 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) - elif method.lower() == 'debiased': - return _barycenter_sinkhorn_debiased(A, M, reg, weights=weights, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + elif method.lower() == 'sinkhorn_log': + return _barycenter_sinkhorn_log(A, M, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) else: raise ValueError("Unknown method '%s'." % method) @@ -1467,119 +1470,61 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, return geometricBar(weights, UKv) -def _barycenter_sinkhorn_debiased(A, M, reg, weights=None, numItermax=1000, - stopThr=1e-4, verbose=False, log=False): - r"""Compute the entropic regularized wasserstein barycenter of distributions A - - The function solves the following optimization problem: - - .. 
math:: - \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) - - where : - - - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein - distance (see :py:func:`ot.bregman.sinkhorn`) - - :math:`\mathbf{a}_i` are training distributions in the columns of matrix - :math:`\mathbf{A}` - - `reg` and :math:`\mathbf{M}` are respectively the regularization term - and the cost matrix for OT - - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix - scaling algorithm as proposed in :ref:`[3] ` - - Parameters - ---------- - A : array-like, shape (dim, n_hists) - `n_hists` training distributions :math:`a_i` of size `dim` - M : array-like, shape (dim, dim) - loss matrix for OT - reg : float - Regularization term > 0 - weights : array-like, shape (n_hists,) - Weights of each histogram :math:`a_i` on the simplex (barycentric coodinates) - numItermax : int, optional - Max number of iterations - stopThr : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - - - Returns - ------- - a : (dim,) array-like - Wasserstein barycenter - log : dict - log dictionary return only if log==True in parameters - - - .. _references-barycenter-sinkhorn: - References - ---------- - - .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). - Iterative Bregman projections for regularized transportation problems. - SIAM Journal on Scientific Computing, 37(2), A1111-A1138. 
- +def _barycenter_sinkhorn_log(A, M, reg, weights=None, numItermax=1000, + stopThr=1e-4, verbose=False, log=False): + r"""Compute the entropic wasserstein barycenter in log-domain """ A, M = list_to_array(A, M) + dim, n_hists = A.shape nx = get_backend(A, M) if weights is None: - weights = nx.ones((A.shape[1],), type_as=A) / A.shape[1] + weights = nx.ones(n_hists, type_as=A) / n_hists else: assert (len(weights) == A.shape[1]) if log: log = {'err': []} - K = nx.exp(-M / reg) - + M = - M / reg + logA = nx.log(A + 1e-15) + log_KU, G = nx.zeros((2, *logA.shape), type_as=A) err = 1 - - UKv = nx.dot(K, (A.T / nx.sum(K, axis=0)).T) - - u = (geometricMean(UKv) / UKv.T).T - c = nx.ones(A.shape[0], type_as=A) - bar = nx.ones(A.shape[0], type_as=A) - for ii in range(numItermax): - bold = bar - UKv = nx.dot(K, A / nx.dot(K, u)) - bar = c * geometricBar(weights, UKv) - u = bar[:, None] / UKv - c = (c * bar / nx.dot(K, c)) ** 0.5 + log_bar = nx.zeros(dim, type_as=A) + for k in range(n_hists): + f = logA[:, k] - nx.logsumexp(M + G[None, :, k], axis=1) + log_KU[:, k] = nx.logsumexp(M + f[:, None], axis=0) + log_bar += weights[k] * log_KU[:, k] - if ii % 10 == 9: - err = abs(bar - bold).max() / max(bar.max(), 1.) + if ii % 10 == 1: + err = nx.exp(G + log_KU).std(axis=1).sum() # log and verbose print if log: log['err'].append(err) - # debiased Sinkhorn does not converge monotonically - # guarantee a few iterations are done before stopping - if err < stopThr and ii > 20: + if err < stopThr: break if verbose: if ii % 200 == 0: print( '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) print('{:5d}|{:8e}|'.format(ii, err)) + + G = log_bar[:, None] - log_KU + else: warnings.warn("Sinkhorn did not converge. 
You might want to " "increase the number of iterations `numItermax` " "or the regularization parameter `reg`.") if log: log['niter'] = ii - return bar, log + return nx.exp(log_bar), log else: - return bar + return nx.exp(log_bar) def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, @@ -1697,13 +1642,15 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, if log: log['err'].append(err) if err < stopThr: + break + if verbose: if ii % 50 == 0: print( '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) print('{:5d}|{:8e}|'.format(ii, err)) else: - warnings.warn("Stabilized Unbalanced Sinkhorn did not converge." + + warnings.warn("Stabilized Sinkhorn did not converge." + "Try a larger entropy `reg`" + "Or a larger absorption threshold `tau`.") if log: @@ -1715,89 +1662,201 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, return q -def convolutional_barycenter2d(A, reg, weights=None, method="sinkhorn", numItermax=10000, - stopThr=1e-5, verbose=False, log=False, **kwargs): - r"""Compute the entropic regularized wasserstein barycenter of distributions A - where A is a collection of 2D images. +def barycenter_debiased(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, + stopThr=1e-4, verbose=False, log=False, **kwargs): + r"""Compute the debiased Sinkhorn barycenter of distributions A The function solves the following optimization problem: .. math:: - \mathbf{a} = arg\min_\mathbf{a} \sum_i OT_{reg}(\mathbf{a},\mathbf{a}_i) + \mathbf{a} = arg\min_\mathbf{a} \sum_i S_{reg}(\mathbf{a},\mathbf{a}_i) where : - - :math:`OT_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein - distance (see :py:func:`ot.bregman.sinkhorn`) - if `method` is `sinkhorn` or `sinkhorn_stabilized`. 
If `method`is `debiased`, - :math:`OT_{reg}(\cdot,\cdot)` is the entropic - sinkhorn divergence (see :py:func:`ot.bregman.empirical_sinkhorn_divergence`) - - :math:`\mathbf{a}_i` are training distributions (2D images) in the mast two dimensions - of matrix :math:`\mathbf{A}` - - `reg` is the regularization strength scalar value + - :math:`S_{reg}(\cdot,\cdot)` is the debiased Sinkhorn divergence + (see :py:func:`ot.bregman.emirical_sinkhorn_divergence`) + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix + :math:`\mathbf{A}` + - `reg` and :math:`\mathbf{M}` are respectively the regularization term and + the cost matrix for OT - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm - as proposed in :ref:`[21] ` + The algorithm used for solving the problem is the debiased Sinkhorn + algorithm as proposed in :ref:`[28] ` Parameters ---------- - A : array-like, shape (n_hists, width, height) - `n` distributions (2D images) of size `width` x `height` + A : array-like, shape (dim, n_hists) + `n_hists` training distributions :math:`a_i` of size `dim` + M : array-like, shape (dim, dim) + loss matrix for OT reg : float - Regularization term >0 + Regularization term > 0 + method : str (optional) + method used for the solver either 'sinkhorn' or 'sinkhorn_log' weights : array-like, shape (n_hists,) - Weights of each image on the simplex (barycentric coodinates) - method : string, optional - method used for the solver either 'sinkhorn' or 'debiased' + Weights of each histogram :math:`a_i` on the simplex (barycentric coodinates) numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshold on error (> 0) - stabThr : float, optional - Stabilization threshold to avoid numerical precision issue + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional record log if True + Returns ------- - a : array-like, shape (width, height) - 2D 
Wasserstein barycenter + a : (dim,) array-like + Wasserstein barycenter log : dict log dictionary return only if log==True in parameters + .. _references-sinkhorn-debiased: + References + ---------- - .. _references-convolutional-barycenter-2d: - References - ---------- - - .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, - A., Nguyen, A. & Guibas, L. (2015). Convolutional wasserstein distances: - Efficient optimal transportation on geometric domains. ACM Transactions - on Graphics (TOG), 34(4), 66 - .. [28] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th - International Conference on Machine Learning, PMLR 119:4692-4701, 2020 + .. [28] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th International + Conference on Machine Learning, PMLR 119:4692-4701, 2020 """ if method.lower() == 'sinkhorn': - return _convolutional_barycenter2d_sinkhorn(A, reg, weights=weights, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, - **kwargs) - elif method.lower() == 'debiased': - return _convolutional_barycenter2d_debiased(A, reg, weights=weights, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + return _barycenter_debiased(A, M, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + elif method.lower() == 'sinkhorn_log': + return _barycenter_debiased_log(A, M, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) else: raise ValueError("Unknown method '%s'." % method) -def _convolutional_barycenter2d_sinkhorn(A, reg, weights=None, numItermax=10000, - stopThr=1e-9, stabThr=1e-30, verbose=False, - log=False): +def _barycenter_debiased(A, M, reg, weights=None, numItermax=1000, + stopThr=1e-4, verbose=False, log=False): + r"""Compute the debiased sinkhorn barycenter of distributions A. 
+ """ + + A, M = list_to_array(A, M) + + nx = get_backend(A, M) + + if weights is None: + weights = nx.ones((A.shape[1],), type_as=A) / A.shape[1] + else: + assert (len(weights) == A.shape[1]) + + if log: + log = {'err': []} + + K = nx.exp(-M / reg) + + err = 1 + + UKv = nx.dot(K, (A.T / nx.sum(K, axis=0)).T) + + u = (geometricMean(UKv) / UKv.T).T + c = nx.ones(A.shape[0], type_as=A) + bar = nx.ones(A.shape[0], type_as=A) + + for ii in range(numItermax): + bold = bar + UKv = nx.dot(K, A / nx.dot(K, u)) + bar = c * geometricBar(weights, UKv) + u = bar[:, None] / UKv + c = (c * bar / nx.dot(K, c)) ** 0.5 + + if ii % 10 == 9: + err = abs(bar - bold).max() / max(bar.max(), 1.) + + # log and verbose print + if log: + log['err'].append(err) + + # debiased Sinkhorn does not converge monotonically + # guarantee a few iterations are done before stopping + if err < stopThr and ii > 20: + break + if verbose: + if ii % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + else: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") + if log: + log['niter'] = ii + return bar, log + else: + return bar + + +def _barycenter_debiased_log(A, M, reg, weights=None, numItermax=1000, + stopThr=1e-4, verbose=False, log=False): + r"""Compute the debiased sinkhorn barycenter in log domain. 
+ """ + + A, M = list_to_array(A, M) + dim, n_hists = A.shape + + nx = get_backend(A, M) + + if weights is None: + weights = nx.ones(n_hists, type_as=A) / n_hists + else: + assert (len(weights) == A.shape[1]) + + if log: + log = {'err': []} + + M = - M / reg + logA = nx.log(A + 1e-15) + log_KU, G = nx.zeros((2, *logA.shape), type_as=A) + c = nx.zeros(dim, type_as=A) + err = 1 + for ii in range(numItermax): + log_bar = nx.zeros(dim, type_as=A) + for k in range(n_hists): + f = logA[:, k] - nx.logsumexp(M + G[None, :, k], axis=1) + log_KU[:, k] = nx.logsumexp(M + f[:, None], axis=0) + log_bar += weights[k] * log_KU[:, k] + log_bar += c + if ii % 10 == 1: + err = nx.exp(G + log_KU).std(axis=1).sum() + + # log and verbose print + if log: + log['err'].append(err) + + if err < stopThr and ii > 20: + break + if verbose: + if ii % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + + G = log_bar[:, None] - log_KU + for _ in range(10): + c = 0.5 * (c + log_bar - nx.logsumexp(M + c[:, None], axis=0)) + + else: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") + if log: + log['niter'] = ii + return nx.exp(log_bar), log + else: + return nx.exp(log_bar) + + +def convolutional_barycenter2d(A, reg, weights=None, method="sinkhorn", numItermax=10000, + stopThr=1e-5, verbose=False, log=False, **kwargs): r"""Compute the entropic regularized wasserstein barycenter of distributions A where A is a collection of 2D images. 
@@ -1808,14 +1867,14 @@ def _convolutional_barycenter2d_sinkhorn(A, reg, weights=None, numItermax=10000, where : - - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance - (see :py:func:`ot.bregman.sinkhorn`) - - :math:`\mathbf{a}_i` are training distributions (2D images) in the mast two - dimensions of matrix :math:`\mathbf{A}` + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein + distance (see :py:func:`ot.bregman.sinkhorn`) + - :math:`\mathbf{a}_i` are training distributions (2D images) in the mast two dimensions + of matrix :math:`\mathbf{A}` - `reg` is the regularization strength scalar value - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling - algorithm as proposed in :ref:`[21] ` + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm + as proposed in :ref:`[21] ` Parameters ---------- @@ -1825,6 +1884,8 @@ def _convolutional_barycenter2d_sinkhorn(A, reg, weights=None, numItermax=10000, Regularization term >0 weights : array-like, shape (n_hists,) Weights of each image on the simplex (barycentric coodinates) + method : string, optional + method used for the solver either 'sinkhorn' or 'sinkhorn_log' numItermax : int, optional Max number of iterations stopThr : float, optional @@ -1848,10 +1909,34 @@ def _convolutional_barycenter2d_sinkhorn(A, reg, weights=None, numItermax=10000, References ---------- - .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, - A. & Guibas, L. (2015). Convolutional wasserstein distances: Efficient - optimal transportation on geometric domains. ACM Transactions on Graphics - (TOG), 34(4), 66 + .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, + A., Nguyen, A. & Guibas, L. (2015). Convolutional wasserstein distances: + Efficient optimal transportation on geometric domains. ACM Transactions + on Graphics (TOG), 34(4), 66 + .. [28] Janati, H., Cuturi, M., Gramfort, A. 
Proceedings of the 37th + International Conference on Machine Learning, PMLR 119:4692-4701, 2020 + """ + + if method.lower() == 'sinkhorn': + return _convolutional_barycenter2d(A, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, + **kwargs) + elif method.lower() == 'sinkhorn_log': + return _convolutional_barycenter2d_log(A, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + else: + raise ValueError("Unknown method '%s'." % method) + + +def _convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, + stopThr=1e-9, stabThr=1e-30, verbose=False, + log=False): + r"""Compute the entropic regularized wasserstein barycenter of distributions A + where A is a collection of 2D images. """ A = list_to_array(A) @@ -1887,16 +1972,15 @@ def convol_imgs(imgs): kxy = nx.einsum("...ij,klj->kli", K2, kx) return kxy - KV = convol_imgs(V) + KU = convol_imgs(U) for ii in range(numItermax): - bold = bar + V = bar[None] / KU + KV = convol_imgs(V) U = A / KV KU = convol_imgs(U) bar = nx.exp((weights[:, None, None] * nx.log(KU + stabThr)).sum(axis=0)) - V = bar[None] / KU - KV = convol_imgs(V) if ii % 10 == 9: - err = abs(bold - bar).max() / max(1., bar.max()) + err = (V * KU).std(axis=0).sum() # log and verbose print if log: log['err'].append(err) @@ -1907,6 +1991,7 @@ def convol_imgs(imgs): print('{:5d}|{:8e}|'.format(ii, err)) if err < stopThr: break + else: warnings.warn("Convolutional Sinkhorn did not converge. 
" "Try a larger number of iterations `numItermax` " @@ -1919,9 +2004,80 @@ def convol_imgs(imgs): return bar -def _convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000, - stopThr=1e-4, stabThr=1e-15, verbose=False, - log=False): +def _convolutional_barycenter2d_log(A, reg, weights=None, numItermax=10000, + stopThr=1e-4, stabThr=1e-30, verbose=False, + log=False): + r"""Compute the entropic regularized wasserstein barycenter of distributions A + where A is a collection of 2D images in log-domain. + """ + + A = list_to_array(A) + + nx = get_backend(A) + + n_hists, width, height = A.shape + + if weights is None: + weights = nx.ones((n_hists,), type_as=A) / n_hists + else: + assert (len(weights) == n_hists) + + if log: + log = {'err': []} + + err = 1 + # build the convolution operator + # this is equivalent to blurring on horizontal then vertical directions + t = nx.linspace(0, 1, width) + [Y, X] = nx.meshgrid(t, t) + M1 = - (X - Y) ** 2 / reg + + t = nx.linspace(0, 1, height) + [Y, X] = nx.meshgrid(t, t) + M2 = - (X - Y) ** 2 / reg + + def convol_img(log_img): + log_img = nx.logsumexp(M1[:, :, None] + log_img[None], axis=1) + log_img = nx.logsumexp(M2[:, :, None] + log_img.T[None], axis=1).T + return log_img + + logA = nx.log(A + stabThr) + log_KU, G, F = nx.zeros((3, *logA.shape), type_as=A) + err = 1 + for ii in range(numItermax): + log_bar = nx.zeros((width, height), type_as=A) + for k in range(n_hists): + f = logA[k] - convol_img(G[k]) + log_KU[k] = convol_img(f) + log_bar += weights[k] * log_KU[k] + + if ii % 10 == 9: + err = nx.exp(G + log_KU).std(axis=0).sum() + # log and verbose print + if log: + log['err'].append(err) + + if verbose: + if ii % 200 == 0: + print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + if err < stopThr: + break + G = log_bar[None, :, :] - log_KU + + else: + warnings.warn("Convolutional Sinkhorn did not converge. 
" + "Try a larger number of iterations `numItermax` " + "or a larger entropy `reg`.") + if log: + log['niter'] = ii + return nx.exp(log_bar), log + else: + return nx.exp(log_bar) + + +def convolutional_barycenter2d_debiased(A, reg, weights=None, method="sinkhorn", numItermax=10000, + stopThr=1e-5, verbose=False, log=False, **kwargs): r"""Compute the debiased sinkhorn barycenter of distributions A where A is a collection of 2D images. @@ -1949,6 +2105,8 @@ def _convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000, Regularization term >0 weights : array-like, shape (n_hists,) Weights of each image on the simplex (barycentric coodinates) + method : string, optional + method used for the solver either 'sinkhorn' or 'sinkhorn_log' numItermax : int, optional Max number of iterations stopThr : float, optional @@ -1976,20 +2134,42 @@ def _convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000, Conference on Machine Learning, PMLR 119:4692-4701, 2020 """ + if method.lower() == 'sinkhorn': + return _convolutional_barycenter2d_debiased(A, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, + **kwargs) + elif method.lower() == 'sinkhorn_log': + return _convolutional_barycenter2d_debiased_log(A, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + else: + raise ValueError("Unknown method '%s'." % method) + + +def _convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000, + stopThr=1e-4, stabThr=1e-15, verbose=False, + log=False): + r"""Compute the debiased barycenter of 2D images via sinkhorn convolutions. 
+ """ + A = list_to_array(A) + n_hists, width, height = A.shape nx = get_backend(A) if weights is None: - weights = nx.ones((A.shape[0],), type_as=A) / A.shape[0] + weights = nx.ones((n_hists,), type_as=A) / n_hists else: - assert (len(weights) == A.shape[0]) + assert (len(weights) == n_hists) if log: log = {'err': []} - bar = nx.ones(A.shape[1:], type_as=A) - bar /= bar.sum() + bar = nx.ones((width, height), type_as=A) + bar /= width * height U = nx.ones(A.shape, type_as=A) V = nx.ones(A.shape, type_as=A) c = nx.ones(A.shape[1:], type_as=A) @@ -1997,11 +2177,11 @@ def _convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000, # build the convolution operator # this is equivalent to blurring on horizontal then vertical directions - t = nx.linspace(0, 1, A.shape[1]) + t = nx.linspace(0, 1, width) [Y, X] = nx.meshgrid(t, t) K1 = nx.exp(-(X - Y) ** 2 / reg) - t = nx.linspace(0, 1, A.shape[2]) + t = nx.linspace(0, 1, height) [Y, X] = nx.meshgrid(t, t) K2 = nx.exp(-(X - Y) ** 2 / reg) @@ -2010,19 +2190,19 @@ def convol_imgs(imgs): kxy = nx.einsum("...ij,klj->kli", K2, kx) return kxy - KV = convol_imgs(V) + KU = convol_imgs(U) for ii in range(numItermax): - bold = bar + V = bar[None] / KU + KV = convol_imgs(V) U = A / KV KU = convol_imgs(U) bar = c * nx.exp((weights[:, None, None] * nx.log(KU + stabThr)).sum(axis=0)) - V = bar[None] / KU - KV = convol_imgs(V) + for _ in range(10): c = (c * bar / convol_imgs(c[None]).squeeze()) ** 0.5 if ii % 10 == 9: - err = abs(bold - bar).max() / max(1., bar.max()) + err = (V * KU).std(axis=0).sum() # log and verbose print if log: log['err'].append(err) @@ -2048,6 +2228,79 @@ def convol_imgs(imgs): return bar +def _convolutional_barycenter2d_debiased_log(A, reg, weights=None, numItermax=10000, + stopThr=1e-4, stabThr=1e-30, verbose=False, + log=False): + r"""Compute the debiased barycenter of 2D images in log-domain. 
+ """ + + A = list_to_array(A) + n_hists, width, height = A.shape + nx = get_backend(A) + + if weights is None: + weights = nx.ones((n_hists,), type_as=A) / n_hists + else: + assert (len(weights) == A.shape[0]) + + if log: + log = {'err': []} + + err = 1 + # build the convolution operator + # this is equivalent to blurring on horizontal then vertical directions + t = nx.linspace(0, 1, width) + [Y, X] = nx.meshgrid(t, t) + M1 = - (X - Y) ** 2 / reg + + t = nx.linspace(0, 1, height) + [Y, X] = nx.meshgrid(t, t) + M2 = - (X - Y) ** 2 / reg + + def convol_img(log_img): + log_img = nx.logsumexp(M1[:, :, None] + log_img[None], axis=1) + log_img = nx.logsumexp(M2[:, :, None] + log_img.T[None], axis=1).T + return log_img + + logA = nx.log(A + stabThr) + log_bar, c = nx.zeros((2, width, height), type_as=A) + log_KU, G, F = nx.zeros((3, *logA.shape), type_as=A) + err = 1 + for ii in range(numItermax): + log_bar = nx.zeros((width, height), type_as=A) + for k in range(n_hists): + f = logA[k] - convol_img(G[k]) + log_KU[k] = convol_img(f) + log_bar += weights[k] * log_KU[k] + log_bar += c + for _ in range(10): + c = 0.5 * (c + log_bar - convol_img(c)) + + if ii % 10 == 9: + err = nx.exp(G + log_KU).std(axis=0).sum() + # log and verbose print + if log: + log['err'].append(err) + + if verbose: + if ii % 200 == 0: + print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + if err < stopThr and ii > 20: + break + G = log_bar[None, :, :] - log_KU + + else: + warnings.warn("Convolutional Sinkhorn did not converge. 
" + "Try a larger number of iterations `numItermax` " + "or a larger entropy `reg`.") + if log: + log['niter'] = ii + return nx.exp(log_bar), log + else: + return nx.exp(log_bar) + + def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, stopThr=1e-3, verbose=False, log=False): r""" From 6a38c03c1b249ace9bdb478192c5fe1d6da390eb Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Fri, 29 Oct 2021 23:17:27 +0200 Subject: [PATCH 12/25] test new API --- test/test_bregman.py | 74 ++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 68 insertions(+), 6 deletions(-) diff --git a/test/test_bregman.py b/test/test_bregman.py index 47fd5105c..a098c0a2f 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -286,7 +286,7 @@ def test_sinkhorn_variants_log_multib(): np.testing.assert_allclose(G0, Gl, atol=1e-05) -@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", "debiased"]) +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"]) def test_barycenter(nx, method): n_bins = 100 # nb bins @@ -317,7 +317,41 @@ def test_barycenter(nx, method): np.testing.assert_allclose(1, np.sum(bary_wass)) np.testing.assert_allclose(bary_wass, bary_wass_np) - ot.bregman.barycenter(Ab, Mb, reg, log=True, verbose=True) + ot.bregman.barycenter(Ab, Mb, reg, log=True) + + +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) +def test_barycenter_debiased(nx, method): + n_bins = 100 # nb bins + + # Gaussian distributions + a1 = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std + a2 = ot.datasets.make_1D_gauss(n_bins, m=40, s=10) + + # creating matrix A containing all distributions + A = np.vstack((a1, a2)).T + + # loss matrix + normalization + M = ot.utils.dist0(n_bins) + M /= M.max() + + alpha = 0.5 # 0<=alpha<=1 + weights = np.array([1 - alpha, alpha]) + + Ab = nx.from_numpy(A) + Mb = nx.from_numpy(M) + weightsb = nx.from_numpy(weights) + + # wasserstein + reg = 1e-2 + bary_wass_np = 
ot.bregman.barycenter_debiased(A, M, reg, weights, method=method) + bary_wass, _ = ot.bregman.barycenter_debiased(Ab, Mb, reg, weightsb, method=method, log=True) + bary_wass = nx.to_numpy(bary_wass) + + np.testing.assert_allclose(1, np.sum(bary_wass), atol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-5) + + ot.bregman.barycenter_debiased(Ab, Mb, reg, log=True, verbose=False) def test_barycenter_stabilization(nx): @@ -356,13 +390,13 @@ def test_barycenter_stabilization(nx): np.testing.assert_allclose(bar, bar_np) -@pytest.mark.parametrize("method", ["sinkhorn", "debiased"]) +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) def test_wasserstein_bary_2d(nx, method): - size = 100 # size of a square image - a1 = np.random.randn(size, size) + size = 20 # size of a square image + a1 = np.random.rand(size, size) a1 += a1.min() a1 = a1 / np.sum(a1) - a2 = np.random.randn(size, size) + a2 = np.random.rand(size, size) a2 += a2.min() a2 = a2 / np.sum(a2) # creating matrix A containing all distributions @@ -380,6 +414,34 @@ def test_wasserstein_bary_2d(nx, method): np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) + # help in checking if log and verbose do not bug the function + # ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) + + +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) +def test_wasserstein_bary_2d_debiased(nx, method): + size = 20 # size of a square image + a1 = np.random.rand(size, size) + a1 += a1.min() + a1 = a1 / np.sum(a1) + a2 = np.random.rand(size, size) + a2 += a2.min() + a2 = a2 / np.sum(a2) + # creating matrix A containing all distributions + A = np.zeros((2, size, size)) + A[0, :, :] = a1 + A[1, :, :] = a2 + + Ab = nx.from_numpy(A) + + # wasserstein + reg = 1e-2 + bary_wass_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method) + bary_wass = 
nx.to_numpy(ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method)) + + np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) + # help in checking if log and verbose do not bug the function ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) From 17e58e7d9804cf747dac44e70b8b12309f2e6354 Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Sat, 30 Oct 2021 13:45:32 +0200 Subject: [PATCH 13/25] fix jax read-only ? --- ot/bregman.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ot/bregman.py b/ot/bregman.py index 7a9b7b728..07d395f37 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1497,7 +1497,7 @@ def _barycenter_sinkhorn_log(A, M, reg, weights=None, numItermax=1000, for k in range(n_hists): f = logA[:, k] - nx.logsumexp(M + G[None, :, k], axis=1) log_KU[:, k] = nx.logsumexp(M + f[:, None], axis=0) - log_bar += weights[k] * log_KU[:, k] + log_bar = log_bar + weights[k] * log_KU[:, k] if ii % 10 == 1: err = nx.exp(G + log_KU).std(axis=1).sum() @@ -2049,7 +2049,7 @@ def convol_img(log_img): for k in range(n_hists): f = logA[k] - convol_img(G[k]) log_KU[k] = convol_img(f) - log_bar += weights[k] * log_KU[k] + log_bar = log_bar + weights[k] * log_KU[k] if ii % 10 == 9: err = nx.exp(G + log_KU).std(axis=0).sum() @@ -2271,7 +2271,7 @@ def convol_img(log_img): for k in range(n_hists): f = logA[k] - convol_img(G[k]) log_KU[k] = convol_img(f) - log_bar += weights[k] * log_KU[k] + log_bar = log_bar + weights[k] * log_KU[k] log_bar += c for _ in range(10): c = 0.5 * (c + log_bar - convol_img(c)) From 62cd9c9f906b3b0617973763ceaa870fa471d812 Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Mon, 1 Nov 2021 18:09:06 +0100 Subject: [PATCH 14/25] raise error for jax --- ot/bregman.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/ot/bregman.py b/ot/bregman.py index 07d395f37..b2c8d7b4c 100644 --- a/ot/bregman.py +++ 
b/ot/bregman.py @@ -1480,6 +1480,10 @@ def _barycenter_sinkhorn_log(A, M, reg, weights=None, numItermax=1000, nx = get_backend(A, M) + if nx.__name__ == "jax": + raise NotImplementedError("Log-domain functions are not yet implemented" + " for Jax. Use numpy or torch arrays instead.") + if weights is None: weights = nx.ones(n_hists, type_as=A) / n_hists else: @@ -1804,6 +1808,9 @@ def _barycenter_debiased_log(A, M, reg, weights=None, numItermax=1000, dim, n_hists = A.shape nx = get_backend(A, M) + if nx.__name__ == "jax": + raise NotImplementedError("Log-domain functions are not yet implemented" + " for Jax. Use numpy or torch arrays instead.") if weights is None: weights = nx.ones(n_hists, type_as=A) / n_hists @@ -2014,6 +2021,9 @@ def _convolutional_barycenter2d_log(A, reg, weights=None, numItermax=10000, A = list_to_array(A) nx = get_backend(A) + if nx.__name__ == "jax": + raise NotImplementedError("Log-domain functions are not yet implemented" + " for Jax. Use numpy or torch arrays instead.") n_hists, width, height = A.shape @@ -2237,7 +2247,9 @@ def _convolutional_barycenter2d_debiased_log(A, reg, weights=None, numItermax=10 A = list_to_array(A) n_hists, width, height = A.shape nx = get_backend(A) - + if nx.__name__ == "jax": + raise NotImplementedError("Log-domain functions are not yet implemented" + " for Jax. 
Use numpy or torch arrays instead.") if weights is None: weights = nx.ones((n_hists,), type_as=A) / n_hists else: From 13d3575929027d9ed9c2c91f95e13b8381704b50 Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Mon, 1 Nov 2021 18:09:16 +0100 Subject: [PATCH 15/25] test catch jax error --- test/test_bregman.py | 73 +++++++++++++++++++++++++++----------------- 1 file changed, 45 insertions(+), 28 deletions(-) diff --git a/test/test_bregman.py b/test/test_bregman.py index a098c0a2f..6c6de9d15 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -7,6 +7,7 @@ # License: MIT License import numpy as np +from numpy.testing import assert_raises import pytest import ot @@ -307,17 +308,21 @@ def test_barycenter(nx, method): Ab = nx.from_numpy(A) Mb = nx.from_numpy(M) weightsb = nx.from_numpy(weights) - - # wasserstein reg = 1e-2 - bary_wass_np, log = ot.bregman.barycenter(A, M, reg, weights, method=method, log=True) - bary_wass, _ = ot.bregman.barycenter(Ab, Mb, reg, weightsb, method=method, log=True) - bary_wass = nx.to_numpy(bary_wass) - np.testing.assert_allclose(1, np.sum(bary_wass)) - np.testing.assert_allclose(bary_wass, bary_wass_np) + if nx.__name__ == "jax" and method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.barycenter(A, M, reg, weights, method=method) + else: + # wasserstein + bary_wass_np = ot.bregman.barycenter(A, M, reg, weights, method=method) + bary_wass, _ = ot.bregman.barycenter(Ab, Mb, reg, weightsb, method=method, log=True) + bary_wass = nx.to_numpy(bary_wass) + + np.testing.assert_allclose(1, np.sum(bary_wass)) + np.testing.assert_allclose(bary_wass, bary_wass_np) - ot.bregman.barycenter(Ab, Mb, reg, log=True) + ot.bregman.barycenter(Ab, Mb, reg, log=True) @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) @@ -344,14 +349,18 @@ def test_barycenter_debiased(nx, method): # wasserstein reg = 1e-2 - bary_wass_np = ot.bregman.barycenter_debiased(A, M, reg, weights, method=method) - bary_wass, _ = 
ot.bregman.barycenter_debiased(Ab, Mb, reg, weightsb, method=method, log=True) - bary_wass = nx.to_numpy(bary_wass) + if nx.__name__ == "jax" and method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.barycenter_debiased(A, M, reg, weights, method=method) + else: + bary_wass_np = ot.bregman.barycenter_debiased(A, M, reg, weights, method=method) + bary_wass, _ = ot.bregman.barycenter_debiased(Ab, Mb, reg, weightsb, method=method, log=True) + bary_wass = nx.to_numpy(bary_wass) - np.testing.assert_allclose(1, np.sum(bary_wass), atol=1e-3) - np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-5) + np.testing.assert_allclose(1, np.sum(bary_wass), atol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-5) - ot.bregman.barycenter_debiased(Ab, Mb, reg, log=True, verbose=False) + ot.bregman.barycenter_debiased(Ab, Mb, reg, log=True, verbose=False) def test_barycenter_stabilization(nx): @@ -408,14 +417,18 @@ def test_wasserstein_bary_2d(nx, method): # wasserstein reg = 1e-2 - bary_wass_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method) - bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(Ab, reg, method=method)) + if nx.__name__ == "jax" and method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.convolutional_barycenter2d(A, reg, method=method) + else: + bary_wass_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method) + bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(Ab, reg, method=method)) - np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) - np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) + np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) - # help in checking if log and verbose do not bug the function - # ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) + # help in checking if log and verbose do not bug the function + 
ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) @@ -436,14 +449,18 @@ def test_wasserstein_bary_2d_debiased(nx, method): # wasserstein reg = 1e-2 - bary_wass_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method) - bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method)) - - np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) - np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) - - # help in checking if log and verbose do not bug the function - ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) + if nx.__name__ == "jax" and method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method) + else: + bary_wass_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method) + bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method)) + + np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) + + # help in checking if log and verbose do not bug the function + ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) def test_unmix(nx): From d8ae66fae5f1573187190e4d7c5e53cdfb0fc71c Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Mon, 1 Nov 2021 18:24:03 +0100 Subject: [PATCH 16/25] fix pytest catch error --- test/test_bregman.py | 122 +++++++++++++++++++++---------------------- 1 file changed, 61 insertions(+), 61 deletions(-) diff --git a/test/test_bregman.py b/test/test_bregman.py index 6c6de9d15..50a473272 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -68,9 +68,9 @@ def test_sinkhorn_backends(nx): G = ot.sinkhorn(a, a, M, 1) ab = nx.from_numpy(a) - Mb = nx.from_numpy(M) + M_nx = nx.from_numpy(M) - Gb = ot.sinkhorn(ab, ab, Mb, 1) + Gb = ot.sinkhorn(ab, ab, M_nx, 1) 
np.allclose(G, nx.to_numpy(Gb)) @@ -89,9 +89,9 @@ def test_sinkhorn2_backends(nx): G = ot.sinkhorn(a, a, M, 1) ab = nx.from_numpy(a) - Mb = nx.from_numpy(M) + M_nx = nx.from_numpy(M) - Gb = ot.sinkhorn2(ab, ab, Mb, 1) + Gb = ot.sinkhorn2(ab, ab, M_nx, 1) np.allclose(G, nx.to_numpy(Gb)) @@ -166,15 +166,15 @@ def test_sinkhorn_variants(nx): M = ot.dist(x, x) ub = nx.from_numpy(u) - Mb = nx.from_numpy(M) + M_nx = nx.from_numpy(M) G = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10) - Gl = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn_log', stopThr=1e-10)) - G0 = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn', stopThr=1e-10)) - Gs = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10)) + Gl = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) + G0 = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn', stopThr=1e-10)) + Gs = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10)) Ges = nx.to_numpy(ot.sinkhorn( - ub, ub, Mb, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10)) - G_green = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='greenkhorn', stopThr=1e-10)) + ub, ub, M_nx, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10)) + G_green = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='greenkhorn', stopThr=1e-10)) # check values np.testing.assert_allclose(G, G0, atol=1e-05) @@ -200,12 +200,12 @@ def test_sinkhorn_variants_multi_b(nx): ub = nx.from_numpy(u) bb = nx.from_numpy(b) - Mb = nx.from_numpy(M) + M_nx = nx.from_numpy(M) G = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10) - Gl = nx.to_numpy(ot.sinkhorn(ub, bb, Mb, 1, method='sinkhorn_log', stopThr=1e-10)) - G0 = nx.to_numpy(ot.sinkhorn(ub, bb, Mb, 1, method='sinkhorn', stopThr=1e-10)) - Gs = nx.to_numpy(ot.sinkhorn(ub, bb, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10)) + Gl = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) + G0 = 
nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10)) + Gs = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10)) # check values np.testing.assert_allclose(G, G0, atol=1e-05) @@ -229,12 +229,12 @@ def test_sinkhorn2_variants_multi_b(nx): ub = nx.from_numpy(u) bb = nx.from_numpy(b) - Mb = nx.from_numpy(M) + M_nx = nx.from_numpy(M) G = ot.sinkhorn2(u, b, M, 1, method='sinkhorn', stopThr=1e-10) - Gl = nx.to_numpy(ot.sinkhorn2(ub, bb, Mb, 1, method='sinkhorn_log', stopThr=1e-10)) - G0 = nx.to_numpy(ot.sinkhorn2(ub, bb, Mb, 1, method='sinkhorn', stopThr=1e-10)) - Gs = nx.to_numpy(ot.sinkhorn2(ub, bb, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10)) + Gl = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) + G0 = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10)) + Gs = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10)) # check values np.testing.assert_allclose(G, G0, atol=1e-05) @@ -305,24 +305,24 @@ def test_barycenter(nx, method): alpha = 0.5 # 0<=alpha<=1 weights = np.array([1 - alpha, alpha]) - Ab = nx.from_numpy(A) - Mb = nx.from_numpy(M) - weightsb = nx.from_numpy(weights) + A_nx = nx.from_numpy(A) + M_nx = nx.from_numpy(M) + weights_nx = nx.from_numpy(weights) reg = 1e-2 if nx.__name__ == "jax" and method == "sinkhorn_log": with pytest.raises(NotImplementedError): - ot.bregman.barycenter(A, M, reg, weights, method=method) + ot.bregman.barycenter(A_nx, M_nx, reg, weights, method=method) else: # wasserstein bary_wass_np = ot.bregman.barycenter(A, M, reg, weights, method=method) - bary_wass, _ = ot.bregman.barycenter(Ab, Mb, reg, weightsb, method=method, log=True) + bary_wass, _ = ot.bregman.barycenter(A_nx, M_nx, reg, weights_nx, method=method, log=True) bary_wass = nx.to_numpy(bary_wass) np.testing.assert_allclose(1, np.sum(bary_wass)) np.testing.assert_allclose(bary_wass, bary_wass_np) - ot.bregman.barycenter(Ab, 
Mb, reg, log=True) + ot.bregman.barycenter(A_nx, M_nx, reg, log=True) @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) @@ -343,24 +343,24 @@ def test_barycenter_debiased(nx, method): alpha = 0.5 # 0<=alpha<=1 weights = np.array([1 - alpha, alpha]) - Ab = nx.from_numpy(A) - Mb = nx.from_numpy(M) - weightsb = nx.from_numpy(weights) + A_nx = nx.from_numpy(A) + M_nx = nx.from_numpy(M) + weights_nx = nx.from_numpy(weights) # wasserstein reg = 1e-2 if nx.__name__ == "jax" and method == "sinkhorn_log": with pytest.raises(NotImplementedError): - ot.bregman.barycenter_debiased(A, M, reg, weights, method=method) + ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights, method=method) else: bary_wass_np = ot.bregman.barycenter_debiased(A, M, reg, weights, method=method) - bary_wass, _ = ot.bregman.barycenter_debiased(Ab, Mb, reg, weightsb, method=method, log=True) + bary_wass, _ = ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights_nx, method=method, log=True) bary_wass = nx.to_numpy(bary_wass) np.testing.assert_allclose(1, np.sum(bary_wass), atol=1e-3) np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-5) - ot.bregman.barycenter_debiased(Ab, Mb, reg, log=True, verbose=False) + ot.bregman.barycenter_debiased(A_nx, M_nx, reg, log=True, verbose=False) def test_barycenter_stabilization(nx): @@ -380,19 +380,19 @@ def test_barycenter_stabilization(nx): alpha = 0.5 # 0<=alpha<=1 weights = np.array([1 - alpha, alpha]) - Ab = nx.from_numpy(A) - Mb = nx.from_numpy(M) + A_nx = nx.from_numpy(A) + M_nx = nx.from_numpy(M) weights_b = nx.from_numpy(weights) # wasserstein reg = 1e-2 bar_np = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn", stopThr=1e-8, verbose=True) bar_stable = nx.to_numpy(ot.bregman.barycenter( - Ab, Mb, reg, weights_b, method="sinkhorn_stabilized", + A_nx, M_nx, reg, weights_b, method="sinkhorn_stabilized", stopThr=1e-8, verbose=True )) bar = nx.to_numpy(ot.bregman.barycenter( - Ab, Mb, reg, weights_b, method="sinkhorn", + A_nx, 
M_nx, reg, weights_b, method="sinkhorn", stopThr=1e-8, verbose=True )) np.testing.assert_allclose(bar, bar_stable) @@ -413,16 +413,16 @@ def test_wasserstein_bary_2d(nx, method): A[0, :, :] = a1 A[1, :, :] = a2 - Ab = nx.from_numpy(A) + A_nx = nx.from_numpy(A) # wasserstein reg = 1e-2 if nx.__name__ == "jax" and method == "sinkhorn_log": with pytest.raises(NotImplementedError): - ot.bregman.convolutional_barycenter2d(A, reg, method=method) + ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method) else: bary_wass_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method) - bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(Ab, reg, method=method)) + bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)) np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) @@ -445,16 +445,16 @@ def test_wasserstein_bary_2d_debiased(nx, method): A[0, :, :] = a1 A[1, :, :] = a2 - Ab = nx.from_numpy(A) + A_nx = nx.from_numpy(A) # wasserstein reg = 1e-2 if nx.__name__ == "jax" and method == "sinkhorn_log": with pytest.raises(NotImplementedError): - ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method) + ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method) else: bary_wass_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method) - bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method)) + bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)) np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) @@ -485,20 +485,20 @@ def test_unmix(nx): ab = nx.from_numpy(a) Db = nx.from_numpy(D) - Mb = nx.from_numpy(M) + M_nx = nx.from_numpy(M) M0b = nx.from_numpy(M0) h0b = nx.from_numpy(h0) # wasserstein reg = 1e-3 um_np = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01) 
- um = nx.to_numpy(ot.bregman.unmix(ab, Db, Mb, M0b, h0b, reg, 1, alpha=0.01)) + um = nx.to_numpy(ot.bregman.unmix(ab, Db, M_nx, M0b, h0b, reg, 1, alpha=0.01)) np.testing.assert_allclose(1, np.sum(um), rtol=1e-03, atol=1e-03) np.testing.assert_allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03) np.testing.assert_allclose(um, um_np) - ot.bregman.unmix(ab, Db, Mb, M0b, h0b, reg, + ot.bregman.unmix(ab, Db, M_nx, M0b, h0b, reg, 1, alpha=0.01, log=True, verbose=True) @@ -517,22 +517,22 @@ def test_empirical_sinkhorn(nx): bb = nx.from_numpy(b) X_sb = nx.from_numpy(X_s) X_tb = nx.from_numpy(X_t) - Mb = nx.from_numpy(M, type_as=ab) + M_nx = nx.from_numpy(M, type_as=ab) M_mb = nx.from_numpy(M_m, type_as=ab) G_sqe = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1)) - sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, Mb, 1)) + sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1)) G_log, log_es = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 0.1, log=True) G_log = nx.to_numpy(G_log) - sinkhorn_log, log_s = ot.sinkhorn(ab, bb, Mb, 0.1, log=True) + sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True) sinkhorn_log = nx.to_numpy(sinkhorn_log) G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean')) sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1)) loss_emp_sinkhorn = nx.to_numpy(ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1)) - loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, Mb, 1)) + loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1)) # check constraints np.testing.assert_allclose( @@ -566,18 +566,18 @@ def test_lazy_empirical_sinkhorn(nx): bb = nx.from_numpy(b) X_sb = nx.from_numpy(X_s) X_tb = nx.from_numpy(X_t) - Mb = nx.from_numpy(M, type_as=ab) + M_nx = nx.from_numpy(M, type_as=ab) M_mb = nx.from_numpy(M_m, type_as=ab) f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True) f, g = nx.to_numpy(f), nx.to_numpy(g) G_sqe = np.exp(f[:, None] + g[None, :] - M / 1) - 
sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, Mb, 1)) + sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1)) f, g, log_es = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) f, g = nx.to_numpy(f), nx.to_numpy(g) G_log = np.exp(f[:, None] + g[None, :] - M / 0.1) - sinkhorn_log, log_s = ot.sinkhorn(ab, bb, Mb, 0.1, log=True) + sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True) sinkhorn_log = nx.to_numpy(sinkhorn_log) f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean', numIterMax=numIterMax, isLazy=True, batchSize=1) @@ -587,7 +587,7 @@ def test_lazy_empirical_sinkhorn(nx): loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) loss_emp_sinkhorn = nx.to_numpy(loss_emp_sinkhorn) - loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, Mb, 1)) + loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1)) # check constraints np.testing.assert_allclose( @@ -621,13 +621,13 @@ def test_empirical_sinkhorn_divergence(nx): bb = nx.from_numpy(b) X_sb = nx.from_numpy(X_s) X_tb = nx.from_numpy(X_t) - Mb = nx.from_numpy(M, type_as=ab) + M_nx = nx.from_numpy(M, type_as=ab) M_sb = nx.from_numpy(M_s, type_as=ab) M_tb = nx.from_numpy(M_t, type_as=ab) emp_sinkhorn_div = nx.to_numpy(ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb)) sinkhorn_div = nx.to_numpy( - ot.sinkhorn2(ab, bb, Mb, 1) + ot.sinkhorn2(ab, bb, M_nx, 1) - 1 / 2 * ot.sinkhorn2(ab, ab, M_sb, 1) - 1 / 2 * ot.sinkhorn2(bb, bb, M_tb, 1) ) @@ -660,14 +660,14 @@ def test_stabilized_vs_sinkhorn_multidim(nx): ab = nx.from_numpy(a) bb = nx.from_numpy(b) - Mb = nx.from_numpy(M, type_as=ab) + M_nx = nx.from_numpy(M, type_as=ab) G_np, _ = ot.bregman.sinkhorn(a, b, M, reg=epsilon, method="sinkhorn", log=True) - G, log = ot.bregman.sinkhorn(ab, bb, Mb, reg=epsilon, + G, log = ot.bregman.sinkhorn(ab, bb, M_nx, reg=epsilon, method="sinkhorn_stabilized", 
log=True) G = nx.to_numpy(G) - G2, log2 = ot.bregman.sinkhorn(ab, bb, Mb, epsilon, + G2, log2 = ot.bregman.sinkhorn(ab, bb, M_nx, epsilon, method="sinkhorn", log=True) G2 = nx.to_numpy(G2) @@ -722,14 +722,14 @@ def test_screenkhorn(nx): ab = nx.from_numpy(a) bb = nx.from_numpy(b) - Mb = nx.from_numpy(M, type_as=ab) + M_nx = nx.from_numpy(M, type_as=ab) # np sinkhorn G_sink_np = ot.sinkhorn(a, b, M, 1e-03) # sinkhorn - G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, Mb, 1e-03)) + G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-03)) # screenkhorn - G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, Mb, 1e-03, uniform=True, verbose=True)) + G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, M_nx, 1e-03, uniform=True, verbose=True)) # check marginals np.testing.assert_allclose(G_sink_np, G_sink) np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02) @@ -739,10 +739,10 @@ def test_screenkhorn(nx): def test_convolutional_barycenter_non_square(nx): # test for image with height not equal width A = np.ones((2, 2, 3)) / (2 * 3) - Ab = nx.from_numpy(A) + A_nx = nx.from_numpy(A) b_np = ot.bregman.convolutional_barycenter2d(A, 1e-03) - b = nx.to_numpy(ot.bregman.convolutional_barycenter2d(Ab, 1e-03)) + b = nx.to_numpy(ot.bregman.convolutional_barycenter2d(A_nx, 1e-03)) np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02) np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02) From 39fcf8392782fa02181f601fbd385432df47b3e7 Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Mon, 1 Nov 2021 19:06:21 +0100 Subject: [PATCH 17/25] fix relative path --- examples/barycenters/plot_convolutional_barycenter.py | 3 ++- examples/barycenters/plot_debiased_barycenter.py | 3 ++- examples/domain-adaptation/plot_otda_color_images.py | 4 +++- examples/domain-adaptation/plot_otda_linear_mapping.py | 4 +++- examples/domain-adaptation/plot_otda_mapping_colors_images.py | 4 +++- examples/gromov/plot_gromov_barycenter.py | 3 ++- 6 files changed, 15 
insertions(+), 6 deletions(-) diff --git a/examples/barycenters/plot_convolutional_barycenter.py b/examples/barycenters/plot_convolutional_barycenter.py index 9f39adb35..3721f31d1 100644 --- a/examples/barycenters/plot_convolutional_barycenter.py +++ b/examples/barycenters/plot_convolutional_barycenter.py @@ -26,7 +26,8 @@ # # The four distributions are constructed from 4 simple images -data_path = os.path.join(Path(__file__).parent.parent.parent, 'data') +this_file = os.path.realpath('__file__') +data_path = os.path.join(Path(this_file).parent.parent.parent, 'data') f1 = 1 - plt.imread(os.path.join(data_path, 'redcross.png'))[:, :, 2] f2 = 1 - plt.imread(os.path.join(data_path, 'tooth.png'))[:, :, 2] diff --git a/examples/barycenters/plot_debiased_barycenter.py b/examples/barycenters/plot_debiased_barycenter.py index f9206b235..6cfce7012 100644 --- a/examples/barycenters/plot_debiased_barycenter.py +++ b/examples/barycenters/plot_debiased_barycenter.py @@ -78,7 +78,8 @@ ############################################################################## # Debiased barycenter of 2D images # --------------------------------- -data_path = os.path.join(Path(__file__).parent.parent.parent, 'data') +this_file = os.path.realpath('__file__') +data_path = os.path.join(Path(this_file).parent.parent.parent, 'data') f1 = 1 - plt.imread(os.path.join(data_path, 'redcross.png'))[:, :, 2] f2 = 1 - plt.imread(os.path.join(data_path, 'tooth.png'))[:, :, 2] f3 = 1 - plt.imread(os.path.join(data_path, 'duck.png'))[:, :, 2] diff --git a/examples/domain-adaptation/plot_otda_color_images.py b/examples/domain-adaptation/plot_otda_color_images.py index 84d0d7f4d..06dc8ab3d 100644 --- a/examples/domain-adaptation/plot_otda_color_images.py +++ b/examples/domain-adaptation/plot_otda_color_images.py @@ -49,7 +49,9 @@ def minmax(img): # ------------- # Loading images -data_path = os.path.join(Path(__file__).parent.parent.parent, 'data') +this_file = os.path.realpath('__file__') +data_path = 
os.path.join(Path(this_file).parent.parent.parent, 'data') + I1 = plt.imread(os.path.join(data_path, 'ocean_day.jpg')).astype(np.float64) / 256 I2 = plt.imread(os.path.join(data_path, 'ocean_sunset.jpg')).astype(np.float64) / 256 diff --git a/examples/domain-adaptation/plot_otda_linear_mapping.py b/examples/domain-adaptation/plot_otda_linear_mapping.py index df36fc00c..a44096a33 100644 --- a/examples/domain-adaptation/plot_otda_linear_mapping.py +++ b/examples/domain-adaptation/plot_otda_linear_mapping.py @@ -98,7 +98,9 @@ def minmax(img): # Loading images -data_path = os.path.join(Path(__file__).parent.parent.parent, 'data') +this_file = os.path.realpath('__file__') +data_path = os.path.join(Path(this_file).parent.parent.parent, 'data') + I1 = plt.imread(os.path.join(data_path, 'ocean_day.jpg')).astype(np.float64) / 256 I2 = plt.imread(os.path.join(data_path, 'ocean_sunset.jpg')).astype(np.float64) / 256 diff --git a/examples/domain-adaptation/plot_otda_mapping_colors_images.py b/examples/domain-adaptation/plot_otda_mapping_colors_images.py index d8b1ff36a..dbece7082 100644 --- a/examples/domain-adaptation/plot_otda_mapping_colors_images.py +++ b/examples/domain-adaptation/plot_otda_mapping_colors_images.py @@ -50,7 +50,9 @@ def minmax(img): # ------------- # Loading images -data_path = os.path.join(Path(__file__).parent.parent.parent, 'data') +this_file = os.path.realpath('__file__') +data_path = os.path.join(Path(this_file).parent.parent.parent, 'data') + I1 = plt.imread(os.path.join(data_path, 'ocean_day.jpg')).astype(np.float64) / 256 I2 = plt.imread(os.path.join(data_path, 'ocean_sunset.jpg')).astype(np.float64) / 256 diff --git a/examples/gromov/plot_gromov_barycenter.py b/examples/gromov/plot_gromov_barycenter.py index af0b8f1a8..7fe081f8f 100755 --- a/examples/gromov/plot_gromov_barycenter.py +++ b/examples/gromov/plot_gromov_barycenter.py @@ -91,7 +91,8 @@ def im2mat(img): return img.reshape((img.shape[0] * img.shape[1], img.shape[2])) -data_path = 
os.path.join(Path(__file__).parent.parent.parent, 'data') +this_file = os.path.realpath('__file__') +data_path = os.path.join(Path(this_file).parent.parent.parent, 'data') square = plt.imread(os.path.join(data_path, 'square.png')).astype(np.float64)[:, :, 2] cross = plt.imread(os.path.join(data_path, 'cross.png')).astype(np.float64)[:, :, 2] From b6cbc2f376b54e9317b04a8f6b97ecf003bfb585 Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Mon, 1 Nov 2021 19:07:00 +0100 Subject: [PATCH 18/25] fix flake8 --- ot/bregman.py | 4 ++-- test/test_bregman.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/ot/bregman.py b/ot/bregman.py index b2c8d7b4c..dc986eab8 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1863,7 +1863,7 @@ def _barycenter_debiased_log(A, M, reg, weights=None, numItermax=1000, def convolutional_barycenter2d(A, reg, weights=None, method="sinkhorn", numItermax=10000, - stopThr=1e-5, verbose=False, log=False, **kwargs): + stopThr=1e-4, verbose=False, log=False, **kwargs): r"""Compute the entropic regularized wasserstein barycenter of distributions A where A is a collection of 2D images. @@ -2087,7 +2087,7 @@ def convol_img(log_img): def convolutional_barycenter2d_debiased(A, reg, weights=None, method="sinkhorn", numItermax=10000, - stopThr=1e-5, verbose=False, log=False, **kwargs): + stopThr=1e-4, verbose=False, log=False, **kwargs): r"""Compute the debiased sinkhorn barycenter of distributions A where A is a collection of 2D images. 
diff --git a/test/test_bregman.py b/test/test_bregman.py index 50a473272..f700cfc29 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -7,7 +7,6 @@ # License: MIT License import numpy as np -from numpy.testing import assert_raises import pytest import ot From 154e20389c9cd863fe860b9a4ea3e19e63095068 Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Tue, 2 Nov 2021 16:29:18 +0100 Subject: [PATCH 19/25] add warn arg everywhere --- ot/bregman.py | 332 ++++++++++++++++++++++++++++++-------------------- 1 file changed, 198 insertions(+), 134 deletions(-) diff --git a/ot/bregman.py b/ot/bregman.py index ba655b775..641fcee9f 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -7,7 +7,7 @@ # Nicolas Courty # Kilian Fatras # Titouan Vayer -# Hicham Janati +# Hicham Janati # Mokhtar Z. Alaya # Alexander Tong # Ievgen Redko @@ -25,7 +25,8 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, - stopThr=1e-9, verbose=False, log=False, **kwargs): + stopThr=1e-9, verbose=False, log=False, warn=True, + **kwargs): r""" Solve the entropic regularization optimal transport problem and return the OT matrix @@ -97,6 +98,8 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. 
Returns ------- @@ -152,29 +155,34 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, if method.lower() == 'sinkhorn': return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, + warn=warn, **kwargs) elif method.lower() == 'sinkhorn_log': return sinkhorn_log(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, + warn=warn, **kwargs) elif method.lower() == 'greenkhorn': return greenkhorn(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log) + stopThr=stopThr, verbose=verbose, log=log, + warn=warn) elif method.lower() == 'sinkhorn_stabilized': return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + log=log, warn=warn, + **kwargs) elif method.lower() == 'sinkhorn_epsilon_scaling': return sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + log=log, warn=warn, + **kwargs) else: raise ValueError("Unknown method '%s'." % method) def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, - stopThr=1e-9, verbose=False, log=False, **kwargs): + stopThr=1e-9, verbose=False, log=False, warn=False, **kwargs): r""" Solve the entropic regularization optimal transport problem and return the loss @@ -245,6 +253,8 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns ------- @@ -343,8 +353,9 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, raise ValueError("Unknown method '%s'." 
% method) -def sinkhorn_knopp(a, b, M, reg, numItermax=1000, - stopThr=1e-9, verbose=False, log=False, **kwargs): +def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, + verbose=False, log=False, warn=True, + **kwargs): r""" Solve the entropic regularization optimal transport problem and return the OT matrix @@ -390,6 +401,8 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns ------- @@ -497,9 +510,10 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) print('{:5d}|{:8e}|'.format(ii, err)) else: - warnings.warn("Sinkhorn did not converge. You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`.") + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: log['niter'] = ii log['u'] = u @@ -520,8 +534,8 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, return u.reshape((-1, 1)) * K * v.reshape((1, -1)) -def sinkhorn_log(a, b, M, reg, numItermax=1000, - stopThr=1e-9, verbose=False, log=False, **kwargs): +def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, + log=False, warn=True, **kwargs): r""" Solve the entropic regularization optimal transport problem in log space and return the OT matrix @@ -566,6 +580,8 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns ------- @@ -698,9 +714,10 @@ def get_logT(u, v): if err < stopThr: break else: - warnings.warn("Sinkhorn did not converge. 
You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`.") + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: log['niter'] = ii @@ -716,7 +733,7 @@ def get_logT(u, v): def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, - log=False): + log=False, warn=True): r""" Solve the entropic regularization optimal transport problem and return the OT matrix @@ -761,6 +778,8 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, Stop threshold on error (>0) log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns ------- @@ -859,9 +878,10 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, if stopThr_val <= stopThr: break else: - warnings.warn("Sinkhorn did not converge. You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`.") + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: log["n_iter"] = ii @@ -876,7 +896,7 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=20, - log=False, **kwargs): + log=False, warn=True, **kwargs): r""" Solve the entropic regularization OT problem with log stabilization @@ -929,6 +949,8 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. 
Returns ------- @@ -1081,9 +1103,10 @@ def get_Gamma(alpha, beta, u, v): v = vprev break else: - warnings.warn("Sinkhorn did not converge. You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`.") + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: if n_hists: alpha = alpha[:, None] @@ -1119,7 +1142,7 @@ def get_Gamma(alpha, beta, u, v): def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInnerItermax=100, tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=10, - log=False, **kwargs): + log=False, warn=True, **kwargs): r""" Solve the entropic regularization optimal transport problem with log stabilization and epsilon scaling. @@ -1171,6 +1194,9 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. + Returns ------- gamma : array-like, shape (dim_a, dim_b) @@ -1266,9 +1292,10 @@ def get_reg(n): # exponential decreasing if err <= stopThr and ii > numItermin: break else: - warnings.warn("Sinkhorn did not converge. You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`.") + if warn: + warnings.warn("Sinkhorn did not converge. 
You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: log['alpha'] = alpha log['beta'] = beta @@ -1309,7 +1336,7 @@ def projC(gamma, q): def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, - stopThr=1e-4, verbose=False, log=False, **kwargs): + stopThr=1e-4, verbose=False, log=False, warn=True, **kwargs): r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` The function solves the following optimization problem: @@ -1350,6 +1377,8 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't converge. Returns @@ -1374,23 +1403,24 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, return barycenter_sinkhorn(A, M, reg, weights=weights, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, + warn=warn, **kwargs) elif method.lower() == 'sinkhorn_stabilized': return barycenter_stabilized(A, M, reg, weights=weights, numItermax=numItermax, stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + log=log, warn=warn, **kwargs) elif method.lower() == 'sinkhorn_log': return _barycenter_sinkhorn_log(A, M, reg, weights=weights, numItermax=numItermax, stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + log=log, warn=warn, **kwargs) else: raise ValueError("Unknown method '%s'." 
% method) def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, - stopThr=1e-4, verbose=False, log=False): + stopThr=1e-4, verbose=False, log=False, warn=True): r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` The function solves the following optimization problem: @@ -1428,6 +1458,8 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns @@ -1488,9 +1520,10 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) print('{:5d}|{:8e}|'.format(ii, err)) else: - warnings.warn("Sinkhorn did not converge. You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`.") + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: log['niter'] = ii return geometricBar(weights, UKv), log @@ -1499,7 +1532,7 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, def _barycenter_sinkhorn_log(A, M, reg, weights=None, numItermax=1000, - stopThr=1e-4, verbose=False, log=False): + stopThr=1e-4, verbose=False, log=False, warn=True): r"""Compute the entropic wasserstein barycenter in log-domain """ @@ -1549,9 +1582,10 @@ def _barycenter_sinkhorn_log(A, M, reg, weights=None, numItermax=1000, G = log_bar[:, None] - log_KU else: - warnings.warn("Sinkhorn did not converge. You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`.") + if warn: + warnings.warn("Sinkhorn did not converge. 
You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: log['niter'] = ii return nx.exp(log_bar), log @@ -1560,7 +1594,7 @@ def _barycenter_sinkhorn_log(A, M, reg, weights=None, numItermax=1000, def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, - stopThr=1e-4, verbose=False, log=False): + stopThr=1e-4, verbose=False, log=False, warn=True): r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` with stabilization. The function solves the following optimization problem: @@ -1601,6 +1635,8 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't converge. Returns @@ -1682,9 +1718,10 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, print('{:5d}|{:8e}|'.format(ii, err)) else: - warnings.warn("Stabilized Sinkhorn did not converge." + - "Try a larger entropy `reg`" + - "Or a larger absorption threshold `tau`.") + if warn: + warnings.warn("Stabilized Sinkhorn did not converge." 
+ + "Try a larger entropy `reg`" + + "Or a larger absorption threshold `tau`.") if log: log['niter'] = ii log['logu'] = np.log(u + 1e-16) @@ -1695,7 +1732,7 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, def barycenter_debiased(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, - stopThr=1e-4, verbose=False, log=False, **kwargs): + stopThr=1e-4, verbose=False, log=False, warn=True, **kwargs): r"""Compute the debiased Sinkhorn barycenter of distributions A The function solves the following optimization problem: @@ -1713,7 +1750,7 @@ def barycenter_debiased(A, M, reg, weights=None, method="sinkhorn", numItermax=1 the cost matrix for OT The algorithm used for solving the problem is the debiased Sinkhorn - algorithm as proposed in :ref:`[28] ` + algorithm as proposed in :ref:`[35] ` Parameters ---------- @@ -1735,6 +1772,9 @@ def barycenter_debiased(A, M, reg, weights=None, method="sinkhorn", numItermax=1 Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't converge. + Returns @@ -1748,7 +1788,7 @@ def barycenter_debiased(A, M, reg, weights=None, method="sinkhorn", numItermax=1 References ---------- - .. [28] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th International + .. [35] Janati, H., Cuturi, M., Gramfort, A. 
Proceedings of the 37th International Conference on Machine Learning, PMLR 119:4692-4701, 2020 """ @@ -1756,18 +1796,18 @@ def barycenter_debiased(A, M, reg, weights=None, method="sinkhorn", numItermax=1 return _barycenter_debiased(A, M, reg, weights=weights, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, - **kwargs) + warn=warn, **kwargs) elif method.lower() == 'sinkhorn_log': return _barycenter_debiased_log(A, M, reg, weights=weights, numItermax=numItermax, stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + log=log, warn=warn, **kwargs) else: raise ValueError("Unknown method '%s'." % method) def _barycenter_debiased(A, M, reg, weights=None, numItermax=1000, - stopThr=1e-4, verbose=False, log=False): + stopThr=1e-4, verbose=False, log=False, warn=True): r"""Compute the debiased sinkhorn barycenter of distributions A. """ @@ -1817,9 +1857,10 @@ def _barycenter_debiased(A, M, reg, weights=None, numItermax=1000, '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) print('{:5d}|{:8e}|'.format(ii, err)) else: - warnings.warn("Sinkhorn did not converge. You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`.") + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: log['niter'] = ii return bar, log @@ -1828,7 +1869,8 @@ def _barycenter_debiased(A, M, reg, weights=None, numItermax=1000, def _barycenter_debiased_log(A, M, reg, weights=None, numItermax=1000, - stopThr=1e-4, verbose=False, log=False): + stopThr=1e-4, verbose=False, log=False, + warn=True): r"""Compute the debiased sinkhorn barycenter in log domain. """ @@ -1880,9 +1922,10 @@ def _barycenter_debiased_log(A, M, reg, weights=None, numItermax=1000, c = 0.5 * (c + log_bar - nx.logsumexp(M + c[:, None], axis=0)) else: - warnings.warn("Sinkhorn did not converge. 
You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`.") + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: log['niter'] = ii return nx.exp(log_bar), log @@ -1891,7 +1934,8 @@ def _barycenter_debiased_log(A, M, reg, weights=None, numItermax=1000, def convolutional_barycenter2d(A, reg, weights=None, method="sinkhorn", numItermax=10000, - stopThr=1e-4, verbose=False, log=False, **kwargs): + stopThr=1e-4, verbose=False, log=False, + warn=True, **kwargs): r"""Compute the entropic regularized wasserstein barycenter of distributions A where A is a collection of 2D images. @@ -1931,6 +1975,8 @@ def convolutional_barycenter2d(A, reg, weights=None, method="sinkhorn", numIterm Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns ------- @@ -1949,7 +1995,7 @@ def convolutional_barycenter2d(A, reg, weights=None, method="sinkhorn", numIterm Efficient optimal transportation on geometric domains. ACM Transactions on Graphics (TOG), 34(4), 66 - .. [28] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th + .. [35] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th International Conference on Machine Learning, PMLR 119:4692-4701, 2020 """ @@ -1957,20 +2003,21 @@ def convolutional_barycenter2d(A, reg, weights=None, method="sinkhorn", numIterm return _convolutional_barycenter2d(A, reg, weights=weights, numItermax=numItermax, stopThr=stopThr, verbose=verbose, - log=log, + log=log, warn=warn, **kwargs) elif method.lower() == 'sinkhorn_log': return _convolutional_barycenter2d_log(A, reg, weights=weights, numItermax=numItermax, stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + log=log, warn=warn, + **kwargs) else: raise ValueError("Unknown method '%s'." 
% method) def _convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1e-9, stabThr=1e-30, verbose=False, - log=False): + log=False, warn=True): r"""Compute the entropic regularized wasserstein barycenter of distributions A where A is a collection of 2D images. """ @@ -2029,9 +2076,10 @@ def convol_imgs(imgs): break else: - warnings.warn("Convolutional Sinkhorn did not converge. " - "Try a larger number of iterations `numItermax` " - "or a larger entropy `reg`.") + if warn: + warnings.warn("Convolutional Sinkhorn did not converge. " + "Try a larger number of iterations `numItermax` " + "or a larger entropy `reg`.") if log: log['niter'] = ii log['U'] = U @@ -2042,7 +2090,7 @@ def convol_imgs(imgs): def _convolutional_barycenter2d_log(A, reg, weights=None, numItermax=10000, stopThr=1e-4, stabThr=1e-30, verbose=False, - log=False): + log=False, warn=True): r"""Compute the entropic regularized wasserstein barycenter of distributions A where A is a collection of 2D images in log-domain. """ @@ -2105,9 +2153,10 @@ def convol_img(log_img): G = log_bar[None, :, :] - log_KU else: - warnings.warn("Convolutional Sinkhorn did not converge. " - "Try a larger number of iterations `numItermax` " - "or a larger entropy `reg`.") + if warn: + warnings.warn("Convolutional Sinkhorn did not converge. " + "Try a larger number of iterations `numItermax` " + "or a larger entropy `reg`.") if log: log['niter'] = ii return nx.exp(log_bar), log @@ -2115,8 +2164,10 @@ def convol_img(log_img): return nx.exp(log_bar) -def convolutional_barycenter2d_debiased(A, reg, weights=None, method="sinkhorn", numItermax=10000, - stopThr=1e-4, verbose=False, log=False, **kwargs): +def convolutional_barycenter2d_debiased(A, reg, weights=None, method="sinkhorn", + numItermax=10000, stopThr=1e-4, + verbose=False, log=False, warn=True, + **kwargs): r"""Compute the debiased sinkhorn barycenter of distributions A where A is a collection of 2D images. 
@@ -2134,7 +2185,7 @@ def convolutional_barycenter2d_debiased(A, reg, weights=None, method="sinkhorn", - `reg` is the regularization strength scalar value The algorithm used for solving the problem is the debiased Sinkhorn scaling - algorithm as proposed in :ref:`[28] ` + algorithm as proposed in :ref:`[35] ` Parameters ---------- @@ -2156,6 +2207,9 @@ def convolutional_barycenter2d_debiased(A, reg, weights=None, method="sinkhorn", Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. + Returns ------- @@ -2169,7 +2223,7 @@ def convolutional_barycenter2d_debiased(A, reg, weights=None, method="sinkhorn", References ---------- - .. [28] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th International + .. [35] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th International Conference on Machine Learning, PMLR 119:4692-4701, 2020 """ @@ -2177,20 +2231,21 @@ def convolutional_barycenter2d_debiased(A, reg, weights=None, method="sinkhorn", return _convolutional_barycenter2d_debiased(A, reg, weights=weights, numItermax=numItermax, stopThr=stopThr, verbose=verbose, - log=log, + log=log, warn=warn, **kwargs) elif method.lower() == 'sinkhorn_log': return _convolutional_barycenter2d_debiased_log(A, reg, weights=weights, numItermax=numItermax, stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + log=log, warn=warn, + **kwargs) else: raise ValueError("Unknown method '%s'." % method) def _convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000, stopThr=1e-4, stabThr=1e-15, verbose=False, - log=False): + log=False, warn=True): r"""Compute the debiased barycenter of 2D images via sinkhorn convolutions. """ @@ -2256,9 +2311,10 @@ def convol_imgs(imgs): if err < stopThr and ii > 20: break else: - warnings.warn("Sinkhorn did not converge. 
You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`.") + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: log['niter'] = ii log['U'] = U @@ -2269,7 +2325,7 @@ def convol_imgs(imgs): def _convolutional_barycenter2d_debiased_log(A, reg, weights=None, numItermax=10000, stopThr=1e-4, stabThr=1e-30, verbose=False, - log=False): + log=False, warn=True): r"""Compute the debiased barycenter of 2D images in log-domain. """ @@ -2332,9 +2388,10 @@ def convol_img(log_img): G = log_bar[None, :, :] - log_KU else: - warnings.warn("Convolutional Sinkhorn did not converge. " - "Try a larger number of iterations `numItermax` " - "or a larger entropy `reg`.") + if warn: + warnings.warn("Convolutional Sinkhorn did not converge. " + "Try a larger number of iterations `numItermax` " + "or a larger entropy `reg`.") if log: log['niter'] = ii return nx.exp(log_bar), log @@ -2343,7 +2400,7 @@ def convol_img(log_img): def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, - stopThr=1e-3, verbose=False, log=False): + stopThr=1e-3, verbose=False, log=False, warn=True): r""" Compute the unmixing of an observation with a given dictionary using Wasserstein distance @@ -2399,7 +2456,8 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, Print information along iterations log : bool, optional record log if True - + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns ------- @@ -2459,9 +2517,10 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, if err < stopThr: break else: - warnings.warn("Unmixing algorithm did not converge. You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`.") + if warn: + warnings.warn("Unmixing algorithm did not converge. 
You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: log['niter'] = ii return nx.sum(K0, axis=1), log @@ -2470,7 +2529,7 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, - stopThr=1e-6, verbose=False, log=False, **kwargs): + stopThr=1e-6, verbose=False, log=False, warn=True, **kwargs): r'''Joint OT and proportion estimation for multi-source target shift as proposed in :ref:`[27] ` @@ -2520,10 +2579,12 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, Max number of iterations stopThr : float, optional Stop threshold on relative change in the barycenter (>0) - log : bool, optional - record log if True verbose : bool, optional (default=False) Controls the verbosity of the optimization algorithm + log : bool, optional + record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns ------- @@ -2628,9 +2689,10 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) print('{:5d}|{:8e}|'.format(ii, err)) else: - warnings.warn("Algorithm did not converge. You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`.") + if warn: + warnings.warn("Algorithm did not converge. 
You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") bary = bary / nx.sum(bary) if log: @@ -2646,7 +2708,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False, - log=False, **kwargs): + log=False, warn=True, **kwargs): r''' Solve the entropic regularization optimal transport problem and return the OT matrix from empirical data @@ -2696,6 +2758,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns @@ -2804,9 +2868,10 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', if err <= stopThr: break else: - warnings.warn("Sinkhorn did not converge. You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`.") + if warn: + warnings.warn("Sinkhorn did not converge. 
You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: dict_log["u"] = f dict_log["v"] = g @@ -2828,7 +2893,7 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, isLazy=False, - batchSize=100, verbose=False, log=False, **kwargs): + batchSize=100, verbose=False, log=False, warn=True, **kwargs): r''' Solve the entropic regularization optimal transport problem from empirical data and return the OT loss @@ -2879,6 +2944,8 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns @@ -2934,12 +3001,14 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', stopThr=stopThr, isLazy=isLazy, batchSize=batchSize, - verbose=verbose, log=log) + verbose=verbose, log=log, + warn=warn) else: f, g = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, numIterMax=numIterMax, stopThr=stopThr, isLazy=isLazy, batchSize=batchSize, - verbose=verbose, log=log) + verbose=verbose, log=log, + warn=warn) bs = batchSize if isinstance(batchSize, int) else batchSize[0] range_s = range(0, ns, bs) @@ -2967,18 +3036,19 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', if log: sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, - **kwargs) + warn=warn, **kwargs) return sinkhorn_loss, log else: sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, - **kwargs) + warn=warn, **kwargs) return sinkhorn_loss def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, - verbose=False, log=False, 
**kwargs): + verbose=False, log=False, warn=True, + **kwargs): r''' Compute the sinkhorn divergence loss from empirical data @@ -3043,6 +3113,8 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns ------- @@ -3073,17 +3145,17 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, - log=log, **kwargs) + log=log, warn=warn, **kwargs) sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, - log=log, **kwargs) + log=log, warn=warn, **kwargs) sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, - log=log, **kwargs) + log=log, warn=warn, **kwargs) sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b) @@ -3100,22 +3172,26 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli else: sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, - verbose=verbose, log=log, **kwargs) + verbose=verbose, log=log, + warn=warn, **kwargs) sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, numIterMax=numIterMax, stopThr=1e-9, - verbose=verbose, log=log, **kwargs) + verbose=verbose, log=log, + warn=warn, **kwargs) sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, - verbose=verbose, log=log, **kwargs) + verbose=verbose, log=log, + warn=warn, **kwargs) sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b) return max(0, sinkhorn_div) -def screenkhorn(a, b, M, 
reg, ns_budget=None, nt_budget=None, uniform=False, restricted=True, - maxiter=10000, maxfun=10000, pgtol=1e-09, verbose=False, log=False): +def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, + restricted=True, maxiter=10000, maxfun=10000, pgtol=1e-09, + verbose=False, log=False): r""" Screening Sinkhorn Algorithm for Regularized Optimal Transport @@ -3145,48 +3221,36 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res Parameters ---------- - a : array-like, shape=(ns,) + a: array-like, shape=(ns,) samples weights in the source domain - - b : array-like, shape=(nt,) + b: array-like, shape=(nt,) samples weights in the target domain - - M : array-like, shape=(ns, nt) + M: array-like, shape=(ns, nt) Cost matrix - - reg : `float` + reg: `float` Level of the entropy regularisation - - ns_budget : `int`, default=None + ns_budget: `int`, default=None Number budget of points to be kept in the source domain. If it is None then 50% of the source sample points will be kept - - nt_budget : `int`, default=None + nt_budget: `int`, default=None Number budget of points to be kept in the target domain. 
If it is None then 50% of the target sample points will be kept - - uniform : `bool`, default=False + uniform: `bool`, default=False If `True`, the source and target distribution are supposed to be uniform, i.e., :math:`a_i = 1 / ns` and :math:`b_j = 1 / nt` - restricted : `bool`, default=True If `True`, a warm-start initialization for the L-BFGS-B solver using a restricted Sinkhorn algorithm with at most 5 iterations - - maxiter : `int`, default=10000 + maxiter: `int`, default=10000 Maximum number of iterations in LBFGS solver - - maxfun : `int`, default=10000 + maxfun: `int`, default=10000 Maximum number of function evaluations in LBFGS solver - - pgtol : `float`, default=1e-09 + pgtol: `float`, default=1e-09 Final objective function accuracy in LBFGS solver - - verbose : `bool`, default=False + verbose: `bool`, default=False If `True`, display informations about the cardinals of the active sets and the parameters kappa and epsilon - Dependency ---------- To gain more efficiency, screenkhorn needs to call the "Bottleneck" From ace0e6770c270dc0a47bb8357ba23c534c73a44c Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Tue, 2 Nov 2021 16:29:41 +0100 Subject: [PATCH 20/25] fix ref number --- examples/barycenters/plot_debiased_barycenter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/barycenters/plot_debiased_barycenter.py b/examples/barycenters/plot_debiased_barycenter.py index 6cfce7012..dc1058d9e 100644 --- a/examples/barycenters/plot_debiased_barycenter.py +++ b/examples/barycenters/plot_debiased_barycenter.py @@ -5,10 +5,10 @@ ================================= This example illustrates the computation of the debiased Sinkhorn barycenter -as proposed in [28]_. +as proposed in [35]_. -.. [28] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th +.. [35] Janati, H., Cuturi, M., Gramfort, A. 
Proceedings of the 37th International Conference on Machine Learning, PMLR 119:4692-4701, 2020 """ From 40ae0de0d15bdc51082a2a297f22ee43d90cb9d3 Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Tue, 2 Nov 2021 17:07:52 +0100 Subject: [PATCH 21/25] catch warnings in tests --- test/test_bregman.py | 84 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 70 insertions(+), 14 deletions(-) diff --git a/test/test_bregman.py b/test/test_bregman.py index ebeb9592d..b1cdac746 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -6,6 +6,8 @@ # # License: MIT License +from itertools import product + import numpy as np import pytest @@ -13,7 +15,8 @@ from ot.backend import torch -def test_sinkhorn(): +@pytest.mark.parametrize("verbose, warn", product([True, False], [True, False])) +def test_sinkhorn(verbose, warn): # test sinkhorn n = 100 rng = np.random.RandomState(0) @@ -23,7 +26,7 @@ def test_sinkhorn(): M = ot.dist(x, x) - G = ot.sinkhorn(u, u, M, 1, stopThr=1e-10) + G = ot.sinkhorn(u, u, M, 1, stopThr=1e-10, verbose=verbose, warn=warn) # check constraints np.testing.assert_allclose( @@ -31,8 +34,29 @@ def test_sinkhorn(): np.testing.assert_allclose( u, G.sum(0), atol=1e-05) # cf convergence sinkhorn + with pytest.warns(UserWarning): + ot.sinkhorn(u, u, M, 1, stopThr=0, numItermax=1) + + +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", + "sinkhorn_log"]) +def test_convergence_warning(method): + # test sinkhorn + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + M = ot.dist(x, x) + + with pytest.warns(UserWarning): + ot.sinkhorn(u, u, M, 1., method=method, stopThr=0, numItermax=1) + with pytest.warns(UserWarning): + ot.sinkhorn2(u, u, M, 1, method=method, stopThr=0, numItermax=1) + -def test_sinkhorn_multi_b(): +@pytest.mark.parametrize("verbose, warn", product([True, False], [True, False])) +def test_sinkhorn_multi_b(verbose, warn): # test sinkhorn n = 10 rng = np.random.RandomState(0) @@ 
-47,7 +71,7 @@ def test_sinkhorn_multi_b(): loss0, log = ot.sinkhorn(u, b, M, .1, stopThr=1e-10, log=True) - loss = [ot.sinkhorn2(u, b[:, k], M, .1, stopThr=1e-10) for k in range(3)] + loss = [ot.sinkhorn2(u, b[:, k], M, .1, stopThr=1e-10, verbose=verbose, warn=warn) for k in range(3)] # check constraints np.testing.assert_allclose( loss0, loss, atol=1e-06) # cf convergence sinkhorn @@ -255,7 +279,7 @@ def test_sinkhorn_variants_log(): Gl, logl = ot.sinkhorn(u, u, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True) Gs, logs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True) Ges, loges = ot.sinkhorn( - u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True) + u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True,) G_green, loggreen = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10, log=True) # check values @@ -265,7 +289,8 @@ def test_sinkhorn_variants_log(): np.testing.assert_allclose(G0, G_green, atol=1e-5) -def test_sinkhorn_variants_log_multib(): +@pytest.mark.parametrize("verbose, warn", product([True, False], [True, False])) +def test_sinkhorn_variants_log_multib(verbose, warn): # test sinkhorn n = 50 rng = np.random.RandomState(0) @@ -278,16 +303,20 @@ def test_sinkhorn_variants_log_multib(): M = ot.dist(x, x) G0, log0 = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10, log=True) - Gl, logl = ot.sinkhorn(u, b, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True) - Gs, logs = ot.sinkhorn(u, b, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True) + Gl, logl = ot.sinkhorn(u, b, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True, + verbose=verbose, warn=warn) + Gs, logs = ot.sinkhorn(u, b, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True, + verbose=verbose, warn=warn) # check values np.testing.assert_allclose(G0, Gs, atol=1e-05) np.testing.assert_allclose(G0, Gl, atol=1e-05) -@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", 
"sinkhorn_log"]) -def test_barycenter(nx, method): +@pytest.mark.parametrize("method, verbose, warn", + product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"], + [True, False], [True, False])) +def test_barycenter(nx, method, verbose, warn): n_bins = 100 # nb bins # Gaussian distributions @@ -314,7 +343,7 @@ def test_barycenter(nx, method): ot.bregman.barycenter(A_nx, M_nx, reg, weights, method=method) else: # wasserstein - bary_wass_np = ot.bregman.barycenter(A, M, reg, weights, method=method) + bary_wass_np = ot.bregman.barycenter(A, M, reg, weights, method=method, verbose=verbose, warn=warn) bary_wass, _ = ot.bregman.barycenter(A_nx, M_nx, reg, weights_nx, method=method, log=True) bary_wass = nx.to_numpy(bary_wass) @@ -324,8 +353,10 @@ def test_barycenter(nx, method): ot.bregman.barycenter(A_nx, M_nx, reg, log=True) -@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) -def test_barycenter_debiased(nx, method): +@pytest.mark.parametrize("method, verbose, warn", + product(["sinkhorn", "sinkhorn_log"], + [True, False], [True, False])) +def test_barycenter_debiased(nx, method, verbose, warn): n_bins = 100 # nb bins # Gaussian distributions @@ -352,7 +383,8 @@ def test_barycenter_debiased(nx, method): with pytest.raises(NotImplementedError): ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights, method=method) else: - bary_wass_np = ot.bregman.barycenter_debiased(A, M, reg, weights, method=method) + bary_wass_np = ot.bregman.barycenter_debiased(A, M, reg, weights, method=method, + verbose=verbose, warn=warn) bary_wass, _ = ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights_nx, method=method, log=True) bary_wass = nx.to_numpy(bary_wass) @@ -362,6 +394,30 @@ def test_barycenter_debiased(nx, method): ot.bregman.barycenter_debiased(A_nx, M_nx, reg, log=True, verbose=False) +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) +def test_convergence_warning_barycenters(method): + n_bins = 100 # nb bins + + # Gaussian distributions + a1 
= ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std + a2 = ot.datasets.make_1D_gauss(n_bins, m=40, s=10) + + # creating matrix A containing all distributions + A = np.vstack((a1, a2)).T + + # loss matrix + normalization + M = ot.utils.dist0(n_bins) + M /= M.max() + + alpha = 0.5 # 0<=alpha<=1 + weights = np.array([1 - alpha, alpha]) + reg = 0.1 + with pytest.warns(UserWarning): + ot.bregman.barycenter_debiased(A, M, reg, weights, method=method, numItermax=1) + with pytest.warns(UserWarning): + ot.bregman.barycenter(A, M, reg, weights, method=method, numItermax=1) + + def test_barycenter_stabilization(nx): n_bins = 100 # nb bins From ebd5f6aa77e2d573642117bc92c03281c5c2d30c Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Tue, 2 Nov 2021 17:23:17 +0100 Subject: [PATCH 22/25] add contrib to readme + change ref number --- README.md | 8 ++++++-- examples/barycenters/plot_debiased_barycenter.py | 5 ++--- ot/bregman.py | 10 +++++----- 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index cfb974479..ff32c53be 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,8 @@ POT provides the following generic OT solvers (links to examples): * [Conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) [6] and [Generalized conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) for regularized OT [7]. * Entropic regularization OT solver with [Sinkhorn Knopp Algorithm](https://pythonot.github.io/auto_examples/plot_OT_1D.html) [2] , stabilized version [9] [10] [34], greedy Sinkhorn [22] and [Screening Sinkhorn [26] ](https://pythonot.github.io/auto_examples/plot_screenkhorn_1D.html). * Bregman projections for [Wasserstein barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html) [3], [convolutional barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_convolutional_barycenter.html) [21] and unmixing [4]. 
-* Sinkhorn divergence [23] and entropic regularization OT from empirical data. +* Sinkhorn divergence [23] and entropic regularization OT from empirical data. +* Debiased Sinkhorn barycenters [Sinkhorn divergence barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_debiased_barycenter.html) [37] * [Smooth optimal transport solvers](https://pythonot.github.io/auto_examples/plot_OT_1D_smooth.html) (dual and semi-dual) for KL and squared L2 regularizations [17]. * Non regularized [Wasserstein barycenters [16] ](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html)) with LP solver (only small scale). * [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12]) @@ -188,7 +189,7 @@ The contributors to this library are * [Kilian Fatras](https://kilianfatras.github.io/) (Stochastic solvers) * [Alain Rakotomamonjy](https://sites.google.com/site/alainrakotomamonjy/home) * [Vayer Titouan](https://tvayer.github.io/) (Gromov-Wasserstein -, Fused-Gromov-Wasserstein) -* [Hicham Janati](https://hichamjanati.github.io/) (Unbalanced OT) +* [Hicham Janati](https://hichamjanati.github.io/) (Unbalanced OT, Debiased barycenters) * [Romain Tavenard](https://rtavenar.github.io/) (1d Wasserstein) * [Mokhtar Z. Alaya](http://mzalaya.github.io/) (Screenkhorn) * [Ievgen Redko](https://ievred.github.io/) (Laplacian DA, JCPOT) @@ -293,3 +294,6 @@ You can also post bug reports and feature requests in Github issues. Make sure t (2019, May). [Sliced-Wasserstein flows: Nonparametric generative modeling via optimal transport and diffusions](http://proceedings.mlr.press/v97/liutkus19a/liutkus19a.pdf). In International Conference on Machine Learning (pp. 4104-4113). PMLR. + +[37] Janati, H., Cuturi, M., Gramfort, A. 
Proceedings of the 37th International +Conference on Machine Learning, PMLR 119:4692-4701, 2020 \ No newline at end of file diff --git a/examples/barycenters/plot_debiased_barycenter.py b/examples/barycenters/plot_debiased_barycenter.py index dc1058d9e..6dea42f13 100644 --- a/examples/barycenters/plot_debiased_barycenter.py +++ b/examples/barycenters/plot_debiased_barycenter.py @@ -5,10 +5,10 @@ ================================= This example illustrates the computation of the debiased Sinkhorn barycenter -as proposed in [35]_. +as proposed in [37]_. -.. [35] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th +.. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th International Conference on Machine Learning, PMLR 119:4692-4701, 2020 """ @@ -16,7 +16,6 @@ # # License: MIT License -# sphinx_gallery_thumbnail_number = 4 import os from pathlib import Path diff --git a/ot/bregman.py b/ot/bregman.py index 641fcee9f..5f4f7c45d 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1750,7 +1750,7 @@ def barycenter_debiased(A, M, reg, weights=None, method="sinkhorn", numItermax=1 the cost matrix for OT The algorithm used for solving the problem is the debiased Sinkhorn - algorithm as proposed in :ref:`[35] ` + algorithm as proposed in :ref:`[37] ` Parameters ---------- @@ -1788,7 +1788,7 @@ def barycenter_debiased(A, M, reg, weights=None, method="sinkhorn", numItermax=1 References ---------- - .. [35] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th International + .. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th International Conference on Machine Learning, PMLR 119:4692-4701, 2020 """ @@ -1995,7 +1995,7 @@ def convolutional_barycenter2d(A, reg, weights=None, method="sinkhorn", numIterm Efficient optimal transportation on geometric domains. ACM Transactions on Graphics (TOG), 34(4), 66 - .. [35] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th + .. [37] Janati, H., Cuturi, M., Gramfort, A. 
Proceedings of the 37th International Conference on Machine Learning, PMLR 119:4692-4701, 2020 """ @@ -2185,7 +2185,7 @@ def convolutional_barycenter2d_debiased(A, reg, weights=None, method="sinkhorn", - `reg` is the regularization strength scalar value The algorithm used for solving the problem is the debiased Sinkhorn scaling - algorithm as proposed in :ref:`[35] ` + algorithm as proposed in :ref:`[37] ` Parameters ---------- @@ -2223,7 +2223,7 @@ def convolutional_barycenter2d_debiased(A, reg, weights=None, method="sinkhorn", References ---------- - .. [35] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th International + .. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th International Conference on Machine Learning, PMLR 119:4692-4701, 2020 """ From 8c2e1f2db42ac817d36f7cf81a46efdd342929c1 Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Tue, 2 Nov 2021 20:58:19 +0100 Subject: [PATCH 23/25] fix convolution example + gallery thumbnails --- examples/barycenters/plot_barycenter_1D.py | 20 ++++------------- .../plot_barycenter_lp_vs_entropic.py | 2 +- .../barycenters/plot_debiased_barycenter.py | 22 +++++++++---------- ot/bregman.py | 6 ++--- 4 files changed, 18 insertions(+), 32 deletions(-) diff --git a/examples/barycenters/plot_barycenter_1D.py b/examples/barycenters/plot_barycenter_1D.py index 00bcfe67c..2373e99a6 100644 --- a/examples/barycenters/plot_barycenter_1D.py +++ b/examples/barycenters/plot_barycenter_1D.py @@ -18,7 +18,7 @@ # # License: MIT License -# sphinx_gallery_thumbnail_number = 4 +# sphinx_gallery_thumbnail_number = 1 import numpy as np import matplotlib.pyplot as plt @@ -50,18 +50,6 @@ M = ot.utils.dist0(n) M /= M.max() -############################################################################## -# Plot data -# --------- - -#%% plot the distributions - -# plt.figure(1, figsize=(6.4, 3)) -# for i in range(n_distributions): -# plt.plot(x, A[:, i]) -# plt.title('Distributions') -# plt.tight_layout() - 
############################################################################## # Barycenter computation # ---------------------- @@ -78,7 +66,7 @@ reg = 1e-3 bary_wass = ot.bregman.barycenter(A, M, reg, weights) -f, (ax1, ax2) = plt.subplots(2, 1, tight_layout=True) +f, (ax1, ax2) = plt.subplots(2, 1, tight_layout=True, num=1) ax1.plot(x, A, color="black") ax1.set_title('Distributions') @@ -109,7 +97,7 @@ B_wass[:, i] = ot.bregman.barycenter(A, M, reg, weights) #%% plot interpolation -plt.figure() +plt.figure(2) cmap = plt.cm.get_cmap('viridis') verts = [] @@ -132,7 +120,7 @@ plt.title('Barycenter interpolation with l2') plt.tight_layout() -plt.figure(4) +plt.figure(3) cmap = plt.cm.get_cmap('viridis') verts = [] zs = alpha_list diff --git a/examples/barycenters/plot_barycenter_lp_vs_entropic.py b/examples/barycenters/plot_barycenter_lp_vs_entropic.py index 57a6bac5e..6502f16d8 100644 --- a/examples/barycenters/plot_barycenter_lp_vs_entropic.py +++ b/examples/barycenters/plot_barycenter_lp_vs_entropic.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- """ ================================================================================= -1D Wasserstein barycenter comparison between exact LP and entropic regularization +1D Wasserstein barycenter: exact LP vs entropic regularization ================================================================================= This example illustrates the computation of regularized Wasserstein Barycenter diff --git a/examples/barycenters/plot_debiased_barycenter.py b/examples/barycenters/plot_debiased_barycenter.py index 6dea42f13..2a603dd9c 100644 --- a/examples/barycenters/plot_debiased_barycenter.py +++ b/examples/barycenters/plot_debiased_barycenter.py @@ -15,6 +15,7 @@ # Author: Hicham Janati # # License: MIT License +# sphinx_gallery_thumbnail_number = 3 import os from pathlib import Path @@ -63,7 +64,8 @@ labels = ["Sinkhorn barycenter", "Debiased barycenter"] colors = ["indianred", "gold"] -f, axes = plt.subplots(1, len(epsilons), 
tight_layout=True, sharey=True, figsize=(12, 4)) +f, axes = plt.subplots(1, len(epsilons), tight_layout=True, sharey=True, + figsize=(12, 4), num=1) for ax, eps, bar, bar_debiased in zip(axes, epsilons, bars, bars_debiased): ax.plot(A[:, 0], color="k", ls="--", label="Input data", alpha=0.3) ax.plot(A[:, 1], color="k", ls="--", alpha=0.3) @@ -79,20 +81,16 @@ # --------------------------------- this_file = os.path.realpath('__file__') data_path = os.path.join(Path(this_file).parent.parent.parent, 'data') -f1 = 1 - plt.imread(os.path.join(data_path, 'redcross.png'))[:, :, 2] -f2 = 1 - plt.imread(os.path.join(data_path, 'tooth.png'))[:, :, 2] -f3 = 1 - plt.imread(os.path.join(data_path, 'duck.png'))[:, :, 2] +f1 = 1 - plt.imread(os.path.join(data_path, 'heart.png'))[:, :, 2] +f2 = 1 - plt.imread(os.path.join(data_path, 'duck.png'))[:, :, 2] -f1 = f1 / np.sum(f1) -f2 = f2 / np.sum(f2) -f3 = f3 / np.sum(f3) - -A = np.array([f1, f2, f3]) +A = np.asarray([f1, f2]) + 1e-2 +A /= A.sum(axis=(1, 2))[:, None, None] ############################################################################## # Display the input images -fig, axes = plt.subplots(1, 3, figsize=(7, 4)) +fig, axes = plt.subplots(1, 2, figsize=(7, 4), num=2) for ax, img in zip(axes, A): ax.imshow(img, cmap="Greys") ax.axis("off") @@ -109,13 +107,13 @@ epsilons = [5e-3, 7e-3, 1e-2] for eps in epsilons: bar = convolutional_barycenter2d(A, eps) - bar_debiased = convolutional_barycenter2d_debiased(A, eps) + bar_debiased, log = convolutional_barycenter2d_debiased(A, eps, log=True) bars_sinkhorn.append(bar) bars_debiased.append(bar_debiased) titles = ["Sinkhorn", "Debiased"] all_bars = [bars_sinkhorn, bars_debiased] -fig, axes = plt.subplots(2, 3, figsize=(8, 6)) +fig, axes = plt.subplots(2, 3, figsize=(8, 6), num=3) for jj, (method, ax_row, bars) in enumerate(zip(titles, axes, all_bars)): for ii, (ax, img, eps) in enumerate(zip(ax_row, bars, epsilons)): ax.imshow(img, cmap="Greys") diff --git a/ot/bregman.py 
b/ot/bregman.py index 5f4f7c45d..502a8f5c3 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -2165,7 +2165,7 @@ def convol_img(log_img): def convolutional_barycenter2d_debiased(A, reg, weights=None, method="sinkhorn", - numItermax=10000, stopThr=1e-4, + numItermax=10000, stopThr=1e-3, verbose=False, log=False, warn=True, **kwargs): r"""Compute the debiased sinkhorn barycenter of distributions A @@ -2244,7 +2244,7 @@ def convolutional_barycenter2d_debiased(A, reg, weights=None, method="sinkhorn", def _convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000, - stopThr=1e-4, stabThr=1e-15, verbose=False, + stopThr=1e-3, stabThr=1e-15, verbose=False, log=False, warn=True): r"""Compute the debiased barycenter of 2D images via sinkhorn convolutions. """ @@ -2324,7 +2324,7 @@ def convol_imgs(imgs): def _convolutional_barycenter2d_debiased_log(A, reg, weights=None, numItermax=10000, - stopThr=1e-4, stabThr=1e-30, verbose=False, + stopThr=1e-3, stabThr=1e-30, verbose=False, log=False, warn=True): r"""Compute the debiased barycenter of 2D images in log-domain. 
""" From 8d37431c1016d885d0ac7d073b4504eea16febff Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Tue, 2 Nov 2021 22:31:55 +0100 Subject: [PATCH 24/25] increase coverage --- ot/bregman.py | 14 +++--- test/test_bregman.py | 107 +++++++++++++++++++++++++++++++++++++------ 2 files changed, 100 insertions(+), 21 deletions(-) diff --git a/ot/bregman.py b/ot/bregman.py index 502a8f5c3..786f151e0 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -486,7 +486,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop - print('Warning: numerical errors at iteration', ii) + warnings.warn('Warning: numerical errors at iteration %d' % ii) u = uprev v = vprev break @@ -1052,8 +1052,8 @@ def get_Gamma(alpha, beta, u, v): vprev = v # sinkhorn update - v = b / (nx.dot(K.T, u) + 1e-16) - u = a / (nx.dot(K, v) + 1e-16) + v = b / (nx.dot(K.T, u)) + u = a / (nx.dot(K, v)) # remove numerical problems and store them in K if nx.max(nx.abs(u)) > tau or nx.max(nx.abs(v)) > tau: @@ -1084,8 +1084,6 @@ def get_Gamma(alpha, beta, u, v): if log: log['err'].append(err) - if err < stopThr: - break if verbose: if ii % (print_period * 20) == 0: print( @@ -1098,7 +1096,7 @@ def get_Gamma(alpha, beta, u, v): if nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)): # we have reached the machine precision # come back to previous solution and quit loop - warnings.warn('Numerical errors at iteration', ii) + warnings.warn('Numerical errors at iteration %d' % ii) u = uprev v = vprev break @@ -1682,11 +1680,11 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, for ii in range(numItermax): qprev = q Kv = nx.dot(K, v) - u = A / (Kv + 1e-16) + u = A / Kv Ktu = nx.dot(K.T, u) q = geometricBar(weights, Ktu) Q = q[:, None] - v = Q / (Ktu + 1e-16) + v = Q / Ktu absorbing = False if nx.any(u > tau) or nx.any(v > tau): absorbing = True diff --git 
a/test/test_bregman.py b/test/test_bregman.py index b1cdac746..04d450c65 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -10,6 +10,7 @@ import numpy as np import pytest +from torch._C import Value import ot from ot.backend import torch @@ -39,24 +40,87 @@ def test_sinkhorn(verbose, warn): @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", + "sinkhorn_epsilon_scaling", + "greenkhorn", "sinkhorn_log"]) def test_convergence_warning(method): # test sinkhorn n = 100 - rng = np.random.RandomState(0) - - x = rng.randn(n, 2) - u = ot.utils.unif(n) - M = ot.dist(x, x) + a1 = ot.datasets.make_1D_gauss(n, m=30, s=10) + a2 = ot.datasets.make_1D_gauss(n, m=40, s=10) + A = np.asarray([a1, a2]).T + M = ot.utils.dist0(n) with pytest.warns(UserWarning): - ot.sinkhorn(u, u, M, 1., method=method, stopThr=0, numItermax=1) + ot.sinkhorn(a1, a2, M, 1., method=method, stopThr=0, numItermax=1) + + if method in ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"]: + with pytest.warns(UserWarning): + ot.barycenter(A, M, 1, method=method, stopThr=0, numItermax=1) + with pytest.warns(UserWarning): + ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1) + + +def test_not_impemented_method(): + # test sinkhorn + w = 10 + n = w ** 2 + rng = np.random.RandomState(42) + A_img = rng.rand(2, w, w) + A_flat = A_img.reshape(n, 2) + a1, a2 = A_flat.T + M_flat = ot.utils.dist0(n) + not_implemented = "new_method" + reg = 0.01 + with pytest.raises(ValueError): + ot.sinkhorn(a1, a2, M_flat, reg, method=not_implemented) + with pytest.raises(ValueError): + ot.sinkhorn2(a1, a2, M_flat, reg, method=not_implemented) + with pytest.raises(ValueError): + ot.barycenter(A_flat, M_flat, reg, method=not_implemented) + with pytest.raises(ValueError): + ot.bregman.barycenter_debiased(A_flat, M_flat, reg, + method=not_implemented) + with pytest.raises(ValueError): + ot.bregman.convolutional_barycenter2d(A_img, reg, + method=not_implemented) + with pytest.raises(ValueError): 
+ ot.bregman.convolutional_barycenter2d_debiased(A_img, reg, + method=not_implemented) + + +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) +def test_nan_warning(method): + # test sinkhorn + n = 100 + a1 = ot.datasets.make_1D_gauss(n, m=30, s=10) + a2 = ot.datasets.make_1D_gauss(n, m=40, s=10) + + M = ot.utils.dist0(n) + reg = 0 with pytest.warns(UserWarning): - ot.sinkhorn2(u, u, M, 1, method=method, stopThr=0, numItermax=1) + # warn set to False to avoid catching a convergence warning instead + ot.sinkhorn(a1, a2, M, reg, method=method, warn=False) -@pytest.mark.parametrize("verbose, warn", product([True, False], [True, False])) -def test_sinkhorn_multi_b(verbose, warn): +def test_sinkhorn_stabilization(): + # test sinkhorn + n = 100 + a1 = ot.datasets.make_1D_gauss(n, m=30, s=10) + a2 = ot.datasets.make_1D_gauss(n, m=40, s=10) + M = ot.utils.dist0(n) + reg = 1e-5 + loss1 = ot.sinkhorn2(a1, a2, M, reg, method="sinkhorn_log") + loss2 = ot.sinkhorn2(a1, a2, M, reg, tau=1, method="sinkhorn_stabilized") + np.testing.assert_allclose( + loss1, loss2, atol=1e-06) # cf convergence sinkhorn + + +@pytest.mark.parametrize("method, verbose, warn", + product(["sinkhorn", "sinkhorn_stabilized", + "sinkhorn_log"], + [True, False], [True, False])) +def test_sinkhorn_multi_b(method, verbose, warn): # test sinkhorn n = 10 rng = np.random.RandomState(0) @@ -69,12 +133,14 @@ def test_sinkhorn_multi_b(verbose, warn): M = ot.dist(x, x) - loss0, log = ot.sinkhorn(u, b, M, .1, stopThr=1e-10, log=True) + loss0, log = ot.sinkhorn(u, b, M, .1, method=method, stopThr=1e-10, + log=True) - loss = [ot.sinkhorn2(u, b[:, k], M, .1, stopThr=1e-10, verbose=verbose, warn=warn) for k in range(3)] + loss = [ot.sinkhorn2(u, b[:, k], M, .1, method=method, stopThr=1e-10, + verbose=verbose, warn=warn) for k in range(3)] # check constraints np.testing.assert_allclose( - loss0, loss, atol=1e-06) # cf convergence sinkhorn + loss0, loss, atol=1e-4) # cf convergence sinkhorn def 
test_sinkhorn_backends(nx): @@ -155,6 +221,12 @@ def test_sinkhorn_empty(): M = ot.dist(x, x) + G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, method="sinkhorn_log", + verbose=True, log=True) + # check constraints + np.testing.assert_allclose(u, G.sum(1), atol=1e-05) + np.testing.assert_allclose(u, G.sum(0), atol=1e-05) + G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, verbose=True, log=True) # check constraints np.testing.assert_allclose(u, G.sum(1), atol=1e-05) @@ -396,7 +468,8 @@ def test_barycenter_debiased(nx, method, verbose, warn): @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) def test_convergence_warning_barycenters(method): - n_bins = 100 # nb bins + w = 10 + n_bins = w ** 2 # nb bins # Gaussian distributions a1 = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std @@ -404,6 +477,8 @@ def test_convergence_warning_barycenters(method): # creating matrix A containing all distributions A = np.vstack((a1, a2)).T + A_img = A.reshape(2, w, w) + A_img /= A_img.sum((1, 2))[:, None, None] # loss matrix + normalization M = ot.utils.dist0(n_bins) @@ -416,6 +491,12 @@ def test_convergence_warning_barycenters(method): ot.bregman.barycenter_debiased(A, M, reg, weights, method=method, numItermax=1) with pytest.warns(UserWarning): ot.bregman.barycenter(A, M, reg, weights, method=method, numItermax=1) + with pytest.warns(UserWarning): + ot.bregman.convolutional_barycenter2d(A_img, reg, weights, + method=method, numItermax=1) + with pytest.warns(UserWarning): + ot.bregman.convolutional_barycenter2d_debiased(A_img, reg, weights, + method=method, numItermax=1) def test_barycenter_stabilization(nx): From 6bd076b9e9eb45a32874be0fdddf73741378e41f Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Tue, 2 Nov 2021 22:50:07 +0100 Subject: [PATCH 25/25] fix flake --- test/test_bregman.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_bregman.py b/test/test_bregman.py index 04d450c65..edfe9c305 100644 --- a/test/test_bregman.py +++ 
b/test/test_bregman.py @@ -10,7 +10,6 @@ import numpy as np import pytest -from torch._C import Value import ot from ot.backend import torch