One array type to rule them all! #943
I believe all of these types have …
@jekbradbury Hmm, that doesn't quite work for me. Here are the results from running pyre:
Oof, it can't infer through JAX's monkeypatching of numpy methods. I'm not sure what to do about that (IIRC the monkeypatching exists to avoid a dependency issue: core JAX array types need to be able to dispatch their operators to …
Ooh boy, I didn't realize the rabbit hole went that deep! Is there any way to create an abstract "fake" type to capture all of the things expected in core and then force things to occupy that fake type?
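For the runtime half of that idea, Python's `abc` module can already "force things to occupy" a fake type: `register()` makes an existing class a virtual subclass without touching its bases. Below is a minimal sketch; `AbstractArray`, `DeviceArrayStandIn`, and `TracerStandIn` are hypothetical names, not JAX classes. One caveat: this affects only `isinstance`/`issubclass` at runtime. Static checkers don't follow `register()` (mypy, for one, ignores it), so on its own this wouldn't make pyre happy.

```python
from abc import ABC


class AbstractArray(ABC):
    """Hypothetical umbrella 'fake' type; not a JAX class."""


class DeviceArrayStandIn:
    """Hypothetical stand-in for a concrete, on-device array."""

    def __init__(self, data):
        self.data = data


class TracerStandIn:
    """Hypothetical stand-in for a value flowing through a transformation."""

    def __init__(self, aval):
        self.aval = aval


# register() turns each class into a *virtual* subclass: isinstance() and
# issubclass() now succeed without changing either class's bases.
AbstractArray.register(DeviceArrayStandIn)
AbstractArray.register(TracerStandIn)

print(isinstance(DeviceArrayStandIn([1, 2]), AbstractArray))  # True
print(isinstance(TracerStandIn(None), AbstractArray))  # True
print(isinstance([1, 2], AbstractArray))  # False
```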
I bet there is, or we could maybe add the expected operators to core arrays with empty implementations, then override them in the monkeypatch (rather than adding them where none existed).
@jekbradbury Yeah, either of these would be great for my use case.
Is https://github.com/google/jax/blob/master/jax/numpy/lax_numpy.py#L84 where you'd add stubs?
Probably, though I’m not super familiar with that part of the code. (CC @mattjj)
Given the error message, that'd also be my best guess at where to add stubs. But I'm not familiar with any of this Python type-checking machinery. JAX relies on a lot of overloading of user code: we have the `ndarray` (effective) subclasses `DeviceArray` and `ShardedDeviceArray` to represent concrete values, but those aren't actually what flow through user code when transforming it. What flows through user code are instances of the …

That's a lot of overloading and indirection for a type checker to reason about. If you figure anything out here, it would be super beneficial to a lot of JAX users, and I'm eager to answer questions about it as best I can, but the core team doesn't have cycles to look into supporting type-checking systems for user code at the moment. What do you think?
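To make the point about what flows through user code concrete, here's a toy, stdlib-only sketch. It is not JAX's tracer implementation; `ToyTracer`, `toy_trace`, and `user_fn` are hypothetical names. A "transformation" calls the user's function with a recording object instead of a concrete array, so an annotation like `x: DeviceArray` on `user_fn` would describe neither what it receives during tracing nor what it returns.

```python
from typing import List, Tuple


class ToyTracer:
    """Hypothetical stand-in for a tracer: records ops instead of computing."""

    def __init__(self, name: str, log: List[Tuple[str, str, str]]):
        self.name = name
        self.log = log

    def _record(self, op: str, other: object) -> "ToyTracer":
        other_name = other.name if isinstance(other, ToyTracer) else repr(other)
        out = ToyTracer(f"t{len(self.log)}", self.log)  # name for the result
        self.log.append((op, self.name, other_name))
        return out

    def __add__(self, other: object) -> "ToyTracer":
        return self._record("add", other)

    def __mul__(self, other: object) -> "ToyTracer":
        return self._record("mul", other)


def toy_trace(f):
    """Call f on a tracer and return the list of recorded operations."""
    log: List[Tuple[str, str, str]] = []
    f(ToyTracer("x", log))
    return log


def user_fn(x):
    print(type(x).__name__)  # prints ToyTracer, not a concrete array class
    return x * 2 + 1


print(toy_trace(user_fn))  # [('mul', 'x', '2'), ('add', 't0', '1')]
```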
FWIW, my solution here has been to begin a `.pyi` module that is symlinked for both …

```python
from typing import Any, Tuple

Shape = Tuple[int, ...]


class ndarray:
    @property
    def shape(self) -> Shape:
        ...

    def __getitem__(self, key) -> Any:
        ...

    # Stub files are never executed, so this unquoted forward reference is fine.
    def __add__(self, other) -> ndarray:
        ...

    # and so on...
```

That certainly doesn't solve the problem for …
Have you looked at https://github.com/numpy/numpy-stubs and https://github.com/ramonhagenaars/nptyping?
Hi, are there any updates on this? How difficult would it be to get the correct behavior? And in anticipation of this, is there a type that we could define that would satisfy the check? Would the following be enough?
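One candidate for a type "we could define" is a plain `Union` alias over the value classes a function might receive. The sketch below uses hypothetical stand-in classes (`DeviceValue`, `TracedValue`), not real JAX types, and assumes both expose a `shape` attribute.

```python
from typing import Tuple, Union


class DeviceValue:
    """Hypothetical stand-in for a concrete array living on a device."""

    def __init__(self, shape: Tuple[int, ...]):
        self.shape = shape


class TracedValue:
    """Hypothetical stand-in for an abstract value seen during tracing."""

    def __init__(self, shape: Tuple[int, ...]):
        self.shape = shape


# A single alias covering every kind of value a function might receive.
Array = Union[DeviceValue, TracedValue]


def ndim(x: Array) -> int:
    # Both members of the Union define .shape, so this type-checks.
    return len(x.shape)


print(ndim(DeviceValue((2, 3))))  # 2
print(ndim(TracedValue(())))  # 0
```

The trade-off is that a `Union` is closed: every new value class has to be added to the alias by hand, which is exactly the maintenance burden the hierarchy in this issue makes painful.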
It looks like this is going to be addressed in #11859, "Jax Enhancement Proposal (JEP)" (still in progress): https://jax--11859.org.readthedocs.build/en/11859/jep/12049-type-annotations.html
This is fixed these days! (…)
I use Python 3's typing features as much as possible. Unfortunately, JAX's value hierarchy makes this a little challenging. Consider the following snippet:
I'd like to be able to fill in the mystery `ArrayType`s with something like a made-up `jp.Array`, but AFAICT from the array class hierarchy, there is no such type that really fits. At first glance `jp.DeviceArray` looks like an eligible candidate, but then there are also `ConcreteArray`, `ShapedArray`, and `UnshapedArray`. I'm not really sure what the differences between them are, but some of them derive from `jax.core.AbstractValue`, while `DeviceArray` does not. If there are other cases, I'd certainly like to avoid limiting my type signatures to arrays that live on-device. To make matters more confusing, there also seem to be `_FilledConstant` and `DeviceConstant`:

All in all, it's not clear to me how these types fit together and how (if possible) to unify them. What's the appropriate type to use here? And if it doesn't exist yet, could we create one?