Skip to content

Commit

Permalink
factor out process_map / post_process_map
Browse files Browse the repository at this point in the history
Also fix a bug from reusing post_process_call for pmap. Fixes #2787
  • Loading branch information
mattjj committed Apr 21, 2020
1 parent 6db1f0c commit 063cf0f
Show file tree
Hide file tree
Showing 8 changed files with 120 additions and 60 deletions.
36 changes: 34 additions & 2 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,8 +942,8 @@ def canonicalize_shape(shape):
"smaller subfunctions.")
raise TypeError(msg.format(shape))

# ------------------- Call -------------------

# ------------------- Call and map -------------------

def apply_todos(todos, outs):
todos_list = list(todos)
Expand Down Expand Up @@ -986,7 +986,6 @@ def call_impl(f: lu.WrappedFun, *args, **params):
del params # params parameterize the call primitive, not the function
return f.call_wrapped(*args)


call_p = Primitive('call')
call_p.multiple_results = True
call_p.call_primitive = True
Expand All @@ -995,6 +994,39 @@ def call_impl(f: lu.WrappedFun, *args, **params):
call_p.def_impl(call_impl)


# TODO(mattjj): de-duplicate the next two functions with the above

@lu.transformation_with_aux
def process_env_traces_map(primitive, level, params_tuple, *args):
outs = yield args, {}
params = dict(params_tuple)
todo = []
while True:
tracers = [x for x in outs if isinstance(x, Tracer) and x._trace.level > level]
if tracers:
ans = max(tracers, key=lambda x: x._trace.level)
else:
break
trace = type(ans._trace)(ans._trace.master, cur_sublevel())
outs = map(trace.full_raise, outs)
outs, cur_todo = trace.post_process_map(primitive, outs, params)
todo.append(cur_todo)
yield outs, tuple(todo) # Ensure the aux output is immutable

def map_bind(primitive, f: lu.WrappedFun, *args, **params):
top_trace = find_top_trace(args)
level = trace_state.trace_stack.next_level(True) if top_trace is None else top_trace.level
params_tuple = tuple(params.items())
f, env_trace_todo = process_env_traces_map(f, primitive, level, params_tuple)
if top_trace is None:
with new_sublevel():
outs = primitive.impl(f, *args, **params)
else:
tracers = map(top_trace.full_raise, args)
outs = map(full_lower, top_trace.process_map(primitive, f, tracers, params))
return apply_todos(env_trace_todo(), outs)


# ------------------- Jaxpr printed representation -------------------

def check_jaxpr(jaxpr: Jaxpr):
Expand Down
3 changes: 3 additions & 0 deletions jax/interpreters/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,9 @@ def todo(x):
return map(partial(JVPTracer, trace), primals, tangents)
return out, todo

process_map = process_call
post_process_map = post_process_call

def process_custom_jvp_call(self, _, __, f_jvp, tracers):
primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
primals_in = map(core.full_lower, primals_in)
Expand Down
30 changes: 18 additions & 12 deletions jax/interpreters/batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,16 +140,21 @@ def process_primitive(self, primitive, tracers, params):
def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
assert call_primitive.multiple_results
params = dict(params, name=wrap_name(params.get('name', f.__name__), 'vmap'))
if call_primitive in pe.map_primitives:
return self.process_map(call_primitive, f, tracers, params)
vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
if all(bdim is not_mapped for bdim in dims):
return call_primitive.bind(f, *vals, **params)
else:
vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
if all(bdim is not_mapped for bdim in dims):
return call_primitive.bind(f, *vals, **params)
else:
f, dims_out = batch_subtrace(f, self.master, dims)
vals_out = call_primitive.bind(f, *vals, **params)
return [BatchTracer(self, v, d) for v, d in zip(vals_out, dims_out())]
f, dims_out = batch_subtrace(f, self.master, dims)
vals_out = call_primitive.bind(f, *vals, **params)
return [BatchTracer(self, v, d) for v, d in zip(vals_out, dims_out())]

def post_process_call(self, call_primitive, out_tracers, params):
vals, dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
master = self.master
def todo(vals):
trace = BatchTrace(master, core.cur_sublevel())
return map(partial(BatchTracer, trace), vals, dims)
return vals, todo

def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
Expand All @@ -166,12 +171,13 @@ def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
dims_out = tuple(d + 1 if d is not not_mapped else d for d in dims_out())
return [BatchTracer(self, v, d) for v, d in zip(vals_out, dims_out)]

def post_process_call(self, call_primitive, out_tracers, params):
def post_process_map(self, call_primitive, out_tracers, params):
vals, dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
master = self.master
def todo(x):
def todo(vals):
trace = BatchTrace(master, core.cur_sublevel())
return map(partial(BatchTracer, trace), x, dims)
return [BatchTracer(trace, v, d + 1 if d is not not_mapped else d)
for v, d in zip(vals, dims)]
return vals, todo

def process_custom_jvp_call(self, prim, fun, jvp, tracers):
Expand Down
5 changes: 0 additions & 5 deletions jax/interpreters/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,6 @@ def process_primitive(self, primitive, tracers, params):
return PapplyTracer(self, name, size, val_out, axis_out)

def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
if call_primitive in pe.map_primitives:
return self.process_map(call_primitive, f, tracers, params)
names, vals, axes = unzip3((t.name, t.val, t.axis) for t in tracers)
if all(axis is not_sharded for axis in axes):
return call_primitive.bind(f, *vals, **params)
Expand All @@ -133,8 +131,5 @@ def todo(x):
return PapplyTracer(trace, name, size, x, axis)
return val, todo

def process_map(self, map_primitive, f :lu.WrappedFun, tracers, params):
raise NotImplementedError # TODO(mattjj,frostig)


papply_primitive_rules: Dict[core.Primitive, Callable] = {}
61 changes: 32 additions & 29 deletions jax/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,9 @@ def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
else:
name = wrap_name(name, 'pe')
params = dict(params, name=name)

if call_primitive in call_partial_eval_rules:
return call_partial_eval_rules[call_primitive](self, call_primitive, f, tracers, params)
if call_primitive in map_primitives:
return self.process_map(call_primitive, f, tracers, params)
in_pvs, in_consts = unzip2([t.pval for t in tracers])
fun, aux = partial_eval(f, self, in_pvs)
out_flat = call_primitive.bind(fun, *in_consts, **params)
Expand All @@ -190,7 +189,38 @@ def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
t.recipe = eqn
return out_tracers

def post_process_call(self, call_primitive, out_tracers, params):
jaxpr, consts, env = tracers_to_jaxpr([], out_tracers)
out_pvs, out_pv_consts = unzip2(t.pval for t in out_tracers)
out = out_pv_consts + consts
del consts, out_pv_consts
master = self.master
def todo(x):
n = len(jaxpr.outvars)
out_pv_consts, consts = x[:n], x[n:]
trace = JaxprTrace(master, core.cur_sublevel())
const_tracers = map(trace.new_instantiated_const, consts)
env_tracers = map(trace.full_raise, env)
lifted_jaxpr = convert_constvars_jaxpr(jaxpr)
out_tracers = [JaxprTracer(trace, PartialVal((out_pv, out_pv_const)), None)
for out_pv, out_pv_const in zip(out_pvs, out_pv_consts)]
new_params = dict(params, call_jaxpr=lifted_jaxpr)
# The `jaxpr` already contains the env_vars at start of invars
eqn = new_eqn_recipe(tuple(it.chain(const_tracers, env_tracers)),
out_tracers, call_primitive, new_params)
for t in out_tracers:
t.recipe = eqn
return out_tracers
return out, todo

def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
name = params.get('name', f.__name__)
if self.master.trace_type is StagingJaxprTrace:
tracers = map(self.instantiate_const_abstracted, tracers)
else:
name = wrap_name(name, 'pe')

params = dict(params, name=name)
in_pvs, in_consts = unzip2([t.pval for t in tracers])
reduced_pvs = [None if pv is None else _mapped_aval(pv) for pv in in_pvs]
fun, aux = partial_eval(f, self, reduced_pvs)
Expand All @@ -216,32 +246,6 @@ def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
t.recipe = eqn
return out_tracers

def post_process_call(self, call_primitive, out_tracers, params):
if call_primitive in map_primitives:
return self.post_process_map(call_primitive, out_tracers, params)
jaxpr, consts, env = tracers_to_jaxpr([], out_tracers)
out_pvs, out_pv_consts = unzip2(t.pval for t in out_tracers)
out = out_pv_consts + consts
del consts, out_pv_consts
master = self.master
def todo(x):
n = len(jaxpr.outvars)
out_pv_consts, consts = x[:n], x[n:]
trace = JaxprTrace(master, core.cur_sublevel())
const_tracers = map(trace.new_instantiated_const, consts)
env_tracers = map(trace.full_raise, env)
lifted_jaxpr = convert_constvars_jaxpr(jaxpr)
out_tracers = [JaxprTracer(trace, PartialVal((out_pv, out_pv_const)), None)
for out_pv, out_pv_const in zip(out_pvs, out_pv_consts)]
new_params = dict(params, call_jaxpr=lifted_jaxpr)
# The `jaxpr` already contains the env_vars at start of invars
eqn = new_eqn_recipe(tuple(it.chain(const_tracers, env_tracers)),
out_tracers, call_primitive, new_params)
for t in out_tracers:
t.recipe = eqn
return out_tracers
return out, todo

def post_process_map(self, map_primitive, out_tracers, params):
jaxpr, consts, env = tracers_to_jaxpr([], out_tracers)
out_pvs_reduced, out_pv_consts = unzip2(t.pval for t in out_tracers)
Expand Down Expand Up @@ -306,7 +310,6 @@ def _unmapped_aval(size, aval):
else:
raise TypeError(aval)

map_primitives: Set[core.Primitive] = set()
custom_partial_eval_rules: Dict[core.Primitive, Callable] = {}
call_partial_eval_rules: Dict[core.Primitive, Callable] = {}

Expand Down
20 changes: 9 additions & 11 deletions jax/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,7 +816,7 @@ def execute_replicated(compiled, backend, in_handler, out_handler, *args):
xla_pmap_p = core.Primitive('xla_pmap')
xla_pmap_p.call_primitive = True
xla_pmap_p.multiple_results = True
xla_pmap = partial(core.call_bind, xla_pmap_p)
xla_pmap = partial(core.map_bind, xla_pmap_p)
xla_pmap_p.def_custom_bind(xla_pmap)
xla_pmap_p.def_impl(xla_pmap_impl)

Expand Down Expand Up @@ -847,7 +847,6 @@ def _pmap_translation_rule(c, axis_env,

xla.call_translations[xla_pmap_p] = _pmap_translation_rule
ad.primitive_transposes[xla_pmap_p] = partial(ad.map_transpose, xla_pmap_p)
pe.map_primitives.add(xla_pmap_p)

def _xla_shard(c, aval, axis_env, x):
if aval is core.abstract_unit:
Expand Down Expand Up @@ -1019,16 +1018,13 @@ def new_tracer(x, a):

def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
assert call_primitive.multiple_results
if call_primitive in pe.map_primitives:
return self.process_map(call_primitive, f, tracers, params)
vals, names = unzip2((t.val, t.axis_name) for t in tracers)
if all(name is not_mapped for name in names):
return call_primitive.bind(f, *vals, **params)
else:
vals, names = unzip2((t.val, t.axis_name) for t in tracers)
if all(name is not_mapped for name in names):
return call_primitive.bind(f, *vals, **params)
else:
f, names_out = split_axis_subtrace(f, self.master, names)
vals_out = call_primitive.bind(f, *vals, **params)
return [SplitAxisTracer(self, a, x) for a, x in zip(names_out(), vals_out)]
f, names_out = split_axis_subtrace(f, self.master, names)
vals_out = call_primitive.bind(f, *vals, **params)
return [SplitAxisTracer(self, a, x) for a, x in zip(names_out(), vals_out)]

def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
vals, names = unzip2((t.val, t.axis_name) for t in tracers)
Expand All @@ -1054,5 +1050,7 @@ def todo(x):
return SplitAxisTracer(trace, name, x)
return val, todo

post_process_map = post_process_call


split_axis_rules: Dict[core.Primitive, Callable] = {}
4 changes: 3 additions & 1 deletion tests/lax_scipy_sparse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,14 @@ def args_maker():
check_dtypes=True,
tol=3e-5)

# TODO(mattjj): I had to loosen the tolerance for complex64[7,7]
# with preconditioner=random
self._CheckAgainstNumpy(
partial(scipy_cg, M=M, maxiter=3),
partial(lax_cg, M=M, maxiter=3),
args_maker,
check_dtypes=True,
tol=1e-4)
tol=3e-3)

self._CheckAgainstNumpy(
np.linalg.solve,
Expand Down
21 changes: 21 additions & 0 deletions tests/pmap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1015,6 +1015,27 @@ def f(args_list):
np.array([sum(vals)] * ndevices),
check_dtypes=True)

def testPostProcessMap(self):
# code from https://github.com/google/jax/issues/2787
def vv(x, y):
"""Vector-vector multiply"""
return np.dot(x, y)

def distributed_matrix_vector(x, y):
"""Matrix vector multiply. First batch it and then row by row"""
fv = lambda z: lax.map(lambda j: vv(j, y), z)
res = pmap(fv)(x.reshape((jax.device_count(), -1) + tuple(x.shape[1:])))
res = res.reshape(res.shape[0] * res.shape[1], *res.shape[2:])
return res

key = random.PRNGKey(1)
x = random.normal(key, (800, 50))
batched_mvm = vmap(lambda b: distributed_matrix_vector(x, b), in_axes=0)
y = random.normal(key, (10, 50, 1))
result = batched_mvm(y)
expected = np.einsum('ij,njk->nik', x, y)
self.assertAllClose(result, expected, check_dtypes=False, atol=1e-3, rtol=1e-3)


class PmapWithDevicesTest(jtu.JaxTestCase):

Expand Down

0 comments on commit 063cf0f

Please sign in to comment.