-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
JEP: Type Annotations #11859
JEP: Type Annotations #11859
Conversation
56ff700
to
c1b07ab
Compare
I'm really happy to see this level of care about type annotations! This is a great design document. I love the division into three levels. Just a few comments that are hopefully useful:
Do we have to wait for Python 3.10?
This is an important point. I have been waiting for MyPy to fix their
Minor (common) error, this should be |
The sentence after the one you quoted mentions this possibility 😁 I think that given how in-flux this is (recently added, not fully supported by mypy, etc.) we can leave this kind of annotation for a later data once things have stabilized. What do you think? |
@NeilGirdhar – I edited the doc a bit to address some of those comments - let me know what you think |
c4ae178
to
ce350d0
Compare
@8bitmp3 Thanks for the comments! I addressed them in the most recent commit. |
My mistake!!
Yes, of course! |
dd03e5d
to
8a39949
Compare
8a39949
to
504076b
Compare
As part of this change, I created a helper function so that the logic of type checking is in a single location. Eventually we can replace this helper function with appropriate isinstance() checks using the APIs described in jax-ml#11859.
504076b
to
231ef3d
Compare
As part of this change, I created a helper function so that the logic of type checking is in a single location. Eventually we can replace this helper function with appropriate isinstance() checks using the APIs described in jax-ml#11859.
2da540e
to
19d02cf
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice improvements since my last read!
19d02cf
to
358363e
Compare
Thanks - I'm squashing the commits and going to merge. Thanks all for the discussion & comments on this! I think we have a nice framework to use going forward. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! I like where we ended up on this.
With the last Jax, release it looks like we're using from jax import Array
BooleanArray = Array[jnp.bool_]
IntegralArray = Array[jnp.integer[Any]]
RealArray = Array[jnp.floating[Any]]
ComplexArray = Array[Union[jnp.floating[Any], jnp.complexfloating[Any, Any]]]
KeyArray = Array[jax.random.PRNGDtype] |
It would be nice, but it's not currently in scope: variadic type generics are still not fully supported by mypy and other type checkers. This is briefly addressed here: https://jax.readthedocs.io/en/latest/jep/12049-type-annotations.html#avoid-unstable-typing-mechanisms |
I agree with avoiding unstable typing mechanisms, but, if I understand correctly, you don't need variadic type generics to have a generic dtype. I think you only need them for the shape. |
Unfortunately, using variadic generics for shape annotations doesn't really work. See my previous discussions on the topic here and here. (Both these links cover a lot of the same points.) TL;DR: with a lot of effort it might be possible to get something very limited. But that's such an unattractive value proposition that no-one has been interested in making that happen. If you want dtype/shape annotations then for the forseeable future I think your best bet is jaxtyping. |
Right, but I'm not suggesting using shape annotations. I'm only wondering whether we could have dtype annotations, which numpy already exposes and both Mypy and Pyright are able to check. Dtype annotations have already saved me many times, so it would be nice to keep them as I switch over to |
This Jax Enhancement Proposal (JEP) introduces a roadmap for type annotations in JAX, addressing a number of goals, non-goals, and design decisions that need to be made.
Part of #12049.