From decd760020039f159e74c17d10968b9a52a12ee0 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Tue, 18 Aug 2020 09:40:57 -0700 Subject: [PATCH] Add experimental __array_module__ method (#4076) * Add experimental __array_module__ method xref https://github.com/google/jax/issues/1565 `__array_module__` (see [NEP 37](https://numpy.org/neps/nep-0037-array-module.html)) is an experimental alternative to `__array_function__` and `__array_ufunc__` for "duck array" compatibility with NumPy that promises to be much less invasive. Example usage: ```python import numpy as np def duckarray_stack(arrays): """This "stack" function should work with any array library, including JAX.""" npx = np.get_array_module(*arrays) arrays = [npx.asarray(arr) for arr in arrays] shapes = {arr.shape for arr in arrays} if len(shapes) != 1: raise ValueError('all input arrays must have the same shape') expanded_arrays = [arr[npx.newaxis, ...] for arr in arrays] return npx.concatenate(expanded_arrays, axis=0) ``` Support for this protocol has *not* yet been implemented in NumPy, but it can be tested with https://github.com/seberg/numpy-dispatch. My reasoning for merging it into JAX (on an experimental basis with no guarantees, of course) is that: 1. It's not invasive -- the implementation is small and self-contained. 2. No backwards compatibility issues. Unlike `__array_function__` and `__array_ufunc__`, `__array_module__` will always require an explicit opt-in by libraries that use it by calling `get_array_module()`. 3. Other NumPy developers [want evidence](https://github.com/numpy/numpy/pull/16935#issuecomment-673951287) that this is actually feasible. 4. Scikit-Learn developers like @thomasjpfan are interested in exploring supporting scikit-learn on top of NumPy-like libraries like JAX, and experimental support for this protocol will make that easier. Note: this PR does add `numpy-dispatch` as an optional testing requirement in order to verify that this works. 
If desired, we could remove this from CI, but installing numpy-dispatch (and its build requirement Cython) appears to only add a few seconds of build time. * don't explicitly list cython * remove UnshapedArray from _JAX_ARRAY_TYPES * Remove incorrect note about metaclasses * remove unnecessary numpy_dispatch.ensure_dispatching() --- .github/workflows/ci-build.yaml | 3 ++- jax/core.py | 6 ++++++ jax/numpy/lax_numpy.py | 16 ++++++++++++++++ tests/lax_numpy_test.py | 24 ++++++++++++++++++++++++ 4 files changed, 48 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index 833604b9c537..78f44196b397 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -60,7 +60,8 @@ jobs: os: ubuntu-latest enable-x64: 1 enable-omnistaging: 0 - package-overrides: "none" + # Test experimental NumPy dispatch + package-overrides: "git+https://github.com/seberg/numpy-dispatch.git" num_generated_cases: 25 - python-version: 3.6 os: ubuntu-latest diff --git a/jax/core.py b/jax/core.py index 8a5576b916cb..2eed75934d6e 100644 --- a/jax/core.py +++ b/jax/core.py @@ -469,6 +469,9 @@ def __len__(self): def aval(self): raise NotImplementedError("must override") + # Python looks up special methods only on classes, not instances. This means + # these methods need to be defined explicitly rather than relying on + # __getattr__. def __neg__(self): return self.aval._neg(self) def __pos__(self): return self.aval._pos(self) def __eq__(self, other): return self.aval._eq(self, other) @@ -528,6 +531,9 @@ def __complex__(self): def __setitem__(self, idx, val): raise TypeError("JAX 'Tracer' objects do not support item assignment") + # NumPy also only looks up special methods on classes. 
+ def __array_module__(self, types): return self.aval._array_module(self, types) + def __getattr__(self, name): # if the aval property raises an AttributeError, gets caught here assert skip_checks or name != "aval" diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index 6610a3ece318..92e247db7c39 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -36,6 +36,7 @@ import numpy as np import opt_einsum +import jax from jax import jit, custom_jvp from .vectorize import vectorize from ._util import _wraps @@ -4574,6 +4575,21 @@ def _operator_round(number, ndigits=None): setattr(DeviceArray, "nbytes", property(_nbytes)) +# Experimental support for NumPy's module dispatch with NEP-37. +# Currently requires https://github.com/seberg/numpy-dispatch +_JAX_ARRAY_TYPES = (DeviceArray, core.Tracer) +_HANDLED_ARRAY_TYPES = _JAX_ARRAY_TYPES + (np.ndarray,) + +def __array_module__(self, types): + if builtins.all(issubclass(t, _HANDLED_ARRAY_TYPES) for t in types): + return jax.numpy + else: + return NotImplemented + +setattr(ShapedArray, "_array_module", staticmethod(__array_module__)) +setattr(DeviceArray, "__array_module__", __array_module__) + + # Extra methods that are handy setattr(ShapedArray, "broadcast", core.aval_method(lax.broadcast)) setattr(ShapedArray, "broadcast_in_dim", core.aval_method(lax.broadcast_in_dim)) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 44b7b2f720fc..69d6d94dfee7 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -28,6 +28,10 @@ from absl.testing import parameterized import numpy as np +try: + import numpy_dispatch +except ImportError: + numpy_dispatch = None import jax import jax.ops @@ -585,6 +589,26 @@ def testBinaryOperatorDefers(self, op_name, rng_factory, dtype): with self.assertRaises(TypeError): op(arg, other) + def testArrayModule(self): + if numpy_dispatch is None: + raise SkipTest('requires https://github.com/seberg/numpy-dispatch') + + jnp_array = jnp.array(1.0) + 
np_array = np.array(1.0) + + module = numpy_dispatch.get_array_module(jnp_array) + self.assertIs(module, jnp) + + module = numpy_dispatch.get_array_module(jnp_array, np_array) + self.assertIs(module, jnp) + + def f(x): + module = numpy_dispatch.get_array_module(x) + self.assertIs(module, jnp) + return x + jax.jit(f)(jnp_array) + jax.grad(f)(jnp_array) + @parameterized.named_parameters(itertools.chain.from_iterable( jtu.cases_from_list( {"testcase_name": jtu.format_test_name_suffix(