Does lax cond short circuit? #3103
Both sides of the conditional are traced, meaning both branch functions are evaluated with tracer objects that don't do any computation, in order to discover the operations to be compiled with XLA.
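A small sketch of what "both sides are traced" means, using `jnp.sin`/`jnp.cos` as stand-in branches and the current `lax.cond` call signature:

```python
import jax
import jax.numpy as jnp

def true_fun(x):
  print("tracing true branch")   # Python print: runs when the branch is traced
  return jnp.sin(x)

def false_fun(x):
  print("tracing false branch")  # also runs at trace time, even if never selected
  return jnp.cos(x)

@jax.jit
def f(x):
  return jax.lax.cond(x > 0, true_fun, false_fun, x)

f(1.0)  # both messages print during tracing, but only sin is evaluated
```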
got it, thanks!
One detail to add on: only the operations in each branch that have a data dependence on the explicit branch operands will be delayed; operations with no data dependence on the operands are executed at trace time. Here's an example:

```python
@jit
def f(x):
  return lax.cond(x > 0,
                  (), lambda _: np.sin(x),
                  (), lambda _: np.cos(x))
```

On the current master branch, both `np.sin(x)` and `np.cos(x)` are executed at trace time and hoisted out of the branches, because neither depends on the explicit branch operands (the empty tuples). To ensure only one side is executed per application of the compiled function, pass `x` as the branch operand:

```python
@jit
def f(x):
  return lax.cond(x > 0,
                  x, lambda x: np.sin(x),
                  x, lambda x: np.cos(x))
```

This is a weird quirk of our tracing implementation, and we're working on revising it. Hoping to land a fix in the next couple weeks!
interesting, thanks @mattjj! One more question: is there a way to determine if certain JAX code was executed? Would be very useful for debugging!
I think you mean executed as in evaluated, like to ensure that only one side of the cond was taken rather than both. (If by "executed" you mean "traced" then you can use Python print function calls.)

Since XLA HLO doesn't have errors, without using side effects I think the only way to do it is via non-termination, like putting an infinite `while_loop` in one branch and seeing whether the program hangs.

Otherwise you'd need to use a side effect. Two readily-available side effects are time and heat (perhaps those are the same thing...); that is, if one side of the cond is much more expensive than the other, you can time the computation (or watch your accelerator heat up) to infer which side ran.

More seriously, there are side effects in XLA, but we have only exposed them in experimental APIs (infeed and outfeed). I don't necessarily recommend using them right now, but the host callback outfeed mechanism is the perfect API for this (cc @gnecula).

Instead of verifying what was executed, it might be good enough to just look at the XLA HLO programs we send to the compiler, then trust in the XLA HLO operational semantics around conditionals. If that works, I can tell you some ways to print the XLA HLO being generated. Then at least you could see the funny hoisting behavior I alluded to, and also see when it's fixed. Would that be useful?
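For reference, one way to dump the HLO (a sketch; `jax.xla_computation` is the API from around the time of this thread, and newer versions offer `jax.jit(f).lower(x).as_text()`):

```python
import jax
import jax.numpy as jnp

def f(x):
  return jax.lax.cond(x > 0, jnp.sin, jnp.cos, x)

# Print the XLA HLO that JAX would hand to the compiler for this function.
print(jax.xla_computation(f)(1.0).as_hlo_text())
```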
I think this should be added to the FAQ, or documented explicitly somewhere.
@joaogui1 I'm reading these threads to supplement the documentation. The comments in these issues are filled with good insight.
@mattjj Thank you for the always-so-helpful response! This is more than enough to move forward.
I am tempted to close this issue. I do not quite understand what needs to be documented. Is it the fact that the only way to tell if code was executed is to use `id_print`? Or is it the hoisting behavior? (The latter is going to change soon.) In general, XLA reserves the right to execute (or not execute) code as long as one cannot tell by the result of the computation. I am closing for now; please re-open if you feel it needs to stay open.
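For completeness, a minimal `id_print` sketch using the experimental `host_callback` API mentioned above (API as it existed around the time of this thread; newer JAX versions provide `jax.debug.print` for the same purpose):

```python
import jax
import jax.numpy as jnp
from jax.experimental import host_callback

def true_fun(x):
  # Host-side side effect: fires only if this branch is actually evaluated.
  x = host_callback.id_print(x, what="true branch ran")
  return jnp.sin(x)

def false_fun(x):
  x = host_callback.id_print(x, what="false branch ran")
  return jnp.cos(x)

@jax.jit
def f(x):
  return jax.lax.cond(x > 0, true_fun, false_fun, x)

f(1.0)   # only the "true branch ran" tap prints
f(-1.0)  # only the "false branch ran" tap prints
```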
I think the data-dependence behavior described above changed after #3370. For example, with:

```python
import jax

def f(x):
  return jax.lax.cond(x > 0,
                      lambda x: x**2,
                      lambda x: jax.lax.while_loop(lambda x: True, lambda _: _, 0),
                      x)
```

then doing `f(1)` would run the infinite loop, since both branches are now staged out?
It's correct that after #3370 the data dependence behavior is gone; all JAX operations in the branch functions (e.g. the `while_loop` here) are staged out regardless of data dependence.

That code doesn't run an infinite loop though. Both sides are staged out, but it's not that both sides are evaluated. As always, only one side of the cond is evaluated.
Maybe this is confusing because staged programming can be confusing, and JAX is (in part) a staged system.

When we run `f`, the first step is to trace both branch functions. The purpose of this first step is not to perform any numerical operations, like FLOPs or (in this case) integer arithmetic. Indeed, the JAX tracer objects used here don't even carry concrete integer values with them! Instead, this step is just setup: it's building a jaxpr (JAX IR) representation for each side of the cond, each symbolically representing the computation that would happen on that branch, so that one of the two sides can be evaluated later.

So in this first step, both Python lambdas are evaluated because we want to construct jaxpr programs representing the computation in their bodies (without actually performing either computation yet!). That's always been true, before and after #3370.

The second step is where the actual numerical evaluation happens. After we've built jaxpr representations of both sides of the cond, we look at the value of the boolean predicate (the first argument to `lax.cond`) and evaluate only the jaxpr for the selected branch.

That's why the above example code doesn't infinite loop: the staged-out `while_loop` is only evaluated if the predicate selects that branch, and tracing it costs nothing more than building its jaxpr.

Hopefully that is a bit more explanation, or at least definition of terminology, for why the above example doesn't infinite-loop. But luckily you can just try it yourself and see :)
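To make the two steps concrete, here's a minimal sketch (with `jnp.sin`/`jnp.cos` standing in for the branch computations): both branches show up in the staged-out jaxpr, but evaluation only runs the selected one.

```python
import jax
import jax.numpy as jnp

def f(x):
  return jax.lax.cond(x > 0, jnp.sin, jnp.cos, x)

# Step 1 (staging): both branches appear in the jaxpr.
print(jax.make_jaxpr(f)(1.0))

# Step 2 (evaluation): only the branch selected by the predicate is run.
print(f(1.0))   # sin(1.0)
print(f(-1.0))  # cos(-1.0)
```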
Here's code that would infinite loop, at tracing/staging time:

```python
import jax

def f(x):
  return jax.lax.cond(x > 0, lambda x: x**2, lambda x: infinite_loop(), x)

def infinite_loop():
  while True: pass
```

The reason is that we have a Python expression which can't be evaluated at tracing/staging time (i.e. a Python infinite loop), rather than a staged infinite loop (i.e. one in the staged-out jaxpr language, as we get with `lax.while_loop`).
Is this behaviour happening in `jnp.where` when `x` is an array? (I use this syntax a lot for buffers.) Which case is it?

The documentation says to use `jax.numpy` in place of the lower-level `lax` methods. Does that still hold for this situation?
If you are in op-by-op mode (i.e. outside JIT), the result of `jnp.where` is computed eagerly: both the true and false values are fully computed before the selection happens.

If you are in a JIT context, then the XLA compiler has the freedom to avoid computing values that are not needed for the output.
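A small sketch of the op-by-op case (hypothetical values): the arguments to `jnp.where` are ordinary arrays that get computed in full before the selection happens.

```python
import jax.numpy as jnp

x = jnp.array([2.0, 0.0, -1.0])

# Outside jit, jnp.log(x) is evaluated for every element (including the
# ones where the condition is False) before where picks between the two.
y = jnp.where(x > 0, jnp.log(x), 0.0)
print(y)  # [0.6931472 0.        0.       ]
```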
Perfect, thank you for the speedy and concrete reply |
Thanks for the thorough explanations @mattjj.
It seems to me that the staged-out branches aren't really staged out; for example, the following code will OOM.

```python
import jax
import jax.numpy as jnp
from jax import lax

def a(key):
  return jax.random.normal(key, (32,)*10).sum()

def b(key):
  return jnp.array(1.)

lax.cond(True, b, a, jax.random.PRNGKey(0))  # this OOMs
jax.jit(lambda x, pred: lax.cond(pred, b, a, x))(jax.random.PRNGKey(0), True)  # this OOMs too
```
Hello everyone,

I'm working on a transformer model where I aim to apply one of two attention mechanisms to each attention head. One mechanism is computationally expensive, while the other is more cost-effective. To switch between the two mechanisms, I've employed conditional logic using `jax.lax.cond`.

Here's the curious part: when I exclusively use the "cheap" attention mechanism, I achieve an inference latency of approximately 33 seconds. On the other hand, if I opt for only the "expensive" mechanism, the latency spikes to around 150 seconds. However, when I use `jax.lax.cond` to conditionally apply either mechanism based on a certain predicate, the inference latency remains consistently around 150 seconds, regardless of how I set the predicate.

This suggests that both branches of the conditional are being evaluated in terms of computational cost, which isn't the intended behavior. Specifically, even when I set the predicate to predominantly trigger the cheaper operation, the latency doesn't decrease, staying in the ballpark of the more expensive mechanism.

I would greatly appreciate insights or suggestions. Am I misunderstanding how `jax.lax.cond` works, or could there be another underlying issue that I'm not accounting for?

cc: @jakevdp @mattjj
It's hard to say without seeing your full code, but is it possible that your module is being executed in a `vmap` (or other batched) context with a per-example predicate? In that case `lax.cond` effectively behaves like a `lax.select`, and both branches are evaluated.
Thanks for your comments. I am modifying the Hugging Face Flax ViT implementation by changing `class FlaxViTSelfAttention` in this codebase. My new `FlaxViTSelfAttention` is below for your reference.
Is there any difference between

```python
jax.vmap(lambda x: jax.lax.cond(x > 0, some_expensive_function, lambda _: 0, x))(xs)
```

vs

```python
jax.vmap(lambda x: jax.lax.cond(x > 0, lambda: some_expensive_function(x), lambda: 0))(xs)
```

?
Those two expressions will lower to the same sequence of operations. You can check this by printing their jaxpr:

```python
import jax
import jax.numpy as jnp

some_expensive_function = jnp.sin  # stand-in

f1 = jax.vmap(lambda x: jax.lax.cond(x > 0, some_expensive_function, lambda _: 0.0, x))
f2 = jax.vmap(lambda x: jax.lax.cond(x > 0, lambda: some_expensive_function(x), lambda: 0.0))

x = jnp.arange(10.0)
print(jax.make_jaxpr(f1)(x))
print(jax.make_jaxpr(f2)(x))
```

Both print the same jaxpr.
So I guess if I want to avoid computing the expensive function on inputs where it's not needed, the only way is to do

```python
ys = jax.vmap(expensive_function)(xs[cond])
zs = jax.vmap(cheap_function)(xs[~cond])
```

?
Yes, that will work. You can do so with boolean mask indexing, as in your snippet.
(but note that boolean mask indexing like `xs[cond]` produces an array whose shape depends on the values in `cond`, so it won't work under `jit`)
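A rough sketch of how the two masked pieces could be stitched back together outside of `jit` (`xs`, `cond`, and the two functions are from your snippet; the zero-initialized output and the shape/dtype assumption are mine):

```python
import jax
import jax.numpy as jnp

def apply_by_mask(xs, cond, expensive_function, cheap_function):
  # cond must be a concrete boolean array; this pattern can't run under jit,
  # since the masked shapes depend on the data.
  exp_idx = jnp.where(cond)[0]
  cheap_idx = jnp.where(~cond)[0]
  out = jnp.zeros_like(xs)  # assumes both functions preserve shape and dtype
  out = out.at[exp_idx].set(jax.vmap(expensive_function)(xs[exp_idx]))
  out = out.at[cheap_idx].set(jax.vmap(cheap_function)(xs[cheap_idx]))
  return out
```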
Hello! I have a function `f` that wraps two functions, one of which is very expensive (`f_1`), the other (`f_2`) is not (they return the same shaped array). If one of the arguments to `f` is false, we do not need the expensive function. Ultimately, I wrap this inside a jitted function, so I must use `lax.cond` to split `f` into `f_1` and `f_2`. Does this buy me anything, or do both sides of the conditional have to be executed because of the way JAX works? Thanks!