Skip to content

Commit

Permalink
document that under disable_jit, individual primitives are still comp…
Browse files Browse the repository at this point in the history
…iled
  • Loading branch information
jakevdp committed Feb 5, 2024
1 parent e224c3d commit 82611eb
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,8 @@ def disable_jit(disable: bool = True):
`cond` functions passed to higher-level primitives like :func:`~jax.lax.scan` and
:func:`~jax.lax.while_loop`, JIT used in implementations of :mod:`jax.numpy` functions,
and any other case where :func:`jit` is used within an API's implementation.
Note however that even under `disable_jit`, individual primitive operations
will still be compiled by XLA as in normal eager op-by-op execution.
Values that have a data dependence on the arguments to a jitted function are
traced and abstracted. For example, an abstract value may be a
Expand Down

0 comments on commit 82611eb

Please sign in to comment.