Skip to content

Commit

Permalink
revert find_top_trace change from #3197
Browse files Browse the repository at this point in the history
The previous version was written and tested for performance; the revised
version caused at least a 25% slowdown in the dispatch time of
`lax.add(1, 2)` (and so likely a much bigger slowdown for the
find_top_trace timing alone).

Instead, we can just change the error message in xla.abstractify, since
invalid types lead to abstractification errors when we apply primitive
impls.
  • Loading branch information
mattjj committed May 31, 2020
1 parent d34deba commit 1930093
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 19 deletions.
28 changes: 10 additions & 18 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from operator import attrgetter
from contextlib import contextmanager
from collections import namedtuple
from functools import total_ordering, reduce
from functools import total_ordering
import itertools as it
from weakref import ref
import threading
Expand Down Expand Up @@ -643,22 +643,14 @@ def full_lower(val):
else:
return val

def find_top_trace(args) -> Optional[Tracer]:
"""Find the tracer with the highest-level, or None. """
def check_arg(top_so_far: Optional[Tracer], arg) -> Optional[Tracer]:
if isinstance(arg, Tracer):
return (top_so_far
if top_so_far and top_so_far.level >= arg._trace.level else arg._trace)
# Raises error here for bind on LAX primitives
if not valid_jaxtype(arg):
raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid JAX type")
return top_so_far

top_trace = reduce(check_arg, args, None) # type: ignore[wrong-arg-types]
if top_trace is not None:
return type(top_trace)(top_trace.master, cur_sublevel()) # type: ignore[call-arg]
else:
return None
def find_top_trace(xs):
try:
top_trace = max((x._trace for x in xs if isinstance(x, Tracer)),
key=attrgetter('level'))
except ValueError:
return None
else:
return type(top_trace)(top_trace.master, cur_sublevel())

@contextmanager
def initial_style_staging():
Expand Down Expand Up @@ -729,7 +721,7 @@ def concrete_aval(x):
for typ in type(x).mro():
handler = pytype_aval_mappings.get(typ)
if handler: return handler(x)
raise TypeError(f"{type(x)} is not a valid Jax type")
raise TypeError(f"{type(x)} is not a valid JAX type")


def get_aval(x):
Expand Down
2 changes: 1 addition & 1 deletion jax/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def abstractify(x) -> core.AbstractValue:
for typ in typ.mro():
aval_fn = pytype_aval_mappings.get(typ)
if aval_fn: return aval_fn(x)
raise TypeError(f"No abstraction handler for type: {type(x)}")
raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")

def _make_abstract_python_scalar(typ, _):
return ShapedArray((), dtypes.python_scalar_dtypes[typ], weak_type=True)
Expand Down
6 changes: 6 additions & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1765,6 +1765,12 @@ def f(a, b):
with self.assertRaisesRegex(ValueError, msg):
g(jnp.ones((1, 1)), b=1)

def test_primitive_jaxtype_error(self):
with self.assertRaisesRegex(
TypeError, "Argument .* of type .* is not a valid JAX type"):
lax.add(1, 'hi')


class JaxprTest(jtu.JaxTestCase):

def test_scalar_literals(self):
Expand Down

0 comments on commit 1930093

Please sign in to comment.