refactor custom_vjp use for type agreement #8

Merged: 1 commit, Sep 29, 2022

@mattjj (Contributor) commented Sep 29, 2022

When defining a forward rule for jax.custom_vjp(primal_fn), if primal_fn has output type T then we need the forward rule to have output type tuple[T, R] for some R. That is, we need the first output of the forward rule to look like the full output of primal_fn. (Here the R values represent the 'residuals' computed on the forward pass to save for use on the backward pass.)
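A minimal sketch of the type contract described above, using a toy function rather than anything from this repository. The names `log1pexp`, `log1pexp_fwd`, and `log1pexp_bwd` are illustrative assumptions; the point is only that the forward rule returns a pair whose first element has the same structure as the primal function's output.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def log1pexp(x):
    # Primal function: its output type T is a single array.
    return jnp.log1p(jnp.exp(x))


def log1pexp_fwd(x):
    # Forward rule: must return (T, R). The first element matches the
    # full output of log1pexp; the second holds the residuals R.
    y = jnp.log1p(jnp.exp(x))
    residuals = (x,)
    return y, residuals


def log1pexp_bwd(residuals, g):
    # Backward rule: receives the residuals R and a cotangent shaped like T,
    # and returns one cotangent per primal input (as a tuple).
    (x,) = residuals
    # d/dx log(1 + e^x) = sigmoid(x)
    return (g * jax.nn.sigmoid(x),)


log1pexp.defvjp(log1pexp_fwd, log1pexp_bwd)

grad_at_zero = jax.grad(log1pexp)(0.0)
```

If the forward rule instead returned something whose first element did not look like the output of `log1pexp`, the two definitions would disagree in the way this PR is about.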

This PR fixes a disagreement between custom_vjp-decorated functions and their corresponding forward rules. In particular, I applied the custom_vjp decorator to the functions which return only the primal outputs of interest (and not the extra residual/intermediate values like row_sum).

One consequence of this choice is that our custom rule won't apply at higher orders of differentiation, since the rule doesn't itself call back into the custom_vjp-decorated function. To make the higher-order case work we'd either need to stop sharing the computation of residuals like row_sum with the primal computation (maybe XLA could DCE them?), or else we'd need to write a backward rule which handles nonzero gradients with respect to those intermediate outputs.
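The arrangement described here might look roughly like the sketch below. It is not the code from this PR: `normalize`, `_normalize_with_row_sum`, and the row-normalization computation are hypothetical stand-ins, chosen only because they produce a `row_sum` intermediate that the backward pass can reuse. The decorated function returns just the primal output, while the forward rule shares the same helper and stashes `row_sum` as a residual.

```python
import jax
import jax.numpy as jnp


def _normalize_with_row_sum(x):
    # Hypothetical helper: returns the primal output plus an intermediate.
    row_sum = jnp.sum(x, axis=-1, keepdims=True)
    return x / row_sum, row_sum


@jax.custom_vjp
def normalize(x):
    # Decorated function exposes only the primal output (type T).
    out, _ = _normalize_with_row_sum(x)
    return out


def normalize_fwd(x):
    # Forward rule returns (T, R): first element matches normalize(x).
    out, row_sum = _normalize_with_row_sum(x)
    return out, (out, row_sum)


def normalize_bwd(residuals, g):
    out, row_sum = residuals
    # VJP of x / sum(x): (g - sum(g * out)) / row_sum, taken per row.
    dx = (g - jnp.sum(g * out, axis=-1, keepdims=True)) / row_sum
    return (dx,)


normalize.defvjp(normalize_fwd, normalize_bwd)


def reference(x):
    # Same computation without a custom rule, for comparison.
    return x / jnp.sum(x, axis=-1, keepdims=True)


x = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
w = jnp.array([[0.5, -1.0, 2.0], [1.0, 0.0, -3.0]])

custom_grad = jax.grad(lambda v: jnp.sum(w * normalize(v)))(x)
reference_grad = jax.grad(lambda v: jnp.sum(w * reference(v)))(x)
```

Because `normalize_fwd` calls the plain helper rather than `normalize` itself, differentiating through the rule a second time would trace the helper's ordinary operations instead of reapplying the custom rule, which is the higher-order limitation noted above.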

The disagreement caused some interesting behavior! It is discussed in #7.

Separately, I'm going to try to get JAX to raise a better error message in this case; the error message was some really confusing JAX-internals thing.

@lucidrains (Owner) commented

@mattjj thank you Matthew!

@lucidrains lucidrains merged commit 25df3b9 into lucidrains:main Sep 29, 2022
@mattjj mattjj deleted the custom-vjp-fix branch September 29, 2022 23:00