
Pytype resolves jnp.ndarray to Any #3784

Closed
aslanides opened this issue Jul 17, 2020 · 4 comments
Labels
enhancement New feature or request

Comments

@aslanides
Contributor

This is simple enough to reproduce:

```python
# foo.py
import jax.numpy as jnp

def foo(x: jnp.ndarray):
  return x + 1

foo('bar')  # should be reported: 'bar' is not an array
```

```
$ pytype foo.py
# Passes type checking.
```
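Why this passes: if pytype cannot load JAX's type information, `jnp.ndarray` is treated as `Any`, and `Any` is compatible with every argument, so there is nothing to reject `'bar'` against. The sketch below is only a runtime analogy of that effect, not how pytype works internally; `accepts`, `foo_unresolved`, `foo_resolved` and `FakeArray` are hypothetical names introduced for illustration.

```python
from typing import Any, get_type_hints


def accepts(func, **kwargs) -> bool:
    """Return True if every keyword argument matches func's annotation.

    Simplified runtime analogy: `Any` (or a missing annotation) accepts
    everything; plain classes are checked with isinstance.
    """
    hints = get_type_hints(func)
    for name, value in kwargs.items():
        expected = hints.get(name, Any)
        if expected is Any:
            continue  # nothing to check against, so any value passes
        if not isinstance(value, expected):
            return False
    return True


# Roughly what the checker sees when jnp.ndarray cannot be resolved.
def foo_unresolved(x: Any):
    return x + 1


# Hypothetical stand-in for an array type the checker *can* resolve.
class FakeArray:
    def __add__(self, other):
        return self


def foo_resolved(x: FakeArray):
    return x + 1


print(accepts(foo_unresolved, x="bar"))  # True: Any accepts a str
print(accepts(foo_resolved, x="bar"))    # False: str is not a FakeArray
```

Once the annotation resolves to a concrete class, the same call is rejected, which is the behavior the issue asks for.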
@jakevdp
Collaborator

jakevdp commented Jul 17, 2020

Relevant discussion here: #943

hawkinsp added the enhancement (New feature or request) label on Jul 20, 2020
@hamzamerzic
Contributor

Hi, are there any updates on this? How difficult would it be to get the correct behavior? In the meantime, is there a type we could define that would satisfy the check? Would the following be enough?

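```python
from typing import Union

import jax

# Array and tracer classes as they were named in JAX at the time of this comment.
JAXArray = Union[
    jax.interpreters.xla.DeviceArray,
    jax.interpreters.pxla.ShardedDeviceArray,
    jax.interpreters.batching.BatchTracer,
]
```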
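What a `Union` alias like this buys over `Any` is that it enumerates the accepted classes, so anything outside the list (such as `'bar'`) can be rejected; by the same token, any array or tracer class missing from the list would be rejected too. A minimal sketch, assuming hypothetical stand-in classes in place of the real JAX ones (the actual class names and module paths depend on the JAX version) and using a runtime `isinstance` check as a stand-in for a static checker:

```python
from typing import Union, get_args


# Hypothetical stand-ins for the three JAX classes named in the Union above.
class DeviceArray: ...
class ShardedDeviceArray: ...
class BatchTracer: ...


JAXArray = Union[DeviceArray, ShardedDeviceArray, BatchTracer]


def is_jax_array(x: object) -> bool:
    # get_args unpacks the Union into a tuple of classes, which isinstance accepts.
    return isinstance(x, get_args(JAXArray))


print(is_jax_array(DeviceArray()))  # True
print(is_jax_array("bar"))          # False
```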

@YouJiacheng
Contributor

Now mypy can correctly check this snippet.

@hawkinsp
Collaborator

There's good news and bad news on this. Outside Google, pytype does not support type checking against the types of external dependencies unless those types are in typeshed. I don't think it would make sense for JAX to add its types to typeshed, not least because JAX changes constantly!

So if you are a downstream consumer of JAX, you cannot get pytype to type-check against the JAX API until pytype supports external dependencies. There's nothing JAX-specific about this!

(Inside Google, pytype supports type checking of external dependencies, and indeed, in that case, this bug is already fixed.)

Since there's no action for us to take on the JAX side, closing. This will work great if and when pytype supports this.

5 participants