From 1662178d4f4fccdfceca35f4809fd99ea6b4adb3 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Fri, 24 Apr 2020 18:03:26 -0700 Subject: [PATCH] only maximally stage out for some call primitives fixes #2833 --- jax/core.py | 2 +- jax/interpreters/partial_eval.py | 4 +++- jax/interpreters/xla.py | 1 + tests/api_test.py | 10 ++++++++++ 4 files changed, 15 insertions(+), 2 deletions(-) diff --git a/jax/core.py b/jax/core.py index 60824cbb5c9d..6029109b60f5 100644 --- a/jax/core.py +++ b/jax/core.py @@ -989,7 +989,7 @@ def process_env_traces(post_processor: str, primitive: Primitive, yield outs, tuple(todo) # Ensure the aux output is immutable def _call_bind(processor: str, post_processor: str, primitive: Primitive, - f: lu.WrappedFun, *args, **params): + f: lu.WrappedFun, *args, **params): top_trace = find_top_trace(args) level = trace_state.trace_stack.next_level(True) if top_trace is None else top_trace.level params_tuple = tuple(params.items()) diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index df88fdfb9ec8..7b27c822e664 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -163,7 +163,8 @@ def default_process_primitive(self, primitive, tracers, params): def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params): name = params.get('name', f.__name__) - if self.master.trace_type is StagingJaxprTrace: + if (self.master.trace_type is StagingJaxprTrace + and call_primitive in staged_out_calls): tracers = map(self.instantiate_const_abstracted, tracers) else: name = wrap_name(name, 'pe') @@ -312,6 +313,7 @@ def _unmapped_aval(size, aval): custom_partial_eval_rules: Dict[core.Primitive, Callable] = {} call_partial_eval_rules: Dict[core.Primitive, Callable] = {} +staged_out_calls: Set[core.Primitive] = set() def partial_eval(f, trace, pvs: Sequence[Optional[AbstractValue]], instantiate=False): diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index d6051ab5ee4d..fbbe2a0c438e 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -608,6 +608,7 @@ def _get_device(device, backend): xla_call = partial(core.call_bind, xla_call_p) xla_call_p.def_custom_bind(xla_call) xla_call_p.def_impl(_xla_call_impl) +pe.staged_out_calls.add(xla_call_p) def _xla_call_translation_rule(c, axis_env, in_nodes, name_stack, backend, name, diff --git a/tests/api_test.py b/tests/api_test.py index 7ce344275568..f8b6f6f1d026 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -1510,6 +1510,16 @@ def scanned_f(x, _): jax.grad(scan_bug)(1.0) # doesn't crash + def test_remat_jit_static_argnum(self): + # https://github.com/google/jax/issues/2833 + def f(a_bool, y): + if a_bool: + return y + 1 + else: + return y + + api.jit(api.remat(f, concrete=True), static_argnums=0)(True, 1) # no crash + def test_trivial_computations(self): x = np.array([1, 2, 3]) y = api.jit(lambda x: x)(x)