Skip to content

Commit

Permalink
add check for __jax_array__ method before error
Browse files Browse the repository at this point in the history
Before raising an error on an unrecognized type, first check if the
object defines a __jax_array__ method. If it does, call it!

This provides a way for custom types to be auto-converted to
JAX-compatible types.

Implementing this method is not sufficient for a type to be duck-typed
enough for use with jax.numpy. But it may be necessary. That is, someone
trying to add a duck-typed array to be used with JAX identified a need
for __jax_array__ or similar. The user would still need to add lots of
other properties and methods, like dtype and shape attributes.

revives #4725 after it was rolled back. fixes #5356.
  • Loading branch information
mattjj committed Feb 6, 2021
1 parent 3c87a36 commit ca4f7f7
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 2 deletions.
2 changes: 2 additions & 0 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,8 @@ def convert_element_type(operand: Array, new_dtype: DType) -> Array:
An array with the same shape as `operand`, cast elementwise to `new_dtype`.
"""
new_dtype = dtypes.canonicalize_dtype(new_dtype)
if hasattr(operand, '__jax_array__'):
operand = operand.__jax_array__()
# Avoids dropping precision by casting Python scalars to the default Jax
# type. If we passed a Python scalar directly to the bind call below, it is
# cast to the default type as part of the calling convention.
Expand Down
6 changes: 4 additions & 2 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,9 +295,11 @@ def _result_dtype(op, *args):
return _dtype(op(*args))


def _arraylike(x): return isinstance(x, ndarray) or isscalar(x)
def _arraylike(x):
return isinstance(x, ndarray) or isscalar(x) or hasattr(x, '__jax_array__')

def _check_arraylike(fun_name, *args):
"""Check if all args fit JAX's definition of arraylike (ndarray or scalar)."""
"""Check if all args fit JAX's definition of arraylike."""
assert isinstance(fun_name, str), f"fun_name must be a string. Got {fun_name}"
if _any(not _arraylike(arg) for arg in args):
pos, arg = next((i, arg) for i, arg in enumerate(args)
Expand Down
2 changes: 2 additions & 0 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,6 +910,8 @@ def concrete_aval(x):
for typ in type(x).mro():
handler = pytype_aval_mappings.get(typ)
if handler: return handler(x)
if hasattr(x, '__jax_array__'):
return concrete_aval(x.__jax_array__())
raise TypeError(f"{type(x)} is not a valid JAX type")


Expand Down
4 changes: 4 additions & 0 deletions jax/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ def canonicalize_dtype(x):
for typ in typ.mro():
handler = canonicalize_dtype_handlers.get(typ)
if handler: return handler(x)
if hasattr(x, '__jax_array__'):
return canonicalize_dtype(x.__jax_array__())
raise TypeError(f"No canonicalize_dtype handler for type: {type(x)}")

def _canonicalize_ndarray_dtype(x):
Expand All @@ -176,6 +178,8 @@ def abstractify(x) -> core.AbstractValue:
for typ in typ.mro():
aval_fn = pytype_aval_mappings.get(typ)
if aval_fn: return aval_fn(x)
if hasattr(x, '__jax_array__'):
return abstractify(x.__jax_array__())
raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")

def _make_abstract_python_scalar(typ, _):
Expand Down
26 changes: 26 additions & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2277,6 +2277,32 @@ def test_default_backend(self):
first_local_device = api.local_devices()[0]
self.assertEqual(first_local_device.platform, api.default_backend())

def test_dunder_jax_array(self):
# https://github.com/google/jax/pull/4725

class AlexArray:
def __init__(self, jax_val):
self.jax_val = jax_val
def __jax_array__(self):
return self.jax_val
dtype = property(lambda self: self.jax_val.dtype)
shape = property(lambda self: self.jax_val.shape)

x = AlexArray(jnp.array([1., 2., 3.]))
y = jnp.sin(x)
self.assertAllClose(y, jnp.sin(jnp.array([1., 2., 3.])))
y = api.grad(api.jit(lambda x: jnp.sin(x).sum()))(x)
self.assertAllClose(y, jnp.cos(jnp.array([1., 2., 3.])))

x = AlexArray(jnp.array([[1., 2., 3.]]))
y = api.pmap(jnp.sin)(x)
self.assertAllClose(y, jnp.sin(jnp.array([[1., 2., 3.]])))

x = jnp.array(1)
a = AlexArray(x)
for f in [jnp.isscalar, jnp.size, jnp.shape, jnp.dtype]:
self.assertEqual(f(x), f(a))


class RematTest(jtu.JaxTestCase):

Expand Down

0 comments on commit ca4f7f7

Please sign in to comment.