[direct-linearize] shmap fixes #27015

Merged
merged 1 commit into from
Mar 14, 2025

Conversation

mattjj
Collaborator

@mattjj mattjj commented Mar 8, 2025

No description provided.

@mattjj mattjj added the pull ready Ready for copybara import and testing label Mar 8, 2025
@mattjj mattjj requested a review from dougalm March 8, 2025 16:02
@@ -1616,48 +1616,48 @@ def _shard_map_linearize(trace, shard_map_p, f: lu.WrappedFun,
   res_names = _all_newly_manual_mesh_names(mesh, auto, trace)

   @as_hashable_function(closure=linearize_outs_thunk)
-  def primal_out_names_thunk():
+  def fwd_out_names_thunk():
Collaborator Author

I renamed "primal" to "fwd", on the reasoning that primal: a -> b, whereas fwd: a -> (b, r), i.e. fwd also returns the residuals r.
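
To make the naming distinction concrete, here is a minimal sketch in plain Python rather than JAX internals. The functions `primal`, `fwd`, and `bwd` are hypothetical illustrations and are not part of the PR; the sketch assumes that "r" stands for residuals saved by the forward pass for later use.

```python
import math

def primal(a):
    # primal : a -> b  (only the output)
    return math.sin(a)

def fwd(a):
    # fwd : a -> (b, r)  (the output plus a residual for the backward pass)
    b = math.sin(a)
    r = math.cos(a)  # residual: the derivative of sin at a
    return b, r

def bwd(r, ct):
    # Uses the residual to map an output cotangent to an input cotangent.
    return ct * r

b, r = fwd(0.0)
in_ct = bwd(r, 2.0)
```

The point of the sketch is only that anything describing the outputs of `fwd` (such as output names) has to cover the residuals too, which is not true of `primal`.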

@mattjj mattjj force-pushed the direct-linearize-fixes-4 branch 2 times, most recently from 0f22a84 to 9676d0b Compare March 8, 2025 21:21
Comment on lines +1737 to +1752
  )[len(res_reshaped):]
  _, in_ct_names = partition_list(in_undef, in_names)
  in_cts = [ad.Zero(_unshard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero
            else x if rewrite
            else jax.lax.psum(x, tuple(_unmentioned2(mesh, ns, auto)))
            for ns, x in zip(in_ct_names, in_cts)]
  res_zeros = [ad_util.zero_from_primal(r) for r in res]
  return merge_lists(in_undef, res_zeros, in_cts)
Collaborator Author

This is the meat (or rather the veggies) of the fix. The previous code zipped together in_names and out = ad.backward_pass(jaxpr_unknown.jaxpr, ...), which assumed that in_names (for the inputs to the tangent function being transposed, residuals included) would match the order of the cotangents produced by ad.backward_pass, i.e. the order of the input binders of jaxpr_unknown.jaxpr. But while in_names matches the residual order of the original jaxpr, the partial_eval_jaxpr_nounits step is free to reorder residuals when producing res_reshaped and jaxpr_unknown. When the residuals were reordered, entries of in_names and out no longer corresponded, and we could get shape errors. Specifically, evaluating ad.Zero(_unshard_aval(mesh, ns, x.aval)) could fail loudly if ns had entries that exceeded the rank of x.aval, or could perhaps fail silently and produce downstream problems.

I'm guessing this issue only arose with direct-linearize because the tangent jaxpr is now produced by a DynamicJaxprTrace and we're processing it with a partial_eval_jaxpr_nounits, whereas without direct-linearize we would produce the jaxpr with partial_eval_jaxpr_nounits, which may have preserved ordering in a way we relied upon.
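
The mismatch described above can be sketched in plain Python. Everything here is illustrative: `partition_list` and `merge_lists` are local stand-ins written for this sketch (assumed to mirror the behavior of the same-named helpers used in the diff), `names_fit` is a hypothetical check, and the shapes and names are made up.

```python
def partition_list(bs, xs):
    # Stand-in: split xs into (entries where b is False, entries where b is True).
    falses, trues = [], []
    for b, x in zip(bs, xs):
        (trues if b else falses).append(x)
    return falses, trues

def merge_lists(bs, l0, l1):
    # Stand-in: rebuild one list, taking from l1 where b is True, else from l0.
    it0, it1 = iter(l0), iter(l1)
    return [next(it1) if b else next(it0) for b in bs]

def names_fit(names, shape):
    # Hypothetical check: names maps array dimension -> mesh axis names,
    # so every key must be a valid dimension index for the shape.
    return all(dim < len(shape) for dim in names)

# Inputs to the tangent function: two residuals, then one tangent input.
in_names = [{}, {1: ("x",)}, {0: ("x",)}]  # scalar res, rank-2 res, rank-1 input
in_undef = [False, False, True]            # only the last input is a tangent

# Assumption: partial evaluation swapped the two residuals.
reordered_shapes = [(4, 8), (), (4,)]

# Old pattern: zip in_names directly against the reordered values.
old_ok = [names_fit(ns, s) for ns, s in zip(in_names, reordered_shapes)]

# New pattern: pair only the tangent-input names with the input cotangents,
# and build the residual zeros from the residuals themselves.
_, in_ct_names = partition_list(in_undef, in_names)
in_ct_shapes = [(4,)]
new_ok = [names_fit(ns, s) for ns, s in zip(in_ct_names, in_ct_shapes)]
merged = merge_lists(in_undef, ["zero(res0)", "zero(res1)"], ["ct_in0"])
```

In this toy setup the old pattern pairs a name entry that mentions dimension 1 with a scalar, which is the kind of rank mismatch the comment describes, while the new pattern never pairs residual names with reordered values.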

@mattjj mattjj force-pushed the direct-linearize-fixes-4 branch 5 times, most recently from 494f526 to 05a6bdf Compare March 14, 2025 21:38
@mattjj mattjj force-pushed the direct-linearize-fixes-4 branch from 05a6bdf to 174dcc7 Compare March 14, 2025 21:38
@copybara-service copybara-service bot merged commit b00a3a1 into jax-ml:main Mar 14, 2025
21 checks passed
1 participant