-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix axis_index inside nested pmaps #4378
Conversation
00683b2
to
171b1c5
Compare
@mattjj I'm not sure what does |
name_idx = core.Primitive.bind(axis_index_p, axis_name=name) | ||
index += name_idx * inner_size | ||
inner_size *= psum(1, name) | ||
return index |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can instead leave the bind rule as it was, and just update the axis_index translation rule to actually use axis_name (and look it up in its axis_env argument).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This change is not a fix for the bug, but adds the "Additionally, this adds support for querying the index along multiple
axes, which will be useful for gmap." part.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, I see. Thanks for explaining. I think that makes sense.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So IIUC the idea here is to handle multiple axes in the bind rule, rather than the alternative of handling it in the translation rule and in process_axis_index
. Is that right? That seems reasonable, though the alternative seems reasonable too. I think the psum
primitive and its rules work more like the latter (i.e. allowing axis_name to be a list/tuple, ultimately calling into this helper to handle it), but we don't need to be consistent with that.
Looks like I broke this in #3370, and we didn't have test coverage for it! (Before #3370 the |
The previous translation rule has assumed that `axis_index` is always taken over the outermost axis in the `axis_env`, and was always producing the same output, no matter which axis has been specified. This fixes the translation rule to start taking the `axis_name` into account. Additionally, this adds support for querying the index along multiple axes, which will be useful for `gmap`.
The rule didn't specify the precision for the `np.arange` constant, which caused an accidental dtype promotion in X64 mode. Previously the error has luckicly been hidden behind a coerction that followed `axis_index` in that test, but the new implementation has surfaced it.
223eee5
to
8ac19c7
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Thanks for finding and fixing this; I'm embarrassed that I added such a bad bug in #3370 (we think), and moreover that we didn't have test coverage for nested axis index (even though it appears in at least one docstring, and sometimes in our pmap demos!).
The previous translation rule has assumed that
axis_index
is alwaystaken over the outermost axis in the
axis_env
, and was always producingthe same output, no matter which axis has been specified. This fixes the
translation rule to start taking the
axis_name
into account.Additionally, this adds support for querying the index along multiple
axes, which will be useful for
gmap
.