Fix axis_index inside nested pmaps #4378

Merged
merged 2 commits into from
Sep 22, 2020

Collaborator

@apaszke apaszke commented Sep 22, 2020

The previous translation rule assumed that axis_index is always
taken over the outermost axis in the axis_env, so it always produced
the same output no matter which axis was specified. This fixes the
translation rule to take the axis_name into account.

Additionally, this adds support for querying the index along multiple
axes, which will be useful for gmap.
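
To make the expected behavior concrete, here is a minimal pure-Python model of two nested named maps. It is not JAX code and does not reflect how the translation rule is implemented; the helper `nested_axis_index` and the axis names `i` and `j` are hypothetical, chosen only for illustration. It shows that querying different axis names should yield different index grids, whereas a rule that ignores the name would return the same grid for both.

```python
def nested_axis_index(outer, inner, axis_name):
    """Model an outer map over `outer` wrapping an inner map over `inner`.

    `outer` and `inner` are (name, size) pairs. Returns the grid of values
    that a per-axis index query should produce for `axis_name`.
    """
    (outer_name, outer_size), (inner_name, inner_size) = outer, inner
    grid = []
    for i in range(outer_size):
        row = []
        for j in range(inner_size):
            # Each program instance sees its own position along every axis.
            env = {outer_name: i, inner_name: j}
            row.append(env[axis_name])
        grid.append(row)
    return grid


by_i = nested_axis_index(("i", 2), ("j", 3), "i")
by_j = nested_axis_index(("i", 2), ("j", 3), "j")
print(by_i)  # [[0, 0, 0], [1, 1, 1]]
print(by_j)  # [[0, 1, 2], [0, 1, 2]]
```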

Collaborator Author

apaszke commented Sep 22, 2020

@mattjj I'm not sure what nreps represents in the axis env, so I tried my best to retain the semantics of the translation rule while fixing it. Please make sure to read that part carefully. The part that confuses me is why we divide the replica ID by nreps / prod(axis_env.sizes). Is it ever the case that we replicate the computation without extending the axis_env?
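
One possible reading of that division, sketched in plain Python: if replicas are numbered row-major with the outermost named axis first, and any replication not covered by a named axis is innermost, then nreps / prod(sizes) is the number of replicas sharing the same position along all named axes. Both layout assumptions are mine and are not confirmed in this thread; the function `axis_index_from_replica` is hypothetical and is not the pxla translation rule.

```python
from math import prod


def axis_index_from_replica(replica_id, nreps, names, sizes, axis_name):
    """Recover the index along `axis_name` from a flat replica ID.

    Assumes a row-major layout: `names`/`sizes` are ordered outermost first,
    and replicas not covered by any named axis vary fastest (assumption).
    """
    pos = names.index(axis_name)
    # Replicas that share one position along every named axis.
    unnamed = nreps // prod(sizes)
    # Stride of `axis_name`: everything nested inside it.
    div = unnamed * prod(sizes[pos + 1:])
    return (replica_id // div) % sizes[pos]


names, sizes, nreps = ("i", "j"), (2, 3), 12  # 12 / (2 * 3) = 2 unnamed replicas
idx_i = axis_index_from_replica(7, nreps, names, sizes, "i")
idx_j = axis_index_from_replica(7, nreps, names, sizes, "j")
print(idx_i, idx_j)  # 1 0, since 7 = 1*6 + 0*2 + 1
```

Under these assumptions, dropping the `unnamed` factor would only be correct when nreps equals prod(sizes), which is exactly the case the question above asks about.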

@apaszke apaszke requested a review from mattjj September 22, 2020 13:33
jax/interpreters/pxla.py (outdated)
name_idx = core.Primitive.bind(axis_index_p, axis_name=name)
index += name_idx * inner_size
inner_size *= psum(1, name)
return index
Collaborator

I think we can instead leave the bind rule as it was, and just update the axis_index translation rule to actually use axis_name (and look it up in its axis_env argument).

Collaborator Author

This change is not a fix for the bug; it implements the "Additionally, this adds support for querying the index along multiple axes, which will be useful for gmap." part.

Collaborator

Oh, I see. Thanks for explaining. I think that makes sense.

Collaborator

So IIUC the idea here is to handle multiple axes in the bind rule, rather than the alternative of handling it in the translation rule and in process_axis_index. Is that right? That seems reasonable, though the alternative seems reasonable too. I think the psum primitive and its rules work more like the latter (i.e. allowing axis_name to be a list/tuple, ultimately calling into this helper to handle it), but we don't need to be consistent with that.
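
For readers following the thread, the accumulation in the quoted bind-rule fragment (`index += name_idx * inner_size`, then growing `inner_size` by the axis size) can be modeled in plain Python as a row-major linearization. This is a sketch only: the function name `combined_axis_index` is hypothetical, explicit sizes stand in for `psum(1, name)`, and iterating from the innermost axis outward is an assumption, since the loop header is not visible in the fragment.

```python
from itertools import product


def combined_axis_index(indices, sizes):
    """Linearize per-axis indices into one index over several axes.

    `indices` and `sizes` are ordered outermost axis first. Iteration runs
    from the innermost axis outward (assumed order).
    """
    index = 0
    inner_size = 1
    for idx, size in zip(reversed(indices), reversed(sizes)):
        index += idx * inner_size  # stride of this axis is the product of inner sizes
        inner_size *= size         # stands in for psum(1, name)
    return index


sizes = (2, 3)
flat = [combined_axis_index(ij, sizes) for ij in product(range(2), range(3))]
print(flat)  # [0, 1, 2, 3, 4, 5]
```

With this ordering, the combined index over both axes enumerates every position exactly once, which is the property a multi-axis query needs.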

Collaborator

mattjj commented Sep 22, 2020

Looks like I broke this in #3370, and we didn't have test coverage for it! (Before #3370 the axis_index bind rule correctly used axis_name, but #3370 deleted that and I didn't fix the translation rule.)

The previous translation rule has assumed that `axis_index` is always
taken over the outermost axis in the `axis_env`, and was always producing
the same output, no matter which axis has been specified. This fixes the
translation rule to start taking the `axis_name` into account.

Additionally, this adds support for querying the index along multiple
axes, which will be useful for `gmap`.
The rule didn't specify the precision for the `np.arange` constant,
which caused an accidental dtype promotion in X64 mode. Previously the
error was luckily hidden behind a coercion that followed
`axis_index` in that test, but the new implementation has surfaced it.
@apaszke apaszke added the pull ready Ready for copybara import and testing label Sep 22, 2020
Collaborator

@mattjj mattjj left a comment


LGTM! Thanks for finding and fixing this; I'm embarrassed that I added such a bad bug in #3370 (we think), and moreover that we didn't have test coverage for nested axis index (even though it appears in at least one docstring, and sometimes in our pmap demos!).

@copybara-service copybara-service bot merged commit 99ffcc4 into jax-ml:master Sep 22, 2020
Labels
cla: yes pull ready Ready for copybara import and testing