
avoid JIT compilation of some code section #6070

Answered by jakevdp
backpropper asked this question in Q&A

There is no way to mark a section of a JIT-compiled function to run outside of JIT, but there are other options.
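One general pattern, shown here as a sketch rather than as the answer's own recommendation, is to split the computation: JIT-compile only the traceable numerical part, and run the NumPy-dependent part as ordinary Python on the concrete result. The names `compute_core` and `postprocess` are hypothetical, and the small input array is made up for illustration.

```python
import numpy as np
import jax.numpy as jnp
from jax import jit

@jit
def compute_core(x):
    # Traceable part: only jax.numpy ops whose output shapes are static.
    return jnp.sin(x)

def postprocess(y):
    # Not jitted: y is a concrete array here, so NumPy calls with
    # data-dependent output shapes (like np.nonzero) are allowed.
    y = np.asarray(y)
    indices, = np.nonzero(np.abs(y) < 0.5)
    return indices.sum()

def f(x):
    # Plain Python wrapper: only compute_core is compiled.
    return postprocess(compute_core(x))

x = jnp.array([0.0, 0.1, 1.0, 2.0, 0.2])
result = f(x)  # indices 0, 1 and 4 satisfy |sin(x)| < 0.5
```

The trade-off is that the NumPy step runs on the host and is not part of the compiled computation, so it cannot be fused or differentiated through.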

Suppose you have a function like the following that you would like to JIT-compile; as written, it raises a TracerArrayConversionError:

import numpy as np
import jax.numpy as jnp
from jax import jit, random

@jit
def f(x):
  y = jnp.sin(x)
  indices, = np.nonzero(abs(y) < 0.5)  # NumPy call on a traced value: fails under jit
  return indices.sum()

x = random.uniform(random.PRNGKey(0), (1000,))
f(x)
# TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object
# Traced<ShapedArray(bool[1000])>with<DynamicJaxprTrace(level=0/1)>
# (https://jax.readthedocs.io/en/latest/errors.html#jax.errors.Tra…
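As one sketch of a jit-compatible alternative, assuming only the sum of the matching indices is needed, the data-dependent output shape of `np.nonzero` can be avoided by masking with `jnp.where`, so every intermediate array has a shape known at trace time:

```python
import jax.numpy as jnp
from jax import jit, random

@jit
def f(x):
    y = jnp.sin(x)
    # Keep a full-length index array instead of extracting a
    # variable-length one; entries failing the mask contribute 0.
    idx = jnp.arange(y.shape[0])
    return jnp.where(jnp.abs(y) < 0.5, idx, 0).sum()

x = random.uniform(random.PRNGKey(0), (1000,))
print(f(x))
```

In recent JAX versions, `jnp.nonzero` also accepts a static `size` argument (with an optional `fill_value`) that pads the result to a fixed length, which is another way to keep shapes static under `jit`.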

Answer selected by backpropper