Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor custom_vjp use for type agreement #8

Merged
merged 1 commit into from
Sep 29, 2022

Commits on Sep 29, 2022

  1. refactor custom_vjp use for type agreement

    When defining a forward rule for `jax.custom_vjp(primal_fn)`, if `primal_fn`
    has output type `T` then we need the forward rule to have output type `(T, R)`
    for some `R`. That is, we need the first output of the forward rule to look
    like the full output of `primal_fn`. (Here the `R` values represent the
    'residuals' computed on the forward pass to save for use on the backward pass.)
    
    This PR fixes a disagreement between `custom_vjp`-decorated functions and their
    corresponding forward rules.
    
    The disagreement caused some interesting behavior! Discussed on
    lucidrains#7
    
    Separately, I'm going to try to get JAX to raise a better error message in this
    case; the error message was some really confusing JAX-internals thing.
    mattjj committed Sep 29, 2022
    Configuration menu
    Copy the full SHA
    78ce0a9 View commit details
    Browse the repository at this point in the history