diff --git a/jax/core.py b/jax/core.py index 578f9e2e2060..aa540dcfc726 100644 --- a/jax/core.py +++ b/jax/core.py @@ -17,7 +17,7 @@ from operator import attrgetter from contextlib import contextmanager from collections import namedtuple -from functools import total_ordering, reduce +from functools import total_ordering import itertools as it from weakref import ref import threading @@ -643,22 +643,14 @@ def full_lower(val): else: return val -def find_top_trace(args) -> Optional[Tracer]: - """Find the tracer with the highest-level, or None. """ - def check_arg(top_so_far: Optional[Tracer], arg) -> Optional[Tracer]: - if isinstance(arg, Tracer): - return (top_so_far - if top_so_far and top_so_far.level >= arg._trace.level else arg._trace) - # Raises error here for bind on LAX primitives - if not valid_jaxtype(arg): - raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid JAX type") - return top_so_far - - top_trace = reduce(check_arg, args, None) # type: ignore[wrong-arg-types] - if top_trace is not None: - return type(top_trace)(top_trace.master, cur_sublevel()) # type: ignore[call-arg] - else: - return None +def find_top_trace(xs): + try: + top_trace = max((x._trace for x in xs if isinstance(x, Tracer)), + key=attrgetter('level')) + except ValueError: + return None + else: + return type(top_trace)(top_trace.master, cur_sublevel()) @contextmanager def initial_style_staging(): @@ -729,7 +721,7 @@ def concrete_aval(x): for typ in type(x).mro(): handler = pytype_aval_mappings.get(typ) if handler: return handler(x) - raise TypeError(f"{type(x)} is not a valid Jax type") + raise TypeError(f"{type(x)} is not a valid JAX type") def get_aval(x): diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index 543a4261f962..8ffcd3d7841b 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -156,7 +156,7 @@ def abstractify(x) -> core.AbstractValue: for typ in typ.mro(): aval_fn = pytype_aval_mappings.get(typ) if aval_fn: return aval_fn(x) - raise TypeError(f"No abstraction handler for type: {type(x)}") + raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type") def _make_abstract_python_scalar(typ, _): return ShapedArray((), dtypes.python_scalar_dtypes[typ], weak_type=True) diff --git a/tests/api_test.py b/tests/api_test.py index 73ac3a1168fe..3c7121f5062c 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -1765,6 +1765,12 @@ def f(a, b): with self.assertRaisesRegex(ValueError, msg): g(jnp.ones((1, 1)), b=1) + def test_primitive_jaxtype_error(self): + with self.assertRaisesRegex( + TypeError, "Argument .* of type .* is not a valid JAX type"): + lax.add(1, 'hi') + + class JaxprTest(jtu.JaxTestCase): def test_scalar_literals(self):