How to detect if within jit? #9241
-
I don't think there's any public API for this, but you can detect it from its side effects, e.g. the fact that array creation results in a tracer within JIT and other transforms:

```python
import jax.numpy as jnp
from jax import core, jit

def f():
    # Inside jit, jnp.array(0) is staged out and comes back as a Tracer.
    is_jit = isinstance(jnp.array(0), core.Tracer)
    return is_jit

print(f())
# False
print(jit(f)())
# True
```

(Note, however, that this will return …)

That said, forking a function's behavior based on whether it is being called within a transform is probably not a good idea: it's likely to have unintended side effects, particularly when it comes to things like jit-invariance and composability of transforms.
-
I'm interested in this as well. I would like a way to warn a user that a function is not transformable, and was thinking of creating a decorator for this. Is it possible?
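One possible sketch, building on the tracer check from the answer above: a decorator that inspects the arguments and warns if any of them is a `jax.core.Tracer`. The name `warn_if_traced` and the warning text are hypothetical, not part of JAX. Because `jit`, `vmap`, and `grad` all pass tracers in as arguments, this catches those transforms; it will not catch traced values that only reach the function through a closure. Note that under `jit` the warning fires at trace time, so cached calls with the same shapes will not warn again.

```python
import functools
import warnings

import jax
import jax.numpy as jnp


def warn_if_traced(fn):
    """Hypothetical decorator: warn when fn receives traced (transformed) inputs."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # Flatten args/kwargs so arrays nested in lists, dicts, etc. are checked too.
        leaves = jax.tree_util.tree_leaves((args, kwargs))
        if any(isinstance(leaf, jax.core.Tracer) for leaf in leaves):
            warnings.warn(
                f"{fn.__name__} is not designed to be transformed "
                "(jit/vmap/grad); results may be incorrect.",
                stacklevel=2,
            )
        return fn(*args, **kwargs)
    return wrapper


@warn_if_traced
def total(x):
    return jnp.sum(x)


x = jnp.arange(3.0)
total(x)           # concrete input: no warning
jax.jit(total)(x)  # traced input: warns once, while tracing
```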
-
Hi everyone, is there a way to detect from within a function whether it's being jitted, e.g. to create a function that works differently when jitted?
In the following example, I'm curious if we can have `jitting` depend on whether or not `jax.jit` is being called on `func`:
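Given the jit-invariance concern raised in the answer, a more explicit alternative (a sketch, assuming you control the call sites) is to make `jitting` an ordinary argument and mark it static with `static_argnames`, so the caller chooses the branch rather than the function inferring it from tracer types. The function bodies here are placeholders; only `jax.jit` and `static_argnames` are real JAX API.

```python
import jax
import jax.numpy as jnp


def func(x, jitting=False):
    # `jitting` is a plain Python bool chosen by the caller, so this branch is
    # resolved once per trace instead of being detected from tracer types.
    if jitting:
        return jnp.sin(x)  # placeholder for the "jitted" code path
    return jnp.cos(x)      # placeholder for the eager code path


# Static arguments are treated as compile-time constants; each distinct value
# triggers its own trace and compilation.
func_jit = jax.jit(func, static_argnames="jitting")

x = jnp.float32(0.0)
print(func(x))                    # eager call takes the cos branch
print(func_jit(x, jitting=True))  # compiled call takes the sin branch
```

A side benefit of this design is that `func_jit(x, jitting=False)` and `func(x)` compute the same thing, so wrapping in `jit` alone never changes the result.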