filter_value_and_grad: preserve return type of wrapped function #557
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
I ran into this while trying to update the RNN tutorial to use
jaxtyping
.Using a bounded type variable conveys to the type checker that wrapping a function with
filter_value_and_grad
will not change the type of value returned by the original function.As an example:
Suppose
f
returns afloat
. Letwrapped = filter_value_and_grad(f)
. Then the return type ofwrapped
istuple[_Scalar, PyTree]
, nottuple[float, PyTree]
, where_Scalar
is defined here:equinox/equinox/_ad.py
Line 102 in 56bafcb
The wrapped function can return a value of a different type so long as it is also a
_Scalar
(likecomplex
).Using the type variable communicates this invariant to the type checker.
I opted not to replace
_Scalar
with_ScalarTy
throughout the entire file because it doesn't make sense to use a type variable which occurs exclusively arguments or the return value, and there were a few other use sites which do that with_Scalar
.