handle mapped_invars correctly in more places
fixes #2822

We didn't handle mapped_invars correctly in all places in #1959. In
particular, in #1959 we:
  1. assumed the `mapped_invars` parameter of xla_pmap_p was only
     populated after partial_eval and set to None otherwise (i.e. when
     staging out for a jit or a control flow primitive),
  2. didn't update it correctly in JVPTrace.process_map (which adds new
     inputs corresponding to nonzero tangents, and hence `mapped_invars`
     must be grown),
  3. didn't update it correctly in JaxprTrace.process_map (which adds
     constant and environment-tracer inputs to the staged-out version of
     the primitive),
  4. didn't forward it correctly in JaxprTrace.process_map anyway (we
     were setting it to all-true for the staged-out eqn),
  5. removed the leading axes of all pvs in JaxprTrace.process_map
     regardless of whether the corresponding entry of `mapped_invars`
     was True or False (see the sketch just below).
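
A distilled sketch of the rule that item 5 violates, using plain shape
tuples and a hypothetical helper name (the real logic lives in
JaxprTrace.process_map via _mapped_aval):

    # Only *mapped* inputs lose their leading axis inside the mapped
    # jaxpr; unmapped (broadcast) inputs keep their shape unchanged.
    def reduce_shape(shape, mapped, axis_size):
      if mapped:
        assert shape[0] == axis_size  # the leading axis is the mapped one
        return shape[1:]
      return shape

    assert reduce_shape((8, 3), True, 8) == (3,)   # mapped: drop leading axis
    assert reduce_shape((3,), False, 8) == (3,)    # unmapped: left alone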

The reason we didn't notice 2 and 3 was that they only arise when doing
control flow (e.g. scan or remat) of pmap involving closed-over tracers
(apparently a rare case), since that's the case where we first form a
jaxpr (populating `mapped_invars`) and then later have to apply
transformations like AD and further partial eval (thus engaging
JVPTrace.process_map and JaxprTrace.process_map with a populated
`mapped_invars` parameter). It worked in other cases, e.g. when the pmap
was not inside control flow or a remat, because in those cases we left
`mapped_invars` set to None, indicating all-true of any length (so it
didn't matter if we added inputs).
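
For reference, here's what the parameter encodes, with hypothetical
values (one boolean per input of the mapped call):

    # For an xla_pmap_p application with three inputs: the first and third
    # are split over their leading axis (of length axis_size), while the
    # second is passed whole to every device (e.g. a lifted closed-over
    # constant).
    params = dict(axis_size=8, mapped_invars=(True, False, True))
    # Under the old convention, mapped_invars=None meant "all mapped",
    # for an argument list of any length.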

This commit fixes those issues by
  1. making `mapped_invars` non-optional (even though the default value
     of None is convenient as long as it lasts, it's not worth the
     complexity of handling the two None-or-populated cases everywhere
     downstream),
  2. handling `mapped_invars` correctly (see the sketch after this list) in
    * JaxprTrace.process_map
    * JVPTrace.process_map
    * ad.map_transpose (since having symbolic-zero cotangents
      effectively prunes inputs, and having undefined-primal args also
      prunes inputs)
    * ad._eval_subjaxpr_primals (since having undefined-primal args
      prunes inputs)
  3. handling the separate cases of calls and maps more explicitly by
     adding a new Primitive.map_primitive boolean attribute (analogous
     to Primitive.call_primitive).
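
A distilled sketch of that bookkeeping, with toy sentinels standing in
for JAX's symbolic-zero and undefined-primal values (not the real
types):

    zero, undefined = object(), object()  # toy stand-ins

    # JVP (fix 2): nonzero tangents are appended as extra inputs, so their
    # mapped-ness flags are appended too, mirroring their primals'.
    mapped_invars = (True, False, True)
    tangents = [object(), zero, object()]  # middle tangent symbolically zero
    grown = (*mapped_invars,
             *[m for m, t in zip(mapped_invars, tangents) if t is not zero])
    assert grown == (True, False, True, True, True)

    # Transpose (ad.map_transpose): undefined-primal args are pruned, and
    # one all-mapped flag is appended per nonzero cotangent.
    args = [object(), undefined, object()]
    cts = [object(), zero]
    pruned = (*[m for m, x in zip(mapped_invars, args) if x is not undefined],
              *[True for c in cts if c is not zero])
    assert pruned == (True, True, True)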

This is begging for a more coherent cleanup. For example, we reuse the
same Primitive class but tag it with `call_primitive` or `map_primitive`
(only one of which can be True); we should instead just have a separate
Primitive class for these cases and track the type tag with built-in
Python mechanisms. Moreover, `call_primitive=True` or
`map_primitive=True` implies things about which `params` must be present
(`call_jaxpr`, plus `mapped_invars` for maps). I plan to follow up with
those cleanups (one possible shape is sketched below), but I wanted to
get something working first.
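
One hypothetical shape for that cleanup (not part of this commit):
separate classes, so the type tag and the required params travel
together and isinstance checks replace the boolean attributes.

    class Primitive:
      multiple_results = False

    class CallPrimitive(Primitive):
      multiple_results = True  # bound with a traceable fun; params carry
                               # 'call_jaxpr' once staged out

    class MapPrimitive(CallPrimitive):
      """Params must additionally carry 'mapped_invars'."""

    # isinstance(p, MapPrimitive) would replace the p.map_primitive tag.
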
mattjj committed Apr 24, 2020
1 parent 6ad2908 commit 7c7a0ed
Showing 6 changed files with 70 additions and 33 deletions.
6 changes: 4 additions & 2 deletions jax/api.py
@@ -982,7 +982,8 @@ def f_pmapped(*args, **kwargs):
         axis_size=local_axis_size,
         global_axis_size=axis_size,
         devices=tuple(devices) if devices is not None else devices,
-        name=flat_fun.__name__)
+        name=flat_fun.__name__,
+        mapped_invars=(True,) * len(args))
     return tree_unflatten(out_tree(), out)
 
   namestr = "pmap({}, axis_name={})".format
@@ -1039,7 +1040,8 @@ def f_pmapped(*args, **kwargs):
     reshaped_outs = pxla.xla_pmap(soft_mapped_fun, *reshaped_args, backend=backend,
                                   axis_name=axis_name, axis_size=num_chunks,
                                   global_axis_size=None, devices=None,
-                                  name=soft_mapped_fun.__name__)
+                                  name=soft_mapped_fun.__name__,
+                                  mapped_invars=(True,) * len(reshaped_args))
     outs = [_reshape_merge(out) for out in reshaped_outs]
     return tree_unflatten(out_tree(), outs)
 
13 changes: 8 additions & 5 deletions jax/core.py
@@ -181,9 +181,9 @@ def __repr__(self):
 literalable_types: Set[type] = set()
 
 class Primitive(object):
-  multiple_results = False  # override for multi-output primitives
-  call_primitive = False    # override for higher-order primitives that are
-                            # processed in final style.
+  multiple_results = False  # set for multi-output primitives
+  call_primitive = False    # set for call primitives processed in final style
+  map_primitive = False     # set for map primitives processed in final style
 
   def __init__(self, name):
     self.name = name
@@ -236,7 +236,7 @@ def extract_call_jaxpr(primitive, params):
   Returns the subjaxpr and the params without the "call_jaxpr" value. If this is
   not a call primitive then returns (None, params).
   """
-  if not primitive.call_primitive:
+  if not (primitive.call_primitive or primitive.map_primitive):
     return (None, params)
   else:
     assert "call_jaxpr" in params
@@ -990,6 +990,9 @@ def process_env_traces(post_processor: str, primitive: Primitive,
 
 def _call_bind(processor: str, post_processor: str, primitive: Primitive,
                f: lu.WrappedFun, *args, **params):
+  # TODO add a check like this, clean up to get rid of strings
+  if processor == 'process_map':
+    assert len(args) == len(params['mapped_invars'])
   top_trace = find_top_trace(args)
   level = trace_state.trace_stack.next_level(True) if top_trace is None else top_trace.level
   params_tuple = tuple(params.items())
@@ -1046,7 +1049,7 @@ def write_env(env: Set[Var], v: Var):
   map(write, jaxpr.constvars)
   map(write, jaxpr.invars)
   for eqn in jaxpr.eqns:
-    if eqn.primitive.call_primitive:
+    if eqn.primitive.call_primitive or eqn.primitive.map_primitive:
       if "call_jaxpr" not in eqn.params:
         raise Exception("Call primitive {} should have a 'call_jaxpr' parameter"
                         .format(eqn.primitive))
54 changes: 40 additions & 14 deletions jax/interpreters/ad.py
@@ -178,23 +178,24 @@ def is_linear(var):
 
   linear_eqns = []
   for eqn in jaxpr.eqns:
-    if not eqn.primitive.call_primitive:
+    prim = eqn.primitive
+    if not (prim.call_primitive or prim.map_primitive):
       if any(is_linear(v) for v in eqn.invars):
         linear_eqns.append(eqn)
       else:
         in_vals = map(read_primal, eqn.invars)
-        ans = eqn.primitive.bind(*in_vals, **eqn.params)
-        if eqn.primitive.multiple_results:
+        ans = prim.bind(*in_vals, **eqn.params)
+        if prim.multiple_results:
           map(write_primal, eqn.outvars, ans)
         else:
           write_primal(eqn.outvars[0], ans)
     else:
-      call_jaxpr, params = core.extract_call_jaxpr(eqn.primitive, eqn.params)
+      call_jaxpr, params = core.extract_call_jaxpr(prim, eqn.params)
       if any(is_linear(v) for v in eqn.invars):
         linear_eqns.append(eqn)
       if any(not is_linear(v) for v in eqn.invars):
         # FIXME: Some invars correspond to tangents
-        ans = _eval_subjaxpr_primals(eqn.primitive, call_jaxpr,
+        ans = _eval_subjaxpr_primals(prim, call_jaxpr,
                                      map(read_primal, eqn.invars), params)
         map(write_primal, eqn.outvars, ans)
 
@@ -216,7 +217,7 @@ def is_linear(var):
       cts_in = map(read_cotangent, eqn.outvars)
     else:
       cts_in, = map(read_cotangent, eqn.outvars)
-    if eqn.primitive.call_primitive:
+    if eqn.primitive.call_primitive or eqn.primitive.map_primitive:
       call_jaxpr, params = core.extract_call_jaxpr(eqn.primitive, eqn.params)
       cts_out = get_primitive_transpose(eqn.primitive)(
           params, call_jaxpr, invals, cts_in)
@@ -236,7 +237,14 @@ def _eval_subjaxpr_primals(prim, jaxpr, in_vals, params):
   all_args, in_tree_def = tree_flatten((in_vals,))
   fun = lu.hashable_partial(lu.wrap_init(_eval_primals), jaxpr)
   fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
-  out_flat = prim.bind(fun, *all_args, **params)
+  assert prim.map_primitive ^ prim.call_primitive
+  if prim.map_primitive:
+    new_mapped_invars = [m for m, x in zip(params['mapped_invars'], in_vals)
+                         if not is_undefined_primal(x)]
+    new_params = dict(params, mapped_invars=tuple(new_mapped_invars))
+    out_flat = prim.bind(fun, *all_args, **new_params)
+  else:
+    out_flat = prim.bind(fun, *all_args, **params)
   return tree_unflatten(out_tree(), out_flat)
 
 def _eval_primals(jaxpr, args):
@@ -262,7 +270,7 @@ def is_linear(var):
   assert not jaxpr.constvars
   map(write_primal, jaxpr.invars, args)
   for eqn in jaxpr.eqns:
-    if not eqn.primitive.call_primitive:
+    if not (eqn.primitive.call_primitive or eqn.primitive.map_primitive):
      if not any(is_linear(v) for v in eqn.invars):
         in_vals = map(read_primal, eqn.invars)
         ans = eqn.primitive.bind(*in_vals, **eqn.params)
@@ -327,14 +335,13 @@ def process_primitive(self, primitive, tracers, params):
 
   def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
     assert call_primitive.multiple_results
-    primals = [t.primal for t in tracers]
-    tangents = [t.tangent for t in tracers]
+    primals, tangents = unzip2((t.primal, t.tangent) for t in tracers)
     nonzero_tangents, in_tree_def = tree_flatten(tangents)
     f_jvp, out_tree_def = traceable(jvp_subtrace(f, self.master),
                                     len(primals), in_tree_def)
     name = params.get('name', f.__name__)
     params = dict(params, name=wrap_name(name, 'jvp'))
-    result = call_primitive.bind(f_jvp, *(primals + nonzero_tangents), **params)
+    result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **params)
     primal_out, tangent_out = tree_unflatten(out_tree_def(), result)
     return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)]
 
@@ -350,7 +357,22 @@ def todo(x):
       return map(partial(JVPTracer, trace), primals, tangents)
     return out, todo
 
-  process_map = process_call
+  def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
+    # only differs from process_call in that it must update mapped_invars
+    # TODO de-duplicate code
+    assert map_primitive.multiple_results
+    primals, tangents = unzip2((t.primal, t.tangent) for t in tracers)
+    nonzero_tangents, in_tree_def = tree_flatten(tangents)
+    f_jvp, out_tree_def = traceable(jvp_subtrace(f, self.master),
+                                    len(primals), in_tree_def)
+    new_name = wrap_name(params.get('name', f.__name__), 'jvp')
+    new_mapped_invars = (*params['mapped_invars'],
+                         *[m for m, t in zip(params['mapped_invars'], tangents)
+                           if t is not zero])
+    new_params = dict(params, name=new_name, mapped_invars=new_mapped_invars)
+    result = map_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
+    primal_out, tangent_out = tree_unflatten(out_tree_def(), result)
+    return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)]
   post_process_map = post_process_call
 
   def process_custom_jvp_call(self, _, __, f_jvp, tracers):
@@ -540,8 +562,12 @@ def map_transpose(primitive, params, call_jaxpr, args, ct):
   all_args, in_tree_def = tree_flatten(((), args, ct))  # empty consts
   fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr)
   fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
-  params = dict(params, name=wrap_name(params['name'], 'transpose'))
-  out_flat = primitive.bind(fun, *all_args, **params)
+  new_mapped_invars = (*[m for m, x in zip(params['mapped_invars'], args)
+                         if not is_undefined_primal(x)],
+                       *[True for x in ct if x is not zero])
+  new_params = dict(params, name=wrap_name(params['name'], 'transpose'),
+                    mapped_invars=new_mapped_invars)
+  out_flat = primitive.bind(fun, *all_args, **new_params)
   arg_cts = tree_unflatten(out_tree(), out_flat)
 
   mapped_invars = params['mapped_invars']  # True for each mapped invar
22 changes: 14 additions & 8 deletions jax/interpreters/partial_eval.py
@@ -222,7 +222,9 @@ def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
 
     params = dict(params, name=name)
     in_pvs, in_consts = unzip2([t.pval for t in tracers])
-    reduced_pvs = [None if pv is None else _mapped_aval(pv) for pv in in_pvs]
+    reduced_pvs = [None if pv is None else
+                   _mapped_aval(params['axis_size'], pv) if m else pv
+                   for pv, m in zip(in_pvs, params['mapped_invars'])]
     fun, aux = partial_eval(f, self, reduced_pvs)
     out_flat = map_primitive.bind(fun, *in_consts, **params)
     out_pvs_reduced, jaxpr, env = aux()
@@ -236,10 +238,12 @@ def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
                    for out_pv, out_pv_const in zip(out_pvs, out_pv_consts)]
     # The `jaxpr` already contains the env_vars at start of invars
     new_params = dict(params,
-                      mapped_invars=tuple([True] * len(const_tracers) +
-                                          [False] * len(env_tracers) +
-                                          [True] * len(tracers)),
+                      mapped_invars=((True,) * len(const_tracers) +
+                                     (False,) * len(env_tracers) +
+                                     params['mapped_invars']),
                       call_jaxpr=lifted_jaxpr)
+    assert (len(new_params['mapped_invars'])
+            == len(const_tracers) + len(env_tracers) + len(tracers))
     eqn = new_eqn_recipe(tuple(it.chain(const_tracers, env_tracers, tracers)),
                          out_tracers, map_primitive, new_params)
     for t in out_tracers:
@@ -293,11 +297,12 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
 class StagingJaxprTrace(JaxprTrace):
   pass
 
-def _mapped_aval(aval):
+def _mapped_aval(size, aval):
   if aval is core.abstract_unit:
     return aval
   elif isinstance(aval, ShapedArray):
     # might be raising abstraction level from Concrete here
+    assert aval.shape[0] == size
     return ShapedArray(aval.shape[1:], aval.dtype)
   else:
     raise TypeError(aval)
@@ -473,10 +478,11 @@ def new_eqn_recipe(invars, outvars, primitive, params):
     primitive: the primitive.
     params: the primitive params
   """
-  if primitive.call_primitive:
-    # TODO(necula): move these checks to core.check_jaxpr, and call it
-    # in more places.
+  # TODO(necula): move these checks to core.check_jaxpr, and call in more places
+  if primitive.call_primitive or primitive.map_primitive:
     assert "call_jaxpr" in params
+  if primitive.map_primitive:
+    assert "mapped_invars" in params
   return JaxprEqnRecipe(object(), tuple(invars), map(ref, outvars), primitive,
                         params)
 
4 changes: 2 additions & 2 deletions jax/interpreters/pxla.py
@@ -560,7 +560,7 @@ def _sharded_device_array_constant_handler(c, val, canonicalize_types=True):
 ### the xla_pmap primitive and its rules are comparable to xla_call in xla.py
 
 def xla_pmap_impl(fun: lu.WrappedFun, *args, backend, axis_name, axis_size, global_axis_size,
-                  devices, name, mapped_invars=None):
+                  devices, name, mapped_invars):
   abstract_args = map(xla.abstractify, args)
   compiled_fun = parallel_callable(fun, backend, axis_name, axis_size,
                                    global_axis_size, devices, name, *abstract_args)
@@ -820,7 +820,7 @@ def execute_replicated(compiled, backend, in_handler, out_handler, *args):
 
 
 xla_pmap_p = core.Primitive('xla_pmap')
-xla_pmap_p.call_primitive = True
+xla_pmap_p.map_primitive = True
 xla_pmap_p.multiple_results = True
 xla_pmap = partial(core.map_bind, xla_pmap_p)
 xla_pmap_p.def_custom_bind(xla_pmap)
4 changes: 2 additions & 2 deletions tests/pmap_test.py
@@ -902,8 +902,8 @@ def testSoftPmapDevicePersistence(self):
     x = x.reshape(2 * device_count, 2, 2, 3)  # axis merge of the wrong size
     self.assertIsInstance(x, xla.DeviceArray)  # should have forced collection
 
-  @jtu.skip_on_devices("gpu")
-  def DISABLED_testSoftPmapAllToAll(self):
+  @jtu.skip_on_devices("gpu", "cpu")
+  def testSoftPmapAllToAll(self):
     n = 4 * xla_bridge.device_count()
     def f(x):
       return lax.all_to_all(x, 'i', 0, 0)
