diff --git a/docs/geometry.rst b/docs/geometry.rst index 42edcdc64..407d8fca0 100644 --- a/docs/geometry.rst +++ b/docs/geometry.rst @@ -63,6 +63,7 @@ Cost Functions costs.SqEuclidean costs.RegTICost costs.Euclidean + costs.EuclideanP costs.Cosine costs.Arccos costs.Bures diff --git a/docs/glossary.rst b/docs/glossary.rst index 918322471..903a5d992 100644 --- a/docs/glossary.rst +++ b/docs/glossary.rst @@ -162,6 +162,12 @@ Glossary evaluations of :math:`c` on various pairs of points, :math:`C=[c(x_i, y_j)]_{ij}`. + Hungarian algorithm + Combinatorial algorithm proposed by Harold Kuhn to solve the + :term:`optimal matching problem`. See the + `Wikipedia definition <https://en.wikipedia.org/wiki/Hungarian_algorithm>`__ + . + implicit differentiation Formula used to compute the vector-Jacobian product of the minimizer of an optimization procedure that leverages @@ -236,7 +242,7 @@ Glossary measurable set :math:`B`, :math:`T\#\mu(B)=\mu(T^{-1}(B))`. Intuitively, it is the measure obtained by applying the map :math:`T` to all points described in the support of :math:`\mu`. See also the - `Wikipedia definition <https://en.wikipedia.org/wiki/Pushforward_measure>`_. + `Wikipedia definition <https://en.wikipedia.org/wiki/Pushforward_measure>`__. optimal transport Theory that characterizes efficient transformations between probability @@ -245,6 +251,11 @@ Glossary whereas computational aspects become relevant when estimating such transforms from samples. + optimal matching problem + Instance of the :term:`Kantorovich problem` where both marginal weight + vectors :math:`a,b` are equal, both being set to the uniform weight vector + of the form :math:`(\tfrac{1}{n},\dots,\tfrac{1}{n})\in\mathbb{R}^n`. + Sinkhorn algorithm Fixed point iteration that solves the :term:`entropy-regularized optimal transport` problem (EOT). @@ -340,7 +351,7 @@ Glossary :term:`ground cost` function that is equal to the optimal objective reached when solving the :term:`Kantorovich problem`. 
The Wasserstein distance is truly a distance (in the sense that it satisfies all 3 - `metric axioms <https://en.wikipedia.org/wiki/Metric_space#Definition>`_ - ), as long as the :term:`ground cost` is itself a distance to a power + `metric axioms <https://en.wikipedia.org/wiki/Metric_space#Definition>`__ + ) if the :term:`ground cost` is itself a distance to a power :math:`p\leq 1`, and the :math:`p` root of the objective of the :term:`Kantorovich problem` is used. diff --git a/docs/spelling/technical.txt b/docs/spelling/technical.txt index c739ca417..6a9fe3a72 100644 --- a/docs/spelling/technical.txt +++ b/docs/spelling/technical.txt @@ -48,6 +48,7 @@ centroids checkpointing chromatin collinear +combinatorial covariance covariances dataclass @@ -165,6 +166,7 @@ transcriptome undirected univariate unnormalized +unregularized unscaled url vectorized diff --git a/docs/tools.rst b/docs/tools.rst index 7fb93b345..c6b8ec471 100644 --- a/docs/tools.rst +++ b/docs/tools.rst @@ -7,11 +7,24 @@ The :mod:`~ott.tools` package contains high level functions that build on outputs produced by lower-level components in the toolbox, such as :mod:`~ott.solvers`. -In particular, we provide user-friendly APIs to compute Sinkhorn divergences -:cite:`genevay:18,sejourne:19`, sliced Wasserstein distances :cite:`rabin:12`, -differentiable approximations to ranks and quantile functions :cite:`cuturi:19`, -and various tools to study Gaussians with the 2-Wasserstein metric -:cite:`gelbrich:90,delon:20`, etc. +In particular, we provide user-friendly APIs to compute unregularized OT +quantities, such as the :term:`Wasserstein distance` between two point clouds +of the same size. We also provide functions to efficiently pad point clouds +when running large-scale OT between them in parallel, implementations of the +Sinkhorn divergence :cite:`genevay:18,sejourne:19`, sliced Wasserstein distances +:cite:`rabin:12`, differentiable approximations to ranks and quantile functions +:cite:`cuturi:19`, and various tools to study Gaussians with the +2-:term:`Wasserstein distance` :cite:`gelbrich:90,delon:20`. 
+ +Unregularized Optimal Transport +------------------------------- +.. autosummary:: + :toctree: _autosummary + + unreg.hungarian + unreg.HungarianOutput + unreg.wassdis_p + Segmented Sinkhorn ------------------ diff --git a/pyproject.toml b/pyproject.toml index 07685dc53..d1014be4d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ dependencies = [ "lineax>=0.0.7", "numpy>=1.20.0", "typing_extensions; python_version <= '3.9'", + "optax>=0.2.4", ] keywords = [ "optimal transport", @@ -59,7 +60,7 @@ Changelog = "https://github.com/ott-jax/ott/releases" [project.optional-dependencies] neural = [ "flax>=0.6.6", - "optax>=0.1.1", + "optax>=0.2.4", "diffrax>=0.4.1", ] dev = [ diff --git a/src/ott/geometry/costs.py b/src/ott/geometry/costs.py index 67ab58c79..508318310 100644 --- a/src/ott/geometry/costs.py +++ b/src/ott/geometry/costs.py @@ -347,6 +347,33 @@ def tree_unflatten(cls, aux_data, children): # noqa: D102 return cls(*aux_data) +@jtu.register_pytree_node_class +class EuclideanP(TICost): + r""":math:`p`-power of Euclidean norm. + + Uses a custom implementation of the norm to avoid ``NaN`` values when + differentiating the norm of :math:`x-x`, i.e. of the zero vector. + + Args: + p: Power used to raise Euclidean norm, in :math:`[1, +\infty)`. + """ + + def __init__(self, p: float): + super().__init__() + self.p = p + + def h(self, z: jnp.ndarray) -> float: # noqa: D102 + return mu.norm(z, ord=2) ** (self.p) # 2-norm of ``z`` raised to the power ``p`` + + def tree_flatten(self): # noqa: D102 + return (), (self.p,) # no children; ``p`` is carried as auxiliary data + + @classmethod + def tree_unflatten(cls, aux_data, children): # noqa: D102 + del children # unused: ``tree_flatten`` stores no children + return cls(*aux_data) + + @jtu.register_pytree_node_class class RegTICost(TICost): r"""Regularized translation-invariant cost. diff --git a/src/ott/geometry/geometry.py b/src/ott/geometry/geometry.py index 79a074ee2..ab527722e 100644 --- a/src/ott/geometry/geometry.py +++ b/src/ott/geometry/geometry.py @@ -48,19 +48,25 @@ class Geometry: cost_matrix: Cost matrix of shape ``[n, m]``. 
kernel_matrix: Kernel matrix of shape ``[n, m]``. epsilon: Regularization parameter. If ``None`` and either - ``relative_epsilon = True`` or ``relative_epsilon = None``, this defaults - to the value computed in :attr:`mean_cost_matrix` / 20. If passed as a + ``relative_epsilon = True`` or ``relative_epsilon = None`` or + ``relative_epsilon = str`` (with ``str`` either ``mean`` or ``std``), + this value defaults to a multiple of :attr:`std_cost_matrix` + (or :attr:`mean_cost_matrix` if ``str`` is ``mean``), where that multiple + is set as ``DEFAULT_SCALE`` in ``epsilon_scheduler.py``. + If passed as a ``float``, then the regularizer that is ultimately used is either that ``float`` value (if ``relative_epsilon = False`` or ``None``) or that - ``float`` times the :attr:`mean_cost_matrix` - (if ``relative_epsilon = True``). Look for + ``float`` times the :attr:`std_cost_matrix` (if + ``relative_epsilon = True`` or ``relative_epsilon = "std"``) or + :attr:`mean_cost_matrix` (if ``relative_epsilon = "mean"``). Look for :class:`~ott.geometry.epsilon_scheduler.Epsilon` when passed as a scheduler. - relative_epsilon: when `False`, the parameter ``epsilon`` specifies the - value of the entropic regularization parameter. When `True`, ``epsilon`` - refers to a fraction of the :attr:`std_cost_matrix`, which is computed - adaptively from data. Can also be set to ``mean`` or ``std`` to use mean - of cost matrix if necessary. + relative_epsilon: when :obj:`False`, the parameter ``epsilon`` specifies the + value of the entropic regularization parameter. When :obj:`True` or set + to a string, ``epsilon`` refers to a fraction of the + :attr:`std_cost_matrix` or :attr:`mean_cost_matrix`, which is computed + adaptively from data, depending on whether it is set to ``mean`` or + ``std``. scale_cost: option to rescale the cost matrix. Implemented scalings are 'median', 'mean', 'std' and 'max_cost'. 
Alternatively, a float factor can be given to rescale the cost such that ``cost_matrix /= scale_cost``. diff --git a/src/ott/math/utils.py b/src/ott/math/utils.py index 591b2fcae..134541f17 100644 --- a/src/ott/math/utils.py +++ b/src/ott/math/utils.py @@ -77,7 +77,7 @@ def norm( is None. If both `axis` and `ord` are None, the 2-norm of ``x.ravel`` will be returned. ord: `{non-zero int, jnp.inf, -jnp.inf, 'fro', 'nuc'}`, Order of the norm. - The default is `None`, which is equivalent to `2.0` for vectors. + The default is `None`, which is equivalent to `2` for vectors. axis: `{None, int, 2-tuple of ints}`, optional. If `axis` is an integer, it specifies the axis of `x` along which to compute the vector norms. If `axis` is a 2-tuple, it specifies the axes that hold 2-D matrices, and diff --git a/src/ott/tools/__init__.py b/src/ott/tools/__init__.py index dbe84d684..b960f9941 100644 --- a/src/ott/tools/__init__.py +++ b/src/ott/tools/__init__.py @@ -20,4 +20,5 @@ sinkhorn_divergence, sliced, soft_sort, + unreg, ) diff --git a/src/ott/tools/unreg.py b/src/ott/tools/unreg.py new file mode 100644 index 000000000..1e4507726 --- /dev/null +++ b/src/ott/tools/unreg.py @@ -0,0 +1,90 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import NamedTuple, Optional, Tuple
+
+import jax.experimental.sparse as jesp
+import jax.numpy as jnp
+
+from optax import assignment
+
+from ott.geometry import costs, geometry, pointcloud
+
+__all__ = ["HungarianOutput", "hungarian", "wassdis_p"]
+
+
+class HungarianOutput(NamedTuple):
+  r"""Output of the Hungarian solver.
+
+  Args:
+    geom: geometry object whose square cost matrix was matched.
+    paired_indices: Array of shape ``[2, n]``, of :math:`n` pairs
+      of indices, for which the optimal transport assigns mass. Namely, for each
+      index :math:`0 \leq k < n`, if one has
+      :math:`i := \text{paired_indices}[0, k]` and
+      :math:`j := \text{paired_indices}[1, k]`, then point :math:`i` in
+      the first geometry sends mass to point :math:`j` in the second.
+  """
+  geom: geometry.Geometry
+  paired_indices: Optional[jnp.ndarray] = None
+
+  @property
+  def matrix(self) -> jesp.BCOO:
+    """``[n, n]`` transport matrix in sparse format, with ``n`` NNZ entries."""
+    n, _ = self.geom.shape
+    unit_mass = jnp.ones((n,)) / n  # each matched pair carries mass 1/n
+    indices = self.paired_indices.swapaxes(0, 1)  # ``[2, n]`` -> ``[n, 2]``
+    return jesp.BCOO((unit_mass, indices), shape=(n, n))
+
+
+def hungarian(geom: geometry.Geometry) -> Tuple[jnp.ndarray, HungarianOutput]:
+  """Solve matching problem using :term:`Hungarian algorithm` from :mod:`optax`.
+
+  Args:
+    geom: Geometry object with square (shape ``[n,n]``)
+      :attr:`~ott.geometry.geometry.Geometry.cost_matrix`.
+
+  Returns:
+    The value of the unregularized OT problem (sum of matched costs divided
+    by ``n``), along with a :class:`HungarianOutput` holding the matching.
+  """
+  n, m = geom.shape
+  assert n == m, f"Hungarian can only match same # of points, got {n} and {m}."
+  i, j = assignment.hungarian_algorithm(geom.cost_matrix)
+
+  hungarian_out = HungarianOutput(geom=geom, paired_indices=jnp.stack((i, j)))
+  return jnp.sum(geom.cost_matrix[i, j]) / n, hungarian_out
+
+
+def wassdis_p(x: jnp.ndarray, y: jnp.ndarray, p: float = 2.0) -> float:
+  """Compute the :term:`Wasserstein distance`, uses :term:`Hungarian algorithm`.
+
+  Uses :func:`hungarian` to solve the :term:`optimal matching problem` between
+  two point clouds of the same size, to compute a :term:`Wasserstein distance`
+  estimator.
+
+  Note:
+    At the moment, only supports point clouds of the same size to be easily
+    cast as an optimal matching problem.
+
+  Args:
+    x: ``[n,d]`` point cloud
+    y: ``[n,d]`` point cloud of the same size
+    p: order of the Wasserstein distance, a float greater than or equal to 1.
+
+  Returns:
+    The p-Wasserstein distance between these point clouds.
+  """
+  geom = pointcloud.PointCloud(x, y, cost_fn=costs.EuclideanP(p))
+  cost, _ = hungarian(geom)
+  return cost ** (1. / p)  # p-th root; ``**`` binds tighter than ``/``
diff --git a/tests/tools/unreg_test.py b/tests/tools/unreg_test.py
new file mode 100644
index 000000000..6707158d9
--- /dev/null
+++ b/tests/tools/unreg_test.py
@@ -0,0 +1,60 @@
+# Copyright OTT-JAX
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional, Tuple
+
+import pytest
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+from ott.geometry import costs, pointcloud
+from ott.solvers import linear
+from ott.tools import unreg
+
+
+class TestHungarian:
+
+  @pytest.mark.parametrize("cost_fn", [costs.PNormP(1.3), None])
+  def test_matches_sink(self, rng: jax.Array, cost_fn: Optional[costs.CostFn]):
+    n, m, dim = 12, 12, 5
+    rng1, _ = jax.random.split(rng, 2)
+    x, y = gen_data(rng1, n, m, dim)
+    geom = pointcloud.PointCloud(x, y, cost_fn=cost_fn, epsilon=.0005)
+    cost_hung, out_hung = jax.jit(unreg.hungarian)(geom)
+    out_sink = jax.jit(linear.solve)(geom)
+    np.testing.assert_allclose(
+        out_sink.primal_cost, cost_hung, rtol=1e-3, atol=1e-3
+    )
+    np.testing.assert_allclose(
+        out_sink.matrix, out_hung.matrix.todense(), rtol=1e-3, atol=1e-3
+    )
+
+  @pytest.mark.parametrize("p", [1.3, 2.3])
+  def test_wass(self, rng: jax.Array, p: float):
+    n, m, dim = 12, 12, 5
+    rng1, _ = jax.random.split(rng, 2)
+    x, y = gen_data(rng1, n, m, dim)
+    geom = pointcloud.PointCloud(x, y, cost_fn=costs.EuclideanP(p=p))
+    cost_hung, _ = jax.jit(unreg.hungarian)(geom)
+    w_p = jax.jit(unreg.wassdis_p)(x, y, p)
+    np.testing.assert_allclose(w_p, cost_hung ** (1. / p), rtol=1e-3, atol=1e-3)
+
+
+def gen_data(rng: jax.Array, n: int, m: int,
+             dim: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
+  rngs = jax.random.split(rng, 4)
+  x = jax.random.uniform(rngs[0], (n, dim))
+  y = jax.random.uniform(rngs[1], (m, dim))
+  return x, y