Skip to content

Commit

Permalink
DOC: clarify behavior of lax.cond & lax.select
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jan 6, 2023
1 parent 7cfea0a commit c9c6263
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
13 changes: 11 additions & 2 deletions jax/_src/lax/control_flow/conditionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,16 +155,25 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
operand=_no_operand_sentinel, linear=None):
"""Conditionally apply ``true_fun`` or ``false_fun``.
Wraps XLA's `Conditional
<https://www.tensorflow.org/xla/operation_semantics#conditional>`_
operator.
Provided arguments are correctly typed, ``cond()`` has equivalent
semantics to this Python implementation::
semantics to this Python implementation, where ``pred`` must be a
scalar type::
def cond(pred, true_fun, false_fun, *operands):
if pred:
return true_fun(*operands)
else:
return false_fun(*operands)
``pred`` must be a scalar type.
In contrast with :func:`jax.lax.select`, using ``cond`` indicates that only one of
the two branches is executed (up to compiler rewrites and optimizations).
However, when transformed with :func:`~jax.vmap` to operate over a batch of
predicates, ``cond`` is converted to :func:`~jax.lax.select`.
Args:
pred: Boolean scalar type, indicating which branch function to apply.
Expand Down
18 changes: 17 additions & 1 deletion jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -908,9 +908,25 @@ def rev(operand: ArrayLike, dimensions: Sequence[int]) -> Array:
return rev_p.bind(operand, dimensions=tuple(dimensions))

def select(pred: ArrayLike, on_true: ArrayLike, on_false: ArrayLike) -> Array:
"""Wraps XLA's `Select
"""Selects between two branches based on a boolean predicate.
Wraps XLA's `Select
<https://www.tensorflow.org/xla/operation_semantics#select>`_
operator.
In general :func:`~jax.lax.select` leads to evaluation of both branches, although
the compiler may elide computations if possible. For a similar function that
usually evaluates only a single branch, see :func:`~jax.lax.cond`.
Args:
pred: boolean array
on_true: array containing entries to return where ``pred`` is True. Must have
the same shape as ``pred``, and the same shape and dtype as ``on_false``.
on_false: array containing entries to return where ``pred`` is False. Must have
the same shape as ``pred``, and the same shape and dtype as ``on_true``.
Returns:
result: array with same shape and dtype as ``on_true`` and ``on_false``.
"""
# Caution! The select_n_p primitive has the *opposite* order of arguments to
# select(). This is because it implements `select_n`.
Expand Down

0 comments on commit c9c6263

Please sign in to comment.