From 8cfe5afefb6d6a58921003a080b8d3ee4b9f7c61 Mon Sep 17 00:00:00 2001 From: marcocuturi Date: Sat, 1 Jul 2023 13:22:40 +0200 Subject: [PATCH 1/7] add quantiles --- src/ott/solvers/linear/sinkhorn.py | 4 +- src/ott/tools/soft_sort.py | 129 +++++++++++++++++++++++------ tests/tools/soft_sort_test.py | 16 ++++ 3 files changed, 122 insertions(+), 27 deletions(-) diff --git a/src/ott/solvers/linear/sinkhorn.py b/src/ott/solvers/linear/sinkhorn.py index ade5000ba..a1eaf17d5 100644 --- a/src/ott/solvers/linear/sinkhorn.py +++ b/src/ott/solvers/linear/sinkhorn.py @@ -732,7 +732,9 @@ class Sinkhorn: gradients have been stopped. This is useful when carrying out first order differentiation, and is only valid (as with ``implicit_differentiation``) when the algorithm has converged with a low tolerance. - initializer: how to compute the initial potentials/scalings. + initializer: how to compute the initial potentials/scalings. This refers to + a few possible classes implemented following the template in + :class:`~ott.initializers.linear.SinkhornInitializer`. progress_fn: callback function which gets called during the Sinkhorn iterations, so the user can display the error at each iteration, e.g., using a progress bar. See :func:`~ott.utils.default_progress_fn` diff --git a/src/ott/tools/soft_sort.py b/src/ott/tools/soft_sort.py index 748c4a65a..25c64df35 100644 --- a/src/ott/tools/soft_sort.py +++ b/src/ott/tools/soft_sort.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import functools -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import jax import jax.numpy as jnp @@ -27,8 +27,8 @@ def transport_for_sort( inputs: jnp.ndarray, - weights: jnp.ndarray, - target_weights: jnp.ndarray, + weights: Optional[jnp.ndarray] = None, + target_weights: Optional[jnp.ndarray] = None, squashing_fun: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None, epsilon: float = 1e-2, **kwargs: Any, @@ -62,14 +62,16 @@ def transport_for_sort( squashing_fun = lambda z: jax.nn.sigmoid((z - jnp.mean(z)) / (jnp.std(z) + 1e-10)) x = squashing_fun(x) + a = jnp.squeeze(weights) b = jnp.squeeze(target_weights) - num_targets = b.shape[0] + num_targets = inputs.shape[0] if b is None else b.shape[0] y = jnp.linspace(0.0, 1.0, num_targets)[:, jnp.newaxis] geom = pointcloud.PointCloud(x, y, epsilon=epsilon) prob = linear_problem.LinearProblem(geom, a=a, b=b) + # Ensure the default initializer is "sorting" if default uniform weights. solver = sinkhorn.Sinkhorn(**kwargs) return solver(prob) @@ -124,7 +126,10 @@ def _sort( jnp.ones(topk, dtype=inputs.dtype) / num_points ]) else: - num_targets = num_points if num_targets is None else num_targets + if num_targets is None or num_targets == num_points: + num_targets = num_points + # use sorting initializer in this case. 
+ kwargs.setdefault("initializer", "sorting") start_index = 0 b = jnp.ones((num_targets,)) / num_targets ot = transport_for_sort(inputs, a, b, **kwargs) @@ -215,50 +220,122 @@ def quantile( inputs: jnp.ndarray, axis: int = -1, level: float = 0.5, - weight: float = 0.05, + weight: Optional[float] = None, + **kwargs: Any, +) -> jnp.ndarray: + """Compute a soft quantile on the input tensor.""" + return quantiles( + inputs, axis, levels=jnp.atleast_1d(level), weight=weight, **kwargs + ) + + +def quantiles( + inputs: jnp.ndarray, + axis: int = -1, + levels: Optional[jnp.ndarray] = None, + weight: Optional[Union[float, jnp.ndarray]] = None, **kwargs: Any, ) -> jnp.ndarray: r"""Apply the soft quantile operator on the input tensor. For instance: + ``` x = jax.random.uniform(rng, (1000,)) - q = quantile(x, level=0.5, weight=0.01) + q = quantiles(x, level=jnp.array([0.2, 0.8], weight=0.01) + ``` - Then q will be computed as a mean over the 10 median points of x. - Therefore, there is a trade-off between accuracy and gradient. + In that case, ``q`` will not hold the 20-th and 80-th percentile in ``x``, but + rather a convex combination (a weighted mean, with weights summing to 1) of + all values in ``x``, that approximates such percentiles. These values offer + a trade-off between accuracy (closeness to the true median) and gradient (the + differentiation of ``q`` will impact all values listed in ``x``). Args: inputs: a jnp.ndarray of any shape. axis: the axis on which to apply the operator. - level: the value of the quantile level to be computed. 0.5 for median. - weight: the weight of the quantile in the transport problem. + levels: values of the quantile level to be computed, e.g. [0.5] for median. + should be >0.0 and <1.0. Selected as [0.2, 0.5, 0.8] by default. + weight: the weight assigned to each quantile target value in the OT problem. 
+ Note: Since the number of quantiles times that weight must be strictly + smaller than 0, in order to leave enough mass to set other target values + in the transport problem, the algorithm ensures this by selecting if needed, + a lower value. kwargs: keyword arguments passed on to lower level functions. Of interest - to the user are ``squashing_fun``, which will redistribute the values in - ``inputs`` to lie in [0,1] (sigmoid of whitened values by default) to - solve the optimal transport problem; ``cost_fn``, used in ``PointCloud``, - that defines the ground cost function to transport from ``inputs`` to the - ``num_targets`` target values (squared Euclidean distance by default, see - ``pointcloud.py`` for more details); ``epsilon`` values as well as other - parameters to shape the ``sinkhorn`` algorithm. + to the user are ``squashing_fun``, which will redistribute the values in + ``inputs`` to lie in [0,1] (sigmoid of whitened values by default) to + solve the optimal transport problem; ``cost_fn``, used in ``PointCloud``, + that defines the ground cost function to transport from ``inputs`` to the + ``num_targets`` target values (squared Euclidean distance by default, see + ``pointcloud.py`` for more details); ``epsilon`` values as well as other + parameters to shape the ``sinkhorn`` algorithm. Returns: - A jnp.ndarray, which has the same shape as the input, except on the give + A jnp.ndarray, which has the same shape as the input, except on the given axis on which the dimension is 1. 
""" def _quantile( - inputs: jnp.ndarray, level: float, weight: float, **kwargs + inputs: jnp.ndarray, levels: float, weight: float, **kwargs ) -> jnp.ndarray: - # TODO(cuturi,oliviert) option to compute several quantiles at once num_points = inputs.shape[0] + num_quantiles = levels.shape[0] a = jnp.ones((num_points,)) / num_points - b = jnp.array([level - weight / 2, weight, 1.0 - weight / 2 - level]) - ot = transport_for_sort(inputs, a, b, **kwargs) - out = 1.0 / b * ot.apply(jnp.squeeze(inputs), axis=0) - return out[1:2] + levels = jnp.array([0.2, 0.5, 0.8]) if levels is None else levels + idx = jnp.argsort(levels) + levels = levels[idx] + + extended_levels = jnp.concatenate([ + jnp.array([0.0]), levels, jnp.array([1.0]) + ]) + filler_weights = extended_levels[1:] - extended_levels[:-1] + safe_weight = 0.5 * jnp.concatenate([ + jnp.array([1.0 / num_quantiles]), filler_weights + ]) + if weight is None: + # Populate with other options. + safe_weight = jnp.concatenate([ + safe_weight, + jnp.array( + [.02] + ), # reasonable mass per quantile for a small number of points + jnp.array( + [1.5 / num_points] + ), # this means each quantile would be ~ assigned 1.5 points. + ]) + else: + safe_weight = jnp.concatenate([safe_weight, jnp.atleast_1d(weight)]) + weight = jnp.min(safe_weight) + weights = jnp.ones(filler_weights.shape) * weight + + # Takes into account quantile_width in the definition of weights + shift = -jnp.ones(filler_weights.shape) + shift = shift + 0.5 * ( + jax.nn.one_hot(0, num_quantiles + 1) + + jax.nn.one_hot(num_quantiles, num_quantiles + 1) + ) + filler_weights = filler_weights + weights * shift + + # Adds one more value to have tensors of the same shape to interleave them. + quantile_weights = jnp.ones(num_quantiles + 1) * weights + + # Interleaves the filler_weights with the quantile weights. 
+ weights = jnp.reshape( + jnp.stack([filler_weights, quantile_weights], axis=1), (-1,) + )[:-1] + + ot = transport_for_sort(inputs, a, weights, **kwargs) + out = 1.0 / weights * ot.apply(jnp.squeeze(inputs), axis=0) + + # Recover odd indices corresponding to the desired quantiles. + odds = jnp.concatenate([ + jnp.zeros((num_quantiles + 1, 1), dtype=bool), + jnp.ones((num_quantiles + 1, 1), dtype=bool) + ], + axis=1).ravel()[:-1] + return (out[odds])[idx] - return apply_on_axis(_quantile, inputs, axis, level, weight, **kwargs) + return apply_on_axis(_quantile, inputs, axis, levels, weight, **kwargs) def _quantile_normalization( diff --git a/tests/tools/soft_sort_test.py b/tests/tools/soft_sort_test.py index f689e3595..ed56a647f 100644 --- a/tests/tools/soft_sort_test.py +++ b/tests/tools/soft_sort_test.py @@ -117,6 +117,22 @@ def test_quantile_on_several_axes(self, rng: jax.random.PRNGKeyArray): q, 0.5 * np.ones((batch, 1, channels)), atol=3e-2 ) + @pytest.mark.fast() + def test_quantiles(self): + inputs = jax.random.uniform(jax.random.PRNGKey(0), (200, 2, 3)) + + level = .5 + m1 = soft_sort.quantile(inputs, level=level, weight=None, axis=0) + np.testing.assert_approx_equal(m1.mean(), level, significant=2) + m2 = soft_sort.quantile(inputs, level=level, weight=.01, axis=0) + np.testing.assert_approx_equal(m2.mean(), level, significant=2) + + levels = jnp.array([.1, .8, .4]) + m1 = soft_sort.quantiles(inputs, levels=levels, weight=None, axis=0) + np.testing.assert_allclose(m1.mean(axis=[1, 2]), levels, atol=5e-2) + m2 = soft_sort.quantiles(inputs, levels=levels, weight=None, axis=0) + np.testing.assert_allclose(m2.mean(axis=[1, 2]), levels, atol=5 - 2) + def test_soft_quantile_normalization(self, rng: jax.random.PRNGKeyArray): rngs = jax.random.split(rng, 2) x = jax.random.uniform(rngs[0], shape=(100,)) From 1ddcfe3fc5aa385bf76dae92e0984204f615b56b Mon Sep 17 00:00:00 2001 From: marcocuturi Date: Sun, 2 Jul 2023 00:42:31 +0200 Subject: [PATCH 2/7] numpy --- 
pyproject.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f3da10bd1..f07e6e743 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,8 +16,7 @@ dependencies = [ "jax>=0.1.67", "jaxlib>=0.1.47", "jaxopt>=0.5.5", - # https://github.com/google/jax/discussions/9951#discussioncomment-3017784 - "numpy>=1.18.4, !=1.23.0", + "numpy>=1.18.0", "flax>=0.5.2", "optax>=0.1.1", "lineax>=0.0.1; python_version >= '3.9'" From 872e3af5ae4bbf69a5af1d0a9de0dc52dbace020 Mon Sep 17 00:00:00 2001 From: marcocuturi Date: Sun, 2 Jul 2023 14:58:45 +0200 Subject: [PATCH 3/7] using Michal's fix --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f07e6e743..53ae77294 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,8 @@ dependencies = [ "jax>=0.1.67", "jaxlib>=0.1.47", "jaxopt>=0.5.5", - "numpy>=1.18.0", + ## https://github.com/google/jax/discussions/9951#discussioncomment-3017784 + "numpy>=1.18.4, !=1.25.0", "flax>=0.5.2", "optax>=0.1.1", "lineax>=0.0.1; python_version >= '3.9'" From 6a95cc4401c54d33e9adbb15cffd06946e30287c Mon Sep 17 00:00:00 2001 From: marcocuturi Date: Mon, 3 Jul 2023 11:02:37 +0200 Subject: [PATCH 4/7] chg threshold in kmeans test --- tests/tools/k_means_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tools/k_means_test.py b/tests/tools/k_means_test.py index 91c56839f..4eba95e99 100644 --- a/tests/tools/k_means_test.py +++ b/tests/tools/k_means_test.py @@ -100,7 +100,7 @@ def test_matches_sklearn(self, rng: jax.random.PRNGKeyArray, k: int): pred_centers.min(axis=0) >= geom.x.min(axis=0), True ) # the largest was 70.56378 - assert jnp.abs(pred_inertia - gt_inertia) <= 100 + assert jnp.abs(pred_inertia - gt_inertia) <= 200 def test_initialization_differentiable(self, rng: jax.random.PRNGKeyArray): From 66b4b122d9b9ea5745782324725d0a4e4613ead7 Mon Sep 17 00:00:00 2001 From: marcocuturi Date: 
Mon, 3 Jul 2023 12:26:56 +0200 Subject: [PATCH 5/7] changing quantile API to match jnp's + pydocs --- src/ott/tools/soft_sort.py | 262 +++++++++++++++++++++------------- tests/tools/soft_sort_test.py | 29 ++-- 2 files changed, 169 insertions(+), 122 deletions(-) diff --git a/src/ott/tools/soft_sort.py b/src/ott/tools/soft_sort.py index 25c64df35..8d9e9bf5a 100644 --- a/src/ott/tools/soft_sort.py +++ b/src/ott/tools/soft_sort.py @@ -71,7 +71,6 @@ def transport_for_sort( geom = pointcloud.PointCloud(x, y, epsilon=epsilon) prob = linear_problem.LinearProblem(geom, a=a, b=b) - # Ensure the default initializer is "sorting" if default uniform weights. solver = sinkhorn.Sinkhorn(**kwargs) return solver(prob) @@ -126,6 +125,7 @@ def _sort( jnp.ones(topk, dtype=inputs.dtype) / num_points ]) else: + # Use the "sorting" initializer if default uniform weights of same size. if num_targets is None or num_targets == num_points: num_targets = num_points # use sorting initializer in this case. @@ -146,28 +146,50 @@ def sort( ) -> jnp.ndarray: r"""Apply the soft sort operator on a given axis of the input. + For instance: + + ``` + x = jax.random.uniform(rng, (100,)) + x_sorted = sort(x) + ``` + + will output sorted convex-combinations of values contained in ``x``, that are + differentiable approximations to the sorted vector of entries in ``x``. + These should be the values produced by :func:`jax.numpy.sort`, + + ``` + x_ranks = jax.numpy.sort(x) + ``` + + Args: inputs: jnp.ndarray of any shape. - axis: the axis on which to apply the operator. + axis: the axis on which to apply the soft-sorting operator. topk: if set to a positive value, the returned vector will only contain - the top-k values. This also reduces the complexity of soft sorting. - num_targets: if top-k is not specified, num_targets defines the number of - (composite) sorted values computed from the inputs (each value is a convex - combination of values recorded in the inputs, provided in increasing - order). 
If not specified, ``num_targets`` is set by default to be the size - of the slices of the input that are sorted, i.e. the number of composite - sorted values is equal to that of the inputs that are sorted. + the top-k values. This also reduces the complexity of soft-sorting, since + the number of target points to which the slice of the ``inputs`` tensor + will be mapped to will be equal to ``topk+1``. + num_targets: if ``topk`` is not specified, ``num_targets`` defines the + number of (composite) sorted values computed from the inputs (each value + is a convex combination of values recorded in the inputs, provided in + increasing order). If neither ``topk`` nor ``num_targets`` are specified, + ``num_targets`` defaults to the size of the slices of the input that are + sorted, i.e. ``inputs.shape[axis]``, and the number of composite sorted + values is equal to the slice of the inputs that are sorted. kwargs: keyword arguments passed on to lower level functions. Of interest to the user are ``squashing_fun``, which will redistribute the values in - ``inputs`` to lie in [0,1] (sigmoid of whitened values by default) to - solve the optimal transport problem; ``cost_fn``, used in ``PointCloud``, - that defines the ground cost function to transport from ``inputs`` to the - ``num_targets`` target values (squared Euclidean distance by default, see - ``pointcloud.py`` for more details); ``epsilon`` values as well as other - parameters to shape the ``sinkhorn`` algorithm. + ``inputs`` to lie in :math:`[0,1]` (sigmoid of whitened values by default) + to solve the optimal transport problem; + attribute :attr:`~ott.geometry.pointcloud.cost_fn` of + :class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground + cost function to transport from ``inputs`` to the ``num_targets`` target + values (:class:`~ott.geometry.costs.SqEuclidean` by default, see + :class:`~ott.geometry.pointcloud.PointCloud`); ``epsilon`` regularization + parameter. 
Remaining ``kwargs`` are passed on to defined the + :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver. Returns: - A jnp.ndarray of the same shape as the input with soft sorted values on the + A jnp.ndarray of the same shape as the input with soft-sorted values on the given axis. """ return apply_on_axis(_sort, inputs, axis, topk, num_targets, **kwargs) @@ -192,26 +214,52 @@ def ranks( ) -> jnp.ndarray: r"""Apply the soft rank operator on input tensor. + For instance: + + ``` + x = jax.random.uniform(rng, (100,)) + x_ranks = ranks(x) + ``` + + will output fractional values, between 0 and 1, that are differentiable + approximations to the normalized ranks of entries in ``x``. These should be + compared to the non-differentiable rank vectors, namely the normalized inverse + permutation produced by :func:`jax.numpy.argsort`, which can be obtained as: + + ``` + x_ranks = jax.numpy.argsort(jax.numpy.argsort(x)) / x.shape[0] + ``` + Args: - inputs: a jnp.ndarray of any shape. - axis: the axis on which to apply the soft ranks operator. - num_targets: num_targets defines the number of targets used to compute a - composite ranks for each value in ``inputs``: that soft rank will be a - convex combination of values in [0,...,``(num_targets-2)/num_targets``,1] - specified by the optimal transport between values in ``inputs`` towards - those values. If not specified, ``num_targets`` is set by default to be - the size of the slices of the input that are sorted. + inputs: jnp.ndarray of any shape. + axis: the axis on which to apply the soft-sorting operator. + topk: if set to a positive value, the returned vector will only contain + the top-k values. This also reduces the complexity of soft-sorting, since + the number of target points to which the slice of the ``inputs`` tensor + will be mapped to will be equal to ``topk+1``. 
+ num_targets: if ``topk`` is not specified, ``num_targets`` defines the + number of (composite) sorted values computed from the inputs (each value + is a convex combination of values recorded in the inputs, provided in + increasing order). If neither ``topk`` nor ``num_targets`` are specified, + ``num_targets`` defaults to the size of the slices of the input that are + sorted, i.e. ``inputs.shape[axis]``, and the number of composite sorted + values is equal to the slice of the inputs that are sorted. kwargs: keyword arguments passed on to lower level functions. Of interest to the user are ``squashing_fun``, which will redistribute the values in - ``inputs`` to lie in [0,1] (sigmoid of whitened values by default) to - solve the optimal transport problem; ``cost_fn``, used in ``PointCloud``, - that defines the ground cost function to transport from ``inputs`` to the - ``num_targets`` target values (squared Euclidean distance by default, see - ``pointcloud.py`` for more details); ``epsilon`` values as well as other - parameters to shape the ``sinkhorn`` algorithm. + ``inputs`` to lie in :math:`[0,1]` (sigmoid of whitened values by default) + to solve the optimal transport problem; + attribute :attr:`~ott.geometry.pointcloud.cost_fn` of + :class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground + cost function to transport from ``inputs`` to the ``num_targets`` target + values (:class:`~ott.geometry.costs.SqEuclidean` by default, see + :class:`~ott.geometry.pointcloud.PointCloud`); ``epsilon`` regularization + parameter. Remaining ``kwargs`` are passed on to define the + :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver. Returns: - A jnp.ndarray of the same shape as inputs, with the ranks. + A jnp.ndarray of the same shape as the input with soft-rank values + normalized to be in :math:`[0,1]` that replace the original ones. 
+ """ return apply_on_axis(_ranks, inputs, axis, num_targets, **kwargs) @@ -219,76 +267,75 @@ def ranks( def quantile( inputs: jnp.ndarray, axis: int = -1, - level: float = 0.5, - weight: Optional[float] = None, - **kwargs: Any, -) -> jnp.ndarray: - """Compute a soft quantile on the input tensor.""" - return quantiles( - inputs, axis, levels=jnp.atleast_1d(level), weight=weight, **kwargs - ) - - -def quantiles( - inputs: jnp.ndarray, - axis: int = -1, - levels: Optional[jnp.ndarray] = None, + q: Optional[jnp.ndarray] = None, weight: Optional[Union[float, jnp.ndarray]] = None, **kwargs: Any, ) -> jnp.ndarray: - r"""Apply the soft quantile operator on the input tensor. + r"""Apply the soft quantiles operator on the input tensor. For instance: ``` - x = jax.random.uniform(rng, (1000,)) - q = quantiles(x, level=jnp.array([0.2, 0.8], weight=0.01) + x = jax.random.uniform(rng, (100,)) + x_quantiles = quantile(x, q=jnp.array([0.2, 0.8])) ``` - In that case, ``q`` will not hold the 20-th and 80-th percentile in ``x``, but - rather a convex combination (a weighted mean, with weights summing to 1) of - all values in ``x``, that approximates such percentiles. These values offer - a trade-off between accuracy (closeness to the true median) and gradient (the - differentiation of ``q`` will impact all values listed in ``x``). + ``x_quantiles`` will hold an approximation to the 20-th and 80-th + percentiles in ``x``, computed as a convex combination + (a weighted mean, with weights summing to 1) of all values in ``x`` (and not, + as would be the usual approach, the 20th and 80th values of ``x`` sorted in + ascending order). These values offer a trade-off between accuracy + (closeness to the true percentiles) and gradient (the Jacobian of + ``x_quantiles`` w.r.t ``x`` will impact all values listed in ``x``, not just + the 20th and 80th). + + The non-differentiable version is given by :func:`jax.numpy.quantile`, e.g. 
+ ``` + x_quantiles = jax.numpy.quantile(x, q=jnp.array([0.2, 0.8])) + ``` Args: inputs: a jnp.ndarray of any shape. axis: the axis on which to apply the operator. - levels: values of the quantile level to be computed, e.g. [0.5] for median. - should be >0.0 and <1.0. Selected as [0.2, 0.5, 0.8] by default. + q: values of the quantile level to be computed, e.g. [0.5] for median. + These values should all lie in :math:`[0,1]` and are selected as + ``[0.2, 0.5, 0.8]`` by default. weight: the weight assigned to each quantile target value in the OT problem. - Note: Since the number of quantiles times that weight must be strictly - smaller than 0, in order to leave enough mass to set other target values - in the transport problem, the algorithm ensures this by selecting if needed, - a lower value. + This weight should be small, typically of the order of ``1/n``, where ``n`` + is the size of ``x``. Note: Since the size of ``q`` times ``weight`` + must be strictly smaller than ``1``, in order to leave enough mass to set + other target values in the transport problem, the algorithm might ensure + this by setting, when needed, a lower value. kwargs: keyword arguments passed on to lower level functions. Of interest - to the user are ``squashing_fun``, which will redistribute the values in - ``inputs`` to lie in [0,1] (sigmoid of whitened values by default) to - solve the optimal transport problem; ``cost_fn``, used in ``PointCloud``, - that defines the ground cost function to transport from ``inputs`` to the - ``num_targets`` target values (squared Euclidean distance by default, see - ``pointcloud.py`` for more details); ``epsilon`` values as well as other - parameters to shape the ``sinkhorn`` algorithm. 
+ to the user are ``squashing_fun``, which will redistribute the values in + ``inputs`` to lie in :math:`[0,1]` (sigmoid of whitened values by default) + to solve the optimal transport problem; + attribute :attr:`~ott.geometry.pointcloud.cost_fn` of + :class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground + cost function to transport from ``inputs`` to the ``num_targets`` target + values (:class:`~ott.geometry.costs.SqEuclidean` by default, see + :class:`~ott.geometry.pointcloud.PointCloud`); ``epsilon`` regularization + parameter. Remaining ``kwargs`` are passed on to defined the + :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver. Returns: - A jnp.ndarray, which has the same shape as the input, except on the given - axis on which the dimension is 1. + A jnp.ndarray, which has the same shape as the input, except on the ``axis`` + that is passed, which has size ``q.shape[0]`` to collect soft-quantile + values. """ def _quantile( - inputs: jnp.ndarray, levels: float, weight: float, **kwargs + inputs: jnp.ndarray, q: float, weight: float, **kwargs ) -> jnp.ndarray: num_points = inputs.shape[0] - num_quantiles = levels.shape[0] + q = jnp.array([0.2, 0.5, 0.8]) if q is None else jnp.atleast_1d(q) + num_quantiles = q.shape[0] a = jnp.ones((num_points,)) / num_points - levels = jnp.array([0.2, 0.5, 0.8]) if levels is None else levels - idx = jnp.argsort(levels) - levels = levels[idx] + idx = jnp.argsort(q) + q = q[idx] - extended_levels = jnp.concatenate([ - jnp.array([0.0]), levels, jnp.array([1.0]) - ]) - filler_weights = extended_levels[1:] - extended_levels[:-1] + extended_q = jnp.concatenate([jnp.array([0.0]), q, jnp.array([1.0])]) + filler_weights = extended_q[1:] - extended_q[:-1] safe_weight = 0.5 * jnp.concatenate([ jnp.array([1.0 / num_quantiles]), filler_weights ]) @@ -335,7 +382,7 @@ def _quantile( axis=1).ravel()[:-1] return (out[odds])[idx] - return apply_on_axis(_quantile, inputs, axis, levels, weight, **kwargs) + return 
apply_on_axis(_quantile, inputs, axis, q, weight, **kwargs) def _quantile_normalization( @@ -357,24 +404,32 @@ def quantile_normalization( ) -> jnp.ndarray: r"""Renormalize inputs so that its quantiles match those of targets/weights. - The idea of quantile normalization is to map the inputs to values so that the - distribution of transformed values matches the distribution of target values. - In a sense, we want to keep the inputs in the same order, but apply the values - of the target. + Quantile normalization rearranges the values in inputs to values that match + the distribution of values described in the discrete distribution ``targets`` + weighted by ``weights``. This transformation preserves the order of values + in ``inputs`` along the specified ``axis``. Args: - inputs: the inputs array of any shape. - targets: the target values of dimension 1. The targets must be sorted. - weights: if set, the weights or the target. - axis: the axis along which to apply the transformation on the inputs. + inputs: array of any shape whose values will be changed to match those in + ``targets``. + targets: sorted array (in ascending order) of dimension 1 describing a + discrete distribution. Note: the``targets`` values must be provided as + a sorted vector. + weights: vector of nonnegative weights, summing to :math:`1.0`, of the same + size as ``targets``. When not set, this defaults to the uniform + distribution. + axis: the axis along which the quantile transformation is applied. kwargs: keyword arguments passed on to lower level functions. 
Of interest to the user are ``squashing_fun``, which will redistribute the values in - ``inputs`` to lie in [0,1] (sigmoid of whitened values by default) to - solve the optimal transport problem; ``cost_fn``, used in ``PointCloud``, - that defines the ground cost function to transport from ``inputs`` to the - ``num_targets`` target values (squared Euclidean distance by default, see - ``pointcloud.py`` for more details); ``epsilon`` values as well as other - parameters to shape the ``sinkhorn`` algorithm. + ``inputs`` to lie in :math:`[0,1]` (sigmoid of whitened values by default) + to solve the optimal transport problem; + attribute :attr:`~ott.geometry.pointcloud.cost_fn` of + :class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground + cost function to transport from ``inputs`` to the ``num_targets`` target + values (:class:`~ott.geometry.costs.SqEuclidean` by default, see + :class:`~ott.geometry.pointcloud.PointCloud`); ``epsilon`` regularization + parameter. Remaining ``kwargs`` are passed on to defined the + :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver. Returns: A jnp.ndarray, which has the same shape as the input, except on the give @@ -420,12 +475,15 @@ def sort_with( topk: The number of outputs to keep. kwargs: keyword arguments passed on to lower level functions. Of interest to the user are ``squashing_fun``, which will redistribute the values in - ``inputs`` to lie in [0,1] (sigmoid of whitened values by default) to - solve the optimal transport problem; ``cost_fn``, used in ``PointCloud``, - that defines the ground cost function to transport from ``inputs`` to the - ``num_targets`` target values (squared Euclidean distance by default, see - ``pointcloud.py`` for more details); ``epsilon`` values as well as other - parameters to shape the ``sinkhorn`` algorithm. 
+ ``inputs`` to lie in :math:`[0,1]` (sigmoid of whitened values by default) + to solve the optimal transport problem; + attribute :attr:`~ott.geometry.pointcloud.cost_fn` of + :class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground + cost function to transport from ``inputs`` to the ``num_targets`` target + values (:class:`~ott.geometry.costs.SqEuclidean` by default, see + :class:`~ott.geometry.pointcloud.PointCloud`); ``epsilon`` regularization + parameter. Remaining ``kwargs`` are passed on to define the + :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver. Returns: A jnp.ndarray[batch | topk, dim]. @@ -451,13 +509,11 @@ def sort_with( return sort_fn(inputs) -def _quantize( - inputs: jnp.ndarray, num_levels: int, **kwargs: Any -) -> jnp.ndarray: +def _quantize(inputs: jnp.ndarray, num_q: int, **kwargs: Any) -> jnp.ndarray: """Apply the soft quantization operator on a one dimensional array.""" num_points = inputs.shape[0] a = jnp.ones((num_points,)) / num_points - b = jnp.ones((num_levels,)) / num_levels + b = jnp.ones((num_q,)) / num_q ot = transport_for_sort(inputs, a, b, **kwargs) return 1.0 / a * ot.apply(1.0 / b * ot.apply(inputs), axis=1) @@ -484,7 +540,7 @@ def quantize( Args: inputs: the inputs as a jnp.ndarray[batch, dim]. - num_levels: number of levels available to quantize the signal. + num_levels: number of quantization levels available to quantize the signal. axis: axis along which quantization is carried out. kwargs: keyword arguments passed on to lower level functions. 
Of interest to the user are ``squashing_fun``, which will redistribute the values in diff --git a/tests/tools/soft_sort_test.py b/tests/tools/soft_sort_test.py index ed56a647f..8e2085b4c 100644 --- a/tests/tools/soft_sort_test.py +++ b/tests/tools/soft_sort_test.py @@ -96,20 +96,18 @@ def test_rank_one_array(self, rng: jax.random.PRNGKeyArray): np.testing.assert_allclose(ranks, expected_ranks, atol=0.9, rtol=0.1) @pytest.mark.fast() - @pytest.mark.parametrize("level", [0.2, 0.9]) - def test_quantile(self, level: float): + @pytest.mark.parametrize("q", [0.2, 0.9]) + def test_quantile(self, q: float): x = jnp.linspace(0.0, 1.0, 100) - q = soft_sort.quantile( - x, level=level, weight=0.05, epsilon=1e-3, lse_mode=True - ) + x_q = soft_sort.quantile(x, q=q, weight=0.05, epsilon=1e-3, lse_mode=True) - np.testing.assert_approx_equal(q, level, significant=1) + np.testing.assert_approx_equal(x_q, q, significant=1) def test_quantile_on_several_axes(self, rng: jax.random.PRNGKeyArray): batch, height, width, channels = 16, 100, 100, 3 x = jax.random.uniform(rng, shape=(batch, height, width, channels)) q = soft_sort.quantile( - x, axis=(1, 2), level=0.5, weight=0.05, epsilon=1e-3, lse_mode=True + x, axis=(1, 2), q=0.5, weight=0.05, epsilon=1e-3, lse_mode=True ) np.testing.assert_array_equal(q.shape, (batch, 1, channels)) @@ -120,18 +118,11 @@ def test_quantile_on_several_axes(self, rng: jax.random.PRNGKeyArray): @pytest.mark.fast() def test_quantiles(self): inputs = jax.random.uniform(jax.random.PRNGKey(0), (200, 2, 3)) - - level = .5 - m1 = soft_sort.quantile(inputs, level=level, weight=None, axis=0) - np.testing.assert_approx_equal(m1.mean(), level, significant=2) - m2 = soft_sort.quantile(inputs, level=level, weight=.01, axis=0) - np.testing.assert_approx_equal(m2.mean(), level, significant=2) - - levels = jnp.array([.1, .8, .4]) - m1 = soft_sort.quantiles(inputs, levels=levels, weight=None, axis=0) - np.testing.assert_allclose(m1.mean(axis=[1, 2]), levels, atol=5e-2) - m2 = 
soft_sort.quantiles(inputs, levels=levels, weight=None, axis=0) - np.testing.assert_allclose(m2.mean(axis=[1, 2]), levels, atol=5 - 2) + q = jnp.array([.1, .8, .4]) + m1 = soft_sort.quantile(inputs, q=q, weight=None, axis=0) + np.testing.assert_allclose(m1.mean(axis=[1, 2]), q, atol=5e-2) + m2 = soft_sort.quantile(inputs, q=q, weight=None, axis=0) + np.testing.assert_allclose(m2.mean(axis=[1, 2]), q, atol=5e-2) def test_soft_quantile_normalization(self, rng: jax.random.PRNGKeyArray): rngs = jax.random.split(rng, 2) From 422afe459901e961c9529409d2fbef25415e1791 Mon Sep 17 00:00:00 2001 From: marcocuturi Date: Mon, 3 Jul 2023 12:28:08 +0200 Subject: [PATCH 6/7] impact chg in NB --- docs/tutorials/notebooks/soft_sort.ipynb | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/docs/tutorials/notebooks/soft_sort.ipynb b/docs/tutorials/notebooks/soft_sort.ipynb index 2a0239b0d..cf0f751ac 100644 --- a/docs/tutorials/notebooks/soft_sort.ipynb +++ b/docs/tutorials/notebooks/soft_sort.ipynb @@ -1,6 +1,7 @@ { "cells": [ { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "Dg9-8gVqjq2H" @@ -172,7 +173,7 @@ } ], "source": [ - "jnp.quantile(x, 0.5)" + "jnp.quantile(x, q=0.5)" ] }, { @@ -196,6 +197,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "mnEkXjwT-Z1C" @@ -334,6 +336,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "JKNcCJOe9Dcl" @@ -377,6 +380,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "4w4ggUy7zYQX" @@ -442,6 +446,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "jQoRgYEd-pkj" @@ -494,7 +499,7 @@ ], "source": [ "softquantile = jax.jit(soft_sort.quantile)\n", - "softquantile(x, level=0.5)" + "softquantile(x, q=0.5)" ] }, { @@ -533,6 +538,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -596,6 +602,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": 
{ "id": "4t3VrtNcmN0R" @@ -611,6 +618,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "tqsCC0tunHQh" @@ -663,6 +671,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "irSGHYZ7nWuY" @@ -709,6 +718,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "jK94muT8oAlQ" @@ -929,6 +939,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -976,6 +987,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ From 4268043e60dd6781f4c2503efc1189a8c1b5f76d Mon Sep 17 00:00:00 2001 From: marcocuturi Date: Mon, 3 Jul 2023 12:41:04 +0200 Subject: [PATCH 7/7] pydocs --- src/ott/tools/soft_sort.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/ott/tools/soft_sort.py b/src/ott/tools/soft_sort.py index 8d9e9bf5a..8f5ed10af 100644 --- a/src/ott/tools/soft_sort.py +++ b/src/ott/tools/soft_sort.py @@ -258,7 +258,7 @@ def ranks( Returns: A jnp.ndarray of the same shape as the input with soft-rank values - normalized to be in :math:`[0,1]` that replace the orignal ones. + normalized to be in :math:`[0,1]`, replacing the original ones. """ return apply_on_axis(_ranks, inputs, axis, num_targets, **kwargs) @@ -280,14 +280,13 @@ def quantile( x_quantiles = quantiles(x, q=jnp.array([0.2, 0.8])) ``` - ``x_quantiles`` will hold an approximation to the 20-th and 80-th - percentiles in ``x``, computed as a convex combination - (a weighted mean, with weights summing to 1) of all values in ``x`` (and not, - as would be the usual approach, the 20th and 80th values of ``x`` sorted in - ascending order). These values offer a trade-off between accuracy - (closeness to the true percentiles) and gradient (the Jacobian of - ``x_quantiles`` w.r.t ``x`` will impact all values listed in ``x``, not just - the 20th and 80th). 
+ ``x_quantiles`` will hold an approximation to the 20 and 80 percentiles in + ``x``, computed as a convex combination (a weighted mean, with weights summing + to 1) of all values in ``x`` (and not, as would be the usual approach, the + values ``x_sorted[20]`` and ``x_sorted[80]`` if ``x_sorted=jnp.sort(x)``). + These values offer a trade-off between accuracy (closeness to the true + percentiles) and gradient (the Jacobian of ``x_quantiles`` w.r.t ``x`` will + impact all values listed in ``x``, not just those indexed at 20 and 80). The non-differentiable version is given by :func:`jax.numpy.quantile`, e.g. ```