Skip to content

Commit

Permalink
DOC: clarify behavior of lax.cond & lax.select
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Dec 17, 2022
1 parent 11a1795 commit 4114ff2
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 3 deletions.
16 changes: 14 additions & 2 deletions jax/_src/lax/control_flow/conditionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,16 +155,28 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
operand=_no_operand_sentinel, linear=None):
"""Conditionally apply ``true_fun`` or ``false_fun``.
Wraps XLA's `Conditional
<https://www.tensorflow.org/xla/operation_semantics#conditional>`_
operator.
Provided arguments are correctly typed, ``cond()`` has equivalent
semantics to this Python implementation::
semantics to this Python implementation, where ``pred`` must be a
scalar type::
def cond(pred, true_fun, false_fun, *operands):
if pred:
return true_fun(*operands)
else:
return false_fun(*operands)
``pred`` must be a scalar type.
When run in a :func:`jax.disable_jit` context, the above is roughly how
``cond()`` is evaluated.
Using ``cond`` rather than :func:`jax.lax.select` signals the intent that only one
of the two branches will be executed, although the compiler may choose to execute
both branches if it is deemed advantageous. Note also that when transformed with
:func:`~jax.vmap` the batched-cond will be converted to a :func:`~jax.lax.select`
operation.
Args:
pred: Boolean scalar type, indicating which branch function to apply.
Expand Down
18 changes: 17 additions & 1 deletion jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -908,9 +908,25 @@ def rev(operand: ArrayLike, dimensions: Sequence[int]) -> Array:
return rev_p.bind(operand, dimensions=tuple(dimensions))

def select(pred: ArrayLike, on_true: ArrayLike, on_false: ArrayLike) -> Array:
"""Wraps XLA's `Select
"""Selects between two branches based on a boolean predicate.
Wraps XLA's `Select
<https://www.tensorflow.org/xla/operation_semantics#select>`_
operator.
In general :func:`~jax.lax.select` leads to evaluation of both branches, although
the compiler may elide computations if possible. For a similar function that
usually evaluates only a single branch, see :func:`~jax.lax.cond`.
Args:
pred: boolean array
on_true: array containing entries to return where ``pred`` is True. Must have
the same shape as ``pred``, and the same shape and dtype as ``on_false``.
on_false: array containing entries to return where ``pred`` is False. Must have
the same shape as ``pred``, and the same shape and dtype as ``on_true``.
Returns:
result: array with same shape and dtype as ``on_true`` and ``on_false``.
"""
# Caution! The select_n_p primitive has the *opposite* order of arguments to
# select(). This is because it implements `select_n`.
Expand Down

0 comments on commit 4114ff2

Please sign in to comment.