From c6b456f50df13f83b2b64b76f851b08fa2fe8eb0 Mon Sep 17 00:00:00 2001 From: George Necula Date: Fri, 8 May 2020 08:13:09 +0300 Subject: [PATCH] Added argument check to all primitives. * Added argument check to all primitives. The issue that inspired this is that `lax.tie_in` is easy to misuse if the first argument is not a JAX type, then it silently disappears. This means that `lax.tie_in((x, x), const)` is the same as `const` even though `x` is a tracer. This error would be caught previously if core.skip_checks == False because then `bind` checks its arguments. I have essentially added an unconditional argument check to `bind`. In case this is considered too inefficient, we can add argument checking to individual primivites, e.g., tie_in. For most primitives if a non-JAX array is passed, the `impl` rule would fire and `numpy` would report the error somehow, perhaps. * Merged find_top_trace with check_args This was previously merged as #2948 but reverted awaiting the fixes in some user code. --- jax/api.py | 1 + jax/core.py | 28 +++++++++++++++++----------- jax/lax/lax.py | 6 +++++- tests/lax_test.py | 6 +++++- 4 files changed, 28 insertions(+), 13 deletions(-) diff --git a/jax/api.py b/jax/api.py index 4bb1008671ab..72c2657df1a3 100644 --- a/jax/api.py +++ b/jax/api.py @@ -1682,6 +1682,7 @@ def _check_arg(arg): raise TypeError("Argument '{}' of type {} is not a valid JAX type" .format(arg, type(arg))) +# TODO(necula): this duplicates code in core.valid_jaxtype def _valid_jaxtype(arg): try: xla.abstractify(arg) # faster than core.get_aval diff --git a/jax/core.py b/jax/core.py index 94de3b4e949a..9c7eb3a6838f 100644 --- a/jax/core.py +++ b/jax/core.py @@ -17,7 +17,7 @@ from operator import attrgetter from contextlib import contextmanager from collections import namedtuple -from functools import total_ordering +from functools import total_ordering, reduce import itertools as it from weakref import ref import threading @@ -204,8 +204,6 @@ def __repr__(self): return '{}'.format(self.name) def bind(self, *args, **kwargs): - assert skip_checks or all(isinstance(arg, Tracer) - or valid_jaxtype(arg) for arg in args), args top_trace = find_top_trace(args) if top_trace is None: return self.impl(*args, **kwargs) @@ -630,14 +628,22 @@ def full_lower(val): else: return val -def find_top_trace(xs): - try: - top_trace = max((x._trace for x in xs if isinstance(x, Tracer)), - key=attrgetter('level')) - except ValueError: - return None - else: - return type(top_trace)(top_trace.master, cur_sublevel()) +def find_top_trace(args) -> Optional[Tracer]: + """Find the tracer with the highest-level, or None. """ + def check_arg(top_so_far: Optional[Tracer], arg) -> Optional[Tracer]: + if isinstance(arg, Tracer): + return (top_so_far + if top_so_far and top_so_far.level >= arg._trace.level else arg._trace) + # Raises error here for bind on LAX primitives + if not valid_jaxtype(arg): + raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid JAX type") + return top_so_far + + top_trace = reduce(check_arg, args, None) # type: ignore[wrong-arg-types] + if top_trace is not None: + return type(top_trace)(top_trace.master, cur_sublevel()) # type: ignore[call-arg] + else: + return None @contextmanager def initial_style_staging(): diff --git a/jax/lax/lax.py b/jax/lax/lax.py index 3918c21f21cc..a022baffb9a1 100644 --- a/jax/lax/lax.py +++ b/jax/lax/lax.py @@ -1215,7 +1215,7 @@ def top_k(operand: Array, k: int) -> Tuple[Array, Array]: return top_k_p.bind(operand, k=k) def tie_in(x: Array, y: Array) -> Array: - """Gives ``y`` a fake data dependence on ``x``. + """Returns the value of ``y`` but with a fake data dependence on ``x``. When staging to XLA (e.g. running under jit or pmap), values that don't depend on computation inputs are computed op-by-op, and folded into the XLA @@ -1225,6 +1225,10 @@ def tie_in(x: Array, y: Array) -> Array: When staging to XLA and ``x`` is already staged, then the result of ``tie_in`` is ``y``, but staged to XLA. Downstream use of the result will also be staged to XLA. + + For example, ``lax.sin(const)`` would be constant-folded if ``const`` is + a constant array, but ``lax.sin(lax.tie_in(x, const))``, will be staged to + XLA as long as ``x`` is staged to XLA. """ return tie_in_p.bind(x, y) diff --git a/tests/lax_test.py b/tests/lax_test.py index 5e61406ba8a9..f602bd3575d0 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -2711,7 +2711,7 @@ def f2(x, y): expected = onp.array(0.0) self.assertAllClose(ans, expected, check_dtypes=False) - with self.assertRaises(TypeError if core.skip_checks else AssertionError): + with self.assertRaises(TypeError): lax.stop_gradient(lambda x: x) # TODO(mattjj): make this a more systematic test @@ -3358,6 +3358,10 @@ def testSort(self, shape, dimension, arity, bdims): # TODO Collapse # TODO Scatter + def test_tie_in_error(self): + with self.assertRaisesRegex(TypeError, + ".*tuple.* is not a valid JAX type"): + api.make_jaxpr(lambda x: lax.tie_in((x, x), 1))(1.) if __name__ == '__main__': absltest.main()