Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

factor out process_map / post_process_map #2788

Merged
merged 3 commits into from
Apr 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,8 +942,8 @@ def canonicalize_shape(shape):
"smaller subfunctions.")
raise TypeError(msg.format(shape))

# ------------------- Call -------------------

# ------------------- Call and map -------------------

def apply_todos(todos, outs):
todos_list = list(todos)
Expand All @@ -952,7 +952,8 @@ def apply_todos(todos, outs):
return outs

@lu.transformation_with_aux
def process_env_traces(primitive, level, params_tuple, *args):
def process_env_traces(post_processor: str, primitive: Primitive,
level: int, params_tuple: tuple, *args):
outs = yield args, {}
params = dict(params_tuple)
todo = []
Expand All @@ -964,29 +965,34 @@ def process_env_traces(primitive, level, params_tuple, *args):
break
trace = type(ans._trace)(ans._trace.master, cur_sublevel())
outs = map(trace.full_raise, outs)
outs, cur_todo = trace.post_process_call(primitive, outs, params)
post_process = getattr(trace, post_processor)
outs, cur_todo = post_process(primitive, outs, params)
todo.append(cur_todo)
yield outs, tuple(todo) # Ensure the aux output is immutable

def call_bind(primitive, f: lu.WrappedFun, *args, **params):
def _call_bind(processor: str, post_processor: str, primitive: Primitive,
f: lu.WrappedFun, *args, **params):
top_trace = find_top_trace(args)
level = trace_state.trace_stack.next_level(True) if top_trace is None else top_trace.level
params_tuple = tuple(params.items())
f, env_trace_todo = process_env_traces(f, primitive, level, params_tuple)
f, env_trace_todo = process_env_traces(f, post_processor, primitive, level, params_tuple)
if top_trace is None:
with new_sublevel():
outs = primitive.impl(f, *args, **params)
else:
tracers = map(top_trace.full_raise, args)
outs = map(full_lower, top_trace.process_call(primitive, f, tracers, params))
process = getattr(top_trace, processor)
outs = map(full_lower, process(primitive, f, tracers, params))
return apply_todos(env_trace_todo(), outs)

call_bind = partial(_call_bind, 'process_call', 'post_process_call')
map_bind = partial(_call_bind, 'process_map', 'post_process_map')


def call_impl(f: lu.WrappedFun, *args, **params):
del params # params parameterize the call primitive, not the function
return f.call_wrapped(*args)


call_p = Primitive('call')
call_p.multiple_results = True
call_p.call_primitive = True
Expand Down
3 changes: 3 additions & 0 deletions jax/interpreters/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,9 @@ def todo(x):
return map(partial(JVPTracer, trace), primals, tangents)
return out, todo

# Map primitives reuse the plain-call handlers here — presumably forward-mode
# differentiation needs no extra per-axis bookkeeping for a mapped call.
# NOTE(review): confirm against the map primitives' bind path.
process_map = process_call
post_process_map = post_process_call

def process_custom_jvp_call(self, _, __, f_jvp, tracers):
primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
primals_in = map(core.full_lower, primals_in)
Expand Down
30 changes: 18 additions & 12 deletions jax/interpreters/batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,16 +140,21 @@ def process_primitive(self, primitive, tracers, params):
def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
assert call_primitive.multiple_results
params = dict(params, name=wrap_name(params.get('name', f.__name__), 'vmap'))
if call_primitive in pe.map_primitives:
return self.process_map(call_primitive, f, tracers, params)
vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
if all(bdim is not_mapped for bdim in dims):
return call_primitive.bind(f, *vals, **params)
else:
vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
if all(bdim is not_mapped for bdim in dims):
return call_primitive.bind(f, *vals, **params)
else:
f, dims_out = batch_subtrace(f, self.master, dims)
vals_out = call_primitive.bind(f, *vals, **params)
return [BatchTracer(self, v, d) for v, d in zip(vals_out, dims_out())]
f, dims_out = batch_subtrace(f, self.master, dims)
vals_out = call_primitive.bind(f, *vals, **params)
return [BatchTracer(self, v, d) for v, d in zip(vals_out, dims_out())]

def post_process_call(self, call_primitive, out_tracers, params):
vals, dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
master = self.master
def todo(vals):
trace = BatchTrace(master, core.cur_sublevel())
return map(partial(BatchTracer, trace), vals, dims)
return vals, todo

def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
Expand All @@ -166,12 +171,13 @@ def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
dims_out = tuple(d + 1 if d is not not_mapped else d for d in dims_out())
return [BatchTracer(self, v, d) for v, d in zip(vals_out, dims_out)]

def post_process_call(self, call_primitive, out_tracers, params):
def post_process_map(self, call_primitive, out_tracers, params):
vals, dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
master = self.master
def todo(x):
def todo(vals):
trace = BatchTrace(master, core.cur_sublevel())
return map(partial(BatchTracer, trace), x, dims)
return [BatchTracer(trace, v, d + 1 if d is not not_mapped else d)
for v, d in zip(vals, dims)]
return vals, todo

def process_custom_jvp_call(self, prim, fun, jvp, tracers):
Expand Down
5 changes: 0 additions & 5 deletions jax/interpreters/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,6 @@ def process_primitive(self, primitive, tracers, params):
return PapplyTracer(self, name, size, val_out, axis_out)

def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
if call_primitive in pe.map_primitives:
return self.process_map(call_primitive, f, tracers, params)
names, vals, axes = unzip3((t.name, t.val, t.axis) for t in tracers)
if all(axis is not_sharded for axis in axes):
return call_primitive.bind(f, *vals, **params)
Expand All @@ -133,8 +131,5 @@ def todo(x):
return PapplyTracer(trace, name, size, x, axis)
return val, todo

def process_map(self, map_primitive, f :lu.WrappedFun, tracers, params):
raise NotImplementedError # TODO(mattjj,frostig)


papply_primitive_rules: Dict[core.Primitive, Callable] = {}
61 changes: 32 additions & 29 deletions jax/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,9 @@ def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
else:
name = wrap_name(name, 'pe')
params = dict(params, name=name)

if call_primitive in call_partial_eval_rules:
return call_partial_eval_rules[call_primitive](self, call_primitive, f, tracers, params)
if call_primitive in map_primitives:
return self.process_map(call_primitive, f, tracers, params)
in_pvs, in_consts = unzip2([t.pval for t in tracers])
fun, aux = partial_eval(f, self, in_pvs)
out_flat = call_primitive.bind(fun, *in_consts, **params)
Expand All @@ -190,7 +189,38 @@ def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
t.recipe = eqn
return out_tracers

def post_process_call(self, call_primitive, out_tracers, params):
jaxpr, consts, env = tracers_to_jaxpr([], out_tracers)
out_pvs, out_pv_consts = unzip2(t.pval for t in out_tracers)
out = out_pv_consts + consts
del consts, out_pv_consts
master = self.master
def todo(x):
n = len(jaxpr.outvars)
out_pv_consts, consts = x[:n], x[n:]
trace = JaxprTrace(master, core.cur_sublevel())
const_tracers = map(trace.new_instantiated_const, consts)
env_tracers = map(trace.full_raise, env)
lifted_jaxpr = convert_constvars_jaxpr(jaxpr)
out_tracers = [JaxprTracer(trace, PartialVal((out_pv, out_pv_const)), None)
for out_pv, out_pv_const in zip(out_pvs, out_pv_consts)]
new_params = dict(params, call_jaxpr=lifted_jaxpr)
# The `jaxpr` already contains the env_vars at start of invars
eqn = new_eqn_recipe(tuple(it.chain(const_tracers, env_tracers)),
out_tracers, call_primitive, new_params)
for t in out_tracers:
t.recipe = eqn
return out_tracers
return out, todo

def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
name = params.get('name', f.__name__)
if self.master.trace_type is StagingJaxprTrace:
tracers = map(self.instantiate_const_abstracted, tracers)
else:
name = wrap_name(name, 'pe')

params = dict(params, name=name)
in_pvs, in_consts = unzip2([t.pval for t in tracers])
reduced_pvs = [None if pv is None else _mapped_aval(pv) for pv in in_pvs]
fun, aux = partial_eval(f, self, reduced_pvs)
Expand All @@ -216,32 +246,6 @@ def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
t.recipe = eqn
return out_tracers

def post_process_call(self, call_primitive, out_tracers, params):
if call_primitive in map_primitives:
return self.post_process_map(call_primitive, out_tracers, params)
jaxpr, consts, env = tracers_to_jaxpr([], out_tracers)
out_pvs, out_pv_consts = unzip2(t.pval for t in out_tracers)
out = out_pv_consts + consts
del consts, out_pv_consts
master = self.master
def todo(x):
n = len(jaxpr.outvars)
out_pv_consts, consts = x[:n], x[n:]
trace = JaxprTrace(master, core.cur_sublevel())
const_tracers = map(trace.new_instantiated_const, consts)
env_tracers = map(trace.full_raise, env)
lifted_jaxpr = convert_constvars_jaxpr(jaxpr)
out_tracers = [JaxprTracer(trace, PartialVal((out_pv, out_pv_const)), None)
for out_pv, out_pv_const in zip(out_pvs, out_pv_consts)]
new_params = dict(params, call_jaxpr=lifted_jaxpr)
# The `jaxpr` already contains the env_vars at start of invars
eqn = new_eqn_recipe(tuple(it.chain(const_tracers, env_tracers)),
out_tracers, call_primitive, new_params)
for t in out_tracers:
t.recipe = eqn
return out_tracers
return out, todo

def post_process_map(self, map_primitive, out_tracers, params):
jaxpr, consts, env = tracers_to_jaxpr([], out_tracers)
out_pvs_reduced, out_pv_consts = unzip2(t.pval for t in out_tracers)
Expand Down Expand Up @@ -306,7 +310,6 @@ def _unmapped_aval(size, aval):
else:
raise TypeError(aval)

map_primitives: Set[core.Primitive] = set()
custom_partial_eval_rules: Dict[core.Primitive, Callable] = {}
call_partial_eval_rules: Dict[core.Primitive, Callable] = {}

Expand Down
20 changes: 9 additions & 11 deletions jax/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,7 +816,7 @@ def execute_replicated(compiled, backend, in_handler, out_handler, *args):
xla_pmap_p = core.Primitive('xla_pmap')
xla_pmap_p.call_primitive = True
xla_pmap_p.multiple_results = True
xla_pmap = partial(core.call_bind, xla_pmap_p)
xla_pmap = partial(core.map_bind, xla_pmap_p)
xla_pmap_p.def_custom_bind(xla_pmap)
xla_pmap_p.def_impl(xla_pmap_impl)

Expand Down Expand Up @@ -847,7 +847,6 @@ def _pmap_translation_rule(c, axis_env,

xla.call_translations[xla_pmap_p] = _pmap_translation_rule
ad.primitive_transposes[xla_pmap_p] = partial(ad.map_transpose, xla_pmap_p)
pe.map_primitives.add(xla_pmap_p)

def _xla_shard(c, aval, axis_env, x):
if aval is core.abstract_unit:
Expand Down Expand Up @@ -1019,16 +1018,13 @@ def new_tracer(x, a):

def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
assert call_primitive.multiple_results
if call_primitive in pe.map_primitives:
return self.process_map(call_primitive, f, tracers, params)
vals, names = unzip2((t.val, t.axis_name) for t in tracers)
if all(name is not_mapped for name in names):
return call_primitive.bind(f, *vals, **params)
else:
vals, names = unzip2((t.val, t.axis_name) for t in tracers)
if all(name is not_mapped for name in names):
return call_primitive.bind(f, *vals, **params)
else:
f, names_out = split_axis_subtrace(f, self.master, names)
vals_out = call_primitive.bind(f, *vals, **params)
return [SplitAxisTracer(self, a, x) for a, x in zip(names_out(), vals_out)]
f, names_out = split_axis_subtrace(f, self.master, names)
vals_out = call_primitive.bind(f, *vals, **params)
return [SplitAxisTracer(self, a, x) for a, x in zip(names_out(), vals_out)]

def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
vals, names = unzip2((t.val, t.axis_name) for t in tracers)
Expand All @@ -1054,5 +1050,7 @@ def todo(x):
return SplitAxisTracer(trace, name, x)
return val, todo

post_process_map = post_process_call


split_axis_rules: Dict[core.Primitive, Callable] = {}
21 changes: 21 additions & 0 deletions tests/pmap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1015,6 +1015,27 @@ def f(args_list):
np.array([sum(vals)] * ndevices),
check_dtypes=True)

def testPostProcessMap(self):
  """Regression test for vmap over pmap: exercises the post_process_map
  path, where pmap outputs escape to an outer vmap trace."""
  # code from https://github.com/google/jax/issues/2787
  def vv(x, y):
    """Vector-vector multiply"""
    return np.dot(x, y)

  def distributed_matrix_vector(x, y):
    """Matrix vector multiply. First batch it and then row by row"""
    fv = lambda z: lax.map(lambda j: vv(j, y), z)
    # Split rows across devices; pmap requires the leading axis to equal
    # the device count.
    res = pmap(fv)(x.reshape((jax.device_count(), -1) + tuple(x.shape[1:])))
    res = res.reshape(res.shape[0] * res.shape[1], *res.shape[2:])
    return res

  key = random.PRNGKey(1)
  x = random.normal(key, (800, 50))
  # vmap closes over `x`; only `y`'s leading axis is batched.
  batched_mvm = vmap(lambda b: distributed_matrix_vector(x, b), in_axes=0)
  y = random.normal(key, (10, 50, 1))
  result = batched_mvm(y)
  expected = np.einsum('ij,njk->nik', x, y)
  # Loose tolerances: pmap result may differ slightly in reduced precision.
  self.assertAllClose(result, expected, check_dtypes=False, atol=1e-3, rtol=1e-3)


class PmapWithDevicesTest(jtu.JaxTestCase):

Expand Down