Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve a ConcretizationTypeError message from dependence on jitted function arguments #4342

Merged
merged 7 commits into from
Sep 26, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,8 +846,6 @@ def raise_concretization_error(val: Tracer, context=""):
msg = ("Abstract tracer value encountered where concrete value is expected.\n\n"
+ context + "\n\n"
+ val._origin_msg() + "\n\n"
+ "You can use transformation parameters such as `static_argnums` for "
"`jit` to avoid tracing particular arguments of transformed functions.\n\n"
"See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.\n\n"
f"Encountered tracer value: {val}")
raise ConcretizationTypeError(msg)
Expand Down
33 changes: 24 additions & 9 deletions jax/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,14 +800,24 @@ def _contents(self):
return ()

def _origin_msg(self):
progenitor_eqns = self._trace.frame.find_progenitors(self)
msgs = [f" operation {core.pp_eqn(eqn, print_shapes=True)}\n"
f" from line {source_info_util.summarize(eqn.source_info)}"
for eqn in progenitor_eqns]
if msgs:
invar_pos, progenitor_eqns = self._trace.frame.find_progenitors(self)
if invar_pos:
origin = (f"While tracing the function {self._trace.main.source_info}, "
"this concrete value was not available in Python because it "
"depends on the value of the arguments to "
f"{self._trace.main.source_info} at flattened positions {invar_pos}, "
"and the computation of these values is being staged out "
"(that is, delayed rather than executed eagerly).\n\n"
"You can use transformation parameters such as `static_argnums` "
"for `jit` to avoid tracing particular arguments of transformed "
"functions, though at the cost of more recompiles.")
elif progenitor_eqns:
msts = [f" operation {core.pp_eqn(eqn, print_shapes=True)}\n"
f" from line {source_info_util.summarize(eqn.source_info)}"
for eqn in progenitor_eqns]
origin = (f"While tracing the function {self._trace.main.source_info}, "
"this value became a tracer due to JAX operations on these lines:"
"\n\n" + "\n\n".join(msgs))
"\n\n" + "\n\n".join(msts))
else:
origin = ("The error occured while tracing the function "
f"{self._trace.main.source_info}.")
Expand All @@ -820,7 +830,7 @@ def _assert_live(self) -> None:

class JaxprStackFrame:
__slots__ = ['newvar', 'tracer_to_var', 'constid_to_var', 'constvar_to_val',
'tracers', 'eqns']
'tracers', 'eqns', 'invars']

def __init__(self):
self.newvar = core.gensym()
Expand All @@ -829,6 +839,7 @@ def __init__(self):
self.constvar_to_val = {}
self.tracers = [] # circ refs, frame->tracer->trace->main->frame,
self.eqns = [] # cleared when we pop frame from main
self.invars = []

def to_jaxpr(self, in_tracers, out_tracers):
invars = [self.tracer_to_var[id(t)] for t in in_tracers]
Expand All @@ -850,7 +861,10 @@ def find_progenitors(self, tracer):
if produced:
active_vars.difference_update(produced)
active_vars.update(eqn.invars)
return [eqn for eqn in self.eqns if set(eqn.invars) & active_vars]
invar_positions = [i for i, v in enumerate(self.invars) if v in active_vars]
constvars = active_vars & set(self.constvar_to_val)
const_eqns = [eqn for eqn in self.eqns if set(eqn.invars) & constvars]
return invar_positions, const_eqns

def _inline_literals(jaxpr, constvals):
consts = dict(zip(jaxpr.constvars, constvals))
Expand Down Expand Up @@ -890,7 +904,8 @@ def frame(self):
def new_arg(self, aval):
tracer = DynamicJaxprTracer(self, aval, source_info_util.current())
self.frame.tracers.append(tracer)
self.frame.tracer_to_var[id(tracer)] = self.frame.newvar(aval)
self.frame.tracer_to_var[id(tracer)] = var = self.frame.newvar(aval)
self.frame.invars.append(var)
return tracer

def new_const(self, val):
Expand Down
27 changes: 27 additions & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1797,6 +1797,33 @@ def f():

f() # doesn't crash

def test_concrete_error_because_arg(self):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test is omnistaging-specific")

@jax.jit
def f(x, y):
if x > y:
return x
else:
return y

msg = r"at flattened positions \[0, 1\]"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1, 2)

def test_concrete_error_because_const(self):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test is omnistaging-specific")

@jax.jit
def f():
assert jnp.add(1, 1) > 0

msg = "on these lines"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f()

def test_xla_computation_zeros_doesnt_device_put(self):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test is omnistaging-specific")
Expand Down