Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -334,3 +334,6 @@ distances between Gaussian distributions](https://hal.science/hal-03197398v2/fil

[59] Taylor A. B. (2017). [Convex interpolation and performance estimation of first-order methods for convex optimization.](https://dial.uclouvain.be/pr/boreal/object/boreal%3A182881/datastream/PDF_01/view) PhD thesis, Catholic University of Louvain, Louvain-la-Neuve, Belgium, 2017.

[60] Feydy, J., Roussillon, P., Trouvé, A., & Gori, P. (2019). [Fast and scalable optimal transport for brain tractograms](https://arxiv.org/pdf/2107.02010.pdf). In Medical Image Computing and Computer Assisted Intervention–MICCAI 2019: 22nd International Conference, Shenzhen, China, October 13–17, 2019, Proceedings, Part III 22 (pp. 636-644). Springer International Publishing.

[61] Charlier, B., Feydy, J., Glaunes, J. A., Collin, F. D., & Durif, G. (2021). [Kernel operations on the gpu, with autodiff, without memory overflows](https://www.jmlr.org/papers/volume22/20-275/20-275.pdf). The Journal of Machine Learning Research, 22(1), 3457-3462.
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
+ Add KL loss to all semi-relaxed (Fused) Gromov-Wasserstein solvers (PR #559)
+ Further upgraded unbalanced OT solvers for more flexibility and future use (PR #551)
+ New API function `ot.solve_sample` for solving OT problems from empirical samples (PR #563)
+ Wrapper for `geomloss`` solver on empirical samples (PR #571)
+ Add `stop_criterion` feature to (un)regularized (f)gw barycenter solvers (PR #578)
+ Add `fixed_structure` and `fixed_features` to entropic fgw barycenter solver (PR #578)

Expand Down
6 changes: 4 additions & 2 deletions ot/bregman/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,17 @@

from ._dictionary import (unmix)

from ._geomloss import (empirical_sinkhorn2_geomloss, geomloss)


__all__ = ['geometricBar', 'geometricMean', 'projR', 'projC',
'sinkhorn', 'sinkhorn2', 'sinkhorn_knopp', 'sinkhorn_log',
'greenkhorn', 'sinkhorn_stabilized', 'sinkhorn_epsilon_scaling',
'barycenter', 'barycenter_sinkhorn', 'free_support_sinkhorn_barycenter',
'barycenter_stabilized', 'barycenter_debiased', 'jcpot_barycenter',
'convolutional_barycenter2d', 'convolutional_barycenter2d_debiased',
'empirical_sinkhorn', 'empirical_sinkhorn2',
'empirical_sinkhorn_divergence',
'empirical_sinkhorn', 'empirical_sinkhorn2', 'empirical_sinkhorn2_geomloss'
'empirical_sinkhorn_divergence', 'geomloss',
'screenkhorn',
'unmix'
]
216 changes: 216 additions & 0 deletions ot/bregman/_geomloss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
# -*- coding: utf-8 -*-
"""
Wrapper functions for geomloss
"""

# Author: Remi Flamary <remi.flamary@unice.fr>
#
# License: MIT License

import numpy as np
try:
import geomloss
from geomloss import SamplesLoss
import torch
from torch.autograd import grad
from ..utils import get_backend, LazyTensor, dist
except ImportError:
geomloss = False


def get_sinkhorn_geomloss_lazytensor(X_a, X_b, f, g, a, b, metric='sqeuclidean', blur=0.1, nx=None):
""" Get a LazyTensor of sinkhorn solution T = exp((f+g^T-C)/reg)*(ab^T)

Parameters
----------
X_a : array-like, shape (n_samples_a, dim)
samples in the source domain
X_torch: array-like, shape (n_samples_b, dim)
samples in the target domain
f : array-like, shape (n_samples_a,)
First dual potentials (log space)
g : array-like, shape (n_samples_b,)
Second dual potentials (log space)
metric : str, default='sqeuclidean'
Metric used for the cost matrix computation
blur : float, default=1e-1
blur term (blur=sqrt(reg)) >0
nx : Backend(), default=None
Numerical backend used


Returns
-------
T : LazyTensor
Lowrank tensor T = exp((f+g^T-C)/reg)*(ab^T)
"""

if nx is None:
nx = get_backend(X_a, X_b, f, g)

shape = (X_a.shape[0], X_b.shape[0])

def func(i, j, X_a, X_b, f, g, a, b, metric, blur):
if metric == 'sqeuclidean':
C = dist(X_a[i], X_b[j], metric=metric) / 2
else:
C = dist(X_a[i], X_b[j], metric=metric)
return nx.exp((f[i, None] + g[None, j] - C) / (blur**2)) * (a[i, None] * b[None, j])

T = LazyTensor(shape, func, X_a=X_a, X_b=X_b, f=f, g=g, a=a, b=b, metric=metric, blur=blur)

return T


def empirical_sinkhorn2_geomloss(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', scaling=0.95,
verbose=False, debias=False, log=False, backend='auto'):
r""" Solve the entropic regularization optimal transport problem with geomloss

The function solves the following optimization problem:

.. math::
\gamma = arg\min_\gamma <\gamma,C>_F + reg\cdot\Omega(\gamma)

s.t. \gamma 1 = a

\gamma^T 1= b

\gamma\geq 0

where :

- :math:`C` is the cost matrix such that :math:`C_{i,j}=d(x_i^s,x_j^t)` and
:math:`d` is a metric.
- :math:`\Omega` is the entropic regularization term
:math:`\Omega(\gamma)=\sum_{i,j}\gamma_{i,j}\log(\gamma_{i,j})-\gamma_{i,j}+1`
- :math:`a` and :math:`b` are source and target weights (sum to 1)

The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
scaling algorithm as proposed in and computed in log space for
better stability and epsilon-scaling. The solution is computed ina lzy way
using the Geomloss [60] and the KeOps library [61].

Parameters
----------
X_s : array-like, shape (n_samples_a, dim)
samples in the source domain
X_t : array-like, shape (n_samples_b, dim)
samples in the target domain
reg : float
Regularization term >0
a : array-like, shape (n_samples_a,), default=None
samples weights in the source domain
b : array-like, shape (n_samples_b,), default=None
samples weights in the target domain
metric : str, default='sqeuclidean'
Metric used for the cost matrix computation Only acepted values are
'sqeuclidean' and 'euclidean'.
scaling : float, default=0.95
Scaling parameter used for epsilon scaling. Value close to one promote
precision while value close to zero promote speed.
verbose : bool, default=False
Print information
debias : bool, default=False
Use the debiased version of Sinkhorn algorithm [12]_.
log : bool, default=False
Return log dictionary containing all computed objects
backend : str, default='auto'
Numerical backend for geomloss. Only 'auto' and 'tensorized' 'online'
and 'multiscale' are accepted values.

Returns
-------
value : float
OT value
log : dict
Log dictionary return only if log==True in parameters

References
----------

.. [60] Feydy, J., Roussillon, P., Trouvé, A., & Gori, P. (2019). [Fast
and scalable optimal transport for brain tractograms. In Medical Image
Computing and Computer Assisted Intervention–MICCAI 2019: 22nd
International Conference, Shenzhen, China, October 13–17, 2019,
Proceedings, Part III 22 (pp. 636-644). Springer International
Publishing.

.. [61] Charlier, B., Feydy, J., Glaunes, J. A., Collin, F. D., & Durif, G.
(2021). Kernel operations on the gpu, with autodiff, without memory
overflows. The Journal of Machine Learning Research, 22(1), 3457-3462.

"""

if geomloss:

nx = get_backend(X_s, X_t, a, b)

if nx.__name__ not in ['torch', 'numpy']:
raise ValueError('geomloss only support torch or numpy backend')

if a is None:
a = nx.ones(X_s.shape[0], type_as=X_s) / X_s.shape[0]
if b is None:
b = nx.ones(X_t.shape[0], type_as=X_t) / X_t.shape[0]

if nx.__name__ == 'numpy':
X_s_torch = torch.tensor(X_s)
X_t_torch = torch.tensor(X_t)

a_torch = torch.tensor(a)
b_torch = torch.tensor(b)

else:
X_s_torch = X_s
X_t_torch = X_t

a_torch = a
b_torch = b

# after that we are all in torch

# set blur value and p
if metric == 'sqeuclidean':
p = 2
blur = np.sqrt(reg / 2) # because geomloss divides cost by two
elif metric == 'euclidean':
p = 1
blur = np.sqrt(reg)
else:
raise ValueError('geomloss only supports sqeuclidean and euclidean metrics')

# force gradients for computing dual
a_torch.requires_grad = True
b_torch.requires_grad = True

loss = SamplesLoss(loss='sinkhorn', p=p, blur=blur, backend=backend, debias=debias, scaling=scaling, verbose=verbose)

# compute value
value = loss(a_torch, X_s_torch, b_torch, X_t_torch) # linear + entropic/KL reg?

# get dual potentials
f, g = grad(value, [a_torch, b_torch])

if metric == 'sqeuclidean':
value *= 2 # because geomloss divides cost by two

if nx.__name__ == 'numpy':
f = f.cpu().detach().numpy()
g = g.cpu().detach().numpy()
value = value.cpu().detach().numpy()

if log:
log = {}
log['f'] = f
log['g'] = g
log['value'] = value

log['lazy_plan'] = get_sinkhorn_geomloss_lazytensor(X_s, X_t, f, g, a, b, metric=metric, blur=blur, nx=nx)

return value, log

else:
return value

else:
raise ImportError('geomloss not installed')
35 changes: 32 additions & 3 deletions ot/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from .lp import emd2, wasserstein_1d
from .backend import get_backend
from .unbalanced import mm_unbalanced, sinkhorn_knopp_unbalanced, lbfgsb_unbalanced
from .bregman import sinkhorn_log, empirical_sinkhorn2
from .bregman import sinkhorn_log, empirical_sinkhorn2, empirical_sinkhorn2_geomloss
from .partial import partial_wasserstein_lagrange
from .smooth import smooth_ot_dual
from .gromov import (gromov_wasserstein2, fused_gromov_wasserstein2,
Expand All @@ -23,6 +23,8 @@
from .gaussian import empirical_bures_wasserstein_distance
from .factored import factored_optimal_transport

lst_method_lazy = ['1d', 'gaussian', 'lowrank', 'factored', 'geomloss', 'geomloss_auto', 'geomloss_tensorized', 'geomloss_online', 'geomloss_multiscale']


def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
unbalanced_type='KL', method=None, n_threads=1, max_iter=None, plan_init=None,
Expand Down Expand Up @@ -865,7 +867,7 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None,

def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_type="KL",
unbalanced=None,
unbalanced_type='KL', lazy=False, batch_size=None, method=None, n_threads=1, max_iter=None, plan_init=None, rank=100,
unbalanced_type='KL', lazy=False, batch_size=None, method=None, n_threads=1, max_iter=None, plan_init=None, rank=100, scaling=0.95,
potentials_init=None, X_init=None, tol=None, verbose=False):
r"""Solve the discrete optimal transport problem using the samples in the source and target domains.

Expand Down Expand Up @@ -922,6 +924,10 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t
Maximum number of iteration, by default None (default values in each solvers)
plan_init : array_like, shape (dim_a, dim_b), optional
Initialization of the OT plan for iterative methods, by default None
rank : int, optional
Rank of the OT matrix for lazy solers (method='factored'), by default 100
scaling : float, optional
Scaling factor for the epsilon scaling lazy solvers (method='geomloss'), by default 0.95
potentials_init : (array_like(dim_a,),array_like(dim_b,)), optional
Initialization of the OT dual potentials for iterative methods, by default None
tol : _type_, optional
Expand All @@ -939,6 +945,7 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t
- res.potentials : OT dual potentials
- res.value : Optimal value of the optimization problem
- res.value_linear : Linear OT loss with the optimal OT plan
- res.lazy_plan : Lazy OT plan (when ``lazy=True`` or lazy method)

See :any:`OTResult` for more information.

Expand Down Expand Up @@ -1148,7 +1155,7 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t

"""

if method is not None and method.lower() in ['1d', 'gaussian', 'lowrank', 'factored']:
if method is not None and method.lower() in lst_method_lazy:
lazy0 = lazy
lazy = True

Expand Down Expand Up @@ -1221,6 +1228,28 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t
if not lazy0: # store plan if not lazy
plan = lazy_plan[:]

elif method.startswith('geomloss'): # Geomloss solver for entropi OT

split_method = method.split('_')
if len(split_method) == 2:
backend = split_method[1]
else:
if lazy0 is None:
backend = 'auto'
elif lazy0:
backend = 'online'
else:
backend = 'tensorized'

value, log = empirical_sinkhorn2_geomloss(X_a, X_b, reg=reg, a=a, b=b, metric=metric, log=True, verbose=verbose, scaling=scaling, backend=backend)

lazy_plan = log['lazy_plan']
if not lazy0: # store plan if not lazy
plan = lazy_plan[:]

# return scaled potentials (to be consistent with other solvers)
potentials = (log['f'] / (lazy_plan.blur**2), log['g'] / (lazy_plan.blur**2))

elif reg is None or reg == 0: # exact OT

if unbalanced is None: # balanced EMD solver not available for lazy
Expand Down
2 changes: 1 addition & 1 deletion ot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ def reduce_lazytensor(a, func, axis=None, nx=None, batch_size=100):
"""

if nx is None:
nx = get_backend(a[0])
nx = get_backend(a[0:1])

if axis is None:
res = 0.0
Expand Down
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,6 @@ jaxlib
tensorflow
pytest
torch_geometric
cvxpy
cvxpy
geomloss
pykeops
Loading