Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DOC: clarify behavior of lax.cond & lax.select #13589

Merged
merged 1 commit into from
Jan 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions jax/_src/lax/control_flow/conditionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,16 +155,25 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
operand=_no_operand_sentinel, linear=None):
"""Conditionally apply ``true_fun`` or ``false_fun``.

Wraps XLA's `Conditional
<https://www.tensorflow.org/xla/operation_semantics#conditional>`_
operator.

Provided arguments are correctly typed, ``cond()`` has equivalent
semantics to this Python implementation::
semantics to this Python implementation, where ``pred`` must be a
scalar type::

def cond(pred, true_fun, false_fun, *operands):
if pred:
return true_fun(*operands)
else:
return false_fun(*operands)

``pred`` must be a scalar type.

In contrast with :func:`jax.lax.select`, using ``cond`` indicates that only one of
the two branches is executed (up to compiler rewrites and optimizations).
However, when transformed with :func:`~jax.vmap` to operate over a batch of
predicates, ``cond`` is converted to :func:`~jax.lax.select`.

Args:
pred: Boolean scalar type, indicating which branch function to apply.
Expand Down
18 changes: 17 additions & 1 deletion jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -908,9 +908,25 @@ def rev(operand: ArrayLike, dimensions: Sequence[int]) -> Array:
return rev_p.bind(operand, dimensions=tuple(dimensions))

def select(pred: ArrayLike, on_true: ArrayLike, on_false: ArrayLike) -> Array:
"""Wraps XLA's `Select
"""Selects between two branches based on a boolean predicate.

Wraps XLA's `Select
<https://www.tensorflow.org/xla/operation_semantics#select>`_
operator.

In general :func:`~jax.lax.select` leads to evaluation of both branches, although
the compiler may elide computations if possible. For a similar function that
usually evaluates only a single branch, see :func:`~jax.lax.cond`.

Args:
pred: boolean array
on_true: array containing entries to return where ``pred`` is True. Must have
the same shape as ``pred``, and the same shape and dtype as ``on_false``.
on_false: array containing entries to return where ``pred`` is False. Must have
the same shape as ``pred``, and the same shape and dtype as ``on_true``.

Returns:
result: array with same shape and dtype as ``on_true`` and ``on_false``.
"""
# Caution! The select_n_p primitive has the *opposite* order of arguments to
# select(). This is because it implements `select_n`.
Expand Down