Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JEP: Type Annotations #11859

Merged
merged 1 commit into from
Sep 13, 2022
Merged

Conversation

jakevdp
Copy link
Collaborator

@jakevdp jakevdp commented Aug 11, 2022

This Jax Enhancement Proposal (JEP) introduces a roadmap for type annotations in JAX, addressing a number of goals, non-goals, and design decisions that need to be made.

Part of #12049.

@jakevdp jakevdp self-assigned this Aug 11, 2022
@jakevdp jakevdp marked this pull request as draft August 11, 2022 22:21
@jakevdp jakevdp added the JEP JAX enhancement proposal label Aug 11, 2022
@NeilGirdhar
Copy link
Contributor

I'm really happy to see this level of care about type annotations! This is a great design document. I love the division into three levels.

Just a few comments that are hopefully useful:

Flexible annotation for decorators has been a long-standing issue in the mypy package, which was only recently resolved by the introduction of ParamSpec in PEP 612, available starting in Python 3.10.
Because JAX follows NEP 29, it cannot rely on Python 3.10 features until sometime after mid-2024.

Do we have to wait for Python 3.10? ParamSpec is available as part of typing-extensions. Jax already uses typing-extensions in various places.

When functions are decorated by jax transformations like jit, vmap, grad, etc. JAX will strip all annotations. The reason for this is that without the mechanisms of PEP 612 there is no good way to do otherwise, and ParamSpec will not be available for use until Python 3.10.

This is an important point. I have been waiting for MyPy to fix their ParamSpec bugs before submitting a fix for the annotation stripping. The MyPy team is rolling out fixes very often, so I'm not even sure, but the fix may already work.

and functions returning a shape should always be Tuple[int]

Minor (common) error, this should be Tuple[int, ...].

@jakevdp
Copy link
Collaborator Author

jakevdp commented Aug 16, 2022

Do we have to wait for Python 3.10? ParamSpec is available as part of typing-extensions. Jax already uses typing-extensions in various places.

The sentence after the one you quoted mentions this possibility 😁

I think that given how in-flux this is (recently added, not fully supported by mypy, etc.) we can leave this kind of annotation for a later data once things have stabilized. What do you think?

@jakevdp
Copy link
Collaborator Author

jakevdp commented Aug 16, 2022

@NeilGirdhar – I edited the doc a bit to address some of those comments - let me know what you think

@jakevdp
Copy link
Collaborator Author

jakevdp commented Aug 17, 2022

@8bitmp3 Thanks for the comments! I addressed them in the most recent commit.

@NeilGirdhar
Copy link
Contributor

The sentence after the one you quoted mentions this possibility grin

My mistake!!

I think that given how in-flux this is (recently added, not fully supported by mypy, etc.) we can leave this kind of annotation for a later data once things have stabilized. What do you think?

Yes, of course!

@jakevdp jakevdp requested a review from froystig August 17, 2022 20:15
@jakevdp jakevdp marked this pull request as ready for review August 17, 2022 20:15
jakevdp added a commit to jakevdp/jax that referenced this pull request Sep 9, 2022
As part of this change, I created a helper function so that the logic of type checking is
in a single location. Eventually we can replace this helper function with appropriate
isinstance() checks using the APIs described in jax-ml#11859.
jakevdp added a commit to jakevdp/jax that referenced this pull request Sep 9, 2022
As part of this change, I created a helper function so that the logic of type checking is
in a single location. Eventually we can replace this helper function with appropriate
isinstance() checks using the APIs described in jax-ml#11859.
Copy link
Member

@froystig froystig left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice improvements since my last read!

@jakevdp
Copy link
Collaborator Author

jakevdp commented Sep 13, 2022

Thanks - I'm squashing the commits and going to merge.

Thanks all for the discussion & comments on this! I think we have a nice framework to use going forward.

Copy link
Collaborator

@patrick-kidger patrick-kidger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! I like where we ended up on this.

@copybara-service copybara-service bot merged commit a2930e6 into jax-ml:main Sep 13, 2022
@jakevdp jakevdp deleted the jep-type-annotation branch September 13, 2022 17:52
@NeilGirdhar
Copy link
Contributor

NeilGirdhar commented Sep 28, 2022

With the last Jax, release it looks like we're using jax.Array now! I'm just curious, is there a plan to make Array generic the same way that numpy arrays are generic? That is, support something like:

from jax import Array
BooleanArray = Array[jnp.bool_]
IntegralArray = Array[jnp.integer[Any]]
RealArray = Array[jnp.floating[Any]]
ComplexArray = Array[Union[jnp.floating[Any], jnp.complexfloating[Any, Any]]]
KeyArray = Array[jax.random.PRNGDtype]

@jakevdp
Copy link
Collaborator Author

jakevdp commented Sep 28, 2022

It would be nice, but it's not currently in scope: variadic type generics are still not fully supported by mypy and other type checkers. This is briefly addressed here: https://jax.readthedocs.io/en/latest/jep/12049-type-annotations.html#avoid-unstable-typing-mechanisms

@NeilGirdhar
Copy link
Contributor

NeilGirdhar commented Sep 28, 2022

variadic type generics

I agree with avoiding unstable typing mechanisms, but, if I understand correctly, you don't need variadic type generics to have a generic dtype. I think you only need them for the shape.

@patrick-kidger
Copy link
Collaborator

patrick-kidger commented Sep 28, 2022

Unfortunately, using variadic generics for shape annotations doesn't really work. See my previous discussions on the topic here and here. (Both these links cover a lot of the same points.)

TL;DR: with a lot of effort it might be possible to get something very limited. But that's such an unattractive value proposition that no-one has been interested in making that happen.

If you want dtype/shape annotations then for the forseeable future I think your best bet is jaxtyping.

@NeilGirdhar
Copy link
Contributor

NeilGirdhar commented Sep 28, 2022

Unfortunately, using variadic generics for shape annotations doesn't really work.

Right, but I'm not suggesting using shape annotations. I'm only wondering whether we could have dtype annotations, which numpy already exposes and both Mypy and Pyright are able to check. Dtype annotations have already saved me many times, so it would be nice to keep them as I switch over to jax.Array.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
JEP JAX enhancement proposal pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants