Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

filter_value_and_grad: preserve return type of wrapped function #557

Conversation

ConnorBaker
Copy link
Contributor

@ConnorBaker ConnorBaker commented Oct 14, 2023

I ran into this while trying to update the RNN tutorial to use jaxtyping.

Using a bounded type variable conveys to the type checker that wrapping a function with filter_value_and_grad will not change the type of value returned by the original function.

As an example:

Suppose f returns a float. Let wrapped = filter_value_and_grad(f). Then the return type of wrapped is tuple[_Scalar, PyTree], not tuple[float, PyTree], where _Scalar is defined here:

_Scalar = Union[float, complex, Float[ArrayLike, ""], Complex[ArrayLike, ""]]

The wrapped function can return a value of a different type so long as it is also a _Scalar (like complex).

Using the type variable communicates this invariant to the type checker.

I opted not to replace _Scalar with _ScalarTy throughout the entire file because it doesn't make sense to use a type variable which occurs exclusively arguments or the return value, and there were a few other use sites which do that with _Scalar.

@patrick-kidger
Copy link
Owner

Hmm, looks like checks are failing just for formatting. (See the contributing guide for how to have these checks run automatically locally, including automatically formatting.)

Other than that, this LGTM -- thank you for the fix.

@ConnorBaker
Copy link
Contributor Author

ConnorBaker commented Oct 15, 2023

That’s odd, I had installed and run them locally to make sure. I wonder if I have some global config set somewhere…

EDIT: Nope, just forgot that pre-commit looks at staged files. Fixed it.

Although, when I ran black . I did notice a number of files were re-formatted :x

@ConnorBaker ConnorBaker force-pushed the fix/filter_value_and_grad-preserves-scalar-type branch from 1359768 to 63ff395 Compare October 15, 2023 19:36
@patrick-kidger patrick-kidger merged commit 53d2fb4 into patrick-kidger:main Oct 16, 2023
@patrick-kidger
Copy link
Owner

patrick-kidger commented Oct 16, 2023

Alright, LGTM! Merged. Thank you for contributing.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants