Add vmap support for rng_bit_generator. #8117

Merged: 1 commit, Oct 7, 2021


LenaMartens
Contributor

Adds a naive batching rule that loops over the primitive and then stacks the results. The advantage is that the vmapped result is the same as calling the primitive n times; the drawback is that it is not very performant.
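For readers who want to see the loop-and-stack idea at user level, here is a minimal sketch. It is not the batching rule from this PR: `looped_rng_bits` is a hypothetical helper, and the `uint32[4]` key layout and key values are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp

def looped_rng_bits(keys, shape):
    # Hypothetical helper, not the batching rule added by this PR.
    new_keys, bits = [], []
    for i in range(keys.shape[0]):  # one primitive call per batch element
        new_key, out = jax.lax.rng_bit_generator(keys[i], shape)
        new_keys.append(new_key)
        bits.append(out)
    # Stack the per-element results along a new leading batch axis.
    return jnp.stack(new_keys), jnp.stack(bits)

# Three illustrative uint32[4] keys (the values are arbitrary).
keys = jnp.arange(12, dtype=jnp.uint32).reshape(3, 4)
new_keys, bits = looped_rng_bits(keys, (5,))
```

Each row of `bits` comes from its own call to the primitive, which is the property the description above refers to.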

@google-cla google-cla bot added the cla: yes label Oct 6, 2021
@LenaMartens LenaMartens requested a review from froystig October 6, 2021 18:23
@froystig
Member

froystig commented Oct 6, 2021

Side note that something like #7199 might have been useful here.

Member

@froystig froystig left a comment


This is an "inline" loop that will emit N individual primitive calls. Is there a batch dimension length at which jax.lax.map would make sense instead? We can follow up after this PR if there's an answer to that question.
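One rough way to see the "N individual primitive calls" effect is to trace a Python loop and count equations in the resulting jaxpr. The sketch below assumes a hypothetical user-level loop (`looped_bits`) rather than the PR's batching rule, and assumes the primitive is named `rng_bit_generator` in the jaxpr.

```python
import jax
import jax.numpy as jnp

def looped_bits(keys):
    # Hypothetical inline loop: one rng_bit_generator call per key.
    outs = [jax.lax.rng_bit_generator(keys[i], (5,))[1]
            for i in range(keys.shape[0])]
    return jnp.stack(outs)

def count_rng_calls(n):
    # Trace with n placeholder keys and count the traced primitive calls.
    keys = jnp.zeros((n, 4), dtype=jnp.uint32)
    jaxpr = jax.make_jaxpr(looped_bits)(keys)
    return sum(eqn.primitive.name == "rng_bit_generator"
               for eqn in jaxpr.jaxpr.eqns)

calls_small = count_rng_calls(3)
calls_large = count_rng_calls(30)
```

The count grows linearly with the batch dimension, which is the cost the question about `jax.lax.map` is getting at.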

@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Oct 6, 2021
@froystig
Member

froystig commented Oct 6, 2021

Should we give some kind of warning in the meantime, especially if the batch dimension is large?

@copybara-service copybara-service bot merged commit 8f0589f into jax-ml:main Oct 7, 2021
@LenaMartens
Contributor Author

Thanks for the comments! I agree that emitting n primitive calls for large n is undesirable. I can follow up with a jax.lax.map version (which, as I understand it, is just a scan).
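A `jax.lax.map` version could look roughly like the sketch below. This is an illustration under assumptions, not the follow-up change itself: `mapped_rng_bits` is a hypothetical helper and the `uint32[4]` keys are arbitrary.

```python
import jax
import jax.numpy as jnp

def mapped_rng_bits(keys, shape):
    # Hypothetical helper: lax.map traces the body once and applies it
    # to each key along the leading axis, stacking the outputs.
    return jax.lax.map(lambda k: jax.lax.rng_bit_generator(k, shape), keys)

# Three illustrative uint32[4] keys (the values are arbitrary).
keys = jnp.arange(12, dtype=jnp.uint32).reshape(3, 4)
new_keys, bits = mapped_rng_bits(keys, (5,))
```

The body is traced a single time regardless of the batch size, so the traced program does not grow with n.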

I can maybe do some benchmarking if I get to it, but do you have an intuition as to what batch size is "too large" for the regular for loop? Or should we just always use a scan?

@froystig
Member

froystig commented Oct 7, 2021

I'm not sure what's "too large" but it might depend on the backend/device. Conceptually we want a map and that alone might be reason enough to switch. One thing we could do is always use an explicit scan for the map and then tune its unroll parameter for performance. This has the additional advantage of not increasing the jaxpr size, only the HLO.
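To make the "explicit scan with a tunable unroll" suggestion concrete, here is a minimal sketch. The helper names (`scanned_rng_bits`, `count_scan_eqns`) and the key values are hypothetical; only `jax.lax.scan` and its `unroll` argument are taken from JAX itself.

```python
import jax
import jax.numpy as jnp

def scanned_rng_bits(keys, shape, unroll=1):
    # Hypothetical helper: an explicit scan standing in for the map.
    def body(carry, key):
        # No state is carried between iterations; each key is independent.
        return carry, jax.lax.rng_bit_generator(key, shape)
    _, (new_keys, bits) = jax.lax.scan(body, None, keys, unroll=unroll)
    return new_keys, bits

# Eight illustrative uint32[4] keys (the values are arbitrary).
keys = jnp.arange(32, dtype=jnp.uint32).reshape(8, 4)
_, bits_1 = scanned_rng_bits(keys, (5,), unroll=1)
_, bits_4 = scanned_rng_bits(keys, (5,), unroll=4)

def count_scan_eqns(unroll):
    # Count top-level scan equations in the traced jaxpr.
    jaxpr = jax.make_jaxpr(
        lambda k: scanned_rng_bits(k, (5,), unroll=unroll))(keys)
    return sum(eqn.primitive.name == "scan" for eqn in jaxpr.jaxpr.eqns)
```

In this sketch the jaxpr contains a single scan equation for either `unroll` value, so any unrolling happens when the scan is lowered rather than at the jaxpr level.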

Labels
cla: yes pull ready Ready for copybara import and testing

3 participants