Add documentation on how to use PyType with Jax (and also common add-on libraries such as Flax) #8224
Comments
Thanks for raising this. I think at this point the best practices generally involve using … Additionally, type-checking of JAX code is hampered by the fact that mypy and pytype have poor support for decorators in general (see e.g. python/mypy#1927), and JAX code tends to use decorators extensively. Beyond the recommendation to use …
@superbobry, any thoughts or references to point to? (@billmark is a Googler.)
I think we can have a proper type for arrays in JAX, and I have a draft commit doing that internally. Trees, however, are harder, because their type is fundamentally recursive and generic, and pytype does not currently support that. I think mypy does have some support for recursive generics, but I'm sure it too has its limits.
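To make "recursive" concrete, here is a minimal sketch of a recursive type alias for array-leaved trees. `ArrayTree` and `global_norm` are hypothetical names. The alias is recursive but not generic over the leaf type (the harder part mentioned above), and it is deliberately narrower than what `jax.tree_util` accepts (no `None`, scalars, or custom registered nodes). Recent mypy releases understand such aliases; per the comment above, pytype did not at the time.

```python
from typing import Mapping, Sequence, Union

import jax
import jax.numpy as jnp

# Hypothetical alias: a tree whose leaves are arrays, nested in sequences
# and string-keyed mappings. The quoted names are forward references.
ArrayTree = Union[jax.Array, Sequence["ArrayTree"], Mapping[str, "ArrayTree"]]


def global_norm(tree: ArrayTree) -> jax.Array:
    """Square root of the sum of squares of every leaf in the tree."""
    leaves = jax.tree_util.tree_leaves(tree)
    return jnp.sqrt(sum(jnp.sum(x ** 2) for x in leaves))


params: ArrayTree = {"w": jnp.ones((2, 2)), "b": [jnp.zeros(2)]}
print(global_norm(params))  # 2.0
```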
There are a bunch of confusing comments in the Flax documentation and code regarding what to do for type annotations on PyTrees. For example: https://flax.readthedocs.io/en/latest/flax.struct.html says "Note: Inherit from PyTreeNode instead to avoid type checking issues when using PyType". These comments imply that there is some best practice for type annotations, but I can't find any coherent explanation of what that best practice is. Maybe it really is to use "Any" everywhere because proper type checking is impossible, but if so that needs to be stated clearly instead of having confusing half-explanations that imply some other solution scattered elsewhere in the docs. This seems to be an area of general confusion, as this discussion is already uncovering. For example, there are bugs like this one: google/flax#620 |
I don't know if it's been considered, but one other option is runtime type checkers. The two main options I know about are typeguard and beartype. See also torchtyping, for PyTorch tensor annotations, as an example of what you can do with them.
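To illustrate what a runtime type checker does, here is a toy decorator that `isinstance`-checks each argument against its annotation at call time. It is a hypothetical sketch of the idea, not the API of typeguard or beartype, which also handle generics, unions, return values, and string annotations that this version simply skips.

```python
import functools
import inspect

import jax
import jax.numpy as jnp


def check_arg_types(fn):
    """Toy runtime checker: on every call, isinstance-check each argument
    whose annotation is a plain class; other annotations are ignored."""
    sig = inspect.signature(fn)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        for name, value in bound.arguments.items():
            expected = sig.parameters[name].annotation
            if isinstance(expected, type) and not isinstance(value, expected):
                raise TypeError(
                    f"{fn.__name__}: argument {name!r} expected "
                    f"{expected.__name__}, got {type(value).__name__}"
                )
        return fn(*args, **kwargs)

    return wrapper


@check_arg_types
def scale(x: jax.Array, factor: float) -> jax.Array:
    return x * factor


print(scale(jnp.ones(2), 2.0))
try:
    scale([1.0, 2.0], 2.0)  # a plain list is not a jax.Array
except TypeError as err:
    print(err)
```

Unlike a static checker, this catches the mistake only when the bad call actually executes.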
@jakevdp, I realize I didn't directly reply to your question:
If those are the best practices, then that's mostly sufficient. Minor additions would include: (1) please remove or amend the other confusing comments about type checking in the Flax docs. (2) Explain why anything better than "Any" isn't possible with the current type checkers. (3) Possibly discuss adding an explicit type annotation every time you perform a functional transformation, as a workaround for the fact that functional transformations break type checking (i.e. my_var: actual_type = jax.jit(blah, blah)). However, it seems from other comments (and also from the Flax docs) that there's no consensus, even among JAX/Flax maintainers, on what the best practices should be. I am not prepared to weigh in on that discussion.
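The explicit-annotation workaround from point (3) might look like the sketch below. `LossFn` and `loss` are hypothetical names chosen for illustration. The annotation is a claim the checker accepts rather than verifies against the transformation, so it is only as correct as you make it; whether a given checker accepts the assignment also depends on how your JAX version annotates `jax.jit` and `jax.grad`.

```python
from typing import Callable

import jax
import jax.numpy as jnp

# Hypothetical shorthand for "a function of (weights, inputs) returning an array".
LossFn = Callable[[jax.Array, jax.Array], jax.Array]


def loss(w: jax.Array, x: jax.Array) -> jax.Array:
    return jnp.sum((x @ w) ** 2)


# Annotate each transformed function explicitly, so the type checker sees a
# concrete signature instead of whatever the transformation returns.
fast_loss: LossFn = jax.jit(loss)
grad_loss: LossFn = jax.grad(loss)  # gradient w.r.t. the first argument, w

w = jnp.ones(2)
x = jnp.ones((3, 2))
print(fast_loss(w, x))  # 12.0
print(grad_loss(w, x))  # [12. 12.]
```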
@beartype was originally gestated out of a multiphysics biology simulator, where runtime type-checking tamed the million-line code beast that nothing else could. We're still as devoted to big data science now as we were back then, and JAX is directly in that wheelhouse. @beartype currently provides explicit support only for NumPy type hints like … Let @beartype know if we can do anything for JAX. Until then, thanks for all the efficient transforms, wonderful JAX team!
Hi, JAX does not currently do much with static typing, beyond some scattered uses of …
FWIW, once the jaxtyping rewrite goes in¹, jaxtyping will actually be PEP-compliant. It shouldn't need any special support from either runtime type checkers or static type checkers. 'Tis a thing of beauty, if I say so myself.
¹ Once @jakevdp and I have settled our differences regarding that pesky …
🥳
Sadly, this is the way. Exactly as you suggest, @jakevdp, Python's NumPy circumvents this by cleverly piggybacking its … Cue the sign for victory. \o/
There are a bunch of tricks that one needs to know to use PyType with JAX (especially in combination with Flax). For example, a PyTree needs to be treated as "Any".
Since it's very common to want to use PyType with JAX, it would be useful to have a section of the JAX documentation summarizing these tricks and best practices. I'm not sure what the best way is to handle the JAX/Flax interactions, but it's important for someone to figure out how to document those best practices too.
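A minimal sketch of the "treat a PyTree as Any" trick: a named alias documents intent for human readers, while the type checker sees plain `Any` and checks nothing about the tree's structure. `PyTree` and `sgd_update` are hypothetical names used only for illustration.

```python
from typing import Any

import jax
import jax.numpy as jnp

# Hypothetical alias: readable in signatures, but identical to Any for a
# static type checker, so mismatched tree structures are not caught.
PyTree = Any


def sgd_update(params: PyTree, grads: PyTree, lr: float) -> PyTree:
    """Apply one SGD step leaf by leaf; params and grads must share a structure."""
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


params = {"w": jnp.ones(2), "b": jnp.zeros(())}
grads = {"w": jnp.ones(2), "b": jnp.ones(())}
new_params = sgd_update(params, grads, lr=0.5)
```

A structure mismatch between `params` and `grads` surfaces only at runtime, as an error from `tree_map`, which is exactly the gap static typing leaves open here.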