Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor custom_vjp use for type agreement
When defining a forward rule for `jax.custom_vjp(primal_fn)`, if `primal_fn` has output type `T` then we need the forward rule to have output type `(T, R)` for some `R`. That is, we need the first output of the forward rule to look like the full output of `primal_fn`. (Here the `R` values represent the 'residuals' computed on the forward pass to save for use on the backward pass.) This PR fixes a disagreement between `custom_vjp`-decorated functions and their corresponding forward rules. The disagreement caused some interesting behavior! Discussed on #7 Separately, I'm going to try to get JAX to raise a better error message in this case; the error message was some really confusing JAX-internals thing.
- Loading branch information