Skip to content

Commit

Permalink
handle mapped_invars correctly in more places (#2828)
Browse files Browse the repository at this point in the history
fixes #2822

We didn't handle `pmap`'s `mapped_invars` correctly in all places in #1959. (I'm actually not sure if #1959 introduced the bug where things were working before, or just refactored it in terms of `mapped_invars`, though my guess is that because the information now contained in `mapped_invars` was implicitly contained in the pmapped jaxpr's `constvars` and `env_vars` that it was working correctly before #1959.) In particular, in #1959 we:
  1. assumed the `mapped_invars` parameter of xla_pmap_p was only populated after partial_eval and set to None otherwise (i.e. staging out for a jit or a control flow primitive),
  2. didn't update it correctly in JVPTrace.process_map (which adds new inputs corresponding to nonzero tangents, and hence `mapped_invars` must be grown),
  3. didn't update it correctly in JaxprTrace.process_map (which adds residual inputs to the staged-out version of the primitive),
  4. didn't forward it correctly in JaxprTrace.process_map anyway (we were setting it to all-true for the staged out eqn for all tracers regardless of what the original `mapped_invars` said),
  5. removed the leading axes of all pvs in JaxprTrace.process_map regardless of whether the corresponding entry of `mapped_invars` was True or False.

The reason we didn't notice 2 and 3 was that they only arise when doing control flow (e.g. scan or remat) of pmap involving closed-over tracers (apparently a rare case), since that's the case where we first form a jaxpr (populating `mapped_invars`) and then later have to apply transformations like AD and further partial eval (thus engaging JVPTrace.process_map and JaxprTrace.process_map with a populated `mapped_invars` parameter). It worked in other cases, e.g. when the pmap was not inside control flow or a remat, because in those cases we left `mapped_invars` set to None, indicating all-true of any length (so it didn't matter if we add inputs).

This commit fixes those issues by
  1. making `mapped_invars` non-optional,
  2. handling `mapped_invars` correctly in
    * JaxprTrace.process_map
    * JVPTrace.process_map
    * ad.map_transpose (since having symbolic-zero cotangents effectively prunes inputs, and having undefined-primal args also prunes inputs)
    * ad._eval_subjaxpr_primals (since having undefined-primal args prunes inputs)
  3. making the separate cases of calls and maps handled more explicitly by adding a new Primitive.map_primitive boolean attribute (analogous to Primitive.call_primitive), to be revised further in #2829.

This is begging for a more coherent cleanup. For example, we reuse the same Primitive class but tag it with `call_primitive` or `map_primitive` (only one of which can be True); we should instead just have a separate Primitive class for these cases and track the type tag with built-in
Python mechanisms. Moreover, when `call_primitive=True` or `map_primitive=True` implies things about what `params` must be present (`call_jaxpr` and `mapped_invars`). I plan to follow up with those cleanups in #2829, but I wanted to get something working first.
  • Loading branch information
mattjj authored Apr 25, 2020
1 parent 8f90245 commit 89e3840
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 33 deletions.
6 changes: 4 additions & 2 deletions jax/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,7 +982,8 @@ def f_pmapped(*args, **kwargs):
axis_size=local_axis_size,
global_axis_size=axis_size,
devices=tuple(devices) if devices is not None else devices,
name=flat_fun.__name__)
name=flat_fun.__name__,
mapped_invars=(True,) * len(args))
return tree_unflatten(out_tree(), out)

namestr = "pmap({}, axis_name={})".format
Expand Down Expand Up @@ -1039,7 +1040,8 @@ def f_pmapped(*args, **kwargs):
reshaped_outs = pxla.xla_pmap(soft_mapped_fun, *reshaped_args, backend=backend,
axis_name=axis_name, axis_size=num_chunks,
global_axis_size=None, devices=None,
name=soft_mapped_fun.__name__)
name=soft_mapped_fun.__name__,
mapped_invars=(True,) * len(reshaped_args))
outs = [_reshape_merge(out) for out in reshaped_outs]
return tree_unflatten(out_tree(), outs)

Expand Down
10 changes: 5 additions & 5 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,9 @@ def __repr__(self):
literalable_types: Set[type] = set()

class Primitive(object):
multiple_results = False # override for multi-output primitives
call_primitive = False # override for higher-order primitives that are
# processed in final style.
multiple_results = False # set for multi-output primitives
call_primitive = False # set for call primitives processed in final style
map_primitive = False # set for map primitives processed in final style

def __init__(self, name):
self.name = name
Expand Down Expand Up @@ -236,7 +236,7 @@ def extract_call_jaxpr(primitive, params):
Returns the subjaxpr and the params without the "call_jaxpr" value. If this is
not a call primitive then returns (None, params).
"""
if not primitive.call_primitive:
if not (primitive.call_primitive or primitive.map_primitive):
return (None, params)
else:
assert "call_jaxpr" in params
Expand Down Expand Up @@ -1046,7 +1046,7 @@ def write_env(env: Set[Var], v: Var):
map(write, jaxpr.constvars)
map(write, jaxpr.invars)
for eqn in jaxpr.eqns:
if eqn.primitive.call_primitive:
if eqn.primitive.call_primitive or eqn.map_primitive:
if "call_jaxpr" not in eqn.params:
raise Exception("Call primitive {} should have a 'call_jaxpr' parameter"
.format(eqn.primitive))
Expand Down
54 changes: 40 additions & 14 deletions jax/interpreters/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,23 +178,24 @@ def is_linear(var):

linear_eqns = []
for eqn in jaxpr.eqns:
if not eqn.primitive.call_primitive:
prim = eqn.primitive
if not (prim.call_primitive or prim.map_primitive):
if any(is_linear(v) for v in eqn.invars):
linear_eqns.append(eqn)
else:
in_vals = map(read_primal, eqn.invars)
ans = eqn.primitive.bind(*in_vals, **eqn.params)
if eqn.primitive.multiple_results:
ans = prim.bind(*in_vals, **eqn.params)
if prim.multiple_results:
map(write_primal, eqn.outvars, ans)
else:
write_primal(eqn.outvars[0], ans)
else:
call_jaxpr, params = core.extract_call_jaxpr(eqn.primitive, eqn.params)
call_jaxpr, params = core.extract_call_jaxpr(prim, eqn.params)
if any(is_linear(v) for v in eqn.invars):
linear_eqns.append(eqn)
if any(not is_linear(v) for v in eqn.invars):
# FIXME: Some invars correspond to tangents
ans = _eval_subjaxpr_primals(eqn.primitive, call_jaxpr,
ans = _eval_subjaxpr_primals(prim, call_jaxpr,
map(read_primal, eqn.invars), params)
map(write_primal, eqn.outvars, ans)

Expand All @@ -216,7 +217,7 @@ def is_linear(var):
cts_in = map(read_cotangent, eqn.outvars)
else:
cts_in, = map(read_cotangent, eqn.outvars)
if eqn.primitive.call_primitive:
if eqn.primitive.call_primitive or eqn.primitive.map_primitive:
call_jaxpr, params = core.extract_call_jaxpr(eqn.primitive, eqn.params)
cts_out = get_primitive_transpose(eqn.primitive)(
params, call_jaxpr, invals, cts_in)
Expand All @@ -236,7 +237,14 @@ def _eval_subjaxpr_primals(prim, jaxpr, in_vals, params):
all_args, in_tree_def = tree_flatten((in_vals,))
fun = lu.hashable_partial(lu.wrap_init(_eval_primals), jaxpr)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
out_flat = prim.bind(fun, *all_args, **params)
assert prim.map_primitive ^ prim.call_primitive
if prim.map_primitive:
new_mapped_invars = [m for m, x in zip(params['mapped_invars'], in_vals)
if not is_undefined_primal(x)]
new_params = dict(params, mapped_invars=tuple(new_mapped_invars))
out_flat = prim.bind(fun, *all_args, **new_params)
else:
out_flat = prim.bind(fun, *all_args, **params)
return tree_unflatten(out_tree(), out_flat)

def _eval_primals(jaxpr, args):
Expand All @@ -262,7 +270,7 @@ def is_linear(var):
assert not jaxpr.constvars
map(write_primal, jaxpr.invars, args)
for eqn in jaxpr.eqns:
if not eqn.primitive.call_primitive:
if not (eqn.primitive.call_primitive or eqn.primitive.map_primitive):
if not any(is_linear(v) for v in eqn.invars):
in_vals = map(read_primal, eqn.invars)
ans = eqn.primitive.bind(*in_vals, **eqn.params)
Expand Down Expand Up @@ -327,14 +335,13 @@ def process_primitive(self, primitive, tracers, params):

def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
assert call_primitive.multiple_results
primals = [t.primal for t in tracers]
tangents = [t.tangent for t in tracers]
primals, tangents = unzip2((t.primal, t.tangent) for t in tracers)
nonzero_tangents, in_tree_def = tree_flatten(tangents)
f_jvp, out_tree_def = traceable(jvp_subtrace(f, self.master),
len(primals), in_tree_def)
name = params.get('name', f.__name__)
params = dict(params, name=wrap_name(name, 'jvp'))
result = call_primitive.bind(f_jvp, *(primals + nonzero_tangents), **params)
result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **params)
primal_out, tangent_out = tree_unflatten(out_tree_def(), result)
return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)]

Expand All @@ -350,7 +357,22 @@ def todo(x):
return map(partial(JVPTracer, trace), primals, tangents)
return out, todo

process_map = process_call
def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
# only differs from process_call in that it must update mapped_invars
# TODO de-duplicate code
assert map_primitive.multiple_results
primals, tangents = unzip2((t.primal, t.tangent) for t in tracers)
nonzero_tangents, in_tree_def = tree_flatten(tangents)
f_jvp, out_tree_def = traceable(jvp_subtrace(f, self.master),
len(primals), in_tree_def)
new_name = wrap_name(params.get('name', f.__name__), 'jvp')
new_mapped_invars = (*params['mapped_invars'],
*[m for m, t in zip(params['mapped_invars'], tangents)
if t is not zero])
new_params = dict(params, name=new_name, mapped_invars=new_mapped_invars)
result = map_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
primal_out, tangent_out = tree_unflatten(out_tree_def(), result)
return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)]
post_process_map = post_process_call

def process_custom_jvp_call(self, _, __, f_jvp, tracers):
Expand Down Expand Up @@ -540,8 +562,12 @@ def map_transpose(primitive, params, call_jaxpr, args, ct):
all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts
fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
params = dict(params, name=wrap_name(params['name'], 'transpose'))
out_flat = primitive.bind(fun, *all_args, **params)
new_mapped_invars = (*[m for m, x in zip(params['mapped_invars'], args)
if not is_undefined_primal(x)],
*[True for x in ct if x is not zero])
new_params = dict(params, name=wrap_name(params['name'], 'transpose'),
mapped_invars=new_mapped_invars)
out_flat = primitive.bind(fun, *all_args, **new_params)
arg_cts = tree_unflatten(out_tree(), out_flat)

mapped_invars = params['mapped_invars'] # True for each mapped invar
Expand Down
22 changes: 14 additions & 8 deletions jax/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,9 @@ def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):

params = dict(params, name=name)
in_pvs, in_consts = unzip2([t.pval for t in tracers])
reduced_pvs = [None if pv is None else _mapped_aval(pv) for pv in in_pvs]
reduced_pvs = [None if pv is None else
_mapped_aval(params['axis_size'], pv) if m else pv
for pv, m in zip(in_pvs, params['mapped_invars'])]
fun, aux = partial_eval(f, self, reduced_pvs)
out_flat = map_primitive.bind(fun, *in_consts, **params)
out_pvs_reduced, jaxpr, env = aux()
Expand All @@ -237,10 +239,12 @@ def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
for out_pv, out_pv_const in zip(out_pvs, out_pv_consts)]
# The `jaxpr` already contains the env_vars at start of invars
new_params = dict(params,
mapped_invars=tuple([True] * len(const_tracers) +
[False] * len(env_tracers) +
[True] * len(tracers)),
mapped_invars=((True,) * len(const_tracers) +
(False,) * len(env_tracers) +
params['mapped_invars']),
call_jaxpr=lifted_jaxpr)
assert (len(new_params['mapped_invars'])
== len(const_tracers) + len(env_tracers) + len(tracers))
eqn = new_eqn_recipe(tuple(it.chain(const_tracers, env_tracers, tracers)),
out_tracers, map_primitive, new_params)
for t in out_tracers:
Expand Down Expand Up @@ -294,11 +298,12 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
class StagingJaxprTrace(JaxprTrace):
pass

def _mapped_aval(aval):
def _mapped_aval(size, aval):
if aval is core.abstract_unit:
return aval
elif isinstance(aval, ShapedArray):
# might be raising abstraction level from Concrete here
assert aval.shape[0] == size
return ShapedArray(aval.shape[1:], aval.dtype)
else:
raise TypeError(aval)
Expand Down Expand Up @@ -475,10 +480,11 @@ def new_eqn_recipe(invars, outvars, primitive, params):
primitive: the primitive.
params: the primitive params
"""
if primitive.call_primitive:
# TODO(necula): move these checks to core.check_jaxpr, and call it
# in more places.
# TODO(necula): move these checks to core.check_jaxpr, and call in more places
if primitive.call_primitive or primitive.map_primitive:
assert "call_jaxpr" in params
if primitive.map_primitive:
assert "mapped_invars" in params
return JaxprEqnRecipe(object(), tuple(invars), map(ref, outvars), primitive,
params)

Expand Down
4 changes: 2 additions & 2 deletions jax/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ def _sharded_device_array_constant_handler(c, val, canonicalize_types=True):
### the xla_pmap primitive and its rules are comparable to xla_call in xla.py

def xla_pmap_impl(fun: lu.WrappedFun, *args, backend, axis_name, axis_size, global_axis_size,
devices, name, mapped_invars=None):
devices, name, mapped_invars):
abstract_args = map(xla.abstractify, args)
compiled_fun = parallel_callable(fun, backend, axis_name, axis_size,
global_axis_size, devices, name, *abstract_args)
Expand Down Expand Up @@ -820,7 +820,7 @@ def execute_replicated(compiled, backend, in_handler, out_handler, *args):


xla_pmap_p = core.Primitive('xla_pmap')
xla_pmap_p.call_primitive = True
xla_pmap_p.map_primitive = True
xla_pmap_p.multiple_results = True
xla_pmap = partial(core.map_bind, xla_pmap_p)
xla_pmap_p.def_custom_bind(xla_pmap)
Expand Down
36 changes: 34 additions & 2 deletions tests/pmap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -902,8 +902,8 @@ def testSoftPmapDevicePersistence(self):
x = x.reshape(2 * device_count, 2, 2, 3) # axis merge of the wrong size
self.assertIsInstance(x, xla.DeviceArray) # should have forced collection

@jtu.skip_on_devices("gpu")
def DISABLED_testSoftPmapAllToAll(self):
def testSoftPmapAllToAll(self):
raise SkipTest("the underlying code here is broken") # TODO(mattjj)
n = 4 * xla_bridge.device_count()
def f(x):
return lax.all_to_all(x, 'i', 0, 0)
Expand Down Expand Up @@ -1050,6 +1050,38 @@ def f(key):
keys = random.split(random.PRNGKey(0), n)
jax.pmap(jax.remat(f), axis_name='i')(keys)

def testPmapMapVmapCombinations(self):
# https://github.com/google/jax/issues/2822
def vv(x, y):
"""Vector-vector multiply"""
return np.dot(x, y)

def matrix_vector(x, y, parallel=True):
"""Matrix vector multiply. First batch it and then row by row"""
fv = lambda z: lax.map(lambda j: vv(j, y), z)
if parallel:
# split leading axis in two
new_x = x.reshape((jax.device_count(), -1, *x.shape[1:]))
# apply map
new_res = pmap(fv)(new_x)
# reshape back out
res = new_res.reshape(x.shape[0], *new_res.shape[2:])
else:
res = fv(x)
return res

x = random.normal(random.PRNGKey(1), (80, 5))
y = random.normal(random.PRNGKey(1), (10, 5))

result1 = vmap(lambda b: matrix_vector(x, b, True))(y) # vmap + pmap
result2 = lax.map(lambda b: matrix_vector(x, b, False), y) # map + map
result3 = lax.map(lambda b: matrix_vector(x, b, True), y) # map + pmap
result4 = np.stack([matrix_vector(x, b, False) for b in y]) # none + map

self.assertAllClose(result1, result2, check_dtypes=False, atol=1e-3, rtol=1e-3)
self.assertAllClose(result1, result3, check_dtypes=False, atol=1e-3, rtol=1e-3)
self.assertAllClose(result1, result4, check_dtypes=False, atol=1e-3, rtol=1e-3)


class PmapWithDevicesTest(jtu.JaxTestCase):

Expand Down

0 comments on commit 89e3840

Please sign in to comment.