Type annotations not working for DeviceArray #6743

Closed
NeilGirdhar opened this issue May 13, 2021 · 8 comments
Labels: duplicate (This issue or pull request already exists)

Comments

@NeilGirdhar
Contributor

NeilGirdhar commented May 13, 2021

Mypy reports that jax.interpreters.xla.DeviceArray has no attribute "shape", is not indexable, and so on. It seems that this name points to DeviceArrayBase rather than to DeviceArray, which is the class that has the annotations.

I know that type annotations are a work-in-progress. This is just a placeholder issue :)

NeilGirdhar added the bug (Something isn't working) label May 13, 2021
@jakevdp
Collaborator

jakevdp commented May 13, 2021

Thanks for the report! This is a known issue, tracked as far back as #943.

It turns out it's quite difficult to satisfy mypy when it comes to annotating JAX arrays. If you have any ideas about how to address this, we're open to contributions.

In the meantime, you might follow the strategy we use frequently throughout the package: create the alias Array = Any and annotate arrays with it. mypy will never complain again 😁
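A minimal sketch of the Array = Any strategy; first_row is a hypothetical helper used only for illustration. Because Any is compatible with every type, mypy accepts any attribute access or indexing on such values, which also means it stops catching genuine mistakes on them.

```python
from typing import Any

# Deliberately loose alias: mypy treats Any as compatible with every type,
# so operations on an Array are never flagged.
Array = Any


def first_row(x: Array) -> Array:
    # Hypothetical helper: indexing type-checks even though Any says nothing
    # about whether x actually supports __getitem__.
    return x[0]
```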

jakevdp added the duplicate (This issue or pull request already exists) label and removed the bug (Something isn't working) label May 13, 2021
jakevdp self-assigned this May 13, 2021
@NeilGirdhar
Contributor Author

NeilGirdhar commented May 13, 2021

@jakevdp Oh! Sorry! Since it was a new error, I wasn't sure if it was another instance of #943. Admittedly, I don't know enough about the problem, but I hoped it might be as easy as pointing jax.interpreters.xla.DeviceArray to tensorflow.DeviceArray, or maybe moving the annotations from tensorflow.DeviceArray to tensorflow.DeviceArrayBase.

Anyway, in my particular case, I do need a direct reference to DeviceArray, since I'm registering a function for single dispatch on it.

Feel free to close this as a duplicate if it is one :)

@jakevdp
Collaborator

jakevdp commented May 13, 2021

I would be careful about registering a function to dispatch on DeviceArray. In particular, if you use any JAX transforms (jit, grad, vmap, etc.), they will pass Tracer objects in place of DeviceArray objects, so your dispatch mechanism will probably not work as intended.
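A small stdlib sketch of the pitfall, assuming hypothetical DeviceArray and Tracer stand-ins (not JAX's classes). functools.singledispatch picks an implementation from the class of the first argument, so an object whose class is unrelated to the registered one silently falls through to the default implementation.

```python
from functools import singledispatch


class DeviceArray:
    """Hypothetical stand-in for a concrete device array."""

    def __init__(self, values):
        self.values = list(values)


class Tracer:
    """Hypothetical stand-in for a tracer; deliberately NOT a DeviceArray subclass."""


@singledispatch
def describe(x) -> str:
    # Default implementation, used for any class without its own registration.
    return "unhandled " + type(x).__name__


@describe.register(DeviceArray)
def _describe_device_array(x: DeviceArray) -> str:
    return "device array " + str(x.values)
```

Here describe(DeviceArray([1, 2])) returns "device array [1, 2]", while describe(Tracer()) returns "unhandled Tracer".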

Additionally, you should be aware that there are multiple flavors of DeviceArray used by JAX, so you should ensure you're handling all relevant cases.

These are some of the same concerns that make satisfying mypy nontrivial.

@NeilGirdhar
Contributor Author

> they will pass Tracer objects in place of DeviceArray objects, and so your dispatch mechanism will probably not work as intended.

Yeah, thanks. I dispatch on tracers too. :)

> Additionally, you should be aware that there are multiple flavors of DeviceArray used by JAX, so you should ensure you're handling all relevant cases.
>
> These are some of the same concerns that make satisfying mypy nontrivial.

Fair enough. All I really want is to print the elements of the arrays that can be printed, and print only the shapes of the other arrays. I've been tweaking this over the last year, but it seems to be working now.
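One way to structure this, as a sketch with hypothetical ConcreteArray and AbstractTracer classes standing in for JAX's types: register one implementation that prints element values when they are known, and another that prints only the shape. The registrations use singledispatch's annotation-based form (Python 3.7+), where the first parameter's annotation selects the class.

```python
from functools import singledispatch
from typing import List, Tuple


class ConcreteArray:
    """Hypothetical array whose element values are known."""

    def __init__(self, values: List[float]) -> None:
        self.values = values
        self.shape: Tuple[int, ...] = (len(values),)


class AbstractTracer:
    """Hypothetical tracer: only the shape is known, not the values."""

    def __init__(self, shape: Tuple[int, ...]) -> None:
        self.shape = shape


@singledispatch
def summarize(x: object) -> str:
    # Fallback for anything that is not one of the array types below.
    return repr(x)


@summarize.register
def _summarize_concrete(x: ConcreteArray) -> str:
    # Element values are available, so print them.
    return "values=" + str(x.values)


@summarize.register
def _summarize_tracer(x: AbstractTracer) -> str:
    # Values are unknown while tracing, so fall back to the shape.
    return "shape=" + str(x.shape)
```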

@jakevdp
Collaborator

jakevdp commented May 13, 2021

I see, that makes sense! In that case, your best bet is probably to tell mypy to ignore these errors, because jax.interpreters.xla.DeviceArray is indeed a base class and does not have all the attributes of an actual instantiated DeviceArray.

Alternatively, you could dispatch on each of the DeviceArray implementations (I believe they are xla._DeviceArray and xla._CppDeviceArray), which do have all the expected attributes.

@NeilGirdhar
Contributor Author

Okay, I think I understand, but if jax.interpreters.xla.DeviceArray is an abstract base class, wouldn't it make more sense for it to define the abstract methods? You could just move the method annotations from tensorflow.DeviceArray to tensorflow.DeviceArrayBase, which could have stubs that raise NotImplementedError.

(I'm not really fond of dispatching on private classes even if they're concrete.)

@jakevdp
Collaborator

jakevdp commented May 13, 2021

I think there's a distinction here between "DeviceArray as xla operand" and "DeviceArray as duck type for numpy arrays". The former should not have numpy-like attributes, the latter should.

This might argue for another base class in the hierarchy that defines the interface shared by the two DeviceArray implementations. I think the reason we haven't done that is that the two implementations are temporary: somewhere there's a TODO to unify them again.

Would it serve your purposes to locally disable mypy type checking, and only use the DeviceArray annotation for dispatch?

@NeilGirdhar
Contributor Author

Okay, that makes sense. Thanks for explaining it. I'll just suppress the mypy errors!
