-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[direct-linearize] shmap fixes #27015
Conversation
@@ -1616,48 +1616,48 @@ def _shard_map_linearize(trace, shard_map_p, f: lu.WrappedFun, | |||
res_names = _all_newly_manual_mesh_names(mesh, auto, trace) | |||
|
|||
@as_hashable_function(closure=linearize_outs_thunk) | |||
def primal_out_names_thunk(): | |||
def fwd_out_names_thunk(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I renamed "primal" to "fwd", thinking primal : a -> b
and fwd: a -> (b, r)
.
0f22a84
to
9676d0b
Compare
)[len(res_reshaped):] | ||
_, in_ct_names = partition_list(in_undef, in_names) | ||
in_cts = [ad.Zero(_unshard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero | ||
else x if rewrite | ||
else jax.lax.psum(x, tuple(_unmentioned2(mesh, ns, auto))) | ||
for ns, x in zip(in_ct_names, in_cts)] | ||
res_zeros = [ad_util.zero_from_primal(r) for r in res] | ||
return merge_lists(in_undef, res_zeros, in_cts) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the meat veggies of the fix. The previous code, by zipping together in_names
and out = ad.backward_pass(jaxpr_unknown.jaxpr, ...)
, assumed that the in_names
(for inputs to the tangent function being transposed), including residuals, would match the order of the cotangents produced by ad.backward_pass
, i.e. the order of the input binders to jaxpr_unknown.jaxpr
. But while in_names
matches the residual order for the original jaxpr
, the partial_eval_jaxpr_nounits
step is free to reorder residuals when producing res_reshaped
and jaxpr_unknown
. When those are reordered, we could get shape errors for entries of in_names
and out
that no longer corresponded; specifically, evaluating ad.Zero(_unshard_aval(mesh, ns, x.aval))
could fail loudly if ns
had entries that exceeded the rank of x.aval
, or perhaps could fail silently and produce downstream problems.
I'm guessing this issue only arose with direct-linearize because the tangent jaxpr is now produced by a DynamicJaxprTrace
and we're processing it with a partial_eval_jaxpr_nounits
, whereas without direct-linearize we would produce the jaxpr with partial_eval_jaxpr_nounits
, which may have preserved ordering in a way we relied upon.
494f526
to
05a6bdf
Compare
05a6bdf
to
174dcc7
Compare
No description provided.