What is a good dummy value? #588

Closed
pablo2909 opened this issue Nov 14, 2023 · 8 comments
Labels
question User queries

Comments

@pablo2909

pablo2909 commented Nov 14, 2023

Hi,

I'm asking a question following the recent blog post, where point 9 raises awareness about lax.cond under jax.vmap. One piece of advice is to use a dummy value, which I understand as follows:

# Instead of

jax.vmap(lambda x: lax.cond(cond, true_f, false_f, x))(X)

# Do:
true_x = jax.vmap(lambda x: lax.cond(cond, true_f, dummy_f, x))(X)
false_x = jax.vmap(lambda x: lax.cond(cond, dummy_f, false_f, x))(X)
combine(true_x, false_x)

If we do that, then I believe we are safe from unwanted evaluation. However, what is a good dummy value? I would have imagined None would work, but if we store it in an array it is converted to jnp.nan. That could be fine, but if a NaN then arises during training it becomes hard to debug. And in my case true_f and false_f can return any real value, so using 0. or something similar is not an option.
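For concreteness, a minimal runnable sketch of the split-and-combine pattern above, where true_f, false_f, dummy_f, and the jnp.where combine step are all illustrative placeholders:

```python
import jax
import jax.numpy as jnp
from jax import lax

def true_f(x):
    return x + 1.0  # placeholder branch

def false_f(x):
    return x * 2.0  # placeholder branch

def dummy_f(x):
    return x  # placeholder; its output is discarded by the combine step

X = jnp.arange(4.0)
cond = X < 2  # batched predicate, one per batch element

true_x = jax.vmap(lambda c, x: lax.cond(c, true_f, dummy_f, x))(cond, X)
false_x = jax.vmap(lambda c, x: lax.cond(c, dummy_f, false_f, x))(cond, X)
combined = jnp.where(cond, true_x, false_x)
```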

Any advice on that?

Otherwise thanks a lot Patrick for the blog post !

@patrick-kidger
Owner

patrick-kidger commented Nov 14, 2023

So typically one has to know the computation to be able to pick the dummy value -- in general you can't know what a safe value is. The typical pattern is loosely something like this:

import jax.numpy as jnp
from jax import lax

def true_fn(x):
    return 0.0

def false_fn(x):
    return 1 / x  # something that explodes at zero

def f(x):
    pred = x == 0
    safe_x = jnp.where(pred, 1.0, x)
    return lax.cond(pred, true_fn, false_fn, safe_x)

In this case, we've set safe_x to a value that can be used on both branches. It is now safe to run jax.vmap(f).

To add a bit more colour: this "safe dummy value" is most important in two cases:

  • when one branch would hit an infinite loop. Otherwise your vmap'd computation never ends.
  • when one branch would produce a NaN (e.g. division by zero as in the above example) and you are also backpropagating. (More on this case here.)

@patrick-kidger patrick-kidger added the question User queries label Nov 14, 2023
@pablo2909
Author

pablo2909 commented Nov 14, 2023

Hmm, let me detail my case a bit more then, and maybe rephrase my question.

No branch would produce an infinite loop or NaN values; I would just like to avoid doing some heavy computation twice.

# Instead of

jax.vmap(lambda x: lax.cond(cond, heavy_computation_1, heavy_computation_2, x))(X)

# I Do:
X1 = jax.vmap(lambda x: lax.cond(cond, heavy_computation_1, no_computation, x))(X)
X2 = jax.vmap(lambda x: lax.cond(cond, no_computation, heavy_computation_2, x))(X)
combine(X1, X2)

Am I right that the second option avoids the extra computation? Or am I misunderstanding something? It feels like the first case does the computation twice while the second case does it only once, but after your explanation I'm a bit unsure.

Edit: the no_computation branch returns a pytree of the same shape as heavy_computation_* but filled with NaN values, which is how I detect them and proceed. But I feel this is not good practice.

Thanks for the help :)

@pablo2909
Author

Oh, I just realized that both cases are equivalent in terms of computation... right?

@patrick-kidger
Owner

Actually, all of your examples will still perform conditional computation, i.e. be efficient. The reason is that you've only vmap'd the argument x. It is specifically when the predicate cond is batched that lax.cond turns into jnp.where.
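A small sketch of that distinction (illustrative branch functions; one way to see it is by inspecting the jaxpr for the cond primitive):

```python
import jax
import jax.numpy as jnp
from jax import lax

def f(pred, x):
    return lax.cond(pred, lambda x: x + 1.0, lambda x: x * 2.0, x)

X = jnp.arange(4.0)
preds = jnp.array([True, False, True, False])

# Predicate unbatched (in_axes=None): the cond primitive survives vmap,
# so only one branch is executed.
unbatched = jax.make_jaxpr(jax.vmap(f, in_axes=(None, 0)))(True, X)

# Predicate batched: both branches are evaluated and combined with a
# select (compare str(batched) with str(unbatched) to see the difference).
batched = jax.make_jaxpr(jax.vmap(f))(preds, X)

out = jax.vmap(f)(preds, X)  # same result as jnp.where(preds, X + 1, X * 2)
```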

@pablo2909
Author

pablo2909 commented Nov 14, 2023

Sorry, this example is closer to my actual use case:

jax.vmap(lambda x, y: lax.cond(y < 0, heavy_computation_1, heavy_computation_2, x))(X, Y)

In that case, if I understand correctly, it will be converted to jnp.where and both branches will be executed. And there's no way around it, right?

Now that you've pointed out that this happens only when cond is batched, it all makes more sense. I missed that important point all this time; thanks a lot for the clarification :)

@patrick-kidger
Owner

patrick-kidger commented Nov 15, 2023

Yup, pretty much.

There is one important exception that does come up occasionally: if you expect the entire predicate to be True across the whole batch, or False across the whole batch, then in principle you could run just one branch. (For example, this comes up in some differential equation solvers: the "expensive branch" is making a numerical step, the "cheap branch" is to do nothing, and you want to keep iterating until every batch element has finished making steps. At the end you'll get a False predicate across the whole batch, and from that point onwards only need to make the cheap evaluations until the end of your loop.)

In this case you can make use of a trick, namely eqx.internal.unvmap_{any, all}, which consumes a batch-of-predicates and applies any or all down the batch dimension to return an unbatched single predicate.
Use that the wrong way, of course, and you end up getting the wrong output: each batch element now interacts with the others. But if you feel like you know what you're doing, and happen to have an example of the above use-case, then this can be a useful trick to have: lax.cond(eqxi.unvmap_any(pred), ...).

@pablo2909
Author

Ohh, that looks super useful, thanks a lot for the tip! :)

@pablo2909
Author

pablo2909 commented Nov 20, 2023

Pointing to this. It could be an alternative in some cases.
