One array type to rule them all! #943

Closed
samuela opened this issue Jun 27, 2019 · 14 comments

@samuela
Contributor

samuela commented Jun 27, 2019

I use Python 3's typing features as much as possible. Unfortunately, JAX's value hierarchy makes this a little challenging. Consider the following snippet:

from typing import NamedTuple

import jax.numpy as jp
from jax import lax, random

class Normal(NamedTuple):
  loc: ArrayType
  scale: ArrayType

  def sample(self, rng, sample_shape=()) -> ArrayType:
    batch_shape = lax.broadcast_shapes(self.loc.shape, self.scale.shape)
    return self.loc + self.scale * random.normal(
        rng, shape=sample_shape + batch_shape)

I'd like to be able to fill in the mystery ArrayTypes with something like a made-up "jp.Array", but AFAICT from the array class hierarchy, there is no such type that really fits. At first glance jp.DeviceArray looks like an eligible candidate, but then there is also ConcreteArray, ShapedArray, and UnshapedArray. I'm not really sure what the differences are between them but some of them derive from jax.core.AbstractValue, while DeviceArray does not... If there are other cases then I'd certainly like to avoid limiting my type signatures to only operating on arrays that live on-device. To make matters more confusing there also seems to be _FilledConstant and DeviceConstant:

In [10]: jp.ones((2, 3))
Out[10]: 
_FilledConstant([[1., 1., 1.],
                 [1., 1., 1.]], dtype=float32)

All in all, it's not clear to me how these types fit together and how (if possible) to unify them. What's the appropriate type to use here? And if it does not yet exist, could we create such a type?

@jekbradbury
Contributor

I believe all of these types have isinstance(x, np.ndarray) == True, but maybe that's not enough for typing?

@samuela
Contributor Author

samuela commented Jun 28, 2019

@jekbradbury Hmm, that doesn't quite work for me. Here are the results from running pyre:

myfile.py:40:11 Undefined attribute [16]: `jp.lax_numpy.ndarray` has no attribute `__add__`.

@jekbradbury
Contributor

Oof, it can't infer through JAX's monkeypatching of numpy methods. I'm not sure what to do about that. IIRC the monkeypatching exists to avoid a dependency cycle: core JAX array types need to dispatch their operators to jax.numpy math functions (rather than the jax.lax math functions, which have different broadcasting semantics), but the core can't depend on jax.numpy, because jax.numpy depends on jax.lax and jax.lax depends on the core.

@samuela
Contributor Author

samuela commented Jun 28, 2019

Ooh boy, I didn't realize the rabbit hole went that deep! Is there any way to create an abstract "fake" type to capture all of the things expected in core and then force things to occupy that fake type?
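
For illustration, one way such a "fake" type could look is a structural typing.Protocol that lists only the operations the calling code needs. This is a rough sketch under assumptions, not anything JAX provides: SupportsArrayOps, ToyArray, and scale_and_shift are hypothetical names made up for the example.

```python
from typing import Any, Iterable, Protocol, Tuple


class SupportsArrayOps(Protocol):
    """Hypothetical structural type: anything with a shape, `+`, and `*`."""

    @property
    def shape(self) -> Tuple[int, ...]: ...

    def __add__(self, other: Any) -> "SupportsArrayOps": ...

    def __mul__(self, other: Any) -> "SupportsArrayOps": ...


class ToyArray:
    """Toy 1-D array used only to exercise the protocol (not a JAX type)."""

    def __init__(self, values: Iterable[float]) -> None:
        self.values = list(values)

    @property
    def shape(self) -> Tuple[int, ...]:
        return (len(self.values),)

    def __add__(self, other: float) -> "ToyArray":
        return ToyArray(v + other for v in self.values)

    def __mul__(self, other: float) -> "ToyArray":
        return ToyArray(v * other for v in self.values)


def scale_and_shift(x: SupportsArrayOps, scale: float, shift: float) -> SupportsArrayOps:
    # Only operations declared on the protocol are used here, so a checker
    # can validate this body without knowing the concrete array class.
    return x * scale + shift


result = scale_and_shift(ToyArray([1.0, 2.0]), 2.0, 1.0)
```

ToyArray never inherits from SupportsArrayOps; it matches purely by having the right members, which is what would let several unrelated array classes share one annotation.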

@jekbradbury
Contributor

I bet there is -- or we could maybe add the expected operators to core arrays with empty implementations, then override them in the monkeypatch (rather than adding them where none existed).

@samuela
Contributor Author

samuela commented Jun 28, 2019

@jekbradbury Yeah, either of these would be great for my use case.

@samuela
Contributor Author

samuela commented Jun 28, 2019

@jekbradbury
Contributor

Probably, though I’m not super familiar with that part of the code. (CC @mattjj)

@mattjj
Collaborator

mattjj commented Jul 4, 2019

Given the error message, that'd also be my best guess at where to add stubs. But I'm not familiar with any of this Python type checking machinery.

JAX relies on a lot of overloading of user code; we have the ndarray (effective) subclasses DeviceArray and ShardedDeviceArray to represent concrete values, but those aren't actually what flow through user code when transforming it. What flows through user code are instances of the Tracer class, and each Tracer instance carries along with it an instance of (a subclass of) the AbstractValue class to which it delegates (i.e. forwards) method calls. For arrays, that abstract value is currently always a subclass of ShapedArray, though we will likely revise the array abstract value lattice in the future (e.g. adding more intermediate levels between unshaped and shaped).
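
The delegation pattern described above can be sketched in plain Python. This is a deliberately simplified toy and not JAX's implementation: ToyShapedArray and ToyTracer are hypothetical names, and the real classes carry much more machinery.

```python
from typing import Any, Tuple


class ToyShapedArray:
    """Toy abstract value: knows shape and dtype but holds no data."""

    def __init__(self, shape: Tuple[int, ...], dtype: str) -> None:
        self.shape = shape
        self.dtype = dtype

    @property
    def ndim(self) -> int:
        return len(self.shape)


class ToyTracer:
    """Toy tracer: stands in for an array while user code is being traced."""

    def __init__(self, aval: ToyShapedArray) -> None:
        self.aval = aval

    def __getattr__(self, name: str) -> Any:
        # Runs only when normal attribute lookup fails, so `aval` itself is
        # found directly and everything else is forwarded to it.
        return getattr(self.aval, name)


tracer = ToyTracer(ToyShapedArray((2, 3), "float32"))
```

Here tracer.shape and tracer.ndim work at runtime, but the only return type a checker can see for them is Any, because the attribute names are resolved dynamically.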

That's a lot of overloading and indirection for a type checker to reason about...

If you figure anything out here it would be super beneficial to a lot of JAX users, and I'm eager to answer questions about it as best I can, but the core team doesn't have cycles to look into supporting type-checking systems for user code at the moment.

What do you think?

@samuela
Contributor Author

samuela commented Jul 8, 2019

FWIW my solution here has been to begin a pyi module that is symlinked for both numpy and jax.numpy:

from typing import Any, Tuple

Shape = Tuple[int, ...]

class ndarray:
  @property
  def shape(self) -> Shape:
    ...

  def __getitem__(self, key) -> Any:
    ...

  def __add__(self, other) -> ndarray:
    ...

# and so on...

That certainly doesn't solve the problem for jax.lax, etc. but it's a start. I guess it also means that one way to approach this would be to have a jax-types package that contains a bunch of pyi definitions, instead of having to start including type annotations in jax itself if that is not desirable.

@NeilGirdhar
Contributor

NeilGirdhar commented Mar 29, 2020

Have you looked at https://github.com/numpy/numpy-stubs and https://github.com/ramonhagenaars/nptyping ?

@hamzamerzic
Contributor

Hi, are there any updates on this? How difficult would it be to get the correct behavior? And in anticipation of this, is there a type that we could define that would satisfy the check? Would the following be enough?

JAXArray = Union[
    jax.interpreters.xla.DeviceArray, 
    jax.interpreters.pxla.ShardedDeviceArray, 
    jax.interpreters.batching.BatchTracer,
]

@wookayin
Contributor

wookayin commented Sep 3, 2022

It looks like this is going to be addressed in #11859 --- "Jax Enhancement Proposal (JEP)" (still in progress): https://jax--11859.org.readthedocs.build/en/11859/jep/12049-type-annotations.html

Or one can use jaxtyping: see https://github.com/google/jaxtyping/blob/main/API.md (UPDATE: still premature when it comes to pytype issues, so it doesn't address the false positive errors like #6743, etc.)

@hawkinsp
Collaborator

hawkinsp commented Mar 3, 2023

This is fixed these days! (jax.Array and jax.typing.ArrayLike are the types you want.)
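
As a sketch of how the snippet from the original post might look with these types, assuming a JAX release recent enough to ship jax.Array and jax.typing: make_normal is a hypothetical helper added for illustration, and the split between ArrayLike for inputs and jax.Array for stored and returned values is one reasonable convention rather than a requirement.

```python
from typing import NamedTuple, Tuple

import jax
import jax.numpy as jnp
from jax import lax, random
from jax.typing import ArrayLike


class Normal(NamedTuple):
    loc: jax.Array
    scale: jax.Array

    def sample(self, rng: jax.Array, sample_shape: Tuple[int, ...] = ()) -> jax.Array:
        batch_shape = lax.broadcast_shapes(self.loc.shape, self.scale.shape)
        return self.loc + self.scale * random.normal(
            rng, shape=sample_shape + batch_shape)


def make_normal(loc: ArrayLike, scale: ArrayLike) -> Normal:
    # Hypothetical helper: ArrayLike accepts Python scalars, NumPy arrays,
    # and JAX arrays; jnp.asarray converts each of them to a jax.Array.
    return Normal(jnp.asarray(loc), jnp.asarray(scale))


normal = make_normal(0.0, jnp.array([1.0, 2.0]))
samples = normal.sample(random.PRNGKey(0), sample_shape=(3,))
```

With these inputs the batch shape is (2,), so samples has shape (3, 2).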
