-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
handle mapped_invars correctly in more places (#2828)
fixes #2822 We didn't handle `pmap`'s `mapped_invars` correctly in all places in #1959. (I'm actually not sure if #1959 introduced the bug where things were working before, or just refactored it in terms of `mapped_invars`, though my guess is that because the information now contained in `mapped_invars` was implicitly contained in the pmapped jaxpr's `constvars` and `env_vars` that it was working correctly before #1959.) In particular, in #1959 we: 1. assumed the `mapped_invars` parameter of xla_pmap_p was only populated after partial_eval and set to None otherwise (i.e. staging out for a jit or a control flow primitive), 2. didn't update it correctly in JVPTrace.process_map (which adds new inputs corresponding to nonzero tangents, and hence `mapped_invars` must be grown), 3. didn't update it correctly in JaxprTrace.process_map (which adds residual inputs to the staged-out version of the primitive), 4. didn't forward it correctly in JaxprTrace.process_map anyway (we were setting it to all-true for the staged out eqn for all tracers regardless of what the original `mapped_invars` said), 5. removed the leading axes of all pvs in JaxprTrace.process_map regardless of whether the corresponding entry of `mapped_invars` was True or False. The reason we didn't notice 2 and 3 was that they only arise when doing control flow (e.g. scan or remat) of pmap involving closed-over tracers (apparently a rare case), since that's the case where we first form a jaxpr (populating `mapped_invars`) and then later have to apply transformations like AD and further partial eval (thus engaging JVPTrace.process_map and JaxprTrace.process_map with a populated `mapped_invars` parameter). It worked in other cases, e.g. when the pmap was not inside control flow or a remat, because in those cases we left `mapped_invars` set to None, indicating all-true of any length (so it didn't matter if we add inputs). This commit fixes those issues by 1. making `mapped_invars` non-optional, 2. handling `mapped_invars` correctly in * JaxprTrace.process_map * JVPTrace.process_map * ad.map_transpose (since having symbolic-zero cotangents effectively prunes inputs, and having undefined-primal args also prunes inputs) * ad._eval_subjaxpr_primals (since having undefined-primal args prunes inputs) 3. making the separate cases of calls and maps handled more explicitly by adding a new Primitive.map_primitive boolean attribute (analogous to Primitive.call_primitive), to be revised further in #2829. This is begging for a more coherent cleanup. For example, we reuse the same Primitive class but tag it with `call_primitive` or `map_primitive` (only one of which can be True); we should instead just have a separate Primitive class for these cases and track the type tag with built-in Python mechanisms. Moreover, when `call_primitive=True` or `map_primitive=True` implies things about what `params` must be present (`call_jaxpr` and `mapped_invars`). I plan to follow up with those cleanups in #2829, but I wanted to get something working first.
- Loading branch information
Showing
6 changed files
with
99 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters