Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

revisions to #3197 #3264

Merged
merged 5 commits into from
Jun 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 12 additions & 18 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from operator import attrgetter
from contextlib import contextmanager
from collections import namedtuple
from functools import total_ordering, reduce
from functools import total_ordering
import itertools as it
from weakref import ref
import threading
Expand Down Expand Up @@ -219,6 +219,8 @@ def __repr__(self):
return '{}'.format(self.name)

def bind(self, *args, **kwargs):
assert skip_checks or all(isinstance(arg, Tracer)
or valid_jaxtype(arg) for arg in args), args
top_trace = find_top_trace(args)
if top_trace is None:
return self.impl(*args, **kwargs)
Expand Down Expand Up @@ -643,22 +645,10 @@ def full_lower(val):
else:
return val

def find_top_trace(args) -> Optional[Tracer]:
"""Find the tracer with the highest-level, or None. """
def check_arg(top_so_far: Optional[Tracer], arg) -> Optional[Tracer]:
if isinstance(arg, Tracer):
return (top_so_far
if top_so_far and top_so_far.level >= arg._trace.level else arg._trace)
# Raises error here for bind on LAX primitives
if not valid_jaxtype(arg):
raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid JAX type")
return top_so_far

top_trace = reduce(check_arg, args, None) # type: ignore[wrong-arg-types]
if top_trace is not None:
return type(top_trace)(top_trace.master, cur_sublevel()) # type: ignore[call-arg]
else:
return None
def find_top_trace(xs):
top_trace = max((x._trace for x in xs if isinstance(x, Tracer)),
key=attrgetter('level'), default=None)
return top_trace and type(top_trace)(top_trace.master, cur_sublevel())

@contextmanager
def initial_style_staging():
Expand Down Expand Up @@ -724,12 +714,16 @@ def valid_jaxtype(x):
else:
return True

def check_valid_jaxtype(x):
if not valid_jaxtype(x):
raise TypeError(f"{x} of type {type(x)} is not a valid JAX type")


def concrete_aval(x):
for typ in type(x).mro():
handler = pytype_aval_mappings.get(typ)
if handler: return handler(x)
raise TypeError(f"{type(x)} is not a valid Jax type")
raise TypeError(f"{type(x)} is not a valid JAX type")


def get_aval(x):
Expand Down
2 changes: 1 addition & 1 deletion jax/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def abstractify(x) -> core.AbstractValue:
for typ in typ.mro():
aval_fn = pytype_aval_mappings.get(typ)
if aval_fn: return aval_fn(x)
raise TypeError(f"No abstraction handler for type: {type(x)}")
raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")

def _make_abstract_python_scalar(typ, _):
return ShapedArray((), dtypes.python_scalar_dtypes[typ], weak_type=True)
Expand Down
7 changes: 6 additions & 1 deletion jax/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -4958,8 +4958,13 @@ def _tie_in_batch_rule(batched_args, batch_dims):
_, bdim_y = batch_dims
return y, bdim_y

def _tie_in_impl(x, y):
core.check_valid_jaxtype(x)
core.check_valid_jaxtype(y)
return y

tie_in_p = Primitive('tie_in')
tie_in_p.def_impl(lambda x, y: y)
tie_in_p.def_impl(_tie_in_impl)
tie_in_p.def_abstract_eval(lambda x, y: raise_to_shaped(y))
xla.translations[tie_in_p] = lambda c, x, y: y
ad.deflinear(tie_in_p, _tie_in_transpose_rule)
Expand Down
1 change: 1 addition & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1765,6 +1765,7 @@ def f(a, b):
with self.assertRaisesRegex(ValueError, msg):
g(jnp.ones((1, 1)), b=1)


class JaxprTest(jtu.JaxTestCase):

def test_scalar_literals(self):
Expand Down
21 changes: 15 additions & 6 deletions tests/lax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1703,6 +1703,18 @@ def testDynamicUpdateSliceTypeErrors(self):
onp.zeros((2, 2), dtype=onp.float32),
(onp.int32(1), onp.int16(2))))

def test_tie_in_error(self):
with core.skipping_checks():
with self.assertRaisesRegex(
TypeError, ".* of type .*tuple.* is not a valid JAX type"):
api.make_jaxpr(lambda x: lax.tie_in((x, x), 1))(1.)

def test_primitive_jaxtype_error(self):
with core.skipping_checks():
with self.assertRaisesRegex(
TypeError, "Argument .* of type .* is not a valid JAX type"):
lax.add(1, 'hi')

class LazyConstantTest(jtu.JaxTestCase):
def _Check(self, make_const, expected):
# check casting to ndarray works
Expand Down Expand Up @@ -2752,8 +2764,9 @@ def f2(x, y):
expected = onp.array(0.0)
self.assertAllClose(ans, expected, check_dtypes=False)

with self.assertRaises(TypeError):
lax.stop_gradient(lambda x: x)
with core.skipping_checks():
with self.assertRaises(TypeError):
lax.stop_gradient(lambda x: x)

# TODO(mattjj): make this a more systematic test
def testRemainder(self):
Expand Down Expand Up @@ -3423,10 +3436,6 @@ def testSort(self, shape, dimension, arity, bdims):
# TODO Collapse
# TODO Scatter

def test_tie_in_error(self):
with self.assertRaisesRegex(TypeError,
".*tuple.* is not a valid JAX type"):
api.make_jaxpr(lambda x: lax.tie_in((x, x), 1))(1.)

if __name__ == '__main__':
absltest.main()