diff --git a/jax/core.py b/jax/core.py index 578f9e2e2060..374d9abf1d67 100644 --- a/jax/core.py +++ b/jax/core.py @@ -17,7 +17,7 @@ from operator import attrgetter from contextlib import contextmanager from collections import namedtuple -from functools import total_ordering, reduce +from functools import total_ordering import itertools as it from weakref import ref import threading @@ -219,6 +219,8 @@ def __repr__(self): return '{}'.format(self.name) def bind(self, *args, **kwargs): + assert skip_checks or all(isinstance(arg, Tracer) + or valid_jaxtype(arg) for arg in args), args top_trace = find_top_trace(args) if top_trace is None: return self.impl(*args, **kwargs) @@ -643,22 +645,10 @@ def full_lower(val): else: return val -def find_top_trace(args) -> Optional[Tracer]: - """Find the tracer with the highest-level, or None. """ - def check_arg(top_so_far: Optional[Tracer], arg) -> Optional[Tracer]: - if isinstance(arg, Tracer): - return (top_so_far - if top_so_far and top_so_far.level >= arg._trace.level else arg._trace) - # Raises error here for bind on LAX primitives - if not valid_jaxtype(arg): - raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid JAX type") - return top_so_far - - top_trace = reduce(check_arg, args, None) # type: ignore[wrong-arg-types] - if top_trace is not None: - return type(top_trace)(top_trace.master, cur_sublevel()) # type: ignore[call-arg] - else: - return None +def find_top_trace(xs): + top_trace = max((x._trace for x in xs if isinstance(x, Tracer)), + key=attrgetter('level'), default=None) + return top_trace and type(top_trace)(top_trace.master, cur_sublevel()) @contextmanager def initial_style_staging(): @@ -724,12 +714,16 @@ def valid_jaxtype(x): else: return True +def check_valid_jaxtype(x): + if not valid_jaxtype(x): + raise TypeError(f"{x} of type {type(x)} is not a valid JAX type") + def concrete_aval(x): for typ in type(x).mro(): handler = pytype_aval_mappings.get(typ) if handler: return handler(x) - raise TypeError(f"{type(x)} is not a valid Jax type") + raise TypeError(f"{type(x)} is not a valid JAX type") def get_aval(x): diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index 543a4261f962..8ffcd3d7841b 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -156,7 +156,7 @@ def abstractify(x) -> core.AbstractValue: for typ in typ.mro(): aval_fn = pytype_aval_mappings.get(typ) if aval_fn: return aval_fn(x) - raise TypeError(f"No abstraction handler for type: {type(x)}") + raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type") def _make_abstract_python_scalar(typ, _): return ShapedArray((), dtypes.python_scalar_dtypes[typ], weak_type=True) diff --git a/jax/lax/lax.py b/jax/lax/lax.py index e27d587df992..1a416c75a95d 100644 --- a/jax/lax/lax.py +++ b/jax/lax/lax.py @@ -4958,8 +4958,13 @@ def _tie_in_batch_rule(batched_args, batch_dims): _, bdim_y = batch_dims return y, bdim_y +def _tie_in_impl(x, y): + core.check_valid_jaxtype(x) + core.check_valid_jaxtype(y) + return y + tie_in_p = Primitive('tie_in') -tie_in_p.def_impl(lambda x, y: y) +tie_in_p.def_impl(_tie_in_impl) tie_in_p.def_abstract_eval(lambda x, y: raise_to_shaped(y)) xla.translations[tie_in_p] = lambda c, x, y: y ad.deflinear(tie_in_p, _tie_in_transpose_rule) diff --git a/tests/api_test.py b/tests/api_test.py index 73ac3a1168fe..0d424ba850b9 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -1765,6 +1765,7 @@ def f(a, b): with self.assertRaisesRegex(ValueError, msg): g(jnp.ones((1, 1)), b=1) + class JaxprTest(jtu.JaxTestCase): def test_scalar_literals(self): diff --git a/tests/lax_test.py b/tests/lax_test.py index 7e3edfaed444..2db5999e5e3a 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -1703,6 +1703,18 @@ def testDynamicUpdateSliceTypeErrors(self): onp.zeros((2, 2), dtype=onp.float32), (onp.int32(1), onp.int16(2)))) + def test_tie_in_error(self): + with core.skipping_checks(): + with self.assertRaisesRegex( + TypeError, ".* of type .*tuple.* is not a valid JAX type"): + api.make_jaxpr(lambda x: lax.tie_in((x, x), 1))(1.) + + def test_primitive_jaxtype_error(self): + with core.skipping_checks(): + with self.assertRaisesRegex( + TypeError, "Argument .* of type .* is not a valid JAX type"): + lax.add(1, 'hi') + class LazyConstantTest(jtu.JaxTestCase): def _Check(self, make_const, expected): # check casting to ndarray works @@ -2752,8 +2764,9 @@ def f2(x, y): expected = onp.array(0.0) self.assertAllClose(ans, expected, check_dtypes=False) - with self.assertRaises(TypeError): - lax.stop_gradient(lambda x: x) + with core.skipping_checks(): + with self.assertRaises(TypeError): + lax.stop_gradient(lambda x: x) # TODO(mattjj): make this a more systematic test def testRemainder(self): @@ -3423,10 +3436,6 @@ def testSort(self, shape, dimension, arity, bdims): # TODO Collapse # TODO Scatter - def test_tie_in_error(self): - with self.assertRaisesRegex(TypeError, - ".*tuple.* is not a valid JAX type"): - api.make_jaxpr(lambda x: lax.tie_in((x, x), 1))(1.) if __name__ == '__main__': absltest.main()