cond does not shortcut in recursions #13796

Closed
gboehl opened this issue Dec 26, 2022 · 2 comments
Labels
bug Something isn't working

gboehl commented Dec 26, 2022

Description

A follow-up on #3103. I am using `cond` in a recursive function because I have to iterate over sub-parts of matrices. However, in this case `cond` evaluates both branches, which leads to an error:

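```python
import jax
import jax.numpy as jnp

def return_x(x):
    return 0

def return_z(x):
    # Zero the first entry, then recurse on the remaining sub-array.
    x = x.at[0].set(0)
    x = x.at[1:].set(rec(x[1:]))
    return x
    # return jnp.ones_like(x)

@jax.jit
def rec(x):
    # Intended to stop recursing once len(x) <= 1.
    return jax.lax.cond(len(x) > 1, return_z, lambda x: 0, x)
    # return jax.lax.cond(len(x) > 1, return_z, return_x, x)

rec(jnp.arange(30))
```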

This results in `IndexError: index is out of bounds for axis 0 with size 0`.

Is there any way to circumvent this?

Thanks for all this great work!

What jax/jaxlib version are you using?

jax/jaxlib=0.3.24

Which accelerator(s) are you using?

CPU

Additional system info

Python 3.10.8, Linux 6.0.12

gboehl added the bug label (Something isn't working) on Dec 26, 2022
jakevdp (Collaborator) commented Dec 26, 2022

Hi, thanks for the question! `lax.cond` does its branching at runtime, not at trace time. During tracing, all branches must be traced, because the value of the predicate is not known until runtime. This means that `lax.cond` cannot be used to break out of a recursive program: the recursive branch is traced even when the predicate is False.
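To illustrate the point above, here is a minimal sketch, assuming a JAX version whose `jax.lax.cond` behaves as described. Each branch function appends to a Python list, which only happens while JAX traces it. After the first call, the list contains both branch names even though the predicate `x.shape[0] > 1` is a concrete Python bool. The names `traced`, `true_branch`, `false_branch`, and `f` are hypothetical.

```python
import jax
import jax.numpy as jnp

traced = []  # hypothetical: records which branch functions JAX traces

def true_branch(x):
    traced.append("true")   # Python side effect: runs only while tracing
    return x + 1

def false_branch(x):
    traced.append("false")  # Python side effect: runs only while tracing
    return x - 1

@jax.jit
def f(x):
    # x.shape[0] > 1 is a concrete Python bool here, yet lax.cond still
    # traces both branches, because it builds one program that can take
    # either path at runtime.
    return jax.lax.cond(x.shape[0] > 1, true_branch, false_branch, x)

out = f(jnp.arange(3))
print(sorted(set(traced)))  # ['false', 'true']
print(out)                  # [1 2 3]
```

In the original recursion, this is why `return_z` keeps being traced on ever-shorter slices until it reaches an empty one and `x.at[0]` fails.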

Fortunately, in this case you're branching on a static attribute (the length of the array), so you can use an ordinary Python conditional instead of `lax.cond`:

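```python
# Replaces the earlier `rec`; `return_z` is unchanged. len(x) is a static
# Python int during tracing, so this conditional stops the recursion.
@jax.jit
def rec(x):
    return return_z(x) if len(x) > 1 else 0
```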
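For reference, a self-contained sketch of the suggested fix, combining `return_z` from the original report with the Python-conditional `rec`. It uses a shorter input (`jnp.arange(10)`) only to keep the nested tracing shallow; this is an assumption for the example, not part of the original report. Because `jax.jit` caches traces by input shape, each distinct length triggers a separate trace of `rec`, so tracing cost grows with the array length.

```python
import jax
import jax.numpy as jnp

def return_z(x):
    # Zero the first entry, then recurse on the remaining sub-array.
    x = x.at[0].set(0)
    x = x.at[1:].set(rec(x[1:]))
    return x

@jax.jit
def rec(x):
    # len(x) is a static Python int at trace time, so this ordinary
    # conditional decides which branch gets traced at all: once len(x) == 1
    # the recursion stops and return_z is never traced on an empty slice.
    return return_z(x) if len(x) > 1 else 0

out = rec(jnp.arange(10))
print(out.shape)  # (10,)
```

Every entry ends up zero: each level zeroes its first element, and the length-1 base case returns the scalar `0`, which `x.at[1:].set(...)` broadcasts into the last slot.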

gboehl (Author) commented Dec 26, 2022

Great, thank you for clarifying and the super fast response!

gboehl closed this as completed on Dec 26, 2022