refactor custom_vjp use for type agreement #8
When defining a forward rule for `jax.custom_vjp(primal_fn)`, if `primal_fn` has output type `T` then we need the forward rule to have output type `tuple[T, R]` for some `R`. That is, we need the first output of the forward rule to look like the full output of `primal_fn`. (Here the `R` values represent the 'residuals' computed on the forward pass to save for use on the backward pass.)

This PR fixes a disagreement between `custom_vjp`-decorated functions and their corresponding forward rules. In particular, I applied the `custom_vjp` decorator to the functions which only have the primal outputs of interest (and not the extra residual/intermediate values like `row_sum`). This choice means that our custom rule won't apply at higher orders of differentiation, since the rule doesn't itself call back into the `custom_vjp`-decorated function. To make the higher-order case work we'd either need to not share the computation of residuals like `row_sum` with the primal computation (maybe XLA could DCE them?), or else we'd need to write a backward rule which handles nonzero gradients with respect to those intermediate outputs.

The disagreement caused some interesting behavior! Discussed on #7.

Separately, I'm going to try to get JAX to raise a better error message in this case; the error message was some really confusing JAX-internals thing.
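For reference, here is a minimal sketch of the type agreement described at the top: the decorated function returns only `T`, and the forward rule returns `tuple[T, R]`. The function names (`normalize`, `_normalize_with_residuals`, `normalize_fwd`, `normalize_bwd`) and the row-normalization computation are hypothetical stand-ins for illustration, not the actual code touched by this PR; only `row_sum` echoes a name used above.

```python
import jax
import jax.numpy as jnp


# Hypothetical helper (not from this PR): computes the primal output together
# with an intermediate value, row_sum, that the backward pass wants to reuse.
def _normalize_with_residuals(x):
    row_sum = jnp.sum(x, axis=-1, keepdims=True)
    return x / row_sum, row_sum


# The custom_vjp-decorated function returns only the primal output (type T).
@jax.custom_vjp
def normalize(x):
    out, _ = _normalize_with_residuals(x)
    return out


# Forward rule: returns tuple[T, R]. The first element has the same structure
# as the output of `normalize`; the second element holds the residuals.
def normalize_fwd(x):
    out, row_sum = _normalize_with_residuals(x)
    return out, (out, row_sum)


# Backward rule: receives the residuals R and a cotangent shaped like T, and
# returns a tuple with one cotangent per primal input.
def normalize_bwd(residuals, g):
    out, row_sum = residuals
    # For y = x / s with s = sum(x): dx_j = (g_j - sum_i g_i * y_i) / s
    dx = (g - jnp.sum(g * out, axis=-1, keepdims=True)) / row_sum
    return (dx,)


normalize.defvjp(normalize_fwd, normalize_bwd)
```

Because `normalize_fwd` recomputes via the helper rather than calling `normalize` itself, this sketch has the same limitation noted above: the custom rule is not reused at higher orders of differentiation.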