Skip to content

Commit

Permalink
Add experimental __array_module__ method
Browse files Browse the repository at this point in the history
xref jax-ml#1565

`__array_module__` (see [NEP 37](https://numpy.org/neps/nep-0037-array-module.html))
is an experimental alternative to `__array_function__` and `__array_ufunc__`
for "duck array" compatibility with NumPy that promises to be much less
invasive.

Example usage:

```python
import numpy as np

def duckarray_stack(arrays):
    """This "stack" function should work with any array library, including JAX."""
    npx = np.get_array_module(*arrays)
    arrays = [npx.asarray(arr) for arr in arrays]
    shapes = {arr.shape for arr in arrays}
    if len(shapes) != 1:
        raise ValueError('all input arrays must have the same shape')
    expanded_arrays = [arr[npx.newaxis, ...] for arr in arrays]
    return npx.concatenate(expanded_arrays, axis=0)
```

Support for this protocol has *not* yet been implemented in NumPy, but it can
be tested with https://github.com/seberg/numpy-dispatch.

My reasoning for merging it into JAX (on an experimental basis with no
guarantees, of course) is that:

1. It's not invasive -- the implementation is small and self-contained.
2. No backwards compatibility issues. Unlike `__array_function__` and
   `__array_ufunc__`, `__array_module__` will always require an explicit
   opt-in by libraries that use it by calling `get_array_module()`.
2. Other NumPy developers
   [want evidence](numpy/numpy#16935 (comment))
   that this is actually feasible.
3. Scikit-Learn developers like @thomasjpfan are interested in exploring
   supporting scikit-learn on top of NumPy-like libraries like JAX, and
   experimental support for this protocol will make that easier.

Note: this PR does add `numpy-dispatch` as a optional testing requirement in
order to verify that this works. If desired, we could remove this from CI, but
installing numpy-dispatch (and its build requirement Cython) appears to only
add a few seconds of build time.
  • Loading branch information
shoyer committed Aug 15, 2020
1 parent 1316562 commit 36d9949
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 1 deletion.
5 changes: 4 additions & 1 deletion .github/workflows/ci-build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@ jobs:
os: ubuntu-latest
enable-x64: 1
enable-omnistaging: 0
package-overrides: "none"
# Test experimental NumPy dispatch
# TODO(shoyer): remove cython after
# https://github.com/seberg/numpy-dispatch/pull/5 is merged
package-overrides: "cython git+https://github.com/seberg/numpy-dispatch.git"
num_generated_cases: 25
- python-version: 3.6
os: ubuntu-latest
Expand Down
6 changes: 6 additions & 0 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,9 @@ def __len__(self):
def aval(self):
raise NotImplementedError("must override")

# Python looks up special methods only on classes, not instances. This means
# these methods needs to be defined explicitly rather than relying on
# __getattr__ (short of using a metaclass).
def __neg__(self): return self.aval._neg(self)
def __pos__(self): return self.aval._pos(self)
def __eq__(self, other): return self.aval._eq(self, other)
Expand Down Expand Up @@ -528,6 +531,9 @@ def __complex__(self):
def __setitem__(self, idx, val):
raise TypeError("JAX 'Tracer' objects do not support item assignment")

# NumPy also only looks up special methods on classes.
def __array_module__(self, types): return self.aval._array_module(self, types)

def __getattr__(self, name):
# if the aval property raises an AttributeError, gets caught here
assert skip_checks or name != "aval"
Expand Down
16 changes: 16 additions & 0 deletions jax/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import numpy as np
import opt_einsum

import jax
from jax import jit, custom_jvp
from .vectorize import vectorize
from ._util import _wraps
Expand Down Expand Up @@ -4574,6 +4575,21 @@ def _operator_round(number, ndigits=None):
setattr(DeviceArray, "nbytes", property(_nbytes))


# Experimental support for NumPy's module dispatch with NEP-37.
# Currently requires https://github.com/seberg/numpy-dispatch
_JAX_ARRAY_TYPES = (UnshapedArray, DeviceArray, core.Tracer)
_HANDLED_ARRAY_TYPES = _JAX_ARRAY_TYPES + (np.ndarray,)

def __array_module__(self, types):
if builtins.all(issubclass(t, _HANDLED_ARRAY_TYPES) for t in types):
return jax.numpy
else:
return NotImplemented

setattr(ShapedArray, "_array_module", staticmethod(__array_module__))
setattr(DeviceArray, "__array_module__", __array_module__)


# Extra methods that are handy
setattr(ShapedArray, "broadcast", core.aval_method(lax.broadcast))
setattr(ShapedArray, "broadcast_in_dim", core.aval_method(lax.broadcast_in_dim))
Expand Down
25 changes: 25 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
from absl.testing import parameterized

import numpy as np
try:
import numpy_dispatch
except ImportError:
numpy_dispatch = None

import jax
import jax.ops
Expand Down Expand Up @@ -585,6 +589,27 @@ def testBinaryOperatorDefers(self, op_name, rng_factory, dtype):
with self.assertRaises(TypeError):
op(arg, other)

def testArrayModule(self):
if numpy_dispatch is None:
raise SkipTest('requires https://github.com/seberg/numpy-dispatch')

jnp_array = jnp.array(1.0)
np_array = np.array(1.0)

with numpy_dispatch.ensure_dispatching():
module = numpy_dispatch.get_array_module(jnp_array)
self.assertIs(module, jnp)

module = numpy_dispatch.get_array_module(jnp_array, np_array)
self.assertIs(module, jnp)

def f(x):
module = numpy_dispatch.get_array_module(x)
self.assertIs(module, jnp)
return x
jax.jit(f)(jnp_array)
jax.grad(f)(jnp_array)

@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix(
Expand Down

0 comments on commit 36d9949

Please sign in to comment.