Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor call primitives, simpler param processing #3491

Merged
merged 1 commit into from
Jun 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 4 additions & 9 deletions jax/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1150,16 +1150,11 @@ def f_pmapped(*args, **kwargs):
for arg in args: _check_arg(arg)
flat_fun, out_tree = flatten_fun(f, in_tree)
out = pxla.xla_pmap(
flat_fun,
*args,
backend=backend,
axis_name=axis_name,
axis_size=local_axis_size,
global_axis_size=axis_size,
devices=tuple(devices) if devices is not None else devices,
name=flat_fun.__name__,
flat_fun, *args, backend=backend, axis_name=axis_name,
axis_size=local_axis_size, global_axis_size=axis_size,
devices=None if devices is None else tuple(devices),
mapped_invars=tuple(axis is not None for axis in in_axes_flat),
donated_invars=tuple(donated_invars))
name=flat_fun.__name__, donated_invars=tuple(donated_invars))
return tree_unflatten(out_tree(), out)

return f_pmapped
Expand Down
64 changes: 44 additions & 20 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1062,7 +1062,7 @@ def canonicalize_shape(shape):
raise TypeError(msg.format(shape))


# ------------------- Call and map -------------------
# ------------------- Call -------------------

def apply_todos(todos, outs):
todos_list = list(todos)
Expand All @@ -1071,7 +1071,7 @@ def apply_todos(todos, outs):
return outs

@lu.transformation_with_aux
def process_env_traces(post_processor: str, primitive: Primitive,
def process_env_traces(primitive: Union['CallPrimitive', 'MapPrimitive'],
level: int, params_tuple: tuple, *args):
outs = yield args, {}
params = dict(params_tuple)
Expand All @@ -1084,41 +1084,59 @@ def process_env_traces(post_processor: str, primitive: Primitive,
break
trace = type(ans._trace)(ans._trace.master, cur_sublevel())
outs = map(trace.full_raise, outs)
post_process = getattr(trace, post_processor)
outs, cur_todo = post_process(primitive, outs, params)
outs, cur_todo = primitive.post_process(trace, outs, params)
todo.append(cur_todo)
yield outs, tuple(todo) # Ensure the aux output is immutable

def _call_bind(processor: str, post_processor: str, primitive: Primitive,
f: lu.WrappedFun, *args, **params):
def call_bind(primitive: Union['CallPrimitive', 'MapPrimitive'],
fun: lu.WrappedFun, *args, **params):
params_tuple = tuple(params.items())
top_trace = find_top_trace(args)
level = trace_state.trace_stack.next_level(True) if top_trace is None else top_trace.level
params_tuple = tuple(params.items())
f, env_trace_todo = process_env_traces(f, post_processor, primitive, level, params_tuple)
fun, env_trace_todo = process_env_traces(fun, primitive, level, params_tuple)
if top_trace is None:
with new_sublevel():
outs = primitive.impl(f, *args, **params)
outs = primitive.impl(fun, *args, **params)
else:
tracers = map(top_trace.full_raise, args)
process = getattr(top_trace, processor)
outs = map(full_lower, process(primitive, f, tracers, params))
outs = primitive.process(top_trace, fun, tracers, params)
return apply_todos(env_trace_todo(), outs)

call_bind = partial(_call_bind, 'process_call', 'post_process_call')
map_bind = partial(_call_bind, 'process_map', 'post_process_map')
class CallPrimitive(Primitive):
multiple_results = True
call_primitive = True
bind = call_bind

def process(self, trace, fun, tracers, params):
return trace.process_call(self, fun, tracers, params)

def post_process(self, trace, out_tracers, params):
return trace.post_process_call(self, out_tracers, params)

def call_impl(f: lu.WrappedFun, *args, **params):
del params # params parameterize the call primitive, not the function
return f.call_wrapped(*args)

call_p = Primitive('call')
call_p.multiple_results = True
call_p.call_primitive = True
call = partial(call_bind, call_p)
call_p.def_custom_bind(call)
call_p = CallPrimitive('call')
call = call_p.bind
call_p.def_impl(call_impl)

# ------------------- Map -------------------

class MapPrimitive(Primitive):
multiple_results = True
map_primitive = True

def bind(self, fun, *args, **params):
assert len(params['mapped_invars']) == len(args)
return call_bind(self, fun, *args, **params)

def process(self, trace, fun, tracers, params):
return trace.process_map(self, fun, tracers, params)

def post_process(self, trace, out_tracers, params):
return trace.post_process_map(self, out_tracers, params)

# ------------------- Jaxpr checking -------------------

Expand Down Expand Up @@ -1168,14 +1186,13 @@ def check_jaxpr(jaxpr: Jaxpr):
try:
_check_jaxpr(jaxpr, [v.aval for v in jaxpr.invars])
except Exception as e:
exception_type = type(e)
msg_context = f"while checking jaxpr:\n\n{jaxpr}\n"
if len(e.args) == 0:
exception_args = [msg_context]
else:
msg = f"{e.args[0]}\n\n" + msg_context
msg = f"{e.args[0]}\n\n{msg_context}"
exception_args = [msg, *e.args[1:]]
raise exception_type(*exception_args) from e
raise type(e)(*exception_args) from e

def _check_jaxpr(jaxpr: Jaxpr, in_avals: Sequence[AbstractValue]):

Expand Down Expand Up @@ -1203,6 +1220,11 @@ def write(v: Var, a: AbstractValue) -> None:
map(write, jaxpr.invars, in_avals)

for eqn in jaxpr.eqns:

if eqn.primitive in skip_check_primitives:
map(write, eqn.outvars, [v.aval for v in eqn.outvars]) # skip checking
continue

in_avals = map(read, eqn.invars)
if eqn.primitive.call_primitive:
out_avals = check_call(eqn.primitive, in_avals, eqn.params)
Expand All @@ -1218,6 +1240,8 @@ def write(v: Var, a: AbstractValue) -> None:

map(read, jaxpr.outvars)

skip_check_primitives: Set[Primitive] = set()

def check_eqn(prim, in_avals, params):
for jaxpr in jaxprs_in_params(params):
check_jaxpr(jaxpr)
Expand Down
82 changes: 42 additions & 40 deletions jax/custom_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial, update_wrapper, reduce
from functools import update_wrapper, reduce
import inspect
import operator as op

Expand Down Expand Up @@ -262,25 +262,30 @@ def _flatten_jvp(in_tree, *args):
raise TypeError(msg.format('\n'.join(disagreements)))
yield primals_out + tangents_out, out_tree

def _custom_jvp_call_bind(prim, fun, jvp, *args):
args = map(core.full_lower, args)
top_trace = core.find_top_trace(args)
if top_trace is None:
with core.new_sublevel():
outs = prim.impl(fun, jvp, *args)
else:
tracers = map(top_trace.full_raise, args)
outs = top_trace.process_custom_jvp_call(prim, fun, jvp, tracers)
return map(core.full_lower, outs)
class CustomJVPCallPrimitive(core.CallPrimitive):
def bind(self, fun, jvp, *args):
args = map(core.full_lower, args)
top_trace = core.find_top_trace(args)
fun, env_trace_todo1 = core.process_env_traces(
fun, self, top_trace and top_trace.level, ())
jvp, env_trace_todo2 = core.process_env_traces(
jvp, self, top_trace and top_trace.level, ())
if top_trace is None:
with core.new_sublevel():
outs = self.impl(fun, jvp, *args)
else:
tracers = map(top_trace.full_raise, args)
outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers)
_, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2)
if env_trace_todo:
raise core.UnexpectedTracerError
return map(core.full_lower, outs)

def _custom_jvp_call_impl(fun, _, *args):
return fun.call_wrapped(*args)
def impl(self, fun, _, *args):
return fun.call_wrapped(*args)

custom_jvp_call_p = core.Primitive('custom_jvp_call')
custom_jvp_call_p.multiple_results = True
custom_jvp_call = partial(_custom_jvp_call_bind, custom_jvp_call_p)
custom_jvp_call_p.def_custom_bind(custom_jvp_call)
custom_jvp_call_p.def_impl(_custom_jvp_call_impl)
custom_jvp_call_p = CustomJVPCallPrimitive('custom_jvp_call')
custom_jvp_call = custom_jvp_call_p.bind


def custom_jvp_call_jaxpr(fun, jvp, *args):
Expand Down Expand Up @@ -501,28 +506,25 @@ def _flatten_bwd(in_tree, out_trees, *args):
raise TypeError(msg.format(in_tree2, in_tree)) from None
yield cts_in

def _custom_vjp_call_bind(prim, fun, fwd, bwd, *args, out_trees):
args = map(core.full_lower, args)
top_trace = core.find_top_trace(args)
if top_trace is None:
with core.new_sublevel():
outs = prim.impl(fun, fwd, bwd, *args, out_trees=out_trees)
else:
tracers = map(top_trace.full_raise, args)
outs = top_trace.process_custom_vjp_call(prim, fun, fwd, bwd, tracers,
out_trees=out_trees)
outs = map(core.full_lower, outs)
return map(core.full_lower, outs)

def _custom_vjp_call_impl(fun, fwd, bwd, *args, out_trees):
del fwd, bwd, out_trees # Unused.
return fun.call_wrapped(*args)

custom_vjp_call_p = core.Primitive('custom_vjp_call')
custom_vjp_call_p.multiple_results = True
custom_vjp_call = partial(_custom_vjp_call_bind, custom_vjp_call_p)
custom_vjp_call_p.def_custom_bind(custom_vjp_call)
custom_vjp_call_p.def_impl(_custom_vjp_call_impl)

class CustomVJPCallPrimitive(core.CallPrimitive):
def bind(self, fun, fwd, bwd, *args, out_trees):
args = map(core.full_lower, args)
top_trace = core.find_top_trace(args)
if top_trace is None:
outs = fun.call_wrapped(*args)
else:
tracers = map(top_trace.full_raise, args)
outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd, tracers,
out_trees=out_trees)
return map(core.full_lower, outs)

def impl(self, fun, fwd, bwd, *args, out_trees):
del fwd, bwd, out_trees
return fun.call_wrapped(*args)

custom_vjp_call_p = CustomVJPCallPrimitive('custom_vjp_call')
custom_vjp_call = custom_vjp_call_p.bind

def custom_vjp_call_jaxpr(fun, fwd, bwd, *args, out_trees):
in_avals = [raise_to_shaped(core.get_aval(x)) for x in args]
Expand Down
12 changes: 9 additions & 3 deletions jax/experimental/host_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,8 @@ def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn,
pred1_and_token1,
xla.xla_call_p,
dict(call_jaxpr=transformed_cond_jaxpr.jaxpr,
name="cond_before"),
name="cond_before",
donated_invars=(False,) * (cond_nconsts + len(carry_invars) + 1)),
eqn.source_info))
# Make a new cond "lambda pred, carry, token: pred"
new_cond_pred_invar = mk_new_var(cond_jaxpr.out_avals[0])
Expand Down Expand Up @@ -667,14 +668,19 @@ def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn,
new_body_carry2 + [new_body_token2],
xla.xla_call_p,
dict(call_jaxpr=transformed_body_jaxpr.jaxpr,
name="body"),
name="body",
donated_invars=(False,) * (len(new_body_invars_body_constvars) +
len(new_body_invars_carry) +
1 + len(new_body_carry2) + 1)),
eqn.source_info),
core.new_jaxpr_eqn(
new_body_invars_cond_constvars + new_body_carry2 + [new_body_token2],
[new_body_pred2, new_body_token3],
xla.xla_call_p,
dict(call_jaxpr=transformed_cond_jaxpr.jaxpr,
name="cond_body"),
name="cond_body",
donated_invars=(False,) * (len(new_body_invars_cond_constvars) +
len(new_body_carry2) + 1 + 2)),
eqn.source_info)
]
new_body_jaxpr = _mk_typed_jaxpr(
Expand Down
Loading