Add support for lax.pmean #929

Closed
jheek opened this issue Jun 26, 2019 · 5 comments
Labels
enhancement New feature or request question Questions for the JAX team

Comments

@jheek
Contributor

jheek commented Jun 26, 2019

When using pmap I often miss a parallel mean function lax.pmean(x, axis_name).
I guess it's not too complicated to write lax.psum(x, axis_name) / axis_size myself. However, there is currently no way to retrieve the size of a parallel axis within a parallel computation.

@mattjj
Collaborator

mattjj commented Jun 26, 2019

I think you can do lax.psum(x, axis_name) / lax.psum(1, axis_name), and also axis_size = lax.psum(1, axis_name). Can you try that and see if that works?
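The suggestion above can be sketched as follows. This is a minimal example, assuming a JAX installation where `jax.pmap` is available; the axis name `"devices"`, the function name `mean_over_devices`, and the input array are made up for illustration and do not come from the thread.

```python
import jax
import jax.numpy as jnp
from jax import lax

# One row per local device; pmap maps over the leading axis.
n_devices = jax.local_device_count()
x = jnp.arange(n_devices * 3, dtype=jnp.float32).reshape(n_devices, 3)

def mean_over_devices(v):
    # Sum across the named axis, then divide by the axis size,
    # which is obtained as psum(1, axis_name).
    return lax.psum(v, "devices") / lax.psum(1, "devices")

# Every device ends up holding the same cross-device mean.
result = jax.pmap(mean_over_devices, axis_name="devices")(x)
```

On a machine with a single device the mapped axis has size 1, so the result simply equals the input; with more devices each row of `result` holds the mean over all rows of `x`.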

@hawkinsp hawkinsp added enhancement New feature or request question Questions for the JAX team labels Jul 2, 2019
@jheek
Contributor Author

jheek commented Jul 4, 2019

That works fine. It would still be nice to have some other way to obtain the axis_size. I didn't benchmark it, but doesn't this compile into something that performs an additional AllReduce?

@mattjj
Collaborator

mattjj commented Jul 4, 2019

Actually no, lax.psum(1, axis_name) won't generate any communication or even any compiled code at all. Instead JAX optimizes it just to extract the axis size, or more generally multiply the argument by the product of the sizes of the axis names you give it. In fact, the same is true more generally: for any collective (not just psum), if the argument isn't mapped (sharded) along the given axis_name or names (not just totally unmapped), JAX will compute the value with no communication (or maybe raise a NotImplementedError if we haven't covered all the primitives yet).

So in short, the computation being generated by pmean = lambda x, axis_name: lax.psum(x, axis_name) / lax.psum(1, axis_name) seems optimal. If we wrote a helper function to compute pmean and include it in lax then we'd just use that definition. The reason to make lax.psum(1, axis_name) the canonical way to compute the axis size is just to minimize the API surface.

What do you think?
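To illustrate the axis-size idiom described in this comment, here is a small sketch, assuming `jax.pmap` is available. The axis name `"devices"` and the helper `scale_by_axis_size` are hypothetical names chosen for the example. It only checks the resulting value; it does not demonstrate anything about what code is or is not compiled.

```python
import jax
import jax.numpy as jnp
from jax import lax

n_devices = jax.local_device_count()

def scale_by_axis_size(v):
    # The argument 1 is a constant, so it is not mapped along "devices";
    # psum then yields the size of the named axis.
    axis_size = lax.psum(1, "devices")
    return v * axis_size

x = jnp.ones((n_devices, 2))
scaled = jax.pmap(scale_by_axis_size, axis_name="devices")(x)
```

Here every entry of `scaled` should equal the number of local devices, which matches the description of `lax.psum(1, axis_name)` as the way to read off the axis size.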

@mattjj mattjj self-assigned this Jul 4, 2019
@lucasb-eyer
Contributor

This issue could be closed, since pmean was implemented and merged in #2081, except that it is somehow missing from the reference documentation.

@jekbradbury
Contributor

Thanks for noticing that oversight! #2778 will add pmean to the docs.
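For readers arriving at this thread later, the following sketch compares `lax.pmean` with the hand-written definition quoted earlier in the thread. It assumes a JAX version that ships `lax.pmean` and `jax.pmap`; the axis name `"devices"` and the variable names are illustrative only.

```python
import jax
import jax.numpy as jnp
from jax import lax

# The definition quoted in the thread.
manual_pmean = lambda x, axis_name: lax.psum(x, axis_name) / lax.psum(1, axis_name)

n_devices = jax.local_device_count()
x = jnp.arange(n_devices * 4, dtype=jnp.float32).reshape(n_devices, 4)

# Run both versions under pmap with the same named axis.
builtin = jax.pmap(lambda v: lax.pmean(v, "devices"), axis_name="devices")(x)
manual = jax.pmap(lambda v: manual_pmean(v, "devices"), axis_name="devices")(x)
```

Both results should agree up to floating-point tolerance, so `lax.pmean` can be used wherever the manual expression was used before.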
