diff --git a/jax/abstract_arrays.py b/jax/abstract_arrays.py index d920c3e1111c..44710e8945dd 100644 --- a/jax/abstract_arrays.py +++ b/jax/abstract_arrays.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial + import numpy as onp from . import ad_util @@ -57,16 +59,16 @@ def zeros_like_shaped_array(aval): core.literalable_types.update(array_types) -def _zeros_like_python_scalar(x): - return onp.array(0, dtypes.python_scalar_dtypes[type(x)]) +def _zeros_like_python_scalar(t, x): + return onp.array(0, dtypes.python_scalar_dtypes[t]) -def _make_concrete_python_scalar(x): +def _make_concrete_python_scalar(t, x): return ConcreteArray( - onp.array(x, dtype=dtypes.python_scalar_dtypes[type(x)]), + onp.array(x, dtype=dtypes.python_scalar_dtypes[t]), weak_type=True) for t in dtypes.python_scalar_dtypes.keys(): - core.pytype_aval_mappings[t] = _make_concrete_python_scalar - ad_util.jaxval_zeros_likers[t] = _zeros_like_python_scalar + core.pytype_aval_mappings[t] = partial(_make_concrete_python_scalar, t) + ad_util.jaxval_zeros_likers[t] = partial(_zeros_like_python_scalar, t) core.literalable_types.update(dtypes.python_scalar_dtypes.keys()) diff --git a/jax/core.py b/jax/core.py index 07983779bb99..ae266911d19d 100644 --- a/jax/core.py +++ b/jax/core.py @@ -698,10 +698,10 @@ def valid_jaxtype(x): def concrete_aval(x): - try: - return pytype_aval_mappings[type(x)](x) - except KeyError as err: - raise TypeError("{} is not a valid Jax type".format(type(x))) from err + for typ in type(x).mro(): + handler = pytype_aval_mappings.get(typ) + if handler: return handler(x) + raise TypeError(f"{type(x)} is not a valid Jax type") def get_aval(x): diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index f20c91ccaa8b..3ccebbed236c 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -168,9 +168,7 @@ class AnEnum(enum.IntEnum): A = 42 B = 101 np.testing.assert_equal(np.array(42), np.array(AnEnum.A)) - with core.skipping_checks(): - # Passing AnEnum.A to jnp.array fails the type check in bind - np.testing.assert_equal(jnp.array(42), jnp.array(AnEnum.A)) + np.testing.assert_equal(jnp.array(42), jnp.array(AnEnum.A)) np.testing.assert_equal(np.int32(101), np.int32(AnEnum.B)) np.testing.assert_equal(jnp.int32(101), jnp.int32(AnEnum.B))