Implement lax.map #1118

Merged
merged 2 commits on Aug 5, 2019
Changes from 1 commit
1 change: 1 addition & 0 deletions docs/jax.lax.rst
@@ -133,6 +133,7 @@ Control flow operators

cond
fori_loop
map
scan
while_loop

73 changes: 52 additions & 21 deletions jax/lax/lax_control_flow.py
@@ -40,7 +40,7 @@
from jax.tree_util import build_tree, tree_unflatten, tree_map
from jax import ad_util

map = safe_map
_map = safe_map
zip = safe_zip


@@ -220,7 +220,7 @@ def _while_loop_batching_rule(batched_args, batch_dims,
init_val, cond_consts, body_consts = batched_args
init_val_bd, cond_consts_bd, body_consts_bd = batch_dims

sizes = lax._reduce(set.union, map(batching.dimsize, batch_dims, batched_args))
sizes = lax._reduce(set.union, _map(batching.dimsize, batch_dims, batched_args))
size = sizes.pop()
assert not sizes

@@ -253,7 +253,7 @@ def lifted(loop_carry, cond_consts, body_consts):
def _jaxtupletree_select(pred, on_true, on_false):
aval = core.get_aval(on_true)
if type(aval) is core.AbstractTuple:
return core.pack(map(partial(_jaxtupletree_select, pred), on_true, on_false))
return core.pack(_map(partial(_jaxtupletree_select, pred), on_true, on_false))
elif isinstance(aval, UnshapedArray):
return lax.select(pred, on_true, on_false)
else:
@@ -402,7 +402,7 @@ def make_computation(jaxpr, operand):

def _maybe_tracer_tuple_to_abstract_tuple(tup):
if isinstance(tup, pe.JaxprTracerTuple):
return core.AbstractTuple(list(map(_maybe_tracer_tuple_to_abstract_tuple, tup)))
return core.AbstractTuple(list(_map(_maybe_tracer_tuple_to_abstract_tuple, tup)))
elif isinstance(tup, core.AbstractValue):
return tup
elif tup is None:
@@ -424,10 +424,10 @@ def _convert_zeros(instantiate, example, tangent):
raise TypeError(tangent) # not clear if ever reachable
elif t is tuple:
if type(tangent) is ad.TangentTuple:
return core.pack(map(_convert_zeros, instantiate, example, tangent))
return core.pack(_map(_convert_zeros, instantiate, example, tangent))
elif tangent is ad_util.zero:
zeros = [ad_util.zero] * len(instantiate)
return core.pack(map(_convert_zeros, instantiate, example, zeros))
return core.pack(_map(_convert_zeros, instantiate, example, zeros))
else:
raise TypeError(tangent)
else:
@@ -436,20 +436,20 @@ def _convert_zeros(instantiate, example, tangent):
def _demote_aval_rank(xs):
assert isinstance(xs, core.AbstractValue)
if isinstance(xs, core.AbstractTuple):
return core.AbstractTuple(map(_demote_aval_rank, xs))
return core.AbstractTuple(_map(_demote_aval_rank, xs))
else:
return ShapedArray(xs.shape[1:], xs.dtype)

def _promote_aval_rank(n, xs):
assert isinstance(xs, core.AbstractValue)
if isinstance(xs, core.AbstractTuple):
return core.AbstractTuple(map(partial(_promote_aval_rank, n), xs))
return core.AbstractTuple(_map(partial(_promote_aval_rank, n), xs))
else:
return ShapedArray((n,) + xs.shape, xs.dtype)

def _leading_dim_size(xs):
if isinstance(xs, core.JaxTuple):
sizes = set(map(_leading_dim_size, xs))
sizes = set(_map(_leading_dim_size, xs))
if len(sizes) == 1:
return sizes.pop()
elif len(sizes) > 1:
@@ -468,21 +468,21 @@ def _leading_dim_size(xs):
def _empty_arrays(aval):
assert isinstance(aval, core.AbstractValue)
if isinstance(aval, core.AbstractTuple):
return core.pack(map(_empty_arrays, aval))
return core.pack(_map(_empty_arrays, aval))
else:
return lax.full(aval.shape, 0, aval.dtype)

def _index_arrays(i, aval, xs):
assert isinstance(aval, core.AbstractValue)
if isinstance(aval, core.AbstractTuple):
return core.pack(map(partial(_index_arrays, i), aval, xs))
return core.pack(_map(partial(_index_arrays, i), aval, xs))
else:
return lax.dynamic_index_in_dim(xs, i, keepdims=False)

def _update_arrays(i, aval, xs, x):
assert isinstance(aval, core.AbstractValue)
if isinstance(aval, core.AbstractTuple):
return core.pack(map(partial(_update_arrays, i), aval, xs, x))
return core.pack(_map(partial(_update_arrays, i), aval, xs, x))
else:
x = lax.reshape(x, (1,) + onp.shape(x))
return lax.dynamic_update_index_in_dim(xs, x, i, axis=0)
@@ -543,7 +543,7 @@ def scan(f, init, xs):
loop carry value and the second element represents the stacked outputs of
the second output of ``f`` when scanned over the leading axis of the inputs.
"""
(init, xs), in_trees = unzip2(map(pytree_to_jaxtupletree, (init, xs)))
(init, xs), in_trees = unzip2(_map(pytree_to_jaxtupletree, (init, xs)))
f, out_tree = pytree_fun_to_jaxtupletree_fun(lu.wrap_init(f), in_trees)
carry_pval = carry_aval, _ = _abstractify(init)
xs_aval, _ = _abstractify(xs)
@@ -639,11 +639,11 @@ def _binary_lattice_fold(f, pack, a, b):
recur = partial(_binary_lattice_fold, f, pack)
t = (type(a), type(b))
if t == (tuple, tuple):
return pack(map(recur, a, b))
return pack(_map(recur, a, b))
elif t == (tuple, bool):
return pack(map(recur, a, (b,) * len(a)))
return pack(_map(recur, a, (b,) * len(a)))
elif t == (bool, tuple):
return pack(map(recur, (a,) * len(b), b))
return pack(_map(recur, (a,) * len(b), b))
elif t == (bool, bool):
return f(a, b)
else:
@@ -659,7 +659,7 @@ def _scan_partial_eval(trace, *tracers, **kwargs):
forward = kwargs.pop('forward')
assert not kwargs
in_pvs, _ = unzip2([t.pval for t in tracers])
sc_consts, sc_init, sc_xs = map(pe.unknown, in_pvs)
sc_consts, sc_init, sc_xs = _map(pe.unknown, in_pvs)

sc_carry = sc_init
for i in range(1000):
@@ -702,8 +702,8 @@ def _lift_tracer(trace, tracer, is_unknown):
else:
return tracer
elif t is tuple:
tracers = map(trace.full_raise, tracer)
return core.pack(map(partial(_lift_tracer, trace), tracers, is_unknown))
tracers = _map(trace.full_raise, tracer)
return core.pack(_map(partial(_lift_tracer, trace), tracers, is_unknown))
else:
raise TypeError(t)

@@ -713,7 +713,7 @@ def _put_known_pvs(is_unknown, aval):
elif is_unknown is True:
return aval
else:
return pe.JaxprTracerTuple(map(_put_known_pvs, is_unknown, aval))
return pe.JaxprTracerTuple(_map(_put_known_pvs, is_unknown, aval))


def _scan_transpose(ct, consts, init, xs, forward, length, jaxpr):
@@ -836,7 +836,7 @@ def _scan_batching_rule(batched_args, batch_dims, forward, length, jaxpr):
consts, init, xs = batched_args
consts_bdim, init_bdim, xs_bdim = batch_dims

sizes = lax._reduce(set.union, map(batching.dimsize, batch_dims, batched_args))
sizes = lax._reduce(set.union, _map(batching.dimsize, batch_dims, batched_args))
size = sizes.pop()
assert not sizes

@@ -889,3 +889,34 @@ def scan_bind(consts, init, xs, forward, length, jaxpr):
pe.custom_partial_eval_rules[scan_p] = _scan_partial_eval
xla.initial_style_translations[scan_p] = xla.lower_fun(_scan_impl, initial_style=True)
batching.primitive_batchers[scan_p] = _scan_batching_rule


def map(f, xs):
"""Map a function over leading array axes.

Like Python's builtin map, except inputs and outputs are in the form of
stacked arrays. Consider using the ``jax.vmap`` transform instead, unless you
need to apply a function element by element for reduced memory usage or
heterogeneous computation with other control flow primitives.

When ``xs`` is an array type, the semantics of ``map`` are given by this
Python implementation::

def map(f, xs):
return np.stack([f(x) for x in xs])

Like ``scan``, ``map`` is implemented in terms of JAX primtivies so many of
Review comment (Contributor): primtivies -> primitives

the same advantages over a Python loop apply: ``xs`` may be an arbitrary
nested pytree type, and the mapped computation is compiled only once.

Args:
f: a Python function to apply element-wise over the first axis or axes of
``xs``.
xs: values over which to map along the leading axis.

Returns:
Mapped values.
"""
g = lambda _, x: ((), f(x))
_, ys = scan(g, (), xs)
return ys
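
A minimal usage sketch of the new function (not part of the diff), assuming the ``import jax.numpy as np`` convention used in the tests below; the pytree example only exercises the docstring's claim and is illustrative::

    import jax.numpy as np
    from jax import lax

    # Apply f element by element over the leading axis; outputs come back stacked.
    ys = lax.map(lambda x: x ** 2, np.arange(10.))   # shape (10,)

    # Per the docstring, xs may also be a (nested) pytree of arrays sharing a
    # leading axis; f receives one per-element slice of each leaf.
    pairs = (np.arange(4.), np.ones(4))
    prods = lax.map(lambda p: p[0] * p[1], pairs)    # shape (4,)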
7 changes: 7 additions & 0 deletions tests/lax_control_flow_test.py
@@ -816,6 +816,13 @@ def testIssue804(self):
f = partial(lax.scan, lambda c, x: (c + lax.psum(x, "i") , c), 0.)
api.pmap(f, axis_name="i")(np.ones((num_devices, 4))) # doesn't crash

def testMap(self):
f = lambda x: x ** 2
xs = np.arange(10)
expected = xs ** 2
actual = lax.map(f, xs)
self.assertAllClose(actual, expected, check_dtypes=True)


if __name__ == '__main__':
absltest.main()
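
As a footnote on the ``jax.vmap`` suggestion in the docstring: the two produce the same values, but ``lax.map`` loops via ``scan`` (compiling ``f`` once and applying it one element at a time, keeping peak memory at a single element's working set), while ``vmap`` vectorizes the whole axis. A sketch under the same import assumptions as above::

    from jax import lax, vmap
    import jax.numpy as np

    f = lambda x: x ** 2
    xs = np.arange(10.)

    ys_map = lax.map(f, xs)    # applied one element at a time via scan
    ys_vmap = vmap(f)(xs)      # vectorized across the whole leading axis

    assert np.allclose(ys_map, ys_vmap)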