-
Notifications
You must be signed in to change notification settings - Fork 82
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Compute multiple soft-quantiles in one execution without using vmap
#382
Conversation
vmap
numba, via jaxopt, seems to be causing the issue. |
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
Codecov Report
❗ Your organization is not using the GitHub App Integration. As a result you may experience degraded service beginning May 15th. Please install the Github App Integration for your organization. Read more. Additional details and impacted files@@ Coverage Diff @@
## main #382 +/- ##
==========================================
+ Coverage 88.51% 88.53% +0.02%
==========================================
Files 52 52
Lines 5660 5679 +19
Branches 839 841 +2
==========================================
+ Hits 5010 5028 +18
Misses 530 530
- Partials 120 121 +1
|
q = jnp.array([.1, .8, .4]) | ||
m1 = soft_sort.quantile(inputs, q=q, weight=None, axis=0) | ||
np.testing.assert_allclose(m1.mean(axis=[1, 2]), q, atol=5e-2) | ||
m2 = soft_sort.quantile(inputs, q=q, weight=None, axis=0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why does m2
exist since it's the same as m1
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks! I think this was naive byproduct of starting PR with 2 different interfaces (quantile
and quantiles
)
@@ -141,28 +146,50 @@ def sort( | |||
) -> jnp.ndarray: | |||
r"""Apply the soft sort operator on a given axis of the input. | |||
|
|||
For instance: | |||
|
|||
``` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use the python code blocks (also in other places), e.g.,:
.. code-block:: python
x = jax.random.uniform(rng, (100,))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Otherwise, it doesn't render nicely, see here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks!
jnp.ones((num_quantiles + 1, 1), dtype=bool) | ||
], | ||
axis=1).ravel()[:-1] | ||
return (out[odds])[idx] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unnecessary ()
|
||
Returns: | ||
A jnp.ndarray of the same shape as the input with soft sorted values on the | ||
A jnp.ndarray of the same shape as the input with soft-sorted values on the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe it's a good time to refactor and say An array of the same shape ...
targets: sorted array (in ascending order) of dimension 1 describing a | ||
discrete distribution. Note: the``targets`` values must be provided as | ||
a sorted vector. | ||
weights: vector of nonnegative weights, summing to :math:`1.0`, of the same |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
non-negative
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
there's an argument for dropping the hyphen :) https://math.stackexchange.com/questions/3342643/nonnegative-vs-non-negative
targets: sorted array (in ascending order) of dimension 1 describing a | ||
discrete distribution. Note: the``targets`` values must be provided as | ||
a sorted vector. | ||
weights: vector of nonnegative weights, summing to :math:`1.0`, of the same |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
:math:`1`
inputs: array of any shape whose values will be changed to match those in | ||
``targets``. | ||
targets: sorted array (in ascending order) of dimension 1 describing a | ||
discrete distribution. Note: the``targets`` values must be provided as |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the ``targets`` (missing space, doesn't render correctly)
specified by the optimal transport between values in ``inputs`` towards | ||
those values. If not specified, ``num_targets`` is set by default to be | ||
the size of the slices of the input that are sorted. | ||
inputs: jnp.ndarray<float> of any shape. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Array of any shape.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also elsewhere, if possible.
num_points = inputs.shape[0] | ||
q = jnp.array([0.2, 0.5, 0.8]) if q is None else jnp.atleast_1d(q) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are defaults needed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, pure implementation question: why pass these as an array, not e.g., as a tuple? To be able to differentiate w.r.t. to it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's also a question of being able to re-run by changing quantile values, without jitting again, pending that the number of quantiles does not change.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you prefer no defaults? it was there because there was a default previously (the median) but maybe you're right, better stick to jax's quantile
API
This PR is a follow up to
#373
This implements a
quantiles
function in thesoft_sort
module to return simultenaously multiple quantile values. Should be more efficient than the vmap proposed in the discussion, but will likely return very slightly different results.