avoid JIT compilation of some code section #6070
I am using
There is not any way to mark a section of a JIT-compiled function to be run outside of JIT, but there are other options. Suppose you have a function like this that you would like to JIT compile; as written it results in a `TracerArrayConversionError`:

```python
import numpy as np
import jax.numpy as jnp
from jax import jit, random

@jit
def f(x):
    y = jnp.sin(x)
    indices, = np.nonzero(abs(y) < 0.5)  # NumPy function applied to a traced value
    return indices.sum()

x = random.uniform(random.PRNGKey(0), (1000,))
f(x)
# TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object
# Traced<ShapedArray(bool[1000])>with<DynamicJaxprTrace(level=0/1)>
# (https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError)
```

If you want all parts of the function except the line with `np.nonzero` to be JIT compiled, you can split the function into JIT-compiled pieces and call them from a regular Python function:

```python
@jit
def _preprocess(x):
    return jnp.sin(x)

@jit
def _postprocess(indices):
    return indices.sum()

def f(x):
    y = _preprocess(x)
    indices, = np.nonzero(abs(y) < 0.5)  # runs eagerly, outside of JIT
    return _postprocess(indices)

x = random.uniform(random.PRNGKey(0), (1000,))
f(x)  # DeviceArray(271019, dtype=int32)
```

But an approach like this is often inconvenient (e.g. if the non-JIT-compatible pieces are deep inside the function) or inefficient, because the non-JIT-compiled pieces cannot be optimized by XLA. Instead, it is often possible to re-express your computation in terms of JIT-compatible operations. For example:

```python
@jit
def f(x):
    y = jnp.sin(x)
    # jnp.where returns an array of fixed shape (len(y),), which JIT requires;
    # np.nonzero's output length depends on the data, which JIT cannot trace
    indices = jnp.where(abs(y) < 0.5, jnp.arange(len(y)), 0)
    return indices.sum()

x = random.uniform(random.PRNGKey(0), (1000,))
f(x)  # DeviceArray(271019, dtype=int32)
```

For some background on why code may or may not be compatible with JIT, see 🔪 JAX - The Sharp Bits 🔪 : Control Flow.
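Another JIT-compatible rewrite keeps the `nonzero` formulation but uses JAX's own `jnp.nonzero` with a static `size` argument, so the output shape no longer depends on the data. This is a minimal sketch, assuming a JAX version whose `jnp.nonzero` accepts the keyword-only `size` and `fill_value` arguments (recent versions do); the name `f_nonzero` is hypothetical and not part of the original answer. Padding with index `0` leaves the sum unchanged, so it computes the same quantity as the `jnp.where` version:

```python
import jax.numpy as jnp
from jax import jit, random

@jit
def f_nonzero(x):
    y = jnp.sin(x)
    # size= fixes the output length at trace time (y.shape[0] is a static int),
    # and unused slots are filled with 0, which does not change the sum below
    indices, = jnp.nonzero(abs(y) < 0.5, size=y.shape[0], fill_value=0)
    return indices.sum()

x = random.uniform(random.PRNGKey(0), (1000,))
print(f_nonzero(x))
```

The trade-off is that you must choose an upper bound for the number of matches in advance; here the array length is always a safe bound.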