Type annotations not working for DeviceArray #6743
Thanks for the report! This is a known issue, tracked as far back as #943. It turns out to be quite difficult to satisfy mypy when annotating JAX arrays. If you have any ideas about how to address this, we're open to contributions. In the meantime, you might follow the strategy we use frequently throughout the package: create the alias …
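The alias itself is cut off in the comment above, so the sketch below assumes, purely for illustration, an alias to `typing.Any`. Aliasing to `Any` makes mypy accept attribute access and indexing on every value annotated with the alias, at the cost of no longer type-checking those values. `describe` is a hypothetical helper.

```python
from typing import Any

# Assumption: a module-level alias to Any. mypy treats Any as compatible
# with everything, so `.shape` and indexing on an `Array` are never flagged.
Array = Any


def describe(x: Array) -> str:
    # Hypothetical helper: this type-checks because x is effectively Any.
    return f"shape={x.shape}"
```

Because every annotation goes through one name, the alias can later be pointed at a real array type in a single place.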
@jakevdp Oh! Sorry! Since it was a new error, I wasn't sure if it was another instance of #943. Admittedly, I don't know enough about the problem, but I hoped it might be as easy as pointing … Anyway, in my particular case, I do need to use a direct pointer to … Feel free to close this as a duplicate if it is one :)
I would be careful registering a function to dispatch on … Additionally, you should be aware that there are multiple flavors of … These are some of the same concerns that make satisfying mypy nontrivial.
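To illustrate why the choice of dispatch class matters, here is a sketch using `functools.singledispatch` (an assumption; the thread does not say which dispatch mechanism is in use). The classes are hypothetical stand-ins, not JAX's actual hierarchy. A handler registered on one concrete flavor silently misses its sibling, while one registered on the shared base class covers both.

```python
from functools import singledispatch


# Hypothetical stand-ins: one base class with two concrete "flavors".
class ArrayBase:
    pass


class FlavorA(ArrayBase):
    pass


class FlavorB(ArrayBase):
    pass


@singledispatch
def show_narrow(x: object) -> str:
    # Fallback used when no registered class matches.
    return "unhandled"


@show_narrow.register(FlavorA)
def _show_flavor_a(x: FlavorA) -> str:
    # Registered on one flavor only, so FlavorB falls through to the fallback.
    return "array"


@singledispatch
def show_broad(x: object) -> str:
    # Fallback used when no registered class matches.
    return "unhandled"


@show_broad.register(ArrayBase)
def _show_array_base(x: ArrayBase) -> str:
    # Registered on the shared base class, so every subclass dispatches here.
    return "array"


print(show_narrow(FlavorB()))  # unhandled
print(show_broad(FlavorB()))  # array
```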
Yeah, thanks. I dispatch on tracers too. :)
Fair enough. All I really want is to print the elements of the arrays whose values can be printed, and print shapes for the other arrays. I've been tweaking it over the last year, but it seems to be working now.
I see, that makes sense! In that case it sounds like your best bet is to tell mypy to ignore these annotations, because indeed … Alternatively, you could dispatch on each of the …
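One way to tell mypy to ignore only the offending lines is a per-line `# type: ignore[code]` comment. The sketch below uses hypothetical stand-in classes, not JAX's; `attr-defined` and `index` are mypy's error codes for a missing attribute and for indexing a non-indexable value, the two complaints in this issue.

```python
class ArrayBase:
    # Hypothetical stand-in for the class mypy resolves the annotation to.
    pass


class ConcreteArray(ArrayBase):
    # Hypothetical concrete subclass that actually holds the data.
    def __init__(self, values: list) -> None:
        self.values = values
        self.shape = (len(values),)

    def __getitem__(self, i: int) -> float:
        return self.values[i]


def summarize(x: ArrayBase) -> str:
    # Each comment suppresses only the named mypy error on its own line.
    n = x.shape[0]  # type: ignore[attr-defined]
    first = x[0]  # type: ignore[index]
    return f"{n} elements, first={first}"
```

Compared with a bare `# type: ignore`, the bracketed code keeps any other error on the same line visible.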
Okay, I think I understand, but if … (I'm not really fond of dispatching on private classes, even if they're concrete.)
I think there's a distinction here between "DeviceArray as xla operand" and "DeviceArray as duck type for numpy arrays". The former should not have numpy-like attributes; the latter should. This might argue that we need another base class in the hierarchy that defines the interface to the two DeviceArray implementations. I think the reason we haven't done that is that the two implementations are temporary: somewhere there's a TODO to unify them again. Would it serve your purposes to locally disable mypy type checking, and only use the …
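The "DeviceArray as duck type for numpy arrays" interface could, as one possible design, be expressed with `typing.Protocol`, which mypy checks structurally, so implementations need not share a base class. This is a hypothetical illustration, not JAX code: `NumpyLike` and `ListArray` are made-up names, and only `shape` and indexing are modeled.

```python
from typing import Any, Protocol, Tuple, runtime_checkable


@runtime_checkable
class NumpyLike(Protocol):
    # Hypothetical interface: the numpy-like members a caller relies on.
    @property
    def shape(self) -> Tuple[int, ...]: ...

    def __getitem__(self, index: Any) -> Any: ...


class ListArray:
    # Hypothetical implementation; it never inherits from NumpyLike and
    # only has to provide matching members (structural typing).
    def __init__(self, data: list) -> None:
        self._data = data

    @property
    def shape(self) -> Tuple[int, ...]:
        return (len(self._data),)

    def __getitem__(self, index: Any) -> Any:
        return self._data[index]


def first_and_shape(x: NumpyLike) -> Tuple[Any, Tuple[int, ...]]:
    # mypy accepts `.shape` and indexing here because NumpyLike declares them.
    return x[0], x.shape
```

An "xla operand" type could then omit these members entirely, while anything passed where `NumpyLike` is expected must provide them.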
Okay, that makes sense. Thanks for explaining it. I'll just block the mypy errors!
mypy says that jax.interpreters.xla.DeviceArray has no attribute "shape", is not indexable, etc. It seems that it points to DeviceArrayBase, not DeviceArray, which has the annotations. I know that type annotations are a work in progress. This is just a placeholder issue :)
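A minimal, hypothetical reproduction of the reported situation, using stand-in classes rather than JAX itself: the public name resolves to a base class that declares no `shape`, while only the subclass defines it. The code runs fine, but mypy checks `x.shape` against the annotated base type and reports an error.

```python
class DeviceArrayBase:
    # Stand-in for the base class: it declares no array attributes.
    pass


class DeviceArray(DeviceArrayBase):
    # Stand-in for the concrete class that carries the annotations.
    shape = (2, 2)


# Stand-in for the public name resolving to the base class rather than
# the annotated subclass (the behavior reported in this issue).
PublicDeviceArray = DeviceArrayBase


def get_shape(x: PublicDeviceArray) -> tuple:
    # Runs fine for a DeviceArray instance, but mypy checks x against
    # DeviceArrayBase and reports an error along the lines of:
    #   "DeviceArrayBase" has no attribute "shape"  [attr-defined]
    return x.shape
```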