
Does lax cond short circuit? #3103

Closed
john-heyer opened this issue May 15, 2020 · 26 comments
Labels: contributions welcome · documentation · question

Comments

@john-heyer commented May 15, 2020

Hello! I have a function f that wraps two functions: one (f_1) is very expensive, and the other (f_2) is not (they return the same-shaped array). If one of the arguments to f is false, we do not need the expensive function. Ultimately I wrap this inside a jitted function, so I must use lax.cond to split f into f_1 and f_2. Does this buy me anything, or do both sides of the conditional have to be executed because of the way JAX works? Thanks!

@skye (Member) commented May 16, 2020

Both sides of the conditional are traced, meaning both branch functions are evaluated with tracer objects that don't perform any computation, in order to discover the operations to be compiled with jit. This should be fast even for the expensive function, since no computation is performed. When the final jitted function is executed with real values, only one branch will be run.
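
(An illustrative sketch of this, assuming the current cond(pred, true_fun, false_fun, operand) signature; cheap and expensive are made-up stand-ins:)

import jax
import jax.numpy as jnp
from jax import lax

def cheap(x):
  print("tracing cheap branch")      # Python print: runs once, at trace time
  return x + 1.0

def expensive(x):
  print("tracing expensive branch")  # also runs at trace time, but costs nothing numerically
  return jnp.sin(x) ** 2

@jax.jit
def f(pred, x):
  return lax.cond(pred, expensive, cheap, x)

f(True, jnp.float32(1.0))
# On the first call both "tracing ..." messages appear (both branches are traced),
# but only the selected branch is actually executed on the device.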

@john-heyer (Author)

got it, thanks!

@mattjj (Collaborator) commented May 16, 2020

One detail to add on: only the operations in each branch that have a data dependence on the explicit branch operands will be delayed; operations with no data dependence on the operands are executed at trace time when not using a jit, and unconditionally when using a jit.

Here's an example:

@jit
def f(x):
  return lax.cond(x > 0,
                  (), lambda _: np.sin(x),
                  (), lambda _: np.cos(x))

On the current master branch, both np.sin(x) and np.cos(x) will be evaluated on each evaluation of f(x). Another way to put it is that they'll be hoisted out of the cond entirely.

To ensure only one side is executed per application of f, we'd need to rewrite it as

@jit
def f(x):
  return lax.cond(x > 0,
                  x, lambda x: np.sin(x),
                  x, lambda x: np.cos(x))

This is a weird quirk of our tracing implementation, and we're working on revising it. Hoping to land a fix in the next couple weeks!

@mattjj added the question label May 16, 2020
@john-heyer (Author)

interesting, thanks @mattjj ! One more q: is there a way to determine if certain jax code was executed? Would be very useful for debugging!

@mattjj (Collaborator) commented May 16, 2020

I think you mean executed as in evaluated, like to ensure that only one side of the cond was taken rather than both. (If by "executed" you mean "traced" then you can use Python print function calls.)

Since XLA HLO doesn't have errors, I think the only way to do it without side effects is via non-termination, e.g. putting an infinite lax.while_loop in one of the branches of the lax.cond.

Otherwise you'd need to use a side-effect. Two readily-available side-effects are time and heat (perhaps those are the same thing...); that is, if f_1 is very expensive perhaps you can decide whether it was executed based on how much time the computation takes, or how much heat your processor generates!

More seriously, there are side-effects in XLA, but we have only exposed them in experimental APIs (infeed and outfeed). I don't necessarily recommend using them right now, but the host callback outfeed mechanism is the perfect API for this (cc @gnecula).

Instead of verifying what was executed, it might be good enough to just look at the XLA HLO programs we send to the compiler, then trust in the XLA HLO operational semantics around conditionals. If that works, I can tell you some ways to print the XLA HLO being generated. Then at least you could see the funny hoisting behavior I alluded to, and also see when it's fixed. Would that be useful?
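
(An illustrative sketch of inspecting what gets generated — not part of the original reply — using today's lax.cond signature and the jit(...).lower(...) API, which may differ from what was available at the time:)

import jax
import jax.numpy as jnp
from jax import lax

def f(x):
  return lax.cond(x > 0, jnp.sin, jnp.cos, x)

# The jaxpr shows both branches staged out under a single cond primitive.
print(jax.make_jaxpr(f)(1.0))

# The lowered text that JAX hands to the XLA compiler.
print(jax.jit(f).lower(1.0).as_text())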

@joaogui1 (Contributor) commented May 16, 2020

I think this should be added to the FAQ, or documented explicitly somewhere

@skye added the documentation and contributions welcome labels May 19, 2020
@NeilGirdhar (Contributor)

@joaogui1 I'm reading these threads to supplement the documentation. The comments in these issues are filled with good insight.

@john-heyer (Author)

@mattjj Thank you for the always-so-helpful response! This is more than enough to move forward.

@gnecula (Collaborator) commented May 26, 2020

I am tempted to close this issue. I do not quite understand what needs to be documented. Is it the fact that the only way to tell whether code was executed is to use id_print? Or is it the hoisting behavior? (The latter is going to change soon.)

In general, XLA reserves the right to execute (or not execute) code as long as one cannot tell by the result of the computation.

I am closing for now; please re-open if you feel it needs to stay open.

@gnecula closed this as completed May 26, 2020
@pedrofale commented Oct 26, 2020

I think the jax.lax.cond API has changed since this issue was first opened and I'm not sure @mattjj's comments apply in the same way. For example, if I do

import jax

def f(x):
  return jax.lax.cond(
      x > 0,
      lambda x: x**2,
      lambda x: jax.lax.while_loop(lambda x: True, lambda _: _, 0),
      x)

then doing f(2) will run the infinite loop. How can I avoid that?

@mattjj (Collaborator) commented Sep 17, 2021

It's correct that after #3370 the data dependence behavior is gone; all JAX operations (e.g. jnp calls, operations on JAX arrays) are staged out under a cond.

That code doesn't run an infinite loop though. Both sides are staged out, but it's not that both sides are evaluated. As always, only one side of the cond is evaluated.


@mattjj (Collaborator) commented Sep 17, 2021

Maybe this is confusing because staged programming can be confusing, and JAX is (in part) a staged system.

When we run f(2) in the example code above, there are two steps it's important to distinguish. The first is evaluation of the two Python lambda bodies when the lambdas are applied to JAX tracer arguments. If we put regular builtin Python print calls in the lambda bodies, this step is when we'd see things on stdout. We often call this step "tracing" or "staging".

The purpose of this first step is not to perform any numerical operations, like FLOPs or (in this case) integer arithmetic. Indeed, the JAX tracer objects used here don't even carry concrete integer values with them! Instead, this step is just setup: it's building a jaxpr (JAX IR) representation for each side of the cond, each symbolically representing the computation that would happen on each branch, so that one of the two sides can be evaluated later. (Notice that cond is a regular Python function, so the argument expression x > 0 is evaluated before cond is even applied. The reason the other arguments to cond are lambdas is precisely so that evaluation of the expression in their bodies can be delayed to after cond is applied, and then traced!)

So in this first step, both Python lambdas are evaluated because we want to construct jaxpr programs representing the computation in their bodies (without actually performing either computation yet!). That's always been true, before and after #3370.

The second step is where the actual numerical evaluation happens. After we've built jaxpr representations of both sides of the cond, then we look at the value of the boolean predicate (the first argument to cond) and decide which branch to execute. Then only one is executed.

That's why the above example code doesn't infinite loop: since x > 0 evaluates to True, only the computation represented by the first lambda is ever evaluated numerically.

Hopefully that is a bit more explanation, or at least definition of terminology, for why the above example doesn't infinite-loop. But luckily you can just try it yourself and see :)
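
(Trying it, as suggested — an illustrative sketch adapted from the example above:)

import jax
from jax import lax

def f(x):
  return lax.cond(x > 0,
                  lambda x: x ** 2,
                  lambda x: lax.while_loop(lambda x: True, lambda x: x, 0),
                  x)

# The jaxpr contains both branches, including the staged-out while loop...
print(jax.make_jaxpr(f)(2))

# ...but since x > 0 is True, only the first branch is evaluated numerically,
# so this returns 4 instead of hanging.
print(f(2))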

@mattjj (Collaborator) commented Sep 17, 2021

Here's code that would infinite loop, at tracing/staging time:

import jax

def f(x):
  return jax.lax.cond(x > 0, lambda x: x**2, lambda x: infinite_loop(), x)

def infinite_loop():
  while True: pass

The reason is that here the infinite loop is a plain Python loop that never finishes at tracing/staging time, rather than a staged infinite loop (i.e. one expressed in the staged-out jaxpr language, as we get with lax.while_loop).

@oliverdutton (Contributor) commented Dec 8, 2021

Is this behaviour happening in jnp.where?

Concretely:

jnp.where(
   x==mask_value,
   expensive_function(x),
   mask_value
)

where x is an array (I use this syntax a lot for buffers)

Which case is it?

  • The expensive function is always evaluated but simply not always used, in which case limiting the buffer size is important.
  • An implicit conditional is used at numerical evaluation time, so there's a minimal performance hit and having a larger buffer only has a memory cost.

The documentation says to use jax.numpy in place of the lower-level lax methods. Does that still hold for this situation?

@jakevdp (Collaborator) commented Dec 8, 2021

If you are in op-by-op mode (i.e. outside JIT), the result of expensive_function will always be fully evaluated.

If you are in a JIT context, like this:

@jax.jit
def f(x):
  return jnp.where(x == mask_value, expensive_function(x), mask_value)

then the XLA compiler has the freedom to avoid computing expensive_function in its optimization pass.
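
(A sketch of how one could check what the optimizer actually did — not part of the original reply; it uses the jit(...).lower(...) ahead-of-time APIs, and expensive_function / mask_value are stand-ins:)

import jax
import jax.numpy as jnp

mask_value = 0.0
expensive_function = jnp.sin  # stand-in

@jax.jit
def f(x):
  return jnp.where(x == mask_value, expensive_function(x), mask_value)

x = jnp.arange(4.0)
print(f.lower(x).as_text())            # the program JAX hands to XLA
print(f.lower(x).compile().as_text())  # the HLO after XLA's optimization passes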

@oliverdutton (Contributor)

Perfect, thank you for the speedy and concrete reply

@epignatelli

(quoting @mattjj's earlier comment above about data dependence and the hoisting quirk)

Thanks for the thorough explanations @mattjj.
Does v0.4.12 still behave as described?

@mavenlin commented Sep 5, 2023

It seems to me that the staged-out branches aren't really staged out; for example, the following code will OOM.

import jax
import jax.numpy as jnp
from jax import lax

def a(key):
  return jax.random.normal(key, (32,)*10).sum()

def b(key):
  return jnp.array(1.)

lax.cond(True, b, a, jax.random.PRNGKey(0)) # this oom
jax.jit(lambda x, pred: lax.cond(pred, b, a, x))(jax.random.PRNGKey(0), True) # this oom too

@dhyani15

Hello everyone,

I'm working on a transformer model where I aim to apply one of two attention mechanisms to each attention head. One mechanism is computationally expensive, while the other is more cost-effective. To switch between the two mechanisms, I've employed conditional logic using jax.lax.cond.

Here's the curious part: when I exclusively use the "cheap" attention mechanism, I achieve an inference latency of approximately 33 seconds. On the other hand, if I opt for only the "expensive" mechanism, the latency spikes to around 150 seconds. However, when I use jax.lax.cond to conditionally apply either mechanism based on a certain predicate, the inference latency remains consistently around 150 seconds, regardless of how I set the predicate. This suggests that both branches of the conditional are being evaluated in terms of computational cost, which isn't the intended behavior.

Specifically, even when I set the predicate to predominantly trigger the cheaper operation, the latency doesn't decrease, staying in the ballpark of the more expensive mechanism.

I would greatly appreciate insights or suggestions. Am I misunderstanding how jax.lax.cond works, or could there be another underlying issue that I'm not accounting for? cc: @jakevdp @mattjj

class ChooseAttention(nn.Module):
    config: ViTConfig
    dtype: jnp.dtype = jnp.float32
    def setup(self):
        self.alpha = self.param('alpha', nn.initializers.ones, (1, self.config.num_attention_heads, 1, 1))
    def choose_attention(self, alpha, x):
        return jax.lax.cond(
            alpha[0, 0, 0, 0] < 0.5,
            lambda x: x / x.shape[-1],
            lambda x: nn.relu(x) / (jnp.sum(nn.relu(x), axis=-1, keepdims=True) + 1e-5),
            x,
        )
    def __call__(self, attn_weights):
        results = []
        for i in range(self.alpha.shape[1]):
            alpha_i = self.alpha[:, i:i+1, :, :]
            x_i = attn_weights[:, i:i+1, :, :]
            result_i = self.choose_attention(alpha_i, x_i)
            results.append(result_i)

        return jnp.concatenate(results, axis=1)

@jakevdp (Collaborator) commented Sep 12, 2023

It's hard to say without seeing your full code, but is it possible that your module is being executed in a vmap context? Note that vmap of cond becomes select, which will execute both branches (see jax.lax.cond).

@dhyani15

Thanks for your comments.

I am modifying the Hugging Face Flax ViT implementation by changing the class FlaxViTSelfAttention in this codebase. My new FlaxViTSelfAttention is below for your reference.
To answer your question with another question: do Flax modules implicitly run subclasses in a vmap context? Because in the codebase I don't see any explicit calls to vmap.

class CustomFlaxViTSelfAttention(FlaxViTSelfAttention):
    config: ViTConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    
    def setup(self):
        # self.alpha = self.param('alpha', nn.initializers.ones, (1, self.config.num_attention_heads, 1, 1))
        self.chooseAttention = ChooseAttention(self.config)
        if self.config.hidden_size % self.config.num_attention_heads != 0:
            raise ValueError(
                "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`:"
                " {self.config.num_attention_heads}"
            )
        # self.qkv = nn.Dense(self.config.hidden_size * 3,dtype=self.dtype,kernel_init=jax.nn.initializers.glorot_uniform(),use_bias=False)
        self.query = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.variance_scaling(
                self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal"
            ),
            use_bias=self.config.qkv_bias,
        )
        self.key = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.variance_scaling(
                self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal"
            ),
            use_bias=self.config.qkv_bias,
        )
        self.value = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.variance_scaling(
                self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal"
            ),
            use_bias=self.config.qkv_bias,
        )
        
    
    def __call__(self, hidden_states, deterministic: bool = True, output_attentions: bool = False):
        head_dim = self.config.hidden_size // self.config.num_attention_heads # 1,257,192
        query_states = self.query(hidden_states).reshape( # query dense 192 -> 192 (12*16)
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
        )
        
        value_states = self.value(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
        )
        
        key_states = self.key(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
        )
        

        dropout_rng = None
        if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
            dropout_rng = self.make_rng("dropout")


        query, key = promote_dtype(query_states, key_states, dtype=self.dtype)
        dtype = query.dtype

        assert query.ndim == key.ndim, 'q, k must have same rank.'
        assert query.shape[:-3] == key.shape[:-3], 'q, k batch dims must match.'
        assert query.shape[-2] == key.shape[-2], 'q, k num_heads must match.'
        assert query.shape[-1] == key.shape[-1], 'q, k depths must match.'

        # calculate attention matrix
        depth = query.shape[-1]
        query = query / jnp.sqrt(depth).astype(dtype)
       
        attn_weights = jnp.einsum('...qhd,...khd->...hqk', query, key, precision='highest').astype(dtype)
        # vectorized_choose_attention = jax.vmap(choose_attention, in_axes=(1, 1), out_axes=1)
        # scalattn = (attn_weights/attn_weights.shape[2]).astype(dtype) # scale attention
        # attn_weights = nn.relu(attn_weights).astype(dtype) 
        # attn_weights = attn_weights / (jnp.sum(attn_weights, axis=-1, keepdims=True) + 1e-5).astype(dtype) # RElusoftmax
        # attn_weights = (self.alpha*attn_weights).astype(dtype) + ((1-self.alpha)*scalattn).astype(dtype) # weighted sum
        # attn_weights = vectorized_choose_attention(self.alpha,attn_weights)
        attn_weights = self.chooseAttention(attn_weights)


        if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
            keep_prob = 1.0 - self.config.attention_probs_dropout_prob
            if True:
                # dropout is broadcast across the batch + head dimensions
                dropout_shape = tuple([1] * (key.ndim - 2)) + attn_weights.shape[-2:]
                keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape)  # type: ignore
            else:
                keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)  # type: ignore
            multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
            attn_weights = attn_weights * multiplier

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))

        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
        return outputs

@howsiyu commented Sep 18, 2023

If vmap of cond becomes select, is there any benefit of doing

jax.vmap(lambda x: jax.lax.cond(x > 0, some_expensive_function, lambda _: 0, x))(xs)

vs

jax.vmap(lambda x: jax.lax.cond(x > 0, lambda: some_expensive_function(x), lambda: 0))(xs)

?

@jakevdp (Collaborator) commented Sep 18, 2023

Those two expressions will lower to the same sequence of operations. You can check this by printing their jaxpr:

import jax
import jax.numpy as jnp
some_expensive_function = jnp.sin  # stand-in

f1 = jax.vmap(lambda x: jax.lax.cond(x > 0, some_expensive_function, lambda _: 0.0, x))
f2 = jax.vmap(lambda x: jax.lax.cond(x > 0, lambda: some_expensive_function(x), lambda: 0.0))

x = jnp.arange(10.0)
print(jax.make_jaxpr(f1)(x))
print(jax.make_jaxpr(f2)(x))

Both print this:

{ lambda ; a:f32[10]. let
    b:bool[10] = gt a 0.0
    c:i32[10] = convert_element_type[new_dtype=int32 weak_type=False] b
    d:bool[10] = eq c 0
    e:f32[10] = stop_gradient a
    _:f32[10] = select_n d e a
    f:f32[10] = broadcast_in_dim[broadcast_dimensions=() shape=(10,)] 0.0
    g:bool[10] = eq c 1
    h:f32[10] = stop_gradient a
    i:f32[10] = select_n g h a
    j:f32[10] = sin i
    k:f32[10] = select_n c f j
  in (k,) }

@howsiyu commented Sep 18, 2023

So I guess if I want to avoid computing the expensive function on inputs where it's not needed, the only way is to do

ys = jax.vmap(expensive_function)(xs[cond])
zs = jax.vmap(cheap_function)(xs[~cond])

?

@jakevdp (Collaborator) commented Sep 18, 2023

Yes, that will work. You can do so with lax.cond as well, just not within a vmap.

@jakevdp (Collaborator) commented Sep 18, 2023

(but note that xs[cond] and xs[~cond] cannot be evaluated within jit or any other jax transformation because the output shapes are data-dependent)
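
(A concrete sketch of that pattern, for illustration only — expensive_function and cheap_function are stand-ins, and the boolean masking works only outside jit because of the data-dependent shapes:)

import jax
import jax.numpy as jnp

expensive_function = jnp.sin         # stand-in
cheap_function = lambda x: 0.0 * x   # stand-in

xs = jnp.arange(-3.0, 3.0)
cond = xs > 0

# Boolean masking produces data-dependent shapes, so this runs op-by-op (outside jit).
ys = jax.vmap(expensive_function)(xs[cond])
zs = jax.vmap(cheap_function)(xs[~cond])

# Scatter the per-branch results back into their original positions.
out = jnp.zeros_like(xs)
out = out.at[jnp.where(cond)[0]].set(ys)
out = out.at[jnp.where(~cond)[0]].set(zs)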
