vmap of cond's predicate results in select, leading to unexpected compute/memory use #8409
Comments
In fact, spamming the idempotent |
XLA may choose to evaluate several or all branches of But XLA has no "batched" conditional construct. Today, when batching the predicate to a Could |
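As a rough illustration of the lowering described above, the sketch below inspects the jaxpr of a vmapped `lax.cond` in two situations: a batched predicate and a predicate shared across the batch. The names `expensive`, `cheap`, and `branchy` are hypothetical stand-ins, not code from this issue, and the exact primitive names printed may differ between JAX versions.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical stand-ins for an expensive branch and a placeholder branch.
def expensive(x):
    return jnp.tanh(x) * 2.0

def cheap(x):
    return jnp.zeros_like(x)

def branchy(pred, x):
    return lax.cond(pred, expensive, cheap, x)

preds = jnp.array([True, False, True])
xs = jnp.arange(12.0).reshape(3, 4)

# Predicate batched along with the operand: the conditional is expected
# to be rewritten into a select over the results of both branches.
batched_pred = str(jax.make_jaxpr(jax.vmap(branchy))(preds, xs))

# Predicate shared across the batch: the conditional is expected to survive.
shared_pred = str(
    jax.make_jaxpr(jax.vmap(branchy, in_axes=(None, 0)))(jnp.array(True), xs)
)

print("select" in batched_pred)
print("cond" in shared_pred)
```

Printing the two jaxpr strings in full is a quick way to check which form a given program takes before worrying about its memory use.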
I see. Is there any reason why vmap'd conds are not supported? What is the logic there? I think this would be very important for a lot of people (and it would be great if these nuances were documented). I don't know the internals of the jax compiler, but I do know that other similar systems support some form of branching. Is there any way to add this capability to XLA or somehow jax? Actually - since I could, if I were careful about my code, manually batch everything, including conditionals, I see no reason why this couldn't be supported with vmap. I did try |
Often, one branch of |
A loop may be possible in some cases, but not all, and I think that makes this problematic. And, if I can, I would like to make an argument for why I think somehow supporting The wonderful thing about In some cases (as I would argue here), |
We're in agreement here by and large. This is something that we've thought about improving before, whether at the JAX or XLA level. I can't find an open issue for it, so let's use this one for it. |
Hi there! I was wondering what the JAX team's latest thinking is regarding the behavior of lax.cond when batched via vmap. I find myself often running into the design pattern of conditionally branching into two subroutines, one expensive, and the other a "placeholder," for example, returning a dummy zero tensor. |
@minqi – the thinking hasn't changed much since this issue was last active. Although there's a fundamental puzzle regarding whether/how to do better, for now we're still producing |
I find this one of the biggest practical issues with jax -- |
Btw does |
Switch is implemented in terms of |
That's what I thought -- might be helpful to document that for |
Hi, I was wondering what the status is on this. I face the following situation: jax.vmap(lambda x, y: lax.cond(y < 0, heavy_computation_1, heavy_computation_2, x))(X, Y) IIUC, I'll execute both branches in this case and I would rather not :) |
This is still the case, as mentioned in the docstring of |
I don't know of any active work to change this. |
Correct me if I'm wrong but I think you can get around this by splitting the vmap dimension into a list of single pieces, |
Hey, thanks for the suggestion. I created an MWE that I think illustrates your idea. Let me know if that is not the case. This example batches the condition as well as the argument of
import jax
from jax import lax
import jax.numpy as jnp
import jax.tree_util as jtu
from jax import random

X = jnp.arange(1000*4).reshape(1000, 4)
sigma = random.normal(random.PRNGKey(0), (1000,)) + 3.
A = random.normal(random.PRNGKey(1), (1000, 4))

def true_f(x):
    return A@x

def false_f(x):
    return -A@x

def f(sigma, x):
    return lax.cond(sigma < 3, true_f, false_f, x)

# Batched predicate: vmap over both the condition and the operand.
@jax.jit
def F(SIGMA, X):
    return jax.vmap(lambda sigma, x: f(sigma, x))(SIGMA, X)

# Unrolled: split the batch into single rows and call f once per row.
@jax.jit
def splitF(SIGMA, X):
    split_sigma = jnp.split(SIGMA, SIGMA.shape[0])
    split_x = jnp.split(X, X.shape[0])
    return jnp.stack(jtu.tree_map(lambda sigma, x: f(sigma[0], x[0]), split_sigma, split_x))

print(F(sigma, X))
print("-----------")
print(splitF(sigma, X))
The compilation of
%timeit F(sigma, X).block_until_ready()
>>> 3.01 ms ± 177 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit splitF(sigma, X).block_until_ready()
>>> 2.81 ms ± 93.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
I'm not sure if we can draw meaningful conclusions from this small example but it's a starting point. |
@pablo2909 Yeah the compilation time would take longer but it should allow taking advantage of cond only computing one branch during run, so it depends what you’d like to take advantage of. |
Here's an example that shows the differences for true and false functions with highly skewed compute times, 20 qr decomps vs a single elementwise multiply:
The times on my machine are:
The unroll has several times longer compile time, but only about a quarter the run time, so if it's going to be run for many iterations the compile time would be worth it. In the case where true and false functions are both light, or can be computed simultaneously on a gpu, it might be worth it to just use vmap, you'd have to experiment. |
Thanks a lot, I'm a bit surprised by your results. Your To adapt it a bit more to my original question, I make
from time import time
import jax
from jax import numpy as jnp, jit

def true_fn(rng_key, x):
    for _ in range(20):
        x = jnp.linalg.qr(x)[0]
        rng_key, subkey = jax.random.split(rng_key)
        y = jax.random.uniform(subkey, (200, 200))
        x = x + y
    return x

def false_fn(rng_key, x):
    for _ in range(20):
        x = jnp.linalg.qr(x)[0]
        rng_key, subkey = jax.random.split(rng_key)
        y = jax.random.uniform(subkey, (200, 200))
        x = x - y
    return x

def main_fn(carry, rng_key_x):
    rng_key, x = rng_key_x
    rng_key, subkey = jax.random.split(rng_key)
    # Pick a branch at random for each batch element.
    c = jax.random.choice(subkey, 2).astype(bool)
    return carry, jax.lax.cond(
        c,
        true_fn,
        false_fn,
        rng_key,
        x,
    )

@jit
def regular_vmap(rng_key, xs):
    rng_keys = jax.random.split(rng_key, xs.shape[0])
    return jax.vmap(main_fn, in_axes=(None, 0))(None, (rng_keys, xs))[1]

# Scan over the batch instead of vmapping it.
@jit
def unrolled(rng_key, xs):
    rng_keys = jax.random.split(rng_key, xs.shape[0])
    _, x_out = jax.lax.scan(main_fn, None, (rng_keys, xs))
    return x_out

rng_key = jax.random.PRNGKey(0)
x_in = jax.random.normal(jax.random.PRNGKey(1), (15, 200, 200))

# compile
t = time()
regular_vmap(rng_key, x_in).block_until_ready()
print("regular_vmap compile", time() - t)
t = time()
unrolled(rng_key, x_in).block_until_ready()
print("list_vmap compile", time() - t)

n = 5
x = x_in
t = time()
for _ in range(n):
    rng_key, subkey = jax.random.split(rng_key)
    x = regular_vmap(subkey, x).block_until_ready()
print("regular_vmap", (time() - t) / n)

x = x_in
t = time()
for _ in range(n):
    rng_key, subkey = jax.random.split(rng_key)
    x = unrolled(subkey, x).block_until_ready()
print("list_vmap", (time() - t) / n)
>>> regular_vmap compile 7.062215089797974
>>> list_vmap compile 4.893687963485718
>>> regular_vmap 2.9101763725280763
>>> list_vmap 1.1428837776184082 |
Scanning over the inputs is a good idea; I think that was mentioned before in a similar issue. It sacrifices a bit of runtime for faster compile (you could of course also do this for the qr for loop). Edit: Actually, after testing this setup, the scan-over-batch version runs faster than the unroll! |
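For reference, the scan-over-batch idea can be written a little more compactly with `lax.map`. This is a minimal sketch rather than the code benchmarked in this thread: `expensive`, `cheap`, `one_example`, and `mapped` are hypothetical names, and the branches are deliberately tiny. Each loop iteration sees a scalar predicate, so `lax.cond` is not batched and only the selected branch is meant to run per element.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical branches: one "expensive", one placeholder.
def expensive(x):
    return jnp.tanh(x) * 2.0

def cheap(x):
    return jnp.zeros_like(x)

def one_example(args):
    pred, x = args
    # pred is a scalar inside the loop body, so this stays a real conditional.
    return lax.cond(pred, expensive, cheap, x)

@jax.jit
def mapped(preds, xs):
    # lax.map runs one_example sequentially over the leading axis.
    return lax.map(one_example, (preds, xs))

preds = jnp.array([True, False, True])
xs = jnp.arange(12.0).reshape(3, 4)
out = mapped(preds, xs)

# Reference result computed by evaluating both branches and selecting.
expected = jnp.where(preds[:, None], expensive(xs), cheap(xs))
print(jnp.allclose(out, expected))
```

The trade is the one discussed in this thread: the loop gives up parallelism across the batch in exchange for skipping the branch that was not taken, so whether it wins depends on the backend and on how unequal the branches are.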
Unfortunately unrolling using scan or list comprehension only seems to consistently improve performance on cpu, not gpu or tpu, unless the batch size is very small and the branches are wildly unequal in compute. So I don't think there's a good solution to this problem without batched cond through XLA :( |
I've been observing the same on gpu :/ |
I'm not familiar with triton at all but maybe there's a way to batch cond through pallas. I don't have enough of a need to look that deep into it though 😁 |
Hello, |
I have had the same issue as @minqi here. Just wanted to express that this is still an issue that people deal with |
Because this hasn't been mentioned yet: As far as I know, using Regarding the It would be nice to have the option to disable vmap completely, i.e. raise an Exception when jax tries to convert a vmap cond to a select. |
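One possible way to approximate the "raise instead of silently converting to select" behaviour today is a custom batching rule. The sketch below assumes `jax.custom_batching.custom_vmap` is available in the installed JAX version; `strict_cond`, `expensive`, and `cheap` are hypothetical names, and this is not an official option of `lax.cond`.

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.custom_batching import custom_vmap

# Hypothetical branches for illustration.
def expensive(x):
    return jnp.tanh(x) * 2.0

def cheap(x):
    return jnp.zeros_like(x)

@custom_vmap
def strict_cond(pred, x):
    return lax.cond(pred, expensive, cheap, x)

@strict_cond.def_vmap
def strict_cond_vmap_rule(axis_size, in_batched, pred, x):
    pred_batched, x_batched = in_batched
    if pred_batched:
        # Refuse to batch the predicate rather than fall back to a select.
        raise TypeError("strict_cond: predicate is batched under vmap")
    # Predicate is shared across the batch: batch only the operand.
    out = jax.vmap(lambda v: lax.cond(pred, expensive, cheap, v))(x)
    return out, True

preds = jnp.array([True, False, True])
xs = jnp.arange(12.0).reshape(3, 4)

# A shared predicate still works under vmap.
shared = jax.vmap(strict_cond, in_axes=(None, 0))(jnp.array(True), xs)

# A batched predicate is rejected at trace time.
try:
    jax.vmap(strict_cond)(preds, xs)
    raised = False
except TypeError:
    raised = True
print(raised)
```

This only guards call sites that go through the wrapper, so it is a local check rather than a global switch.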
It might be worth mentioning that it is possible to use |
@inversecrime would sequential_vmap be similar to unrolling the batch using a for loop? |
It generates a |
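A minimal sketch of the `sequential_vmap` approach, assuming a JAX version that ships `jax.custom_batching.sequential_vmap`. The function `branchy` and its two branches are hypothetical examples. A function decorated this way is called in a loop when it is vmapped, so the `lax.cond` inside it sees an unbatched predicate on every iteration.

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.custom_batching import sequential_vmap

# Hypothetical branches for illustration.
def expensive(x):
    return jnp.tanh(x) * 2.0

def cheap(x):
    return jnp.zeros_like(x)

@sequential_vmap
def branchy(pred, x):
    return lax.cond(pred, expensive, cheap, x)

preds = jnp.array([True, False, True])
xs = jnp.arange(12.0).reshape(3, 4)

# vmap over the decorated function runs it once per batch element.
out = jax.vmap(branchy)(preds, xs)

# Reference result computed by evaluating both branches and selecting.
expected = jnp.where(preds[:, None], expensive(xs), cheap(xs))
print(jnp.allclose(out, expected))
```

Only the decorated function is serialized; code around it that is also under the same `vmap` is still batched as usual.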
I tried using |
Indeed, it all depends on the program (ignoring possible compiler rewrites), i.e. what's being computed within each branch. The two approaches trade off total compute vs. parallelism. For similar reasons, there are many possible implementations of "batched cond," all along this tradeoff curve. |
I have been playing around with converting diffmpm from the difftaichi package into a jax version, and while the forward pass has been working wonderfully, the backward pass has been using way too much GPU memory.
Today, I was able to track down that memory usage to the grid op. The grid op step is a series of nested if statements. At first, I was using jnp.where, which evaluates all branches. That is extremely inefficient and can lead to OOM errors. I simplified my code and switched to jax.lax.cond, but my only conclusion is that cond is also evaluating both branches; otherwise I cannot see why this would run into OOM issues.
Below is a modified version of the grid op that is composed into itself 4,000 times, like a simulation. Even when run with the XLA_PYTHON_CLIENT_PREALLOCATE=false flag, this quickly leads to the whole GPU being used, and more if the loop length is increased. This is not true if every line from
lin = ....
until right before the return of grid_op is commented out. In that case, memory usage is practically negligible. Note that because bound = 0, literally every line written v_out = jax.lax.cond ...
evaluates to False by definition, and so most of the expressions, including the v_out_gate's and their dependencies, shouldn't even need to be evaluated in the jitted function.
Maybe I am misunderstanding cond; if so, what is the proper way to get this sparse branching behavior? I don't want to evaluate and hang onto a bunch of expensive tensors that are never actually needed and crash my GPU with OOM, especially in a backward pass. This is a core bottleneck to practical deployment of my code and a feature that I think should be supported. FWIW, I am using Version: 0.1.69+cuda101
Code to reproduce is below.