Skip to content

Commit

Permalink
revisions to #3197 (#3264)
Browse files Browse the repository at this point in the history
revert find_top_trace change from #3197

The previous version was written and tested for performance; the revised
version caused at least a 25% slowdown in the dispatch time of
`lax.add(1, 2)` (and so likely a much bigger slowdown for the
find_top_trace timing alone).

Instead, we can just change the error message in xla.abstractify, since
invalid types lead to abstractification errors when we apply primitive
impls.
  • Loading branch information
mattjj authored Jun 1, 2020
1 parent 7df8375 commit 49a441f
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 26 deletions.
30 changes: 12 additions & 18 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from operator import attrgetter
from contextlib import contextmanager
from collections import namedtuple
from functools import total_ordering, reduce
from functools import total_ordering
import itertools as it
from weakref import ref
import threading
Expand Down Expand Up @@ -219,6 +219,8 @@ def __repr__(self):
return '{}'.format(self.name)

def bind(self, *args, **kwargs):
assert skip_checks or all(isinstance(arg, Tracer)
or valid_jaxtype(arg) for arg in args), args
top_trace = find_top_trace(args)
if top_trace is None:
return self.impl(*args, **kwargs)
Expand Down Expand Up @@ -643,22 +645,10 @@ def full_lower(val):
else:
return val

def find_top_trace(args) -> Optional[Tracer]:
"""Find the tracer with the highest-level, or None. """
def check_arg(top_so_far: Optional[Tracer], arg) -> Optional[Tracer]:
if isinstance(arg, Tracer):
return (top_so_far
if top_so_far and top_so_far.level >= arg._trace.level else arg._trace)
# Raises error here for bind on LAX primitives
if not valid_jaxtype(arg):
raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid JAX type")
return top_so_far

top_trace = reduce(check_arg, args, None) # type: ignore[wrong-arg-types]
if top_trace is not None:
return type(top_trace)(top_trace.master, cur_sublevel()) # type: ignore[call-arg]
else:
return None
def find_top_trace(xs):
top_trace = max((x._trace for x in xs if isinstance(x, Tracer)),
key=attrgetter('level'), default=None)
return top_trace and type(top_trace)(top_trace.master, cur_sublevel())

@contextmanager
def initial_style_staging():
Expand Down Expand Up @@ -724,12 +714,16 @@ def valid_jaxtype(x):
else:
return True

def check_valid_jaxtype(x):
if not valid_jaxtype(x):
raise TypeError(f"{x} of type {type(x)} is not a valid JAX type")


def concrete_aval(x):
for typ in type(x).mro():
handler = pytype_aval_mappings.get(typ)
if handler: return handler(x)
raise TypeError(f"{type(x)} is not a valid Jax type")
raise TypeError(f"{type(x)} is not a valid JAX type")


def get_aval(x):
Expand Down
2 changes: 1 addition & 1 deletion jax/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def abstractify(x) -> core.AbstractValue:
for typ in typ.mro():
aval_fn = pytype_aval_mappings.get(typ)
if aval_fn: return aval_fn(x)
raise TypeError(f"No abstraction handler for type: {type(x)}")
raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")

def _make_abstract_python_scalar(typ, _):
return ShapedArray((), dtypes.python_scalar_dtypes[typ], weak_type=True)
Expand Down
7 changes: 6 additions & 1 deletion jax/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -4958,8 +4958,13 @@ def _tie_in_batch_rule(batched_args, batch_dims):
_, bdim_y = batch_dims
return y, bdim_y

def _tie_in_impl(x, y):
core.check_valid_jaxtype(x)
core.check_valid_jaxtype(y)
return y

tie_in_p = Primitive('tie_in')
tie_in_p.def_impl(lambda x, y: y)
tie_in_p.def_impl(_tie_in_impl)
tie_in_p.def_abstract_eval(lambda x, y: raise_to_shaped(y))
xla.translations[tie_in_p] = lambda c, x, y: y
ad.deflinear(tie_in_p, _tie_in_transpose_rule)
Expand Down
1 change: 1 addition & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1765,6 +1765,7 @@ def f(a, b):
with self.assertRaisesRegex(ValueError, msg):
g(jnp.ones((1, 1)), b=1)


class JaxprTest(jtu.JaxTestCase):

def test_scalar_literals(self):
Expand Down
21 changes: 15 additions & 6 deletions tests/lax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1697,6 +1697,18 @@ def testDynamicUpdateSliceTypeErrors(self):
onp.zeros((2, 2), dtype=onp.float32),
(onp.int32(1), onp.int16(2))))

def test_tie_in_error(self):
with core.skipping_checks():
with self.assertRaisesRegex(
TypeError, ".* of type .*tuple.* is not a valid JAX type"):
api.make_jaxpr(lambda x: lax.tie_in((x, x), 1))(1.)

def test_primitive_jaxtype_error(self):
with core.skipping_checks():
with self.assertRaisesRegex(
TypeError, "Argument .* of type .* is not a valid JAX type"):
lax.add(1, 'hi')

class LazyConstantTest(jtu.JaxTestCase):
def _Check(self, make_const, expected):
# check casting to ndarray works
Expand Down Expand Up @@ -2746,8 +2758,9 @@ def f2(x, y):
expected = onp.array(0.0)
self.assertAllClose(ans, expected, check_dtypes=False)

with self.assertRaises(TypeError):
lax.stop_gradient(lambda x: x)
with core.skipping_checks():
with self.assertRaises(TypeError):
lax.stop_gradient(lambda x: x)

# TODO(mattjj): make this a more systematic test
def testRemainder(self):
Expand Down Expand Up @@ -3417,10 +3430,6 @@ def testSort(self, shape, dimension, arity, bdims):
# TODO Collapse
# TODO Scatter

def test_tie_in_error(self):
with self.assertRaisesRegex(TypeError,
".*tuple.* is not a valid JAX type"):
api.make_jaxpr(lambda x: lax.tie_in((x, x), 1))(1.)

if __name__ == '__main__':
absltest.main()

0 comments on commit 49a441f

Please sign in to comment.