Skip to content

Commit 7b967cb

Browse files
committed
Add weighted quantile and percentile support with tests
1 parent 5f881bf commit 7b967cb

File tree

2 files changed: 53 additions and 67 deletions

jax/_src/numpy/reductions.py

Lines changed: 44 additions & 56 deletions
Original file line number | Diff line number | Diff line change
@@ -25,12 +25,13 @@
2525

2626
import jax
2727
from jax import lax
28+
from jax import numpy as jnp
2829
from jax._src import api
2930
from jax._src import core
3031
from jax._src import deprecations
3132
from jax._src import dtypes
3233
from jax._src.numpy.util import (
33-
_broadcast_to, check_arraylike, _complex_elem_type,
34+
_broadcast_to, check_arraylike, ensure_arraylike, _complex_elem_type,
3435
promote_dtypes_inexact, promote_dtypes_numeric, _where)
3536
from jax._src.lax import lax as lax_internal
3637
from jax._src.typing import Array, ArrayLike, DType, DTypeLike, DeprecatedArg
@@ -2379,6 +2380,10 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No
23792380
Array([2., 4., 7.], dtype=float32)
23802381
"""
23812382
check_arraylike("quantile", a, q)
2383+
if weights is None:
2384+
a, q = ensure_arraylike("quantile", a, q)
2385+
else:
2386+
a, q, weights = ensure_arraylike("quantile", a, q, weights)
23822387
if overwrite_input or out is not None:
23832388
raise ValueError("jax.numpy.quantile does not support overwrite_input=True "
23842389
"or out != None")
@@ -2435,6 +2440,10 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None =
24352440
Array([1.5, 3. , 4.5], dtype=float32)
24362441
"""
24372442
check_arraylike("nanquantile", a, q)
2443+
if weights is None:
2444+
a, q = ensure_arraylike("nanquantile", a, q)
2445+
else:
2446+
a, q, weights = ensure_arraylike("nanquantile", a, q, weights)
24382447
if overwrite_input or out is not None:
24392448
msg = ("jax.numpy.nanquantile does not support overwrite_input=True or "
24402449
"out != None")
@@ -2445,10 +2454,9 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None =
24452454
return _quantile(lax.asarray(a), lax.asarray(q), axis, method, keepdims, True, weights)
24462455

24472456
def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
2448-
method: str, keepdims: bool, squash_nans: bool, weights: ArrayLike | None = None) -> Array:
2457+
method: str, keepdims: bool, squash_nans: bool, weights: Array | None = None) -> Array:
24492458
if method not in ["linear", "lower", "higher", "midpoint", "nearest", "inverted_cdf"]:
24502459
raise ValueError("method can only be 'linear', 'lower', 'higher', 'midpoint', 'nearest', or 'inverted_cdf'")
2451-
a, = promote_dtypes_inexact(a)
24522460
keepdim = []
24532461
if dtypes.issubdtype(a.dtype, np.complexfloating):
24542462
raise ValueError("quantile does not support complex input, as the operation is poorly defined.")
@@ -2477,9 +2485,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
24772485
axis = _canonicalize_axis(-1, a.ndim)
24782486
else:
24792487
axis = _canonicalize_axis(axis, a.ndim)
2480-
2481-
# Ensure q is an array and inexact
2482-
q = lax.asarray(q)
2488+
24832489
q, = promote_dtypes_inexact(q)
24842490

24852491
q_shape = q.shape
@@ -2489,8 +2495,10 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
24892495

24902496
a_shape = a.shape
24912497
# Handle weights
2492-
if weights is not None:
2493-
a, weights = promote_dtypes_inexact(a, weights)
2498+
if weights is None:
2499+
a, = promote_dtypes_inexact(a)
2500+
else:
2501+
a, q = promote_dtypes_inexact(a, q)
24942502
a_shape = a.shape
24952503
w_shape = np.shape(weights)
24962504
if w_shape != a_shape:
@@ -2502,28 +2510,13 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
25022510
raise ValueError("Length of weights not compatible with specified axis.")
25032511
resh = [1] * a.ndim
25042512
resh[axis] = w_shape[0]
2505-
weights = lax.reshape(lax_internal.asarray(weights), tuple(resh))
2513+
weights = lax.expand_dims(weights, axis)
25062514
weights = _broadcast_to(weights, a.shape)
25072515

2508-
if isinstance(weights, core.Tracer):
2509-
weights_arr = None
2510-
else:
2511-
try:
2512-
weights_arr = np.asarray(weights)
2513-
except Exception:
2514-
weights_arr = None
2515-
2516-
if weights_arr is not None:
2517-
if np.any(weights_arr < 0):
2518-
raise ValueError("Weights must be non-negative.")
2519-
if np.all(weights_arr == 0):
2520-
raise ValueError("Sum of weights must not be zero.")
2521-
if np.any(np.isnan(weights_arr)):
2522-
out_shape = q.shape if hasattr(q, "shape") and getattr(q, "ndim", 0) > 0 else ()
2523-
return lax.full(out_shape, np.nan, dtype=a.dtype)
2524-
weights_have_nan = np.any(np.isnan(weights_arr))
2525-
else:
2526-
weights_have_nan = False
2516+
weights_have_nan = jnp.any(jnp.isnan(weights))
2517+
if weights_have_nan:
2518+
out_shape = q.shape if hasattr(q, "shape") and getattr(q, "ndim", 0) > 0 else ()
2519+
return lax.full(out_shape, np.nan, dtype=a.dtype)
25272520

25282521
if squash_nans:
25292522
nan_mask = ~lax_internal._isnan(a)
@@ -2537,29 +2530,16 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
25372530
cum_weights = lax.cumsum(weights_sorted, axis=axis)
25382531
cum_weights_norm = lax.div(cum_weights, total_weight)
25392532

2540-
slice_sizes = list(a_sorted.shape)
2541-
slice_sizes[axis] = 1
2542-
dnums = lax.GatherDimensionNumbers(
2543-
offset_dims=tuple(range(
2544-
0,
2545-
len(a_sorted.shape) if keepdims else len(a_sorted.shape) - 1)),
2546-
collapsed_slice_dims=() if keepdims else (axis,),
2547-
start_index_map=(axis,))
2548-
25492533
def _weighted_quantile(qi, weights_have_nan=weights_have_nan):
2550-
index_dtype = dtypes.canonicalize_dtype(int)
2534+
index_dtype = dtypes.default_int_dtype()
25512535
idx = sum(lax.lt(cum_weights_norm, qi), axis=axis, dtype=index_dtype)
25522536
idx = lax.clamp(0, idx, a_sorted.shape[axis] - 1)
2553-
slicer = [slice(None)] * a_sorted.ndim
2554-
slicer[axis] = idx
2555-
val = a_sorted[tuple(slicer)]
2537+
val = jnp.take_along_axis(a_sorted, jnp.expand_dims(idx, axis), axis)
25562538

25572539
idx_prev = lax.clamp(idx - 1, 0, a_sorted.shape[axis] - 1)
2558-
slicer_prev = slicer.copy()
2559-
slicer_prev[axis] = idx_prev
2560-
val_prev = a_sorted[tuple(slicer_prev)]
2561-
cw_prev = cum_weights_norm[tuple(slicer_prev)]
2562-
cw_next = cum_weights_norm[tuple(slicer)]
2540+
val_prev = jnp.take_along_axis(a_sorted, jnp.expand_dims(idx_prev, axis), axis)
2541+
cw_prev = jnp.take_along_axis(cum_weights_norm, jnp.expand_dims(idx_prev, axis), axis)
2542+
cw_next = jnp.take_along_axis(cum_weights_norm, jnp.expand_dims(idx, axis), axis)
25632543

25642544
if method == "linear":
25652545
denom = cw_next - cw_prev
@@ -2603,12 +2583,12 @@ def _weighted_quantile(qi, weights_have_nan=weights_have_nan):
26032583
high_weight = lax.sub(q, low)
26042584
low_weight = lax.sub(_lax_const(high_weight, 1), high_weight)
26052585

2606-
low = lax.max(_lax_const(low, 0), lax.min(low, counts - 1))
2607-
high = lax.max(_lax_const(high, 0), lax.min(high, counts - 1))
2608-
low = lax.convert_element_type(low, dtypes.canonicalize_dtype(int))
2609-
high = lax.convert_element_type(high, dtypes.canonicalize_dtype(int))
2586+
low = lax.max(lax._const(low, 0), lax.min(low, counts - 1))
2587+
high = lax.max(lax._const(high, 0), lax.min(high, counts - 1))
2588+
low = lax.convert_element_type(low, int)
2589+
high = lax.convert_element_type(high, int)
26102590
out_shape = q_shape + shape_after_reduction
2611-
index = [lax.broadcasted_iota(dtypes.canonicalize_dtype(int), out_shape, dim + q_ndim)
2591+
index = [lax.broadcasted_iota(int, out_shape, dim + q_ndim)
26122592
for dim in range(len(shape_after_reduction))]
26132593
if keepdims:
26142594
index[axis] = low
@@ -2628,10 +2608,10 @@ def _weighted_quantile(qi, weights_have_nan=weights_have_nan):
26282608
high_weight = lax.sub(q, low)
26292609
low_weight = lax.sub(_lax_const(high_weight, 1), high_weight)
26302610

2631-
low = lax.clamp(_lax_const(low, 0), low, n - 1)
2632-
high = lax.clamp(_lax_const(high, 0), high, n - 1)
2633-
low = lax.convert_element_type(low, dtypes.int_)
2634-
high = lax.convert_element_type(high, dtypes.int_)
2611+
low = lax.clamp(lax._const(low, 0), low, n - 1)
2612+
high = lax.clamp(lax._const(high, 0), high, n - 1)
2613+
low = lax.convert_element_type(low, int)
2614+
high = lax.convert_element_type(high, int)
26352615

26362616
slice_sizes = list(a_shape)
26372617
slice_sizes[axis] = 1
@@ -2662,7 +2642,7 @@ def _weighted_quantile(qi, weights_have_nan=weights_have_nan):
26622642
pred = lax.le(high_weight, _lax_const(high_weight, 0.5))
26632643
result = lax.select(pred, low_value, high_value)
26642644
elif method == "midpoint":
2665-
result = lax.mul(lax.add(low_value, high_value), _lax_const(low_value, 0.5))
2645+
result = lax.mul(lax.add(low_value, high_value), lax._const(low_value, 0.5))
26662646
elif method == "inverted_cdf":
26672647
result = high_value
26682648
else:
@@ -2721,6 +2701,10 @@ def percentile(a: ArrayLike, q: ArrayLike,
27212701
Array([1., 3., 4.], dtype=float32)
27222702
"""
27232703
check_arraylike("percentile", a, q)
2704+
if weights is None:
2705+
a, q = ensure_arraylike("percentile", a, q)
2706+
else:
2707+
a, q, weights = ensure_arraylike("percentile", a, q, weights)
27242708
q, = promote_dtypes_inexact(q)
27252709
if not isinstance(interpolation, DeprecatedArg):
27262710
deprecations.warn(
@@ -2781,6 +2765,10 @@ def nanpercentile(a: ArrayLike, q: ArrayLike,
27812765
Array([1.5, 3. , 4.5], dtype=float32)
27822766
"""
27832767
check_arraylike("nanpercentile", a, q)
2768+
if weights is None:
2769+
a, q = ensure_arraylike("nanpercentile", a, q)
2770+
else:
2771+
a, q, weights = ensure_arraylike("nanpercentile", a, q, weights)
27842772
q, = promote_dtypes_inexact(q)
27852773
q = q / 100
27862774
if not isinstance(interpolation, DeprecatedArg):

tests/lax_numpy_reducers_test.py

Lines changed: 9 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -23,14 +23,12 @@
2323
from absl.testing import parameterized
2424

2525
import numpy as np
26-
from jax._src.numpy.reductions import _quantile
2726

2827
import jax
2928
from jax import numpy as jnp
3029

3130
from jax._src import config
3231
from jax._src import dtypes
33-
from jax._src.numpy.reductions import quantile
3432
from jax._src import test_util as jtu
3533
from jax._src.util import NumpyComplexWarning
3634

@@ -770,31 +768,31 @@ def test_weighted_quantile_all_weights_one(self):
770768
a = jnp.array([1, 2, 3, 4, 5], dtype=float)
771769
weights = jnp.ones_like(a)
772770
q = jnp.array([0.25, 0.5, 0.75])
773-
result = _quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, squash_nans=False, weights=weights)
771+
result = jnp.quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, squash_nans=False, weights=weights)
774772
expected = np.quantile(np.array(a), np.array(q), axis=0, weights=np.array(weights), method="inverted_cdf")
775773
np.testing.assert_allclose(np.array(result), expected, rtol=1e-6)
776774

777775
def test_weighted_quantile_multiple_q(self):
778776
a = jnp.arange(10, dtype=float)
779777
weights = jnp.ones_like(a)
780778
q = jnp.array([0.25, 0.5, 0.75])
781-
result = _quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, squash_nans=False, weights=weights)
779+
result = jnp.quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, squash_nans=False, weights=weights)
782780
expected = np.quantile(np.array(a), np.array(q), axis=0, weights=np.array(weights), method="inverted_cdf")
783781
np.testing.assert_allclose(np.array(result), expected, rtol=1e-6)
784782

785783
def test_weighted_quantile_keepdims(self):
786784
a = jnp.array([1, 2, 3, 4], dtype=float)
787785
weights = jnp.array([1, 1, 1, 1], dtype=float)
788786
q = 0.5
789-
result = _quantile(a, q, axis=0, method="inverted_cdf", keepdims=True, squash_nans=False, weights=weights)
787+
result = jnp.quantile(a, q, axis=0, method="inverted_cdf", keepdims=True, squash_nans=False, weights=weights)
790788
expected = np.quantile(np.array(a), np.array(q), axis=0, keepdims=True, weights=np.array(weights), method="inverted_cdf")
791789
np.testing.assert_allclose(np.array(result), expected, rtol=1e-6)
792790

793791
def test_weighted_quantile_linear(self):
794792
a = jnp.array([1, 2, 3, 4, 5], dtype=float)
795793
weights = jnp.array([1, 2, 1, 1, 1], dtype=float)
796794
q = jnp.array([0.5])
797-
result = _quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, squash_nans=False, weights=weights)
795+
result = jnp.quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, squash_nans=False, weights=weights)
798796
expected = np.quantile(np.array(a), np.array(q), axis=0, weights=np.array(weights), method="inverted_cdf")
799797
np.testing.assert_allclose(np.array(result), expected, rtol=1e-6)
800798

@@ -803,35 +801,35 @@ def test_weighted_quantile_negative_weights(self):
803801
weights = jnp.array([1, -1, 1, 1, 1], dtype=float)
804802
q = jnp.array([0.5])
805803
with pytest.raises(ValueError):
806-
_quantile(a, q, axis=0, method="linear", keepdims=False, squash_nans=False, weights=weights)
804+
jnp.quantile(a, q, axis=0, method="linear", keepdims=False, squash_nans=False, weights=weights)
807805

808806
def test_weighted_quantile_all_weights_zero(self):
809807
a = jnp.array([1, 2, 3, 4, 5], dtype=float)
810808
weights = jnp.zeros_like(a)
811809
q = jnp.array([0.5])
812810
with pytest.raises(ValueError):
813-
_quantile(a, q, axis=0, method="linear", keepdims=False, squash_nans=False, weights=weights)
811+
jnp.quantile(a, q, axis=0, method="linear", keepdims=False, squash_nans=False, weights=weights)
814812

815813
def test_weighted_quantile_weights_with_nan(self):
816814
a = jnp.array([1, 2, 3, 4, 5], dtype=float)
817815
weights = jnp.array([1, np.nan, 1, 1, 1], dtype=float)
818816
q = jnp.array([0.5])
819-
result = _quantile(a, q, axis=0, method="linear", keepdims=False, squash_nans=False, weights=weights)
817+
result = jnp.quantile(a, q, axis=0, method="linear", keepdims=False, squash_nans=False, weights=weights)
820818
assert np.isnan(np.array(result)).all()
821819

822820
def test_weighted_quantile_scalar_q(self):
823821
a = jnp.array([1, 2, 3, 4, 5], dtype=float)
824822
weights = jnp.array([1, 2, 1, 1, 1], dtype=float)
825823
q = 0.5
826-
result = _quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, squash_nans=False, weights=weights)
824+
result = jnp.quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, squash_nans=False, weights=weights)
827825
assert jnp.issubdtype(result.dtype, jnp.floating)
828826
assert result.shape == ()
829827

830828
def test_weighted_quantile_jit(self):
831829
a = jnp.array([1, 2, 3, 4, 5], dtype=float)
832830
weights = jnp.array([1, 2, 1, 1, 1], dtype=float)
833831
q = jnp.array([0.25, 0.5, 0.75])
834-
quantile_jit = jax.jit(lambda a, q, weights: _quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, squash_nans=False, weights=weights))
832+
quantile_jit = jax.jit(lambda a, q, weights: jnp.quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, squash_nans=False, weights=weights))
835833
result = quantile_jit(a, q, weights)
836834
expected = np.quantile(np.array(a), np.array(q), axis=0, weights=np.array(weights), method="inverted_cdf")
837835
np.testing.assert_allclose(np.array(result), expected, rtol=1e-6)

0 commit comments

Comments (0)