From ca4f7f79649ba9823352809bff338a366bd913e5 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Fri, 5 Feb 2021 20:30:14 -0800 Subject: [PATCH] add check for __jax_array__ method before error Before raising an error on an unrecognized type, first check if the object defines a __jax_array__ method. If it does, call it! This provides a way for custom types to be auto-converted to JAX-compatible types. Implementing this method is not sufficient for a type to be duck-typed enough for use with jax.numpy. But it may be necessary. That is, someone trying to add a duck-typed array to be used with JAX identified a need for __jax_array__ or similar. The user would still need to add lots of other properties and methods, like dtype and shape attributes. revives #4725 after it was rolled back. fixes #5356. --- jax/_src/lax/lax.py | 2 ++ jax/_src/numpy/lax_numpy.py | 6 ++++-- jax/core.py | 2 ++ jax/interpreters/xla.py | 4 ++++ tests/api_test.py | 26 ++++++++++++++++++++++++++ 5 files changed, 38 insertions(+), 2 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 0a823ab3229d..69f94058ae69 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -422,6 +422,8 @@ def convert_element_type(operand: Array, new_dtype: DType) -> Array: An array with the same shape as `operand`, cast elementwise to `new_dtype`. """ new_dtype = dtypes.canonicalize_dtype(new_dtype) + if hasattr(operand, '__jax_array__'): + operand = operand.__jax_array__() # Avoids dropping precision by casting Python scalars to the default Jax # type. If we passed a Python scalar directly to the bind call below, it is # cast to the default type as part of the calling convention. diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index d7b79bdc35a6..c45a8538ec5d 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -295,9 +295,11 @@ def _result_dtype(op, *args): return _dtype(op(*args)) -def _arraylike(x): return isinstance(x, ndarray) or isscalar(x) +def _arraylike(x): + return isinstance(x, ndarray) or isscalar(x) or hasattr(x, '__jax_array__') + def _check_arraylike(fun_name, *args): - """Check if all args fit JAX's definition of arraylike (ndarray or scalar).""" + """Check if all args fit JAX's definition of arraylike.""" assert isinstance(fun_name, str), f"fun_name must be a string. Got {fun_name}" if _any(not _arraylike(arg) for arg in args): pos, arg = next((i, arg) for i, arg in enumerate(args) diff --git a/jax/core.py b/jax/core.py index 419c8e4d1712..f01b64807029 100644 --- a/jax/core.py +++ b/jax/core.py @@ -910,6 +910,8 @@ def concrete_aval(x): for typ in type(x).mro(): handler = pytype_aval_mappings.get(typ) if handler: return handler(x) + if hasattr(x, '__jax_array__'): + return concrete_aval(x.__jax_array__()) raise TypeError(f"{type(x)} is not a valid JAX type") diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index 0a05c5ceaa4a..8f0027e56169 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -154,6 +154,8 @@ def canonicalize_dtype(x): for typ in typ.mro(): handler = canonicalize_dtype_handlers.get(typ) if handler: return handler(x) + if hasattr(x, '__jax_array__'): + return canonicalize_dtype(x.__jax_array__()) raise TypeError(f"No canonicalize_dtype handler for type: {type(x)}") def _canonicalize_ndarray_dtype(x): @@ -176,6 +178,8 @@ def abstractify(x) -> core.AbstractValue: for typ in typ.mro(): aval_fn = pytype_aval_mappings.get(typ) if aval_fn: return aval_fn(x) + if hasattr(x, '__jax_array__'): + return abstractify(x.__jax_array__()) raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type") def _make_abstract_python_scalar(typ, _): diff --git a/tests/api_test.py b/tests/api_test.py index 1c6d680a9136..5fa481430195 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -2277,6 +2277,32 @@ def test_default_backend(self): first_local_device = api.local_devices()[0] self.assertEqual(first_local_device.platform, api.default_backend()) + def test_dunder_jax_array(self): + # https://github.com/google/jax/pull/4725 + + class AlexArray: + def __init__(self, jax_val): + self.jax_val = jax_val + def __jax_array__(self): + return self.jax_val + dtype = property(lambda self: self.jax_val.dtype) + shape = property(lambda self: self.jax_val.shape) + + x = AlexArray(jnp.array([1., 2., 3.])) + y = jnp.sin(x) + self.assertAllClose(y, jnp.sin(jnp.array([1., 2., 3.]))) + y = api.grad(api.jit(lambda x: jnp.sin(x).sum()))(x) + self.assertAllClose(y, jnp.cos(jnp.array([1., 2., 3.]))) + + x = AlexArray(jnp.array([[1., 2., 3.]])) + y = api.pmap(jnp.sin)(x) + self.assertAllClose(y, jnp.sin(jnp.array([[1., 2., 3.]]))) + + x = jnp.array(1) + a = AlexArray(x) + for f in [jnp.isscalar, jnp.size, jnp.shape, jnp.dtype]: + self.assertEqual(f(x), f(a)) + class RematTest(jtu.JaxTestCase):