diff --git a/docs/tutorials/notebooks/soft_sort.ipynb b/docs/tutorials/notebooks/soft_sort.ipynb
index 2a0239b0d..cf0f751ac 100644
--- a/docs/tutorials/notebooks/soft_sort.ipynb
+++ b/docs/tutorials/notebooks/soft_sort.ipynb
@@ -1,6 +1,7 @@
 {
  "cells": [
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {
     "id": "Dg9-8gVqjq2H"
    }
@@ -172,7 +173,7 @@
     }
    ],
    "source": [
-    "jnp.quantile(x, 0.5)"
+    "jnp.quantile(x, q=0.5)"
    ]
   },
   {
@@ -196,6 +197,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {
     "id": "mnEkXjwT-Z1C"
    }
@@ -334,6 +336,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {
     "id": "JKNcCJOe9Dcl"
    }
@@ -377,6 +380,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {
     "id": "4w4ggUy7zYQX"
    }
@@ -442,6 +446,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {
     "id": "jQoRgYEd-pkj"
    }
@@ -494,7 +499,7 @@
    ],
    "source": [
     "softquantile = jax.jit(soft_sort.quantile)\n",
-    "softquantile(x, level=0.5)"
+    "softquantile(x, q=0.5)"
    ]
   },
   {
@@ -533,6 +538,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -596,6 +602,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {
     "id": "4t3VrtNcmN0R"
    }
@@ -611,6 +618,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {
     "id": "tqsCC0tunHQh"
    }
@@ -663,6 +671,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {
     "id": "irSGHYZ7nWuY"
    }
@@ -709,6 +718,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {
     "id": "jK94muT8oAlQ"
    }
@@ -929,6 +939,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -976,6 +987,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
diff --git a/pyproject.toml b/pyproject.toml
index f3da10bd1..53ae77294 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -16,8 +16,8 @@ dependencies = [
     "jax>=0.1.67",
     "jaxlib>=0.1.47",
     "jaxopt>=0.5.5",
-    # https://github.com/google/jax/discussions/9951#discussioncomment-3017784
-    "numpy>=1.18.4, !=1.23.0",
+    ## https://github.com/google/jax/discussions/9951#discussioncomment-3017784
+    "numpy>=1.18.4, !=1.25.0",
     "flax>=0.5.2",
     "optax>=0.1.1",
     "lineax>=0.0.1; python_version >= '3.9'"
diff --git a/src/ott/solvers/linear/sinkhorn.py b/src/ott/solvers/linear/sinkhorn.py
index ade5000ba..a1eaf17d5 100644
--- a/src/ott/solvers/linear/sinkhorn.py
+++ b/src/ott/solvers/linear/sinkhorn.py
@@ -732,7 +732,9 @@ class Sinkhorn:
       gradients have been stopped. This is useful when carrying out first order
       differentiation, and is only valid (as with ``implicit_differentiation``)
       when the algorithm has converged with a low tolerance.
-    initializer: how to compute the initial potentials/scalings.
+    initializer: how to compute the initial potentials/scalings. This refers to
+      a few possible classes implemented following the template in
+      :class:`~ott.initializers.linear.SinkhornInitializer`.
     progress_fn: callback function which gets called during the Sinkhorn
       iterations, so the user can display the error at each iteration, e.g.,
       using a progress bar. See :func:`~ott.utils.default_progress_fn`
diff --git a/src/ott/tools/soft_sort.py b/src/ott/tools/soft_sort.py
index 748c4a65a..8f5ed10af 100644
--- a/src/ott/tools/soft_sort.py
+++ b/src/ott/tools/soft_sort.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import functools
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Optional, Union
 
 import jax
 import jax.numpy as jnp
@@ -27,8 +27,8 @@
 
 def transport_for_sort(
     inputs: jnp.ndarray,
-    weights: jnp.ndarray,
-    target_weights: jnp.ndarray,
+    weights: Optional[jnp.ndarray] = None,
+    target_weights: Optional[jnp.ndarray] = None,
     squashing_fun: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None,
     epsilon: float = 1e-2,
     **kwargs: Any,
@@ -62,9 +62,10 @@
     squashing_fun = lambda z: jax.nn.sigmoid((z - jnp.mean(z)) /
                                              (jnp.std(z) + 1e-10))
   x = squashing_fun(x)
-  b = jnp.squeeze(target_weights)
-  num_targets = b.shape[0]
+  a = None if weights is None else jnp.squeeze(weights)
+  b = None if target_weights is None else jnp.squeeze(target_weights)
+  num_targets = inputs.shape[0] if b is None else b.shape[0]
 
   y = jnp.linspace(0.0, 1.0, num_targets)[:, jnp.newaxis]
   geom = pointcloud.PointCloud(x, y, epsilon=epsilon)
@@ -124,7 +125,11 @@ def _sort(
       jnp.ones(topk, dtype=inputs.dtype) / num_points
     ])
   else:
-    num_targets = num_points if num_targets is None else num_targets
+    # Use the "sorting" initializer when the targets are, as by default,
+    # uniform weights of the same size as the inputs.
+    if num_targets is None or num_targets == num_points:
+      num_targets = num_points
+      kwargs.setdefault("initializer", "sorting")
     start_index = 0
     b = jnp.ones((num_targets,)) / num_targets
   ot = transport_for_sort(inputs, a, b, **kwargs)
@@ -141,28 +146,50 @@ def sort(
 ) -> jnp.ndarray:
   r"""Apply the soft sort operator on a given axis of the input.
 
+  For instance:
+
+  ```
+  x = jax.random.uniform(rng, (100,))
+  x_sorted = sort(x)
+  ```
+
+  will output sorted convex combinations of values contained in ``x``, which
+  are differentiable approximations to the sorted vector of entries in ``x``.
+  These can be compared with the values produced by :func:`jax.numpy.sort`:
+
+  ```
+  x_sorted = jax.numpy.sort(x)
+  ```
+
+
   Args:
     inputs: jnp.ndarray of any shape.
-    axis: the axis on which to apply the operator.
+    axis: the axis on which to apply the soft-sorting operator.
     topk: if set to a positive value, the returned vector will only contain
-      the top-k values. This also reduces the complexity of soft sorting.
+      the top-k values. This also reduces the complexity of soft-sorting, since
+      the number of target points to which the slice of the ``inputs`` tensor
+      is mapped will be equal to ``topk+1``.
-    num_targets: if top-k is not specified, num_targets defines the number of
-      (composite) sorted values computed from the inputs (each value is a convex
-      combination of values recorded in the inputs, provided in increasing
-      order). If not specified, ``num_targets`` is set by default to be the size
-      of the slices of the input that are sorted, i.e. the number of composite
-      sorted values is equal to that of the inputs that are sorted.
+    num_targets: if ``topk`` is not specified, ``num_targets`` defines the
+      number of (composite) sorted values computed from the inputs (each value
+      is a convex combination of values recorded in the inputs, provided in
+      increasing order). If neither ``topk`` nor ``num_targets`` is specified,
+      ``num_targets`` defaults to the size of the slices of the input that are
+      sorted, i.e. ``inputs.shape[axis]``, so that the number of composite
+      sorted values matches that of the inputs being sorted.
     kwargs: keyword arguments passed on to lower level functions.
      Of interest to the user are ``squashing_fun``, which will redistribute the values in
-      ``inputs`` to lie in [0,1] (sigmoid of whitened values by default) to
-      solve the optimal transport problem; ``cost_fn``, used in ``PointCloud``,
-      that defines the ground cost function to transport from ``inputs`` to the
-      ``num_targets`` target values (squared Euclidean distance by default, see
-      ``pointcloud.py`` for more details); ``epsilon`` values as well as other
-      parameters to shape the ``sinkhorn`` algorithm.
+      ``inputs`` to lie in :math:`[0,1]` (sigmoid of whitened values by default)
+      to solve the optimal transport problem;
+      attribute :attr:`~ott.geometry.pointcloud.PointCloud.cost_fn` of
+      :class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground
+      cost function to transport from ``inputs`` to the ``num_targets`` target
+      values (:class:`~ott.geometry.costs.SqEuclidean` by default, see
+      :class:`~ott.geometry.pointcloud.PointCloud`); ``epsilon`` regularization
+      parameter. Remaining ``kwargs`` are passed on to the
+      :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
 
   Returns:
-    A jnp.ndarray of the same shape as the input with soft sorted values on the
+    A jnp.ndarray of the same shape as the input with soft-sorted values on the
     given axis.
   """
   return apply_on_axis(_sort, inputs, axis, topk, num_targets, **kwargs)
 
 
@@ -187,26 +214,48 @@ def ranks(
 ) -> jnp.ndarray:
   r"""Apply the soft rank operator on input tensor.
 
+  For instance:
+
+  ```
+  x = jax.random.uniform(rng, (100,))
+  x_ranks = ranks(x)
+  ```
+
+  will output fractional values, between 0 and 1, that are differentiable
+  approximations to the normalized ranks of entries in ``x``. These should be
+  compared to the non-differentiable rank vector, namely the normalized inverse
+  permutation produced by :func:`jax.numpy.argsort`, which can be obtained as:
+
+  ```
+  x_ranks = jax.numpy.argsort(jax.numpy.argsort(x)) / x.shape[0]
+  ```
+
   Args:
-    inputs: a jnp.ndarray of any shape.
-    axis: the axis on which to apply the soft ranks operator.
-    num_targets: num_targets defines the number of targets used to compute a
-      composite ranks for each value in ``inputs``: that soft rank will be a
-      convex combination of values in [0,...,``(num_targets-2)/num_targets``,1]
-      specified by the optimal transport between values in ``inputs`` towards
-      those values. If not specified, ``num_targets`` is set by default to be
-      the size of the slices of the input that are sorted.
+    inputs: jnp.ndarray of any shape.
+    axis: the axis on which to apply the soft-ranking operator.
+    num_targets: ``num_targets`` defines the number of targets used to compute
+      a composite rank for each value in ``inputs``: that soft rank will be a
+      convex combination of values in
+      ``[0, ..., (num_targets - 2) / num_targets, 1]``, specified by the
+      optimal transport between values in ``inputs`` towards those values. If
+      not specified, ``num_targets`` defaults to the size of the slices of the
+      input that are ranked, i.e. ``inputs.shape[axis]``.
     kwargs: keyword arguments passed on to lower level functions.
      Of interest to the user are ``squashing_fun``, which will redistribute the values in
-      ``inputs`` to lie in [0,1] (sigmoid of whitened values by default) to
-      solve the optimal transport problem; ``cost_fn``, used in ``PointCloud``,
-      that defines the ground cost function to transport from ``inputs`` to the
-      ``num_targets`` target values (squared Euclidean distance by default, see
-      ``pointcloud.py`` for more details); ``epsilon`` values as well as other
-      parameters to shape the ``sinkhorn`` algorithm.
+      ``inputs`` to lie in :math:`[0,1]` (sigmoid of whitened values by default)
+      to solve the optimal transport problem;
+      attribute :attr:`~ott.geometry.pointcloud.PointCloud.cost_fn` of
+      :class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground
+      cost function to transport from ``inputs`` to the ``num_targets`` target
+      values (:class:`~ott.geometry.costs.SqEuclidean` by default, see
+      :class:`~ott.geometry.pointcloud.PointCloud`); ``epsilon`` regularization
+      parameter. Remaining ``kwargs`` are passed on to the
+      :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
 
   Returns:
-    A jnp.ndarray of the same shape as inputs, with the ranks.
+    A jnp.ndarray of the same shape as the input with soft-rank values
+      normalized to be in :math:`[0,1]`, replacing the original ones.
+
   """
   return apply_on_axis(_ranks, inputs, axis, num_targets, **kwargs)
 
 
@@ -214,51 +263,124 @@ def quantile(
     inputs: jnp.ndarray,
     axis: int = -1,
-    level: float = 0.5,
-    weight: float = 0.05,
+    q: Optional[jnp.ndarray] = None,
+    weight: Optional[Union[float, jnp.ndarray]] = None,
     **kwargs: Any,
 ) -> jnp.ndarray:
-  r"""Apply the soft quantile operator on the input tensor.
+  r"""Apply the soft quantiles operator on the input tensor.
 
   For instance:
 
-  x = jax.random.uniform(rng, (1000,))
-  q = quantile(x, level=0.5, weight=0.01)
+  ```
+  x = jax.random.uniform(rng, (100,))
+  x_quantiles = quantile(x, q=jnp.array([0.2, 0.8]))
+  ```
+
+  ``x_quantiles`` will hold an approximation to the 20th and 80th percentiles
+  in ``x``, computed as a convex combination (a weighted mean, with weights
+  summing to 1) of all values in ``x`` (and not, as would be the usual
+  approach, the values ``x_sorted[20]`` and ``x_sorted[80]``, where
+  ``x_sorted = jnp.sort(x)``). These values offer a trade-off between accuracy
+  (closeness to the true percentiles) and gradient (the Jacobian of
+  ``x_quantiles`` w.r.t. ``x`` impacts all entries of ``x``, not just two).
 
-  Then q will be computed as a mean over the 10 median points of x.
-  Therefore, there is a trade-off between accuracy and gradient.
+  The non-differentiable version is given by :func:`jax.numpy.quantile`, e.g.
+
+  ```
+  x_quantiles = jax.numpy.quantile(x, q=jnp.array([0.2, 0.8]))
+  ```
 
   Args:
     inputs: a jnp.ndarray of any shape.
     axis: the axis on which to apply the operator.
-    level: the value of the quantile level to be computed. 0.5 for median.
-    weight: the weight of the quantile in the transport problem.
+    q: values of the quantile levels to be computed, e.g. ``[0.5]`` for the
+      median. These values should all lie in :math:`[0,1]` and are set to
+      ``[0.2, 0.5, 0.8]`` by default.
+    weight: the weight assigned to each quantile target value in the OT
+      problem. This weight should be small, typically of the order of
+      ``1/n``, where ``n`` is the size of ``x``.
+      Note: since ``q.shape[0] * weight``, the total mass assigned to quantile
+      targets, must be strictly smaller than ``1`` in order to leave enough
+      mass for the other target values in the transport problem, the
+      algorithm lowers this value when needed.
     kwargs: keyword arguments passed on to lower level functions.
      Of interest to the user are ``squashing_fun``, which will redistribute the values in
-      ``inputs`` to lie in [0,1] (sigmoid of whitened values by default) to
-      solve the optimal transport problem; ``cost_fn``, used in ``PointCloud``,
-      that defines the ground cost function to transport from ``inputs`` to the
-      ``num_targets`` target values (squared Euclidean distance by default, see
-      ``pointcloud.py`` for more details); ``epsilon`` values as well as other
-      parameters to shape the ``sinkhorn`` algorithm.
+      ``inputs`` to lie in :math:`[0,1]` (sigmoid of whitened values by default)
+      to solve the optimal transport problem;
+      attribute :attr:`~ott.geometry.pointcloud.PointCloud.cost_fn` of
+      :class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground
+      cost function to transport from ``inputs`` to the ``num_targets`` target
+      values (:class:`~ott.geometry.costs.SqEuclidean` by default, see
+      :class:`~ott.geometry.pointcloud.PointCloud`); ``epsilon`` regularization
+      parameter. Remaining ``kwargs`` are passed on to the
+      :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
 
   Returns:
-    A jnp.ndarray, which has the same shape as the input, except on the give
-    axis on which the dimension is 1.
+    A jnp.ndarray, which has the same shape as the input, except along the
+    ``axis`` that is passed, whose size becomes ``q.shape[0]`` in order to
+    collect the soft-quantile values.
   """
 
   def _quantile(
-      inputs: jnp.ndarray, level: float, weight: float, **kwargs
+      inputs: jnp.ndarray, q: Optional[jnp.ndarray], weight: float, **kwargs
   ) -> jnp.ndarray:
-    # TODO(cuturi,oliviert) option to compute several quantiles at once
     num_points = inputs.shape[0]
+    q = jnp.array([0.2, 0.5, 0.8]) if q is None else jnp.atleast_1d(q)
+    num_quantiles = q.shape[0]
     a = jnp.ones((num_points,)) / num_points
-    b = jnp.array([level - weight / 2, weight, 1.0 - weight / 2 - level])
-    ot = transport_for_sort(inputs, a, b, **kwargs)
-    out = 1.0 / b * ot.apply(jnp.squeeze(inputs), axis=0)
-    return out[1:2]
+    idx = jnp.argsort(q)
+    q = q[idx]
+
+    extended_q = jnp.concatenate([jnp.array([0.0]), q, jnp.array([1.0])])
+    filler_weights = extended_q[1:] - extended_q[:-1]
+    safe_weight = 0.5 * jnp.concatenate([
+        jnp.array([1.0 / num_quantiles]), filler_weights
+    ])
+    if weight is None:
+      # Populate with additional default options for the quantile weight.
+      safe_weight = jnp.concatenate([
+          safe_weight,
+          jnp.array(
+              [.02]
+          ),  # reasonable mass per quantile for a small number of points.
+          jnp.array(
+              [1.5 / num_points]
+          ),  # i.e. each quantile is assigned roughly 1.5 input points.
+      ])
+    else:
+      safe_weight = jnp.concatenate([safe_weight, jnp.atleast_1d(weight)])
+    weight = jnp.min(safe_weight)
+    weights = jnp.ones(filler_weights.shape) * weight
+
+    # Shifts filler weights to account for the mass given to each quantile.
+    shift = -jnp.ones(filler_weights.shape)
+    shift = shift + 0.5 * (
+        jax.nn.one_hot(0, num_quantiles + 1) +
+        jax.nn.one_hot(num_quantiles, num_quantiles + 1)
+    )
+    filler_weights = filler_weights + weights * shift
+
+    # Adds one more value to have tensors of the same shape to interleave them.
+    quantile_weights = jnp.ones(num_quantiles + 1) * weights
+
+    # Interleaves the filler_weights with the quantile weights.
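+    # Illustration (hypothetical values): for q = [0.5] and weight w, the
+    # interleaving below yields the target histogram [0.5 - w/2, w, 0.5 - w/2]:
+    # mass w on the median target, the rest split between the two fillers.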
+    weights = jnp.reshape(
+        jnp.stack([filler_weights, quantile_weights], axis=1), (-1,)
+    )[:-1]
+
+    ot = transport_for_sort(inputs, a, weights, **kwargs)
+    out = 1.0 / weights * ot.apply(jnp.squeeze(inputs), axis=0)
-  return apply_on_axis(_quantile, inputs, axis, level, weight, **kwargs)
+    # Recover odd indices (the quantiles), in the original order of ``q``.
+    odds = jnp.concatenate([
+        jnp.zeros((num_quantiles + 1, 1), dtype=bool),
+        jnp.ones((num_quantiles + 1, 1), dtype=bool)
+    ],
+                           axis=1).ravel()[:-1]
+    return (out[odds])[jnp.argsort(idx)]
+
+  return apply_on_axis(_quantile, inputs, axis, q, weight, **kwargs)
 
 
 def _quantile_normalization(
@@ -280,24 +402,32 @@ def quantile_normalization(
 ) -> jnp.ndarray:
   r"""Renormalize inputs so that its quantiles match those of targets/weights.
 
-  The idea of quantile normalization is to map the inputs to values so that the
-  distribution of transformed values matches the distribution of target values.
-  In a sense, we want to keep the inputs in the same order, but apply the values
-  of the target.
+  Quantile normalization maps the values in ``inputs`` to new values that
+  match the discrete distribution described by ``targets``, weighted by
+  ``weights``. This transformation preserves the order of values in
+  ``inputs`` along the specified ``axis``.
 
   Args:
-    inputs: the inputs array of any shape.
-    targets: the target values of dimension 1. The targets must be sorted.
-    weights: if set, the weights or the target.
-    axis: the axis along which to apply the transformation on the inputs.
+    inputs: array of any shape whose values will be changed to match those in
+      ``targets``.
+    targets: sorted array (in ascending order) of dimension 1 describing a
+      discrete distribution. Note: the ``targets`` values must be provided as
+      a sorted vector.
+    weights: vector of nonnegative weights, summing to :math:`1.0`, of the
+      same size as ``targets``. When not set, this defaults to the uniform
+      distribution.
+    axis: the axis along which the quantile transformation is applied.
     kwargs: keyword arguments passed on to lower level functions.
      Of interest to the user are ``squashing_fun``, which will redistribute the values in
-      ``inputs`` to lie in [0,1] (sigmoid of whitened values by default) to
-      solve the optimal transport problem; ``cost_fn``, used in ``PointCloud``,
-      that defines the ground cost function to transport from ``inputs`` to the
-      ``num_targets`` target values (squared Euclidean distance by default, see
-      ``pointcloud.py`` for more details); ``epsilon`` values as well as other
-      parameters to shape the ``sinkhorn`` algorithm.
+      ``inputs`` to lie in :math:`[0,1]` (sigmoid of whitened values by default)
+      to solve the optimal transport problem;
+      attribute :attr:`~ott.geometry.pointcloud.PointCloud.cost_fn` of
+      :class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground
+      cost function to transport from ``inputs`` to the ``num_targets`` target
+      values (:class:`~ott.geometry.costs.SqEuclidean` by default, see
+      :class:`~ott.geometry.pointcloud.PointCloud`); ``epsilon`` regularization
+      parameter. Remaining ``kwargs`` are passed on to the
+      :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
 
   Returns:
     A jnp.ndarray, which has the same shape as the input, except on the give
@@ -343,12 +473,15 @@ def sort_with(
     topk: The number of outputs to keep.
     kwargs: keyword arguments passed on to lower level functions.
      Of interest to the user are ``squashing_fun``, which will redistribute the values in
-      ``inputs`` to lie in [0,1] (sigmoid of whitened values by default) to
-      solve the optimal transport problem; ``cost_fn``, used in ``PointCloud``,
-      that defines the ground cost function to transport from ``inputs`` to the
-      ``num_targets`` target values (squared Euclidean distance by default, see
-      ``pointcloud.py`` for more details); ``epsilon`` values as well as other
-      parameters to shape the ``sinkhorn`` algorithm.
+      ``inputs`` to lie in :math:`[0,1]` (sigmoid of whitened values by default)
+      to solve the optimal transport problem;
+      attribute :attr:`~ott.geometry.pointcloud.PointCloud.cost_fn` of
+      :class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground
+      cost function to transport from ``inputs`` to the ``num_targets`` target
+      values (:class:`~ott.geometry.costs.SqEuclidean` by default, see
+      :class:`~ott.geometry.pointcloud.PointCloud`); ``epsilon`` regularization
+      parameter. Remaining ``kwargs`` are passed on to the
+      :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
 
   Returns:
     A jnp.ndarray[batch | topk, dim].
@@ -374,13 +507,11 @@ def sort_with(
   return sort_fn(inputs)
 
 
-def _quantize(
-    inputs: jnp.ndarray, num_levels: int, **kwargs: Any
-) -> jnp.ndarray:
+def _quantize(inputs: jnp.ndarray, num_q: int, **kwargs: Any) -> jnp.ndarray:
   """Apply the soft quantization operator on a one dimensional array."""
   num_points = inputs.shape[0]
   a = jnp.ones((num_points,)) / num_points
-  b = jnp.ones((num_levels,)) / num_levels
+  b = jnp.ones((num_q,)) / num_q
   ot = transport_for_sort(inputs, a, b, **kwargs)
   return 1.0 / a * ot.apply(1.0 / b * ot.apply(inputs), axis=1)
@@ -407,7 +538,7 @@ def quantize(
   Args:
     inputs: the inputs as a jnp.ndarray[batch, dim].
-    num_levels: number of levels available to quantize the signal.
+    num_levels: number of levels used to quantize the signal.
     axis: axis along which quantization is carried out.
     kwargs: keyword arguments passed on to lower level functions.
      Of interest to the user are ``squashing_fun``, which will redistribute the values in
diff --git a/tests/tools/k_means_test.py b/tests/tools/k_means_test.py
index 91c56839f..4eba95e99 100644
--- a/tests/tools/k_means_test.py
+++ b/tests/tools/k_means_test.py
@@ -100,7 +100,7 @@ def test_matches_sklearn(self, rng: jax.random.PRNGKeyArray, k: int):
         pred_centers.min(axis=0) >= geom.x.min(axis=0), True
     )
     # the largest was 70.56378
-    assert jnp.abs(pred_inertia - gt_inertia) <= 100
+    assert jnp.abs(pred_inertia - gt_inertia) <= 200
 
   def test_initialization_differentiable(self, rng: jax.random.PRNGKeyArray):
diff --git a/tests/tools/soft_sort_test.py b/tests/tools/soft_sort_test.py
index f689e3595..8e2085b4c 100644
--- a/tests/tools/soft_sort_test.py
+++ b/tests/tools/soft_sort_test.py
@@ -96,20 +96,18 @@ def test_rank_one_array(self, rng: jax.random.PRNGKeyArray):
     np.testing.assert_allclose(ranks, expected_ranks, atol=0.9, rtol=0.1)
 
   @pytest.mark.fast()
-  @pytest.mark.parametrize("level", [0.2, 0.9])
-  def test_quantile(self, level: float):
+  @pytest.mark.parametrize("q", [0.2, 0.9])
+  def test_quantile(self, q: float):
     x = jnp.linspace(0.0, 1.0, 100)
-    q = soft_sort.quantile(
-        x, level=level, weight=0.05, epsilon=1e-3, lse_mode=True
-    )
+    x_q = soft_sort.quantile(x, q=q, weight=0.05, epsilon=1e-3, lse_mode=True)
 
-    np.testing.assert_approx_equal(q, level, significant=1)
+    np.testing.assert_approx_equal(x_q, q, significant=1)
 
   def test_quantile_on_several_axes(self, rng: jax.random.PRNGKeyArray):
     batch, height, width, channels = 16, 100, 100, 3
     x = jax.random.uniform(rng, shape=(batch, height, width, channels))
     q = soft_sort.quantile(
-        x, axis=(1, 2), level=0.5, weight=0.05, epsilon=1e-3, lse_mode=True
+        x, axis=(1, 2), q=0.5, weight=0.05, epsilon=1e-3, lse_mode=True
     )
 
     np.testing.assert_array_equal(q.shape, (batch, 1, channels))
@@ -117,6 +115,15 @@ def test_quantile_on_several_axes(self, rng: jax.random.PRNGKeyArray):
       q, 0.5 * np.ones((batch, 1, channels)), atol=3e-2
     )
 
+  @pytest.mark.fast()
+  def test_quantiles(self):
+    inputs = jax.random.uniform(jax.random.PRNGKey(0), (200, 2, 3))
+    q = jnp.array([.1, .8, .4])
+    m1 = soft_sort.quantile(inputs, q=q, weight=None, axis=0)
+    np.testing.assert_allclose(m1.mean(axis=[1, 2]), q, atol=5e-2)
+    m2 = soft_sort.quantile(inputs, q=q, weight=None, axis=0)
+    np.testing.assert_allclose(m2.mean(axis=[1, 2]), q, atol=5e-2)
+
   def test_soft_quantile_normalization(self, rng: jax.random.PRNGKeyArray):
     rngs = jax.random.split(rng, 2)
     x = jax.random.uniform(rngs[0], shape=(100,))
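For reference, a minimal sketch of the reworked ``soft_sort.quantile`` API that the tests above exercise; it is not part of the diff, and the variable names and shapes are illustrative only:

```
import jax
import jax.numpy as jnp

from ott.tools import soft_sort

rng = jax.random.PRNGKey(0)
x = jax.random.uniform(rng, (100,))

# Several soft quantiles in a single call; `q` replaces the former scalar
# `level` argument, and `weight=None` lets the function pick a safe mass.
quants = soft_sort.quantile(x, q=jnp.array([0.2, 0.5, 0.8]), weight=None)

# The operator remains differentiable: the gradient of a soft quantile
# touches every entry of `x`, not only the one at the quantile position.
grad_median = jax.grad(lambda z: soft_sort.quantile(z, q=0.5).sum())(x)
```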