From 20e5090b61f4cb18d366a36b2b32fcce57486740 Mon Sep 17 00:00:00 2001 From: michaelmarien Date: Mon, 14 Feb 2022 21:41:30 +0100 Subject: [PATCH] Add a warning to random.choice to notify users of the ill-defined behaviour when requesting more samples than non-zero probabilities and replace=False --- jax/_src/random.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/jax/_src/random.py b/jax/_src/random.py index 0e7c067eb892..6b51fb208a8e 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -435,6 +435,11 @@ def choice(key: KeyArray, axis: int = 0) -> jnp.ndarray: """Generates a random sample from a given array. + .. warning:: + If ``p`` has fewer non-zero elements than the requested number of samples, + as specified in ``shape``, and ``replace=False``, the output of this + function is ill-defined. Please make sure to use appropriate inputs. + Args: key: a PRNG key used as the random key. a : array or int. If an ndarray, a random sample is generated from