What is a good dummy value? #588
Comments
So typically one has to know the computation to be able to pick the dummy value; in general you can't know what a safe value is. The typical pattern is loosely something like this:

    import jax.numpy as jnp
    from jax import lax

    def true_fn(x):
        return jnp.zeros_like(x)  # same dtype as false_fn's output

    def false_fn(x):
        return 1 / x  # something that explodes at zero

    def f(x):
        pred = x == 0
        # Swap in a harmless input wherever the true branch will be selected.
        safe_x = jnp.where(pred, 1, x)
        return lax.cond(pred, true_fn, false_fn, safe_x)

In this case, we've set `safe_x` to 1 wherever `x == 0`, so `false_fn` never sees a zero.

To add a bit more colour: this "safe dummy value" is most important in two cases:
|
Hmm, let me detail my case a bit more then, and maybe rephrase my question. No branch would produce an infinite loop or nan values; I would just like to avoid doing some heavy computation twice.

    # Instead of
    jax.vmap(lambda x: lax.cond(cond, heavy_computation_1, heavy_computation_2, x))(X)
    # I do:
    X1 = jax.vmap(lambda x: lax.cond(cond, heavy_computation_1, no_computation, x))(X)
    X2 = jax.vmap(lambda x: lax.cond(cond, no_computation, heavy_computation_2, x))(X)
    combine(X1, X2)

Am I right in doing the second option to avoid extra computation? Or am I misunderstanding something? It feels like the first case will do the computation twice while the second case only once, but I'm a bit unsure from your explanation now.

Edit: And the

Thanks for the help :) |
Oh, I just realized that both cases are equivalent in terms of computation, right? |
Actually, all of your examples will still perform conditional computation, i.e. be efficient. The reason is that you've only vmap'd the argument `x`, not the predicate `cond`. |
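One way to check this for yourself is to look at the jaxpr in both situations. This is a rough sketch assuming a recent JAX release; `heavy_1`, `heavy_2`, `unbatched_pred` and `batched_pred` are hypothetical stand-ins, not names from the thread.

```python
import jax
import jax.numpy as jnp
from jax import lax

def heavy_1(x):  # hypothetical stand-in for heavy_computation_1
    return jnp.sin(x)

def heavy_2(x):  # hypothetical stand-in for heavy_computation_2
    return jnp.cos(x)

def unbatched_pred(flag, X):
    # flag is shared by every element of the batch; only x is vmap'd.
    return jax.vmap(lambda x: lax.cond(flag, heavy_1, heavy_2, x))(X)

def batched_pred(X, Y):
    # The predicate y < 0 differs per element, so it is batched too.
    return jax.vmap(lambda x, y: lax.cond(y < 0, heavy_1, heavy_2, x))(X, Y)

X = jnp.arange(3.0)
Y = jnp.array([-1.0, 1.0, -1.0])

unbatched_jaxpr = str(jax.make_jaxpr(unbatched_pred)(True, X))
batched_jaxpr = str(jax.make_jaxpr(batched_pred)(X, Y))

# An unbatched predicate keeps a real cond; a batched one becomes a select.
print("cond" in unbatched_jaxpr)
print("select_n" in batched_jaxpr)
```

Printing the two jaxprs in full shows the difference more directly: the batched version contains both `sin` and `cos` applied to the whole batch.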
Sorry, this example is closer to my actual use case:

    jax.vmap(lambda x, y: lax.cond(y < 0, heavy_computation_1, heavy_computation_2, x))(X, Y)

In that case, if I understand correctly, it will be converted to a select, with both branches evaluated. Now that you pointed out that this happens only when `cond` is batched, it all makes more sense. I missed that important point all this time, thanks a lot for the clarification :) |
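To make the batched-predicate case concrete, here is a small sketch comparing the vmap'd `lax.cond` with an explicit "compute both, then select" version. `heavy_1` and `heavy_2` are hypothetical stand-ins for the two heavy computations, and the shapes are arbitrary.

```python
import jax
import jax.numpy as jnp
from jax import lax

def heavy_1(x):  # hypothetical stand-in for heavy_computation_1
    return jnp.sum(jnp.sin(x))

def heavy_2(x):  # hypothetical stand-in for heavy_computation_2
    return jnp.sum(jnp.cos(x))

X = jnp.arange(12.0).reshape(3, 4)  # batch of 3 inputs
Y = jnp.array([-1.0, 2.0, -3.0])    # one predicate input per batch element

# Batched predicate: the cond is lowered to a select over both branch outputs.
via_cond = jax.vmap(lambda x, y: lax.cond(y < 0, heavy_1, heavy_2, x))(X, Y)

# The same thing written out by hand: both branches run on the whole batch.
via_where = jnp.where(Y < 0, jax.vmap(heavy_1)(X), jax.vmap(heavy_2)(X))

print(via_cond)
print(via_where)
```

Both expressions give the same values; the point is only that neither of them skips work per element.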
Yup, pretty much. There is one important exception that does come up occasionally: if you expect the entire predicate to be In this case you can make use of a trick, namely |
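A hedged sketch of one way to exploit a predicate that happens to be uniform across the batch, not necessarily the trick meant above: reduce the predicate outside of `vmap` and branch once on the reduced scalar, falling back to the select only for mixed batches. `heavy_1`, `heavy_2` and `apply_batched` are hypothetical names.

```python
import jax
import jax.numpy as jnp
from jax import lax

def heavy_1(x):  # hypothetical stand-in for heavy_computation_1
    return jnp.sum(jnp.sin(x))

def heavy_2(x):  # hypothetical stand-in for heavy_computation_2
    return jnp.sum(jnp.cos(x))

def apply_batched(X, Y):
    pred = Y < 0  # batched predicate, one entry per batch element

    def mixed(X_):
        # Fallback for mixed batches: evaluate both branches and select.
        return jnp.where(pred, jax.vmap(heavy_1)(X_), jax.vmap(heavy_2)(X_))

    def not_all_true(X_):
        # jnp.all(~pred) is a scalar, so this cond stays a real conditional.
        return lax.cond(jnp.all(~pred), jax.vmap(heavy_2), mixed, X_)

    # Likewise jnp.all(pred) is a scalar: only one vmap'd branch runs here.
    return lax.cond(jnp.all(pred), jax.vmap(heavy_1), not_all_true, X)

X = jnp.arange(12.0).reshape(3, 4)
Y_all_neg = jnp.array([-1.0, -2.0, -3.0])
Y_all_pos = jnp.array([1.0, 2.0, 3.0])
Y_mixed = jnp.array([-1.0, 2.0, -3.0])

print(apply_batched(X, Y_all_neg))
print(apply_batched(X, Y_all_pos))
print(apply_batched(X, Y_mixed))
```

The cost is a reduction over the predicate and an extra level of `lax.cond`, which only pays off if uniform batches are actually common.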
Ohh, that looks super useful, thanks a lot for the tip! :) |
Pointing out to this. Could be an alternative in some cases |
Hi,
I'm asking a question following the recent blog post where point 9 raises awareness on `lax.cond` with `jax.vmap`. One piece of advice is to use a dummy value, which I understand as such:

If we do that, then we are safe from doing unwanted evaluation, I believe. However, what is a good dummy value? I would have imagined `None` would be, but if we store it in an array it is converted to a `jnp.nan`. Which could be fine, but then if nan arises during training it becomes hard to debug. And in my case `true_f` and `false_f` can return any real value, so it's not an option to use `0.` or something similar.

Any advice on that?
Otherwise thanks a lot Patrick for the blog post!