Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

only maximally stage out for some call primitives #2834

Merged
merged 1 commit into from
Apr 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -989,7 +989,7 @@ def process_env_traces(post_processor: str, primitive: Primitive,
yield outs, tuple(todo) # Ensure the aux output is immutable

def _call_bind(processor: str, post_processor: str, primitive: Primitive,
f: lu.WrappedFun, *args, **params):
f: lu.WrappedFun, *args, **params):
top_trace = find_top_trace(args)
level = trace_state.trace_stack.next_level(True) if top_trace is None else top_trace.level
params_tuple = tuple(params.items())
Expand Down
4 changes: 3 additions & 1 deletion jax/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,8 @@ def default_process_primitive(self, primitive, tracers, params):

def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
name = params.get('name', f.__name__)
if self.master.trace_type is StagingJaxprTrace:
if (self.master.trace_type is StagingJaxprTrace
and call_primitive in staged_out_calls):
tracers = map(self.instantiate_const_abstracted, tracers)
else:
name = wrap_name(name, 'pe')
Expand Down Expand Up @@ -312,6 +313,7 @@ def _unmapped_aval(size, aval):

custom_partial_eval_rules: Dict[core.Primitive, Callable] = {}
call_partial_eval_rules: Dict[core.Primitive, Callable] = {}
staged_out_calls: Set[core.Primitive] = set()


def partial_eval(f, trace, pvs: Sequence[Optional[AbstractValue]], instantiate=False):
Expand Down
1 change: 1 addition & 0 deletions jax/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,7 @@ def _get_device(device, backend):
xla_call = partial(core.call_bind, xla_call_p)
xla_call_p.def_custom_bind(xla_call)
xla_call_p.def_impl(_xla_call_impl)
pe.staged_out_calls.add(xla_call_p)

def _xla_call_translation_rule(c, axis_env,
in_nodes, name_stack, backend, name,
Expand Down
10 changes: 10 additions & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1510,6 +1510,16 @@ def scanned_f(x, _):

jax.grad(scan_bug)(1.0) # doesn't crash

def test_remat_jit_static_argnum(self):
# https://github.com/google/jax/issues/2833
def f(a_bool, y):
if a_bool:
return y + 1
else:
return y

api.jit(api.remat(f, concrete=True), static_argnums=0)(True, 1) # no crash

def test_trivial_computations(self):
x = np.array([1, 2, 3])
y = api.jit(lambda x: x)(x)
Expand Down