Skip to content

Commit

Permalink
Undo strict checking of LAX primitives (jax-ml#2996)
Browse files Browse the repository at this point in the history
This undoes d08dec5d20
  • Loading branch information
gnecula authored and Jamie Townsend committed May 14, 2020
1 parent e0a8f70 commit 086ccdf
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 28 deletions.
1 change: 0 additions & 1 deletion jax/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1574,7 +1574,6 @@ def _check_args(args):
raise TypeError("Argument '{}' of type {} is not a valid JAX type"
.format(arg, type(arg)))

# TODO(necula): this duplicates code in core.valid_jaxtype
def _valid_jaxtype(arg):
try:
xla.abstractify(arg) # faster than core.get_aval
Expand Down
28 changes: 11 additions & 17 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from operator import attrgetter
from contextlib import contextmanager
from collections import namedtuple
from functools import total_ordering, reduce
from functools import total_ordering
import itertools as it
from weakref import ref
import threading
Expand Down Expand Up @@ -204,6 +204,8 @@ def __repr__(self):
return '{}'.format(self.name)

def bind(self, *args, **kwargs):
assert skip_checks or all(isinstance(arg, Tracer)
or valid_jaxtype(arg) for arg in args), args
top_trace = find_top_trace(args)
if top_trace is None:
return self.impl(*args, **kwargs)
Expand Down Expand Up @@ -621,22 +623,14 @@ def full_lower(val):
else:
return val

def find_top_trace(args) -> Optional[Tracer]:
"""Find the tracer with the highest-level, or None. """
def check_arg(top_so_far: Optional[Tracer], arg) -> Optional[Tracer]:
if isinstance(arg, Tracer):
return (top_so_far
if top_so_far and top_so_far.level >= arg._trace.level else arg._trace)
# Raises error here for bind on LAX primitives
if not valid_jaxtype(arg):
raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid JAX type")
return top_so_far

top_trace = reduce(check_arg, args, None) # type: ignore[wrong-arg-types]
if top_trace is not None:
return type(top_trace)(top_trace.master, cur_sublevel()) # type: ignore[call-arg]
else:
return None
def find_top_trace(xs):
try:
top_trace = max((x._trace for x in xs if isinstance(x, Tracer)),
key=attrgetter('level'))
except ValueError:
return None
else:
return type(top_trace)(top_trace.master, cur_sublevel())

@contextmanager
def initial_style_staging():
Expand Down
6 changes: 1 addition & 5 deletions jax/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -1199,7 +1199,7 @@ def top_k(operand: Array, k: int) -> Tuple[Array, Array]:
return top_k_p.bind(operand, k=k)

def tie_in(x: Array, y: Array) -> Array:
"""Returns the value of ``y`` but with a fake data dependence on ``x``.
"""Gives ``y`` a fake data dependence on ``x``.
When staging to XLA (e.g. running under jit or pmap), values that don't depend
on computation inputs are computed op-by-op, and folded into the XLA
Expand All @@ -1209,10 +1209,6 @@ def tie_in(x: Array, y: Array) -> Array:
When staging to XLA and ``x`` is already staged, then the result of ``tie_in``
is ``y``, but staged to XLA. Downstream use of the result will also be staged
to XLA.
For example, ``lax.sin(const)`` would be constant-folded if ``const`` is
a constant array, but ``lax.sin(lax.tie_in(x, const))``, will be staged to
XLA as long as ``x`` is staged to XLA.
"""
return tie_in_p.bind(x, y)

Expand Down
6 changes: 1 addition & 5 deletions tests/lax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2633,7 +2633,7 @@ def f2(x, y):
expected = onp.array(0.0)
self.assertAllClose(ans, expected, check_dtypes=False)

with self.assertRaises(TypeError):
with self.assertRaises(TypeError if core.skip_checks else AssertionError):
lax.stop_gradient(lambda x: x)

# TODO(mattjj): make this a more systematic test
Expand Down Expand Up @@ -3250,10 +3250,6 @@ def testTopK(self, shape, dtype, k, bdims, rng_factory):
# TODO Collapse
# TODO Scatter

def test_tie_in_error(self):
with self.assertRaisesRegex(TypeError,
".*tuple.* is not a valid JAX type"):
api.make_jaxpr(lambda x: lax.tie_in((x, x), 1))(1.)

if __name__ == '__main__':
absltest.main()

0 comments on commit 086ccdf

Please sign in to comment.