diff --git a/docs/jax.lax.rst b/docs/jax.lax.rst
index 3e6dc45063ac..8f74f3414a54 100644
--- a/docs/jax.lax.rst
+++ b/docs/jax.lax.rst
@@ -133,6 +133,7 @@ Control flow operators
 
   cond
   fori_loop
+  map
   scan
   while_loop
 
diff --git a/jax/lax/lax_control_flow.py b/jax/lax/lax_control_flow.py
index 1fae92fe017e..09bf8e27a655 100644
--- a/jax/lax/lax_control_flow.py
+++ b/jax/lax/lax_control_flow.py
@@ -40,7 +40,7 @@
 from jax.tree_util import build_tree, tree_unflatten, tree_map
 from jax import ad_util
 
-map = safe_map
+_map = safe_map
 zip = safe_zip
 
 
@@ -220,7 +220,7 @@ def _while_loop_batching_rule(batched_args, batch_dims,
   init_val, cond_consts, body_consts = batched_args
   init_val_bd, cond_consts_bd, body_consts_bd = batch_dims
 
-  sizes = lax._reduce(set.union, map(batching.dimsize, batch_dims, batched_args))
+  sizes = lax._reduce(set.union, _map(batching.dimsize, batch_dims, batched_args))
   size = sizes.pop()
   assert not sizes
 
@@ -253,7 +253,7 @@ def lifted(loop_carry, cond_consts, body_consts):
 def _jaxtupletree_select(pred, on_true, on_false):
   aval = core.get_aval(on_true)
   if type(aval) is core.AbstractTuple:
-    return core.pack(map(partial(_jaxtupletree_select, pred), on_true, on_false))
+    return core.pack(_map(partial(_jaxtupletree_select, pred), on_true, on_false))
   elif isinstance(aval, UnshapedArray):
     return lax.select(pred, on_true, on_false)
   else:
@@ -402,7 +402,7 @@ def make_computation(jaxpr, operand):
 
 def _maybe_tracer_tuple_to_abstract_tuple(tup):
   if isinstance(tup, pe.JaxprTracerTuple):
-    return core.AbstractTuple(list(map(_maybe_tracer_tuple_to_abstract_tuple, tup)))
+    return core.AbstractTuple(list(_map(_maybe_tracer_tuple_to_abstract_tuple, tup)))
   elif isinstance(tup, core.AbstractValue):
     return tup
   elif tup is None:
@@ -424,10 +424,10 @@ def _convert_zeros(instantiate, example, tangent):
     raise TypeError(tangent)  # not clear if ever reachable
   elif t is tuple:
     if type(tangent) is ad.TangentTuple:
-      return core.pack(map(_convert_zeros, instantiate, example, tangent))
+      return core.pack(_map(_convert_zeros, instantiate, example, tangent))
     elif tangent is ad_util.zero:
       zeros = [ad_util.zero] * len(instantiate)
-      return core.pack(map(_convert_zeros, instantiate, example, zeros))
+      return core.pack(_map(_convert_zeros, instantiate, example, zeros))
     else:
       raise TypeError(tangent)
   else:
@@ -436,20 +436,20 @@ def _convert_zeros(instantiate, example, tangent):
 def _demote_aval_rank(xs):
   assert isinstance(xs, core.AbstractValue)
   if isinstance(xs, core.AbstractTuple):
-    return core.AbstractTuple(map(_demote_aval_rank, xs))
+    return core.AbstractTuple(_map(_demote_aval_rank, xs))
   else:
     return ShapedArray(xs.shape[1:], xs.dtype)
 
 def _promote_aval_rank(n, xs):
   assert isinstance(xs, core.AbstractValue)
   if isinstance(xs, core.AbstractTuple):
-    return core.AbstractTuple(map(partial(_promote_aval_rank, n), xs))
+    return core.AbstractTuple(_map(partial(_promote_aval_rank, n), xs))
   else:
     return ShapedArray((n,) + xs.shape, xs.dtype)
 
 def _leading_dim_size(xs):
   if isinstance(xs, core.JaxTuple):
-    sizes = set(map(_leading_dim_size, xs))
+    sizes = set(_map(_leading_dim_size, xs))
     if len(sizes) == 1:
       return sizes.pop()
     elif len(sizes) > 1:
@@ -468,21 +468,21 @@ def _leading_dim_size(xs):
 def _empty_arrays(aval):
   assert isinstance(aval, core.AbstractValue)
   if isinstance(aval, core.AbstractTuple):
-    return core.pack(map(_empty_arrays, aval))
+    return core.pack(_map(_empty_arrays, aval))
   else:
     return lax.full(aval.shape, 0, aval.dtype)
 
 def _index_arrays(i, aval, xs):
   assert isinstance(aval, core.AbstractValue)
   if isinstance(aval, core.AbstractTuple):
-    return core.pack(map(partial(_index_arrays, i), aval, xs))
+    return core.pack(_map(partial(_index_arrays, i), aval, xs))
   else:
     return lax.dynamic_index_in_dim(xs, i, keepdims=False)
 
 def _update_arrays(i, aval, xs, x):
   assert isinstance(aval, core.AbstractValue)
   if isinstance(aval, core.AbstractTuple):
-    return core.pack(map(partial(_update_arrays, i), aval, xs, x))
+    return core.pack(_map(partial(_update_arrays, i), aval, xs, x))
   else:
     x = lax.reshape(x, (1,) + onp.shape(x))
     return lax.dynamic_update_index_in_dim(xs, x, i, axis=0)
@@ -543,7 +543,7 @@ def scan(f, init, xs):
     loop carry value and the second element represents the stacked outputs of
     the second output of ``f`` when scanned over the leading axis of the inputs.
   """
-  (init, xs), in_trees = unzip2(map(pytree_to_jaxtupletree, (init, xs)))
+  (init, xs), in_trees = unzip2(_map(pytree_to_jaxtupletree, (init, xs)))
   f, out_tree = pytree_fun_to_jaxtupletree_fun(lu.wrap_init(f), in_trees)
   carry_pval = carry_aval, _ = _abstractify(init)
   xs_aval, _ = _abstractify(xs)
@@ -639,11 +639,11 @@ def _binary_lattice_fold(f, pack, a, b):
   recur = partial(_binary_lattice_fold, f, pack)
   t = (type(a), type(b))
   if t == (tuple, tuple):
-    return pack(map(recur, a, b))
+    return pack(_map(recur, a, b))
   elif t == (tuple, bool):
-    return pack(map(recur, a, (b,) * len(a)))
+    return pack(_map(recur, a, (b,) * len(a)))
   elif t == (bool, tuple):
-    return pack(map(recur, (a,) * len(b), b))
+    return pack(_map(recur, (a,) * len(b), b))
   elif t == (bool, bool):
     return f(a, b)
   else:
@@ -659,7 +659,7 @@ def _scan_partial_eval(trace, *tracers, **kwargs):
   forward = kwargs.pop('forward')
   assert not kwargs
   in_pvs, _ = unzip2([t.pval for t in tracers])
-  sc_consts, sc_init, sc_xs = map(pe.unknown, in_pvs)
+  sc_consts, sc_init, sc_xs = _map(pe.unknown, in_pvs)
 
   sc_carry = sc_init
   for i in range(1000):
@@ -702,8 +702,8 @@ def _lift_tracer(trace, tracer, is_unknown):
     else:
       return tracer
   elif t is tuple:
-    tracers = map(trace.full_raise, tracer)
-    return core.pack(map(partial(_lift_tracer, trace), tracers, is_unknown))
+    tracers = _map(trace.full_raise, tracer)
+    return core.pack(_map(partial(_lift_tracer, trace), tracers, is_unknown))
   else:
     raise TypeError(t)
 
@@ -713,7 +713,7 @@ def _put_known_pvs(is_unknown, aval):
   elif is_unknown is True:
     return aval
   else:
-    return pe.JaxprTracerTuple(map(_put_known_pvs, is_unknown, aval))
+    return pe.JaxprTracerTuple(_map(_put_known_pvs, is_unknown, aval))
 
 
 def _scan_transpose(ct, consts, init, xs, forward, length, jaxpr):
@@ -836,7 +836,7 @@ def _scan_batching_rule(batched_args, batch_dims, forward, length, jaxpr):
   consts, init, xs = batched_args
   consts_bdim, init_bdim, xs_bdim = batch_dims
 
-  sizes = lax._reduce(set.union, map(batching.dimsize, batch_dims, batched_args))
+  sizes = lax._reduce(set.union, _map(batching.dimsize, batch_dims, batched_args))
   size = sizes.pop()
   assert not sizes
 
@@ -889,3 +889,34 @@ def scan_bind(consts, init, xs, forward, length, jaxpr):
 pe.custom_partial_eval_rules[scan_p] = _scan_partial_eval
 xla.initial_style_translations[scan_p] = xla.lower_fun(_scan_impl, initial_style=True)
 batching.primitive_batchers[scan_p] = _scan_batching_rule
+
+
+def map(f, xs):
+  """Map a function over leading array axes.
+
+  Like Python's builtin map, except inputs and outputs are in the form of
+  stacked arrays. Consider using the ``jax.vmap`` transform instead, unless you
+  need to apply a function element by element for reduced memory usage or
+  heterogeneous computation with other control flow primitives.
+
+  When ``xs`` is an array type, the semantics of ``map`` are given by this
+  Python implementation::
+
+    def map(f, xs):
+      return np.stack([f(x) for x in xs])
+
+  Like ``scan``, ``map`` is implemented in terms of JAX primitives so many of
+  the same advantages over a Python loop apply: ``xs`` may be an arbitrary
+  nested pytree type, and the mapped computation is compiled only once.
+
+  Args:
+    f: a Python function to apply element-wise over the first axis or axes of
+      ``xs``.
+    xs: values over which to map along the leading axis.
+
+  Returns:
+    Mapped values.
+  """
+  g = lambda _, x: ((), f(x))
+  _, ys = scan(g, (), xs)
+  return ys
diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py
index 4b9c1d2f93ef..a6d07298e39a 100644
--- a/tests/lax_control_flow_test.py
+++ b/tests/lax_control_flow_test.py
@@ -816,6 +816,13 @@ def testIssue804(self):
     f = partial(lax.scan, lambda c, x: (c + lax.psum(x, "i") , c), 0.)
     api.pmap(f, axis_name="i")(np.ones((num_devices, 4)))  # doesn't crash
 
+  def testMap(self):
+    f = lambda x: x ** 2
+    xs = np.arange(10)
+    expected = xs ** 2
+    actual = lax.map(f, xs)
+    self.assertAllClose(actual, expected, check_dtypes=True)
+
 
 if __name__ == '__main__':
   absltest.main()
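
A usage sketch, not part of the diff above: assuming a build with this change
installed, the new ``lax.map`` behaves like the reference implementation in its
docstring, ``np.stack([f(x) for x in xs])``. The function ``f`` and the shapes
below are illustrative only::

    from jax import lax
    import jax.numpy as np

    f = lambda x: np.dot(x, x)         # reduce each length-3 slice to a scalar
    xs = np.arange(12.).reshape(4, 3)  # leading axis of size 4

    ys = lax.map(f, xs)                # applies f to each slice along axis 0
    # ys matches np.stack([f(x) for x in xs]) and has shape (4,)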
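
As the docstring suggests, ``jax.vmap`` computes the same values by vectorizing
over the leading axis in a single batched call, whereas ``lax.map`` applies
``f`` once per slice via ``scan``, which can reduce peak memory. A minimal
comparison sketch, reusing the illustrative ``f`` and ``xs`` from above::

    from jax import vmap

    zs = vmap(f)(xs)  # same values as lax.map(f, xs), computed in one batched call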