diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index df74981162c7..bf191d46e908 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -155,8 +155,13 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands, operand=_no_operand_sentinel, linear=None): """Conditionally apply ``true_fun`` or ``false_fun``. + Wraps XLA's `Conditional + `_ + operator. + Provided arguments are correctly typed, ``cond()`` has equivalent - semantics to this Python implementation:: + semantics to this Python implementation, where ``pred`` must be a + scalar type:: def cond(pred, true_fun, false_fun, *operands): if pred: @@ -164,7 +169,14 @@ def cond(pred, true_fun, false_fun, *operands): else: return false_fun(*operands) - ``pred`` must be a scalar type. + When run in a :func:`jax.disable_jit` context, the above is roughly how + ``cond()`` is evaluated. + + Using ``cond`` rather than :func:`jax.lax.select` signals the intent that only one + of the two branches will be executed, although the compiler may choose to execute + both branches if it is deemed advantageous. Note also that when transformed with + :func:`~jax.vmap` the batched-cond will be converted to a :func:`~jax.lax.select` + operation. Args: pred: Boolean scalar type, indicating which branch function to apply. diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index ca62b2f09959..f3248606506d 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -908,9 +908,25 @@ def rev(operand: ArrayLike, dimensions: Sequence[int]) -> Array: return rev_p.bind(operand, dimensions=tuple(dimensions)) def select(pred: ArrayLike, on_true: ArrayLike, on_false: ArrayLike) -> Array: - """Wraps XLA's `Select + """Selects between two branches based on a boolean predicate. + + Wraps XLA's `Select `_ operator. + + In general :func:`~jax.lax.select` leads to evaluation of both branches, although + the compiler may elide computations if possible. For a similar function that + usually evaluates only a single branch, see :func:`~jax.lax.cond`. + + Args: + pred: boolean array + on_true: array containing entries to return where ``pred`` is True. Must have + the same shape as ``pred``, and the same shape and dtype as ``on_false``. + on_false: array containing entries to return where ``pred`` is False. Must have + the same shape as ``pred``, and the same shape and dtype as ``on_true``. + + Returns: + result: array with same shape and dtype as ``on_true`` and ``on_false``. """ # Caution! The select_n_p primitive has the *opposite* order of arguments to # select(). This is because it implements `select_n`.