Skip to content

Commit

Permalink
Added argument check to all primitives. (jax-ml#3197)
Browse files Browse the repository at this point in the history
* Added argument check to all primitives.

The issue that inspired this is that `lax.tie_in` is
easy to misuse if the first argument is not a JAX type, then
it silently disappears. This means that `lax.tie_in((x, x), const)`
is the same as `const` even though `x` is a tracer.

This error would be caught previously if core.skip_checks == False
because then `bind` checks its arguments. I have essentially added
an unconditional argument check to `bind`.

In case this is considered too inefficient, we can add argument
checking to individual primivites, e.g., tie_in. For most primitives
if a non-JAX array is passed, the `impl` rule would fire and `numpy`
would report the error somehow, perhaps.

* Merged find_top_trace with check_args

This was previously merged as jax-ml#2948 but reverted awaiting the fixes
in some user code.
  • Loading branch information
gnecula authored and NeilGirdhar committed Jun 11, 2020
1 parent 6c711b5 commit 3dd07b3
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 13 deletions.
1 change: 1 addition & 0 deletions jax/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1682,6 +1682,7 @@ def _check_arg(arg):
raise TypeError("Argument '{}' of type {} is not a valid JAX type"
.format(arg, type(arg)))

# TODO(necula): this duplicates code in core.valid_jaxtype
def _valid_jaxtype(arg):
try:
xla.abstractify(arg) # faster than core.get_aval
Expand Down
28 changes: 17 additions & 11 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from operator import attrgetter
from contextlib import contextmanager
from collections import namedtuple
from functools import total_ordering
from functools import total_ordering, reduce
import itertools as it
from weakref import ref
import threading
Expand Down Expand Up @@ -204,8 +204,6 @@ def __repr__(self):
return '{}'.format(self.name)

def bind(self, *args, **kwargs):
assert skip_checks or all(isinstance(arg, Tracer)
or valid_jaxtype(arg) for arg in args), args
top_trace = find_top_trace(args)
if top_trace is None:
return self.impl(*args, **kwargs)
Expand Down Expand Up @@ -630,14 +628,22 @@ def full_lower(val):
else:
return val

def find_top_trace(xs):
try:
top_trace = max((x._trace for x in xs if isinstance(x, Tracer)),
key=attrgetter('level'))
except ValueError:
return None
else:
return type(top_trace)(top_trace.master, cur_sublevel())
def find_top_trace(args) -> Optional[Tracer]:
"""Find the tracer with the highest-level, or None. """
def check_arg(top_so_far: Optional[Tracer], arg) -> Optional[Tracer]:
if isinstance(arg, Tracer):
return (top_so_far
if top_so_far and top_so_far.level >= arg._trace.level else arg._trace)
# Raises error here for bind on LAX primitives
if not valid_jaxtype(arg):
raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid JAX type")
return top_so_far

top_trace = reduce(check_arg, args, None) # type: ignore[wrong-arg-types]
if top_trace is not None:
return type(top_trace)(top_trace.master, cur_sublevel()) # type: ignore[call-arg]
else:
return None

@contextmanager
def initial_style_staging():
Expand Down
6 changes: 5 additions & 1 deletion jax/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -1215,7 +1215,7 @@ def top_k(operand: Array, k: int) -> Tuple[Array, Array]:
return top_k_p.bind(operand, k=k)

def tie_in(x: Array, y: Array) -> Array:
"""Gives ``y`` a fake data dependence on ``x``.
"""Returns the value of ``y`` but with a fake data dependence on ``x``.
When staging to XLA (e.g. running under jit or pmap), values that don't depend
on computation inputs are computed op-by-op, and folded into the XLA
Expand All @@ -1225,6 +1225,10 @@ def tie_in(x: Array, y: Array) -> Array:
When staging to XLA and ``x`` is already staged, then the result of ``tie_in``
is ``y``, but staged to XLA. Downstream use of the result will also be staged
to XLA.
For example, ``lax.sin(const)`` would be constant-folded if ``const`` is
a constant array, but ``lax.sin(lax.tie_in(x, const))``, will be staged to
XLA as long as ``x`` is staged to XLA.
"""
return tie_in_p.bind(x, y)

Expand Down
6 changes: 5 additions & 1 deletion tests/lax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2711,7 +2711,7 @@ def f2(x, y):
expected = onp.array(0.0)
self.assertAllClose(ans, expected, check_dtypes=False)

with self.assertRaises(TypeError if core.skip_checks else AssertionError):
with self.assertRaises(TypeError):
lax.stop_gradient(lambda x: x)

# TODO(mattjj): make this a more systematic test
Expand Down Expand Up @@ -3358,6 +3358,10 @@ def testSort(self, shape, dimension, arity, bdims):
# TODO Collapse
# TODO Scatter

def test_tie_in_error(self):
with self.assertRaisesRegex(TypeError,
".*tuple.* is not a valid JAX type"):
api.make_jaxpr(lambda x: lax.tie_in((x, x), 1))(1.)

if __name__ == '__main__':
absltest.main()

0 comments on commit 3dd07b3

Please sign in to comment.