Add experimental __array_module__ method #4076

Merged
merged 5 commits into from
Aug 18, 2020

Collaborator

@shoyer shoyer commented Aug 15, 2020

xref #1565

__array_module__ (see NEP 37) is an experimental alternative to __array_function__ and __array_ufunc__ for "duck array" compatibility with NumPy that promises to be much less invasive.

Example usage, for writing a generic version of np.stack() on top of concatenate, asarray, shape attributes and array indexing:

import numpy_dispatch

def duckarray_stack(arrays):
    """This "stack" function should work with any array library, including JAX."""
    npx = numpy_dispatch.get_array_module(*arrays)  # returns jax.numpy for JAX arrays
    arrays = [npx.asarray(arr) for arr in arrays]
    shapes = {arr.shape for arr in arrays}
    if len(shapes) != 1:
        raise ValueError('all input arrays must have the same shape')
    expanded_arrays = [arr[npx.newaxis, ...] for arr in arrays]
    return npx.concatenate(expanded_arrays, axis=0)

For now, you need to use numpy_dispatch.get_array_module() to try this feature (with https://github.com/seberg/numpy-dispatch) but in the long term the hope is to merge this into NumPy proper as np.get_array_module().
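To make the dispatch step concrete, here is a minimal pure-Python sketch of how a `get_array_module()`-style lookup could work. It loosely follows the reference logic described in NEP 37 but is simplified (it ignores subclass ordering), and it is not the `numpy_dispatch` implementation. The names `get_array_module_sketch`, `ToyArray`, `OtherArray`, `toy_module` and `other_module` are hypothetical stand-ins.

```python
import math
import types

# Hypothetical stand-ins for NumPy-like modules such as jax.numpy.
toy_module = types.SimpleNamespace(name="toy_module")
other_module = types.SimpleNamespace(name="other_module")


class ToyArray:
    """Toy duck array that opts in to the protocol."""

    def __array_module__(self, arg_types):
        # Only claim the call if every participating type is one we handle.
        if all(issubclass(t, ToyArray) for t in arg_types):
            return toy_module
        return NotImplemented


class OtherArray:
    """A second, unrelated toy duck array."""

    def __array_module__(self, arg_types):
        if all(issubclass(t, OtherArray) for t in arg_types):
            return other_module
        return NotImplemented


def get_array_module_sketch(*arrays, default=None):
    """Simplified lookup: ask each implementing argument for its module."""
    implementing = [a for a in arrays if hasattr(type(a), "__array_module__")]
    if not implementing:
        if default is None:
            raise TypeError("no array module found and no default given")
        return default
    arg_types = frozenset(type(a) for a in implementing)
    for array in implementing:
        module = array.__array_module__(arg_types)
        if module is not NotImplemented:
            return module
    raise TypeError("no common array module found")
```

In this sketch, arguments that do not define `__array_module__` (such as Python scalars) are skipped, and mixing `ToyArray` with `OtherArray` raises `TypeError` because neither type claims the other.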

My reasoning for merging it into JAX (on an experimental basis with no guarantees, of course) is that:

  1. It's not invasive -- the implementation is small and self-contained.
  2. No backwards compatibility issues. Unlike __array_function__ and __array_ufunc__, __array_module__ will always require an explicit opt-in by libraries that use it by calling get_array_module().
  3. Other NumPy developers want evidence that this is actually feasible.
  4. Scikit-Learn developers like @thomasjpfan are interested in exploring supporting scikit-learn on top of NumPy-like libraries like JAX, and experimental support for this protocol will make that easier.

Note: this PR adds numpy-dispatch as an optional testing requirement in order to verify that this works. If desired, we could remove it from CI, but installing numpy-dispatch appears to add only about 7 extra seconds of build time.

xref jax-ml#1565

`__array_module__` (see [NEP 37](https://numpy.org/neps/nep-0037-array-module.html))
is an experimental alternative to `__array_function__` and `__array_ufunc__`
for "duck array" compatibility with NumPy that promises to be much less
invasive.

Example usage:

```python
import numpy as np

def duckarray_stack(arrays):
    """This "stack" function should work with any array library, including JAX."""
    npx = np.get_array_module(*arrays)
    arrays = [npx.asarray(arr) for arr in arrays]
    shapes = {arr.shape for arr in arrays}
    if len(shapes) != 1:
        raise ValueError('all input arrays must have the same shape')
    expanded_arrays = [arr[npx.newaxis, ...] for arr in arrays]
    return npx.concatenate(expanded_arrays, axis=0)
```

Support for this protocol has *not* yet been implemented in NumPy, but it can
be tested with https://github.com/seberg/numpy-dispatch.

My reasoning for merging it into JAX (on an experimental basis with no
guarantees, of course) is that:

1. It's not invasive -- the implementation is small and self-contained.
2. No backwards compatibility issues. Unlike `__array_function__` and
   `__array_ufunc__`, `__array_module__` will always require an explicit
   opt-in by libraries that use it by calling `get_array_module()`.
3. Other NumPy developers
   [want evidence](numpy/numpy#16935 (comment))
   that this is actually feasible.
4. Scikit-Learn developers like @thomasjpfan are interested in exploring
   supporting scikit-learn on top of NumPy-like libraries like JAX, and
   experimental support for this protocol will make that easier.

Note: this PR adds `numpy-dispatch` as an optional testing requirement in
order to verify that this works. If desired, we could remove it from CI, but
installing numpy-dispatch (and its build requirement Cython) appears to add
only a few seconds of build time.
@google-cla google-cla bot added the cla: yes label Aug 15, 2020
@@ -4574,6 +4575,21 @@ def _operator_round(number, ndigits=None):
setattr(DeviceArray, "nbytes", property(_nbytes))


# Experimental support for NumPy's module dispatch with NEP-37.
# Currently requires https://github.com/seberg/numpy-dispatch
_JAX_ARRAY_TYPES = (DeviceArray, core.Tracer)
Collaborator Author

I'm not quite sure if this list should include UnshapedArray or one of JAX's other abstract array types.

Their presence in the set of types checked by jax.numpy.ndarray suggests yes, but on the other hand I don't think I've ever seen them in user-supplied functions transformed with JAX?

Collaborator

I like limiting the set for now. If it becomes clear later that UnshapedArray or others are necessary, they are easy enough to add.
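As a rough illustration of why the set is easy to extend later, the sketch below shows a handled-types check driven by a single tuple. It is an assumption about the general shape of such a check, not the code in this PR; `ToyDeviceArray`, `ToyTracer`, `ToyAbstractArray`, `HANDLED_TYPES` and `array_module_for` are hypothetical names, and the string `"toy_numpy"` stands in for a real module.

```python
class ToyDeviceArray:
    """Hypothetical stand-in for a concrete array type."""


class ToyTracer:
    """Hypothetical stand-in for a tracer type."""


class ToyAbstractArray:
    """Hypothetical stand-in for an abstract array type left out for now."""


# The deliberately small set of types that the check accepts.
HANDLED_TYPES = (ToyDeviceArray, ToyTracer)


def array_module_for(arg_types, handled_types=HANDLED_TYPES):
    """Return the module stand-in only if every type is handled."""
    if all(issubclass(t, handled_types) for t in arg_types):
        return "toy_numpy"
    return NotImplemented


# Widening the set later is a one-line change to the tuple.
WIDER_TYPES = HANDLED_TYPES + (ToyAbstractArray,)
```

With the narrow tuple, a call that involves `ToyAbstractArray` gets `NotImplemented`; with `WIDER_TYPES` the same call is accepted.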

Collaborator

@jakevdp jakevdp left a comment

Looks good! I'm in favor of provisionally adding this to help motivate this kind of feature in the broader Python array ecosystem.

On the JAX side, one detail we may need to think about is this from the Array module contract:

> Unimplemented functionality should simply be omitted (e.g., accessing an unimplemented function should raise AttributeError).

This is counter to JAX's current strategy, which is to define unimplemented functions but raise NotImplementedError when they are called. I'm not sure how important that detail is. I think it's probably OK to push forward as-is and address that if it becomes an issue.

Collaborator Author

shoyer commented Aug 17, 2020

> Unimplemented functionality should simply be omitted (e.g., accessing an unimplemented function should raise AttributeError).
>
> This is counter to JAX's current strategy, which is to define unimplemented functions but raise NotImplementedError when they are called. I'm not sure how important that detail is. I think it's probably OK to push forward as-is and address that if it becomes an issue.

I don't think this detail of NEP-37 (the precise error message) matters too much.

Possibly NEP-37 should be updated, since it can indeed be helpful to distinguish between a spelling error vs. an error from a function that just hasn't been implemented. Perhaps the cleanest way to do this would be module level __getattr__ from Python 3.7+ (https://www.python.org/dev/peps/pep-0562/), e.g.,

# in jax/numpy/__init__.py
def __getattr__(name):
  import numpy
  if hasattr(numpy, name):
    raise AttributeError(f'numpy.{name} has not been implemented yet in jax.numpy')
  else:
    raise AttributeError(f"module 'jax.numpy' has no attribute {name!r}")
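
The same idea can be tried without touching JAX. The sketch below builds a throwaway module object and attaches a PEP 562 style `__getattr__` to it, with the standard library's `math` playing the role of NumPy. The module name `toy_math` and the helper `_toy_getattr` are hypothetical, chosen only for illustration.

```python
import math
import types

# A throwaway module that "implements" only part of the math API.
toy = types.ModuleType("toy_math")
toy.sqrt = math.sqrt


def _toy_getattr(name):
    """Module-level __getattr__ (PEP 562): called only for missing names."""
    if hasattr(math, name):
        # The name exists in the reference API but is missing here.
        raise AttributeError(
            f"math.{name} has not been implemented yet in toy_math")
    # Otherwise it is most likely a spelling error.
    raise AttributeError(f"module 'toy_math' has no attribute {name!r}")


# Python looks up __getattr__ in the module's __dict__ when a name is missing.
toy.__getattr__ = _toy_getattr
```

Because both branches raise `AttributeError`, `hasattr(toy, "cos")` is still false, which keeps the "omitted means missing" contract while letting the message tell a gap in coverage apart from a typo.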

@jakevdp jakevdp merged commit decd760 into jax-ml:master Aug 18, 2020
@shoyer shoyer deleted the array-module branch August 18, 2020 17:32
@thomasjpfan

@shoyer Thank you for working on this! I am looking forward to trying this out.

tensorflow-copybara pushed a commit to tensorflow/tensorflow that referenced this pull request Aug 24, 2020
__array_module__ is an experimental protocol for "duck array" compatibility that
indicates how to find a "numpy compatible" module. The hope is to make it
easier to write generic code that works across a range of array libraries.

A full example, which should work equally well for TF-NumPy as for JAX, can be
found at jax-ml/jax#4076. More examples and motivation for
this protocol can be found at https://numpy.org/neps/nep-0037-array-module.html.

This design has not yet been finalized in NumPy, so at present it requires using
the experimental numpy_dispatch module: https://github.com/seberg/numpy-dispatch.

Unlike NumPy's __array_ufunc__ and __array_function__ protocols, __array_module__
by design has no backwards compatibility consequences. The protocol only controls
the behavior of numpy_dispatch.get_array_module() or numpy.get_array_module() --
it does not change any existing NumPy functions.

PiperOrigin-RevId: 328181308
Change-Id: I98b648a59709ed3bc295aaf97f471a1ff7ae1f05
@juliuskunze juliuskunze mentioned this pull request Nov 30, 2020
3 participants