Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DOC: clarify behavior of lax.cond & lax.select #13589

Merged
merged 1 commit into from
Jan 7, 2023

Conversation

jakevdp
Copy link
Collaborator

@jakevdp jakevdp commented Dec 9, 2022

In response to questions at #13586; also #7934

Related to #8409

@skirsten
Copy link

skirsten commented Dec 9, 2022

I also just found #7934 which will probably be resolved by this.

@jakevdp
Copy link
Collaborator Author

jakevdp commented Dec 10, 2022

Thanks for the pointer!

@jakevdp
Copy link
Collaborator Author

jakevdp commented Dec 19, 2022

@froystig – Updated based on offline discussion - please take a look!

@bheijden
Copy link

bheijden commented Dec 23, 2022

I've been noticing super linear compilation times in my application when using jax.lax.switch with an increasing number of branches (similar to this issue). From what I read here, jax.lax.cond is implemented as a special case of switch.

Could these lines explain the increased compilation time? You state that XLA may nevertheless choose to use a select instead of branching if it is deemed advantageous, but could this process of inference lead to super linear compilation times?

In my application, I chose to use a switch instead of a select because I feared long compilation times, but your comments seem to suggest that under the hood XLA may be ignoring this. If this is indeed the case, it would be great if there was an option to force jax/xla to forgo on this optimization to lower the compilation times.

jax/_src/lax/control_flow/conditionals.py Outdated Show resolved Hide resolved
jax/_src/lax/control_flow/conditionals.py Outdated Show resolved Hide resolved
@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Jan 6, 2023
copybara-service bot pushed a commit that referenced this pull request Jan 6, 2023
FUTURE_COPYBARA_INTEGRATE_REVIEW=#13589 from jakevdp:cond-doc c9c6263
PiperOrigin-RevId: 499899435
copybara-service bot pushed a commit that referenced this pull request Jan 6, 2023
FUTURE_COPYBARA_INTEGRATE_REVIEW=#13589 from jakevdp:cond-doc c9c6263
PiperOrigin-RevId: 499899435
@copybara-service copybara-service bot merged commit 01f9934 into jax-ml:main Jan 7, 2023
@jakevdp jakevdp deleted the cond-doc branch January 7, 2023 00:11
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants