2525
2626import jax
2727from jax import lax
28+ from jax import numpy as jnp
2829from jax ._src import api
2930from jax ._src import core
3031from jax ._src import deprecations
3132from jax ._src import dtypes
3233from jax ._src .numpy .util import (
33- _broadcast_to , check_arraylike , _complex_elem_type ,
34+ _broadcast_to , check_arraylike , ensure_arraylike , _complex_elem_type ,
3435 promote_dtypes_inexact , promote_dtypes_numeric , _where )
3536from jax ._src .lax import lax as lax_internal
3637from jax ._src .typing import Array , ArrayLike , DType , DTypeLike , DeprecatedArg
@@ -2379,6 +2380,10 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No
23792380 Array([2., 4., 7.], dtype=float32)
23802381 """
23812382 check_arraylike ("quantile" , a , q )
2383+ if weights is None :
2384+ a , q = ensure_arraylike ("quantile" , a , q )
2385+ else :
2386+ a , q , weights = ensure_arraylike ("quantile" , a , q , weights )
23822387 if overwrite_input or out is not None :
23832388 raise ValueError ("jax.numpy.quantile does not support overwrite_input=True "
23842389 "or out != None" )
@@ -2435,6 +2440,10 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None =
24352440 Array([1.5, 3. , 4.5], dtype=float32)
24362441 """
24372442 check_arraylike ("nanquantile" , a , q )
2443+ if weights is None :
2444+ a , q = ensure_arraylike ("nanquantile" , a , q )
2445+ else :
2446+ a , q , weights = ensure_arraylike ("nanquantile" , a , q , weights )
24382447 if overwrite_input or out is not None :
24392448 msg = ("jax.numpy.nanquantile does not support overwrite_input=True or "
24402449 "out != None" )
@@ -2445,10 +2454,9 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None =
24452454 return _quantile (lax .asarray (a ), lax .asarray (q ), axis , method , keepdims , True , weights )
24462455
24472456def _quantile (a : Array , q : Array , axis : int | tuple [int , ...] | None ,
2448- method : str , keepdims : bool , squash_nans : bool , weights : ArrayLike | None = None ) -> Array :
2457+ method : str , keepdims : bool , squash_nans : bool , weights : Array | None = None ) -> Array :
24492458 if method not in ["linear" , "lower" , "higher" , "midpoint" , "nearest" , "inverted_cdf" ]:
24502459 raise ValueError ("method can only be 'linear', 'lower', 'higher', 'midpoint', 'nearest', or 'inverted_cdf'" )
2451- a , = promote_dtypes_inexact (a )
24522460 keepdim = []
24532461 if dtypes .issubdtype (a .dtype , np .complexfloating ):
24542462 raise ValueError ("quantile does not support complex input, as the operation is poorly defined." )
@@ -2477,9 +2485,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
24772485 axis = _canonicalize_axis (- 1 , a .ndim )
24782486 else :
24792487 axis = _canonicalize_axis (axis , a .ndim )
2480-
2481- # Ensure q is an array and inexact
2482- q = lax .asarray (q )
2488+
24832489 q , = promote_dtypes_inexact (q )
24842490
24852491 q_shape = q .shape
@@ -2489,8 +2495,10 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
24892495
24902496 a_shape = a .shape
24912497 # Handle weights
2492- if weights is not None :
2493- a , weights = promote_dtypes_inexact (a , weights )
2498+ if weights is None :
2499+ a , = promote_dtypes_inexact (a )
2500+ else :
2501+ a , q = promote_dtypes_inexact (a , q )
24942502 a_shape = a .shape
24952503 w_shape = np .shape (weights )
24962504 if w_shape != a_shape :
@@ -2502,28 +2510,13 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
25022510 raise ValueError ("Length of weights not compatible with specified axis." )
25032511 resh = [1 ] * a .ndim
25042512 resh [axis ] = w_shape [0 ]
2505- weights = lax .reshape ( lax_internal . asarray ( weights ), tuple ( resh ) )
2513+ weights = lax .expand_dims ( weights , axis )
25062514 weights = _broadcast_to (weights , a .shape )
25072515
2508- if isinstance (weights , core .Tracer ):
2509- weights_arr = None
2510- else :
2511- try :
2512- weights_arr = np .asarray (weights )
2513- except Exception :
2514- weights_arr = None
2515-
2516- if weights_arr is not None :
2517- if np .any (weights_arr < 0 ):
2518- raise ValueError ("Weights must be non-negative." )
2519- if np .all (weights_arr == 0 ):
2520- raise ValueError ("Sum of weights must not be zero." )
2521- if np .any (np .isnan (weights_arr )):
2522- out_shape = q .shape if hasattr (q , "shape" ) and getattr (q , "ndim" , 0 ) > 0 else ()
2523- return lax .full (out_shape , np .nan , dtype = a .dtype )
2524- weights_have_nan = np .any (np .isnan (weights_arr ))
2525- else :
2526- weights_have_nan = False
2516+ weights_have_nan = jnp .any (jnp .isnan (weights ))
2517+ if weights_have_nan :
2518+ out_shape = q .shape if hasattr (q , "shape" ) and getattr (q , "ndim" , 0 ) > 0 else ()
2519+ return lax .full (out_shape , np .nan , dtype = a .dtype )
25272520
25282521 if squash_nans :
25292522 nan_mask = ~ lax_internal ._isnan (a )
@@ -2537,29 +2530,16 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
25372530 cum_weights = lax .cumsum (weights_sorted , axis = axis )
25382531 cum_weights_norm = lax .div (cum_weights , total_weight )
25392532
2540- slice_sizes = list (a_sorted .shape )
2541- slice_sizes [axis ] = 1
2542- dnums = lax .GatherDimensionNumbers (
2543- offset_dims = tuple (range (
2544- 0 ,
2545- len (a_sorted .shape ) if keepdims else len (a_sorted .shape ) - 1 )),
2546- collapsed_slice_dims = () if keepdims else (axis ,),
2547- start_index_map = (axis ,))
2548-
25492533 def _weighted_quantile (qi , weights_have_nan = weights_have_nan ):
2550- index_dtype = dtypes .canonicalize_dtype ( int )
2534+ index_dtype = dtypes .default_int_dtype ( )
25512535 idx = sum (lax .lt (cum_weights_norm , qi ), axis = axis , dtype = index_dtype )
25522536 idx = lax .clamp (0 , idx , a_sorted .shape [axis ] - 1 )
2553- slicer = [slice (None )] * a_sorted .ndim
2554- slicer [axis ] = idx
2555- val = a_sorted [tuple (slicer )]
2537+ val = jnp .take_along_axis (a_sorted , jnp .expand_dims (idx , axis ), axis )
25562538
25572539 idx_prev = lax .clamp (idx - 1 , 0 , a_sorted .shape [axis ] - 1 )
2558- slicer_prev = slicer .copy ()
2559- slicer_prev [axis ] = idx_prev
2560- val_prev = a_sorted [tuple (slicer_prev )]
2561- cw_prev = cum_weights_norm [tuple (slicer_prev )]
2562- cw_next = cum_weights_norm [tuple (slicer )]
2540+ val_prev = jnp .take_along_axis (a_sorted , jnp .expand_dims (idx_prev , axis ), axis )
2541+ cw_prev = jnp .take_along_axis (cum_weights_norm , jnp .expand_dims (idx_prev , axis ), axis )
2542+ cw_next = jnp .take_along_axis (cum_weights_norm , jnp .expand_dims (idx , axis ), axis )
25632543
25642544 if method == "linear" :
25652545 denom = cw_next - cw_prev
@@ -2603,12 +2583,12 @@ def _weighted_quantile(qi, weights_have_nan=weights_have_nan):
26032583 high_weight = lax .sub (q , low )
26042584 low_weight = lax .sub (_lax_const (high_weight , 1 ), high_weight )
26052585
2606- low = lax .max (_lax_const (low , 0 ), lax .min (low , counts - 1 ))
2607- high = lax .max (_lax_const (high , 0 ), lax .min (high , counts - 1 ))
2608- low = lax .convert_element_type (low , dtypes . canonicalize_dtype ( int ) )
2609- high = lax .convert_element_type (high , dtypes . canonicalize_dtype ( int ) )
2586+ low = lax .max (lax . _const (low , 0 ), lax .min (low , counts - 1 ))
2587+ high = lax .max (lax . _const (high , 0 ), lax .min (high , counts - 1 ))
2588+ low = lax .convert_element_type (low , int )
2589+ high = lax .convert_element_type (high , int )
26102590 out_shape = q_shape + shape_after_reduction
2611- index = [lax .broadcasted_iota (dtypes . canonicalize_dtype ( int ) , out_shape , dim + q_ndim )
2591+ index = [lax .broadcasted_iota (int , out_shape , dim + q_ndim )
26122592 for dim in range (len (shape_after_reduction ))]
26132593 if keepdims :
26142594 index [axis ] = low
@@ -2628,10 +2608,10 @@ def _weighted_quantile(qi, weights_have_nan=weights_have_nan):
26282608 high_weight = lax .sub (q , low )
26292609 low_weight = lax .sub (_lax_const (high_weight , 1 ), high_weight )
26302610
2631- low = lax .clamp (_lax_const (low , 0 ), low , n - 1 )
2632- high = lax .clamp (_lax_const (high , 0 ), high , n - 1 )
2633- low = lax .convert_element_type (low , dtypes . int_ )
2634- high = lax .convert_element_type (high , dtypes . int_ )
2611+ low = lax .clamp (lax . _const (low , 0 ), low , n - 1 )
2612+ high = lax .clamp (lax . _const (high , 0 ), high , n - 1 )
2613+ low = lax .convert_element_type (low , int )
2614+ high = lax .convert_element_type (high , int )
26352615
26362616 slice_sizes = list (a_shape )
26372617 slice_sizes [axis ] = 1
@@ -2662,7 +2642,7 @@ def _weighted_quantile(qi, weights_have_nan=weights_have_nan):
26622642 pred = lax .le (high_weight , _lax_const (high_weight , 0.5 ))
26632643 result = lax .select (pred , low_value , high_value )
26642644 elif method == "midpoint" :
2665- result = lax .mul (lax .add (low_value , high_value ), _lax_const (low_value , 0.5 ))
2645+ result = lax .mul (lax .add (low_value , high_value ), lax . _const (low_value , 0.5 ))
26662646 elif method == "inverted_cdf" :
26672647 result = high_value
26682648 else :
@@ -2721,6 +2701,10 @@ def percentile(a: ArrayLike, q: ArrayLike,
27212701 Array([1., 3., 4.], dtype=float32)
27222702 """
27232703 check_arraylike ("percentile" , a , q )
2704+ if weights is None :
2705+ a , q = ensure_arraylike ("percentile" , a , q )
2706+ else :
2707+ a , q , weights = ensure_arraylike ("percentile" , a , q , weights )
27242708 q , = promote_dtypes_inexact (q )
27252709 if not isinstance (interpolation , DeprecatedArg ):
27262710 deprecations .warn (
@@ -2781,6 +2765,10 @@ def nanpercentile(a: ArrayLike, q: ArrayLike,
27812765 Array([1.5, 3. , 4.5], dtype=float32)
27822766 """
27832767 check_arraylike ("nanpercentile" , a , q )
2768+ if weights is None :
2769+ a , q = ensure_arraylike ("nanpercentile" , a , q )
2770+ else :
2771+ a , q , weights = ensure_arraylike ("nanpercentile" , a , q , weights )
27842772 q , = promote_dtypes_inexact (q )
27852773 q = q / 100
27862774 if not isinstance (interpolation , DeprecatedArg ):
0 commit comments