Skip to content

Commit

Permalink
add Hungarian solver from optax (#598)
Browse files Browse the repository at this point in the history
* add Hungarian solver

* bump optax

* optax.

* add `wass_p` to docs

* Correct mistake on defining norm

* fixes

* add combinatorial

* using anonymous references for Wikipedia entries

* `mu.norm` does not return ||.||^2!!

* fix doc for norm
  • Loading branch information
marcocuturi authored Nov 19, 2024
1 parent fd18299 commit b479e5f
Show file tree
Hide file tree
Showing 11 changed files with 231 additions and 19 deletions.
1 change: 1 addition & 0 deletions docs/geometry.rst
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ Cost Functions
costs.SqEuclidean
costs.RegTICost
costs.Euclidean
costs.EuclideanP
costs.Cosine
costs.Arccos
costs.Bures
Expand Down
17 changes: 14 additions & 3 deletions docs/glossary.rst
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,12 @@ Glossary
evaluations of :math:`c` on various pairs of points,
:math:`C=[c(x_i, y_j)]_{ij}`.

Hungarian algorithm
Combinatorial algorithm proposed by Harold Kuhn to solve the
:term:`optimal matching problem`. See the
`Wikipedia definition <https://en.wikipedia.org/wiki/Hungarian_algorithm>`__
.

implicit differentiation
Formula used to compute the vector-Jacobian
product of the minimizer of an optimization procedure that leverages
Expand Down Expand Up @@ -236,7 +242,7 @@ Glossary
measurable set :math:`B`, :math:`T\#\mu(B)=\mu(T^{-1}(B))`. Intuitively,
it is the measure obtained by applying the map :math:`T` to all points
described in the support of :math:`\mu`. See also the
`Wikipedia definition <https://en.wikipedia.org/wiki/push-forward_measure>`_.
`Wikipedia definition <https://en.wikipedia.org/wiki/push-forward_measure>`__.

optimal transport
Theory that characterizes efficient transformations between probability
Expand All @@ -245,6 +251,11 @@ Glossary
whereas computational aspects become relevant when estimating such
transforms from samples.

optimal matching problem
Instance of the :term:`Kantorovich problem` where both marginal weight
vectors :math:`a,b` are equal, and set both to a uniform weight vector
of the form :math:`(\tfrac{1}{n},\dots,\tfrac{1}{n})\in\mathbb{R}^n`.

Sinkhorn algorithm
Fixed point iteration that solves the
:term:`entropy-regularized optimal transport` problem (EOT).
Expand Down Expand Up @@ -340,7 +351,7 @@ Glossary
:term:`ground cost` function that is equal to the optimal objective
reached when solving the :term:`Kantorovich problem`. The Wasserstein
distance is truly a distance (in the sense that it satisfies all 3
`metric axioms <https://en.wikipedia.org/wiki/Metric_space#Definition>`_
), as long as the :term:`ground cost` is itself a distance to a power
`metric axioms <https://en.wikipedia.org/wiki/Metric_space#Definition>`__
) if the :term:`ground cost` is itself a distance to a power
:math:`p\leq 1`, and the :math:`p` root of the objective of the
:term:`Kantorovich problem` is used.
2 changes: 2 additions & 0 deletions docs/spelling/technical.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ centroids
checkpointing
chromatin
collinear
combinatorial
covariance
covariances
dataclass
Expand Down Expand Up @@ -165,6 +166,7 @@ transcriptome
undirected
univariate
unnormalized
unregularized
unscaled
url
vectorized
Expand Down
23 changes: 18 additions & 5 deletions docs/tools.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,24 @@ The :mod:`~ott.tools` package contains high level functions that build on
outputs produced by lower-level components in the toolbox, such as
:mod:`~ott.solvers`.

In particular, we provide user-friendly APIs to compute Sinkhorn divergences
:cite:`genevay:18,sejourne:19`, sliced Wasserstein distances :cite:`rabin:12`,
differentiable approximations to ranks and quantile functions :cite:`cuturi:19`,
and various tools to study Gaussians with the 2-Wasserstein metric
:cite:`gelbrich:90,delon:20`, etc.
In particular, we provide user-friendly APIs to unregularized OT quantities,
such as the :term:`Wasserstein distance` for two point clouds of the same size.
We also provide functions to pad efficiently point clouds when doing large scale
OT between them in parallel, implementations of the Sinkhorn
divergence :cite:`genevay:18,sejourne:19`, sliced Wasserstein distances
:cite:`rabin:12`, differentiable approximations to ranks and quantile functions
:cite:`cuturi:19`, and various tools to study Gaussians with the
2-:term:`Wasserstein distance` :cite:`gelbrich:90,delon:20`.

Unregularized Optimal Transport
-------------------------------
.. autosummary::
:toctree: _autosummary

unreg.hungarian
unreg.HungarianOutput
unreg.wassdis_p


Segmented Sinkhorn
------------------
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ dependencies = [
"lineax>=0.0.7",
"numpy>=1.20.0",
"typing_extensions; python_version <= '3.9'",
"optax>=0.2.4",
]
keywords = [
"optimal transport",
Expand Down Expand Up @@ -59,7 +60,7 @@ Changelog = "https://github.com/ott-jax/ott/releases"
[project.optional-dependencies]
neural = [
"flax>=0.6.6",
"optax>=0.1.1",
"optax>=0.2.4",
"diffrax>=0.4.1",
]
dev = [
Expand Down
27 changes: 27 additions & 0 deletions src/ott/geometry/costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,33 @@ def tree_unflatten(cls, aux_data, children): # noqa: D102
return cls(*aux_data)


@jtu.register_pytree_node_class
class EuclideanP(TICost):
r""":math:`p`-power of Euclidean norm.
Uses custom implementation of `norm` to avoid `NaN` values when
differentiating the norm of :math:`x-x`.
Args:
p: Power used to raise Euclidean norm, in :math:`[1, +\infty)`.
"""

def __init__(self, p: float):
super().__init__()
self.p = p

def h(self, z: jnp.ndarray) -> float: # noqa: D102
return mu.norm(z, ord=2) ** (self.p)

def tree_flatten(self): # noqa: D102
return (), (self.p,)

@classmethod
def tree_unflatten(cls, aux_data, children): # noqa: D102
del children
return cls(*aux_data)


@jtu.register_pytree_node_class
class RegTICost(TICost):
r"""Regularized translation-invariant cost.
Expand Down
24 changes: 15 additions & 9 deletions src/ott/geometry/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,19 +48,25 @@ class Geometry:
cost_matrix: Cost matrix of shape ``[n, m]``.
kernel_matrix: Kernel matrix of shape ``[n, m]``.
epsilon: Regularization parameter. If ``None`` and either
``relative_epsilon = True`` or ``relative_epsilon = None``, this defaults
to the value computed in :attr:`mean_cost_matrix` / 20. If passed as a
``relative_epsilon = True`` or ``relative_epsilon = None`` or
``relative_epsilon = str`` where ``str`` can be either ``mean`` or ``std``
, this value defaults to a multiple of :attr:`std_cost_matrix`
(or :attr:`mean_cost_matrix` if ``str`` is ``mean``), where that multiple
is set as ``DEFAULT_SCALE`` in ``epsilon_scheduler.py```.
If passed as a
``float``, then the regularizer that is ultimately used is either that
``float`` value (if ``relative_epsilon = False`` or ``None``) or that
``float`` times the :attr:`mean_cost_matrix`
(if ``relative_epsilon = True``). Look for
``float`` times the :attr:`std_cost_matrix` (if
``relative_epsilon = True`` or ``relative_epsilon = `std```) or
:attr:`mean_cost_matrix` (if ``relative_epsilon = `mean```). Look for
:class:`~ott.geometry.epsilon_scheduler.Epsilon` when passed as a
scheduler.
relative_epsilon: when `False`, the parameter ``epsilon`` specifies the
value of the entropic regularization parameter. When `True`, ``epsilon``
refers to a fraction of the :attr:`std_cost_matrix`, which is computed
adaptively from data. Can also be set to ``mean`` or ``std`` to use mean
of cost matrix if necessary.
relative_epsilon: when :obj:`False`, the parameter ``epsilon`` specifies the
value of the entropic regularization parameter. When :obj:`True` or set
to a string, ``epsilon`` refers to a fraction of the
:attr:`std_cost_matrix` or :attr:`mean_cost_matrix`, which is computed
adaptively from data, depending on whether it is set to ``mean`` or
``std``.
scale_cost: option to rescale the cost matrix. Implemented scalings are
'median', 'mean', 'std' and 'max_cost'. Alternatively, a float factor can
be given to rescale the cost such that ``cost_matrix /= scale_cost``.
Expand Down
2 changes: 1 addition & 1 deletion src/ott/math/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def norm(
is None. If both `axis` and `ord` are None, the 2-norm of ``x.ravel``
will be returned.
ord: `{non-zero int, jnp.inf, -jnp.inf, 'fro', 'nuc'}`, Order of the norm.
The default is `None`, which is equivalent to `2.0` for vectors.
The default is `None`, which is equivalent to `2` for vectors.
axis: `{None, int, 2-tuple of ints}`, optional. If `axis` is an integer, it
specifies the axis of `x` along which to compute the vector norms.
If `axis` is a 2-tuple, it specifies the axes that hold 2-D matrices, and
Expand Down
1 change: 1 addition & 0 deletions src/ott/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,5 @@
sinkhorn_divergence,
sliced,
soft_sort,
unreg,
)
90 changes: 90 additions & 0 deletions src/ott/tools/unreg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import NamedTuple, Optional, Tuple

import jax.experimental.sparse as jesp
import jax.numpy as jnp

from optax import assignment

from ott.geometry import costs, geometry, pointcloud

__all__ = ["hungarian"]


class HungarianOutput(NamedTuple):
r"""Output of the Hungarian solver.
Args:
geom: geometry object
paired_indices: Array of shape ``[2, n]``, of :math:`n` pairs
of indices, for which the optimal transport assigns mass. Namely, for each
index :math:`0 \leq k < n`, if one has
:math:`i := \text{paired_indices}[0, k]` and
:math:`j := \text{paired_indices}[1, k]`, then point :math:`i` in
the first geometry sends mass to point :math:`j` in the second.
"""
geom: geometry.Geometry
paired_indices: Optional[jnp.ndarray] = None

@property
def matrix(self) -> jesp.BCOO:
"""``[n, n]`` transport matrix in sparse format, with ``n`` NNZ entries."""
n, _ = self.geom.shape
unit_mass = jnp.ones((n,)) / n
indices = self.paired_indices.swapaxes(0, 1)
return jesp.BCOO((unit_mass, indices), shape=(n, n))


def hungarian(geom: geometry.Geometry) -> Tuple[jnp.ndarray, HungarianOutput]:
"""Solve matching problem using :term:`Hungarian algorithm` from :mod:`optax`.
Args:
geom: Geometry object with square (shape ``[n,n]``)
:attr:`~ott.geometry.geometry.Geomgetry.cost matrix`.
Returns:
The value of the unregularized OT problem, along with an output
object listing relevant information on outputs.
"""
n, m = geom.shape
assert n == m, f"Hungarian can only match same # of points, got {n} and {m}."
i, j = assignment.hungarian_algorithm(geom.cost_matrix)

hungarian_out = HungarianOutput(geom=geom, paired_indices=jnp.stack((i, j)))
return jnp.sum(geom.cost_matrix[i, j]) / n, hungarian_out


def wassdis_p(x: jnp.ndarray, y: jnp.ndarray, p: float = 2.0) -> float:
"""Compute the :term:`Wasserstein distance`, uses :term:`Hungarian algorithm`.
Uses :func:`hungarian` to solve the :term:`optimal matching problem` between
two point clouds of the same size, to compute a :term:`Wasserstein distance`
estimator.
Note:
At the moment, only supports point clouds of the same size to be easily
cast as an optimal matching problem.
Args:
x: ``[n,d]`` point cloud
y: ``[n,d]`` point cloud of the same size
p: order of the Wasserstein distance, non-negative float.
Returns:
The p-Wasserstein distance between these point clouds.hungarian
"""
geom = pointcloud.PointCloud(x, y, cost_fn=costs.EuclideanP(p))
cost, _ = hungarian(geom)
return cost ** 1. / p
60 changes: 60 additions & 0 deletions tests/tools/unreg_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple

import pytest

import jax
import jax.numpy as jnp
import numpy as np

from ott.geometry import costs, pointcloud
from ott.solvers import linear
from ott.tools import unreg


class TestHungarian:

@pytest.mark.parametrize("cost_fn", [costs.PNormP(1.3), None])
def test_matches_sink(self, rng: jax.Array, cost_fn: Optional[costs.CostFn]):
n, m, dim = 12, 12, 5
rng1, rng2 = jax.random.split(rng, 2)
x, y = gen_data(rng1, n, m, dim)
geom = pointcloud.PointCloud(x, y, cost_fn=cost_fn, epsilon=.0005)
cost_hung, out_hung = jax.jit(unreg.hungarian)(geom)
out_sink = jax.jit(linear.solve)(geom)
np.testing.assert_allclose(
out_sink.primal_cost, cost_hung, rtol=1e-3, atol=1e-3
)
np.testing.assert_allclose(
out_sink.matrix, out_hung.matrix.todense(), rtol=1e-3, atol=1e-3
)

@pytest.mark.parametrize("p", [1.3, 2.3])
def test_wass(self, rng: jax.Array, p: float):
n, m, dim = 12, 12, 5
rng1, rng2 = jax.random.split(rng, 2)
x, y = gen_data(rng1, n, m, dim)
geom = pointcloud.PointCloud(x, y, cost_fn=costs.EuclideanP(p=p))
cost_hung, _ = jax.jit(unreg.hungarian)(geom)
w_p = jax.jit(unreg.wassdis_p)(x, y, p)
np.testing.assert_allclose(w_p, cost_hung ** 1. / p, rtol=1e-3, atol=1e-3)


def gen_data(rng: jax.Array, n: int, m: int,
dim: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
rngs = jax.random.split(rng, 4)
x = jax.random.uniform(rngs[0], (n, dim))
y = jax.random.uniform(rngs[1], (m, dim))
return x, y

0 comments on commit b479e5f

Please sign in to comment.