handle mapped_invars correctly in more places #2828

Merged: 4 commits, Apr 25, 2020
6 changes: 4 additions & 2 deletions jax/api.py
@@ -982,7 +982,8 @@ def f_pmapped(*args, **kwargs):
axis_size=local_axis_size,
global_axis_size=axis_size,
devices=tuple(devices) if devices is not None else devices,
name=flat_fun.__name__)
name=flat_fun.__name__,
mapped_invars=(True,) * len(args))
return tree_unflatten(out_tree(), out)

namestr = "pmap({}, axis_name={})".format
@@ -1039,7 +1040,8 @@ def f_pmapped(*args, **kwargs):
reshaped_outs = pxla.xla_pmap(soft_mapped_fun, *reshaped_args, backend=backend,
axis_name=axis_name, axis_size=num_chunks,
global_axis_size=None, devices=None,
name=soft_mapped_fun.__name__)
name=soft_mapped_fun.__name__,
mapped_invars=(True,) * len(reshaped_args))
outs = [_reshape_merge(out) for out in reshaped_outs]
return tree_unflatten(out_tree(), outs)

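Note on the api.py change above: both pmap call sites now pass mapped_invars explicitly, one boolean per flat argument, with True meaning the argument carries a leading mapped axis of size axis_size. A minimal standalone sketch of that semantics, not JAX internals (the split_args helper is illustrative only):

import numpy as np

def split_args(args, mapped_invars, axis_size):
    """Sketch: yield the per-index argument tuples a map primitive sees.

    Mapped arguments (flag True) are indexed along their leading axis;
    unmapped arguments are passed unchanged to every index.
    """
    for i in range(axis_size):
        yield tuple(a[i] if mapped else a
                    for a, mapped in zip(args, mapped_invars))

# First argument mapped over an axis of size 4, second argument shared.
xs = np.arange(8).reshape(4, 2)
w = np.ones(2)
for x_i, w_i in split_args((xs, w), (True, False), axis_size=4):
    assert x_i.shape == (2,) and w_i.shape == (2,)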
10 changes: 5 additions & 5 deletions jax/core.py
@@ -181,9 +181,9 @@ def __repr__(self):
literalable_types: Set[type] = set()

class Primitive(object):
multiple_results = False # override for multi-output primitives
call_primitive = False # override for higher-order primitives that are
# processed in final style.
multiple_results = False # set for multi-output primitives
call_primitive = False # set for call primitives processed in final style
map_primitive = False # set for map primitives processed in final style

def __init__(self, name):
self.name = name
@@ -236,7 +236,7 @@ def extract_call_jaxpr(primitive, params):
Returns the subjaxpr and the params without the "call_jaxpr" value. If this is
not a call primitive then returns (None, params).
"""
if not primitive.call_primitive:
if not (primitive.call_primitive or primitive.map_primitive):
return (None, params)
else:
assert "call_jaxpr" in params
@@ -1046,7 +1046,7 @@ def write_env(env: Set[Var], v: Var):
map(write, jaxpr.constvars)
map(write, jaxpr.invars)
for eqn in jaxpr.eqns:
if eqn.primitive.call_primitive:
if eqn.primitive.call_primitive or eqn.primitive.map_primitive:
if "call_jaxpr" not in eqn.params:
raise Exception("Call primitive {} should have a 'call_jaxpr' parameter"
.format(eqn.primitive))
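Note on the core.py change above: Primitive now carries two separate flags, call_primitive and map_primitive, and both extract_call_jaxpr and the jaxpr traversal expect either kind to provide a 'call_jaxpr' parameter. A rough standalone sketch of the dispatch this enables (FakePrimitive is illustrative, not core.Primitive):

class FakePrimitive:
    multiple_results = False
    call_primitive = False
    map_primitive = False
    def __init__(self, name):
        self.name = name

def expects_call_jaxpr(prim):
    # Mirrors the updated extract_call_jaxpr check: call primitives and
    # map primitives both carry a traced subjaxpr in their params.
    return prim.call_primitive or prim.map_primitive

pmap_like = FakePrimitive("fake_pmap")
pmap_like.map_primitive = True
pmap_like.multiple_results = True
assert expects_call_jaxpr(pmap_like)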
54 changes: 40 additions & 14 deletions jax/interpreters/ad.py
@@ -178,23 +178,24 @@ def is_linear(var):

linear_eqns = []
for eqn in jaxpr.eqns:
if not eqn.primitive.call_primitive:
prim = eqn.primitive
if not (prim.call_primitive or prim.map_primitive):
if any(is_linear(v) for v in eqn.invars):
linear_eqns.append(eqn)
else:
in_vals = map(read_primal, eqn.invars)
ans = eqn.primitive.bind(*in_vals, **eqn.params)
if eqn.primitive.multiple_results:
ans = prim.bind(*in_vals, **eqn.params)
if prim.multiple_results:
map(write_primal, eqn.outvars, ans)
else:
write_primal(eqn.outvars[0], ans)
else:
call_jaxpr, params = core.extract_call_jaxpr(eqn.primitive, eqn.params)
call_jaxpr, params = core.extract_call_jaxpr(prim, eqn.params)
if any(is_linear(v) for v in eqn.invars):
linear_eqns.append(eqn)
if any(not is_linear(v) for v in eqn.invars):
# FIXME: Some invars correspond to tangents
ans = _eval_subjaxpr_primals(eqn.primitive, call_jaxpr,
ans = _eval_subjaxpr_primals(prim, call_jaxpr,
map(read_primal, eqn.invars), params)
map(write_primal, eqn.outvars, ans)

@@ -216,7 +217,7 @@ def is_linear(var):
cts_in = map(read_cotangent, eqn.outvars)
else:
cts_in, = map(read_cotangent, eqn.outvars)
if eqn.primitive.call_primitive:
if eqn.primitive.call_primitive or eqn.primitive.map_primitive:
call_jaxpr, params = core.extract_call_jaxpr(eqn.primitive, eqn.params)
cts_out = get_primitive_transpose(eqn.primitive)(
params, call_jaxpr, invals, cts_in)
@@ -236,7 +237,14 @@ def _eval_subjaxpr_primals(prim, jaxpr, in_vals, params):
all_args, in_tree_def = tree_flatten((in_vals,))
fun = lu.hashable_partial(lu.wrap_init(_eval_primals), jaxpr)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
out_flat = prim.bind(fun, *all_args, **params)
assert prim.map_primitive ^ prim.call_primitive
if prim.map_primitive:
new_mapped_invars = [m for m, x in zip(params['mapped_invars'], in_vals)
if not is_undefined_primal(x)]
new_params = dict(params, mapped_invars=tuple(new_mapped_invars))
out_flat = prim.bind(fun, *all_args, **new_params)
else:
out_flat = prim.bind(fun, *all_args, **params)
return tree_unflatten(out_tree(), out_flat)

def _eval_primals(jaxpr, args):
Expand All @@ -262,7 +270,7 @@ def is_linear(var):
assert not jaxpr.constvars
map(write_primal, jaxpr.invars, args)
for eqn in jaxpr.eqns:
if not eqn.primitive.call_primitive:
if not (eqn.primitive.call_primitive or eqn.primitive.map_primitive):
if not any(is_linear(v) for v in eqn.invars):
in_vals = map(read_primal, eqn.invars)
ans = eqn.primitive.bind(*in_vals, **eqn.params)
@@ -327,14 +335,13 @@ def process_primitive(self, primitive, tracers, params):

def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
assert call_primitive.multiple_results
primals = [t.primal for t in tracers]
tangents = [t.tangent for t in tracers]
primals, tangents = unzip2((t.primal, t.tangent) for t in tracers)
nonzero_tangents, in_tree_def = tree_flatten(tangents)
f_jvp, out_tree_def = traceable(jvp_subtrace(f, self.master),
len(primals), in_tree_def)
name = params.get('name', f.__name__)
params = dict(params, name=wrap_name(name, 'jvp'))
result = call_primitive.bind(f_jvp, *(primals + nonzero_tangents), **params)
result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **params)
primal_out, tangent_out = tree_unflatten(out_tree_def(), result)
return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)]

@@ -350,7 +357,22 @@ def todo(x):
return map(partial(JVPTracer, trace), primals, tangents)
return out, todo

process_map = process_call
def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
# only differs from process_call in that it must update mapped_invars
# TODO de-duplicate code
assert map_primitive.multiple_results
primals, tangents = unzip2((t.primal, t.tangent) for t in tracers)
nonzero_tangents, in_tree_def = tree_flatten(tangents)
f_jvp, out_tree_def = traceable(jvp_subtrace(f, self.master),
len(primals), in_tree_def)
new_name = wrap_name(params.get('name', f.__name__), 'jvp')
new_mapped_invars = (*params['mapped_invars'],
*[m for m, t in zip(params['mapped_invars'], tangents)
if t is not zero])
new_params = dict(params, name=new_name, mapped_invars=new_mapped_invars)
result = map_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
primal_out, tangent_out = tree_unflatten(out_tree_def(), result)
return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)]
post_process_map = post_process_call

def process_custom_jvp_call(self, _, __, f_jvp, tracers):
@@ -540,8 +562,12 @@ def map_transpose(primitive, params, call_jaxpr, args, ct):
all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts
fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
params = dict(params, name=wrap_name(params['name'], 'transpose'))
out_flat = primitive.bind(fun, *all_args, **params)
new_mapped_invars = (*[m for m, x in zip(params['mapped_invars'], args)
if not is_undefined_primal(x)],
*[True for x in ct if x is not zero])
new_params = dict(params, name=wrap_name(params['name'], 'transpose'),
mapped_invars=new_mapped_invars)
out_flat = primitive.bind(fun, *all_args, **new_params)
arg_cts = tree_unflatten(out_tree(), out_flat)

mapped_invars = params['mapped_invars'] # True for each mapped invar
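Note on the ad.py changes above: transposing a map changes the flat argument list, so mapped_invars has to be rebuilt to match it. The flag of every primal argument that is still passed through is kept (undefined primals, i.e. the linear inputs being transposed, drop out of the forward arguments), and each nonzero output cotangent is appended as a mapped input, since map outputs always carry the mapped axis. A small standalone sketch of that bookkeeping, with plain sentinels standing in for JAX's undefined-primal and symbolic-zero objects:

UNDEFINED = object()  # stands in for an undefined primal (a linear input)
ZERO = object()       # stands in for a symbolic-zero cotangent

def transpose_mapped_invars(mapped_invars, args, cts):
    """Keep flags of defined primals, then mark each nonzero cotangent as mapped."""
    kept = [m for m, x in zip(mapped_invars, args) if x is not UNDEFINED]
    appended = [True for c in cts if c is not ZERO]
    return (*kept, *appended)

# One mapped linear input (undefined) plus one unmapped constant, one cotangent.
assert transpose_mapped_invars((True, False), (UNDEFINED, 1.0), (1.0,)) == (False, True)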
22 changes: 14 additions & 8 deletions jax/interpreters/partial_eval.py
@@ -223,7 +223,9 @@ def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):

params = dict(params, name=name)
in_pvs, in_consts = unzip2([t.pval for t in tracers])
reduced_pvs = [None if pv is None else _mapped_aval(pv) for pv in in_pvs]
reduced_pvs = [None if pv is None else
_mapped_aval(params['axis_size'], pv) if m else pv
for pv, m in zip(in_pvs, params['mapped_invars'])]
fun, aux = partial_eval(f, self, reduced_pvs)
out_flat = map_primitive.bind(fun, *in_consts, **params)
out_pvs_reduced, jaxpr, env = aux()
@@ -237,10 +239,12 @@ def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
for out_pv, out_pv_const in zip(out_pvs, out_pv_consts)]
# The `jaxpr` already contains the env_vars at start of invars
new_params = dict(params,
mapped_invars=tuple([True] * len(const_tracers) +
[False] * len(env_tracers) +
[True] * len(tracers)),
mapped_invars=((True,) * len(const_tracers) +
(False,) * len(env_tracers) +
params['mapped_invars']),
call_jaxpr=lifted_jaxpr)
assert (len(new_params['mapped_invars'])
== len(const_tracers) + len(env_tracers) + len(tracers))
eqn = new_eqn_recipe(tuple(it.chain(const_tracers, env_tracers, tracers)),
out_tracers, map_primitive, new_params)
for t in out_tracers:
@@ -294,11 +298,12 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
class StagingJaxprTrace(JaxprTrace):
pass

def _mapped_aval(aval):
def _mapped_aval(size, aval):
if aval is core.abstract_unit:
return aval
elif isinstance(aval, ShapedArray):
# might be raising abstraction level from Concrete here
assert aval.shape[0] == size
return ShapedArray(aval.shape[1:], aval.dtype)
else:
raise TypeError(aval)
@@ -475,10 +480,11 @@ def new_eqn_recipe(invars, outvars, primitive, params):
primitive: the primitive.
params: the primitive params
"""
if primitive.call_primitive:
# TODO(necula): move these checks to core.check_jaxpr, and call it
# in more places.
# TODO(necula): move these checks to core.check_jaxpr, and call in more places
if primitive.call_primitive or primitive.map_primitive:
assert "call_jaxpr" in params
if primitive.map_primitive:
assert "mapped_invars" in params
return JaxprEqnRecipe(object(), tuple(invars), map(ref, outvars), primitive,
params)

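Note on the partial_eval.py changes above: _mapped_aval now takes the axis size and is applied only to inputs flagged as mapped, so unmapped inputs keep their abstract value while mapped inputs lose a leading axis that must equal axis_size. A standalone sketch of that per-argument reduction (ShapedArrayStub is a stand-in for JAX's ShapedArray, used only to keep the example self-contained):

from collections import namedtuple

ShapedArrayStub = namedtuple("ShapedArrayStub", ["shape", "dtype"])

def reduce_avals(avals, mapped_invars, axis_size):
    """Strip the leading axis only from mapped inputs; leave the rest alone."""
    out = []
    for aval, mapped in zip(avals, mapped_invars):
        if mapped:
            assert aval.shape[0] == axis_size  # mapped axis must match the map size
            out.append(ShapedArrayStub(aval.shape[1:], aval.dtype))
        else:
            out.append(aval)
    return out

avals = [ShapedArrayStub((8, 3), "float32"), ShapedArrayStub((3,), "float32")]
assert reduce_avals(avals, (True, False), axis_size=8)[0].shape == (3,)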
4 changes: 2 additions & 2 deletions jax/interpreters/pxla.py
@@ -560,7 +560,7 @@ def _sharded_device_array_constant_handler(c, val, canonicalize_types=True):
### the xla_pmap primitive and its rules are comparable to xla_call in xla.py

def xla_pmap_impl(fun: lu.WrappedFun, *args, backend, axis_name, axis_size, global_axis_size,
devices, name, mapped_invars=None):
devices, name, mapped_invars):
abstract_args = map(xla.abstractify, args)
compiled_fun = parallel_callable(fun, backend, axis_name, axis_size,
global_axis_size, devices, name, *abstract_args)
@@ -820,7 +820,7 @@ def execute_replicated(compiled, backend, in_handler, out_handler, *args):


xla_pmap_p = core.Primitive('xla_pmap')
xla_pmap_p.call_primitive = True
xla_pmap_p.map_primitive = True
xla_pmap_p.multiple_results = True
xla_pmap = partial(core.map_bind, xla_pmap_p)
xla_pmap_p.def_custom_bind(xla_pmap)
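Note on the pxla.py changes above: xla_pmap_p is now flagged as a map primitive rather than a call primitive, and mapped_invars loses its None default, so every bind site has to supply one flag per flat argument (which is what the api.py changes do). A tiny sketch of the invariant this makes available at bind time (bind_map_sketch is illustrative, not the real xla_pmap_impl):

def bind_map_sketch(*args, **params):
    """Sketch: a map primitive impl can now assume mapped_invars is present."""
    mapped_invars = params["mapped_invars"]  # always supplied, no default
    assert len(mapped_invars) == len(args)
    return mapped_invars

assert bind_map_sketch(1.0, 2.0, mapped_invars=(True, True), name="f") == (True, True)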
36 changes: 34 additions & 2 deletions tests/pmap_test.py
@@ -902,8 +902,8 @@ def testSoftPmapDevicePersistence(self):
x = x.reshape(2 * device_count, 2, 2, 3) # axis merge of the wrong size
self.assertIsInstance(x, xla.DeviceArray) # should have forced collection

@jtu.skip_on_devices("gpu")
def DISABLED_testSoftPmapAllToAll(self):
def testSoftPmapAllToAll(self):
raise SkipTest("the underlying code here is broken") # TODO(mattjj)
n = 4 * xla_bridge.device_count()
def f(x):
return lax.all_to_all(x, 'i', 0, 0)
@@ -1050,6 +1050,38 @@ def f(key):
keys = random.split(random.PRNGKey(0), n)
jax.pmap(jax.remat(f), axis_name='i')(keys)

def testPmapMapVmapCombinations(self):
# https://github.com/google/jax/issues/2822
def vv(x, y):
"""Vector-vector multiply"""
return np.dot(x, y)

def matrix_vector(x, y, parallel=True):
"""Matrix vector multiply. First batch it and then row by row"""
fv = lambda z: lax.map(lambda j: vv(j, y), z)
if parallel:
# split leading axis in two
new_x = x.reshape((jax.device_count(), -1, *x.shape[1:]))
# apply map
new_res = pmap(fv)(new_x)
# reshape back out
res = new_res.reshape(x.shape[0], *new_res.shape[2:])
else:
res = fv(x)
return res

x = random.normal(random.PRNGKey(1), (80, 5))
y = random.normal(random.PRNGKey(1), (10, 5))

result1 = vmap(lambda b: matrix_vector(x, b, True))(y) # vmap + pmap
result2 = lax.map(lambda b: matrix_vector(x, b, False), y) # map + map
result3 = lax.map(lambda b: matrix_vector(x, b, True), y) # map + pmap
result4 = np.stack([matrix_vector(x, b, False) for b in y]) # none + map

self.assertAllClose(result1, result2, check_dtypes=False, atol=1e-3, rtol=1e-3)
self.assertAllClose(result1, result3, check_dtypes=False, atol=1e-3, rtol=1e-3)
self.assertAllClose(result1, result4, check_dtypes=False, atol=1e-3, rtol=1e-3)


class PmapWithDevicesTest(jtu.JaxTestCase):
