Skip to content

Commit

Permalink
Merge pull request #9369 from froystig:custom-vmap-outputs
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 438111265
  • Loading branch information
jax authors committed Mar 29, 2022
2 parents 4b4010d + b2de101 commit 085d390
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 69 deletions.
60 changes: 30 additions & 30 deletions jax/_src/custom_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import functools
import operator
from typing import Callable, Optional, Sequence
from typing import Callable, Optional

import jax
from jax import core
Expand Down Expand Up @@ -71,7 +71,8 @@ def __call__(self, *args, **kwargs):
out_flat = custom_vmap_p.bind(*consts, *args_flat,
call=closed_call,
rule=self.vmap_rule,
in_tree=in_tree)
in_tree=in_tree,
out_tree=out_tree())
return tree_unflatten(out_tree(), out_flat)


Expand All @@ -85,24 +86,21 @@ def rule_name(rule):
return getattr(rule, '__name__', '<unnamed rule>')

def call_rule(rule, axis_size, in_batched, args):
outs, out_batched = rule(axis_size, ensure_list(in_batched), *args)
if not isinstance(outs, Sequence):
raise TypeError(
'custom vmap rule output values must be a sequence, '
f'rule ({rule_name(rule)}) returned {type(outs)}')
if not isinstance(out_batched, Sequence):
raise TypeError(
'custom vmap rule output batching specification must be a sequence, '
f'rule ({rule_name(rule)}) returned {type(out_batched)}')
return ensure_list(outs), ensure_list(out_batched)

def check_vmap_rule_trees(rule, out_tree, out_batched_tree):
return rule(axis_size, ensure_list(in_batched), *args)

def check_vmap_rule_trees(rule, original_out_tree, out_tree, out_batched_tree):
if out_tree != out_batched_tree:
raise ValueError(
'structure of output values and output batching specification returned '
'structure of output value and output batching specification returned '
f'by custom vmap rule ({rule_name(rule)}) do not match.\n'
f'Output values: {out_tree}\n'
f'Batching spec: {out_batched_tree}')
if out_tree != original_out_tree:
raise ValueError(
f'structure of output returned by custom vmap rule ({rule_name(rule)}) '
'does not match that of original custom-vmapped function.\n'
f'Original output: {original_out_tree}\n'
f'Rule output: {out_tree}')

# Like batching.bdim_at_front, but doesn't broadcast if not mapped
def maybe_bdim_at_front(x, bdim):
Expand All @@ -127,23 +125,23 @@ def vmap_unrestricted(f: lu.WrappedFun, *args, in_axes, axis_name, axis_size):
### custom_vmap_p rules


def custom_vmap_impl(*args, call, rule, in_tree):
del rule, in_tree
def custom_vmap_impl(*args, call, rule, in_tree, out_tree):
del rule, in_tree, out_tree
return core.jaxpr_as_fun(call)(*args)


def custom_vmap_batching(args_flat, dims, *, call, rule, in_tree):
def custom_vmap_batching(args_flat, dims, *, call, rule, in_tree, out_tree):
del call
axis_size, = {x.shape[d] for x, d in zip(args_flat, dims) if d is not None}
args_flat = map(maybe_bdim_at_front, args_flat, dims)
flat_in_batched = [d is not not_mapped for d in dims]

args = tree_unflatten(in_tree, args_flat)
in_batched = tree_unflatten(in_tree, flat_in_batched)
outs, out_batched = call_rule(rule, axis_size, in_batched, args)
flat_outs, tree1 = tree_flatten(outs)
out, out_batched = call_rule(rule, axis_size, in_batched, args)
flat_outs, tree1 = tree_flatten(out)
flat_out_batched, tree2 = tree_flatten(out_batched)
check_vmap_rule_trees(rule, tree1, tree2)
check_vmap_rule_trees(rule, out_tree, tree1, tree2)
flat_out_dims = [0 if b else not_mapped for b in flat_out_batched]
return flat_outs, flat_out_dims

Expand All @@ -152,7 +150,7 @@ def custom_vmap_abstract_eval(*in_avals, call, **_):
return call.out_avals


def custom_vmap_jvp(primals, tangents, *, call, rule, in_tree):
def custom_vmap_jvp(primals, tangents, *, call, rule, in_tree, out_tree):
def jvp_of_rule_rule(axis_size, in_batched, primals, tangents):
in_batched_ps, in_batched_ts = in_batched

Expand All @@ -175,16 +173,16 @@ def jvp_of_rule_rule(axis_size, in_batched, primals, tangents):
del tree_ps_ts2

def to_jvp(*primals):
outs, out_batched = call_rule(rule, axis_size, mutually_batched, primals)
out, out_batched = call_rule(rule, axis_size, mutually_batched, primals)
check_vmap_rule_trees(
rule, tree_structure(outs), tree_structure(out_batched))
rule, out_tree, tree_structure(out), tree_structure(out_batched))
out_mutually_batched.store(out_batched)
return outs
return out

def to_vmap_over_extra_batched_dims(primals, tangents):
return jax.jvp(to_jvp, primals, tangents)

to_vmap_over_extra_batched_dims_flat, out_tree = flatten_fun_nokwargs(
to_vmap_over_extra_batched_dims_flat, out_tree2 = flatten_fun_nokwargs(
lu.wrap_init(to_vmap_over_extra_batched_dims),
tree_ps_ts)

Expand All @@ -203,9 +201,9 @@ def to_vmap_over_extra_batched_dims(primals, tangents):
flat_out_extra_batched_ts = [d is not not_mapped for d in flat_out_axes_t]

out_ps, out_ts = tree_unflatten(
out_tree(), [*flat_out_ps, *flat_out_ts])
out_tree2(), [*flat_out_ps, *flat_out_ts])
out_extra_batched_ps, out_extra_batched_ts = tree_unflatten(
out_tree(), [*flat_out_extra_batched_ps, *flat_out_extra_batched_ts])
out_tree2(), [*flat_out_extra_batched_ps, *flat_out_extra_batched_ts])

out_batched_ps = tree_map(
operator.or_, out_mutually_batched.val, out_extra_batched_ps)
Expand All @@ -217,9 +215,11 @@ def to_vmap_over_extra_batched_dims(primals, tangents):
tangents = map(ad.instantiate_zeros, tangents)
jvp_call, _ = ad.jvp_jaxpr(call, [True] * len(primals), True)
jvp_in_tree = treedef_tuple((in_tree, in_tree))
jvp_out_tree = treedef_tuple((out_tree, out_tree))
outs = custom_vmap_p.bind(
*primals, *tangents,
call=jvp_call, rule=jvp_of_rule_rule, in_tree=jvp_in_tree)
call=jvp_call, rule=jvp_of_rule_rule,
in_tree=jvp_in_tree, out_tree=jvp_out_tree)
assert len(outs) % 2 == 0, len(outs)
out_primals, out_tangents = util.split_list(outs, [len(outs) // 2])
return out_primals, out_tangents
Expand Down Expand Up @@ -265,6 +265,6 @@ def to_map(mapped_args):
mapped_args, bcast_args = tree_split(in_batched, list(args))
out = jax.lax.map(to_map, mapped_args)
out_batched = tree_map(lambda _: True, out)
return [out], [out_batched]
return out, out_batched

return f
81 changes: 42 additions & 39 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6971,16 +6971,33 @@ def f(x): return jnp.sin(x)

@f.def_vmap
def rule(axis_size, in_batched, xs):
self.assertEqual(in_batched, [True])
xs_batched, = in_batched
self.assertEqual(xs_batched, True)
self.assertEqual(axis_size, xs.shape[0])
return [jnp.cos(xs)], in_batched
return jnp.cos(xs), xs_batched

x, xs = jnp.array(1.), jnp.arange(3)
y = f(x)
self.assertAllClose(y, jnp.sin(x))
ys = api.vmap(f)(xs)
self.assertAllClose(ys, jnp.cos(xs))

def test_rule_multi_output(self):
@api.custom_vmap
def f(x): return jnp.sin(x), jnp.cos(x)

@f.def_vmap
def rule(axis_size, in_batched, xs):
return (jnp.cos(xs), jnp.sin(xs)), tuple(in_batched * 2)

x, xs = jnp.array(1.), jnp.arange(3)
y1, y2 = f(x)
self.assertAllClose(y1, jnp.sin(x))
self.assertAllClose(y2, jnp.cos(x))
ys1, ys2 = api.vmap(f)(xs)
self.assertAllClose(ys1, jnp.cos(xs))
self.assertAllClose(ys2, jnp.sin(xs))

def test_nary(self):
@api.custom_vmap
def f(x, y): return jnp.sin(x) + y ** 2.
Expand All @@ -6991,7 +7008,7 @@ def rule(axis_size, in_batched, xs, ys):
self.assertEqual(axis_size, 3)
self.assertEqual(axis_size, xs.shape[0])
self.assertEqual(axis_size, ys.shape[0])
return [jnp.cos(xs) + ys ** 2.], [True]
return jnp.cos(xs) + ys ** 2., True

xs, ys = jnp.arange(3), jnp.arange(3)
zs = api.vmap(f)(xs, ys)
Expand Down Expand Up @@ -7029,7 +7046,7 @@ def vector_dot_vmap_rule(axis_size, in_batched, u, v):
out = jnp.sum(u * v, axis=1)
else:
out = u @ v if u_batched else v @ u
return [out], [u_batched or v_batched]
return out, u_batched or v_batched

f = vector_dot
v = lambda *shape: jnp.ones(shape)
Expand All @@ -7056,30 +7073,16 @@ def f(x): return jnp.sin(x)
@f.def_vmap
def rule(axis_size, in_batched, xs):
rule_args.append((axis_size, in_batched))
return [jnp.cos(xs)], in_batched
return jnp.cos(xs), in_batched[0]

xs = jnp.arange(3)
_ = api.vmap(f)(xs)
(axis_size, in_batched), = rule_args
self.assertIs(type(axis_size), int)
self.assertIs(type(in_batched), list)
self.assertEqual(len(in_batched), 1)

def test_rule_output_signature_any_sequence(self):
@api.custom_vmap
def f(x): return jnp.sin(x)

Box = collections.namedtuple('Box', 'value')

@f.def_vmap
def rule(axis_size, in_batched, xs):
# custom vmap machinery should handle any sequence type for either output
return Box(jnp.cos(xs)), tuple(in_batched)

xs = jnp.arange(3)
ys = api.vmap(f)(xs)
self.assertAllClose(ys, jnp.cos(xs))

def test_rule_output_mismatch(self):
def test_rule_output_vs_batching_output_mismatch(self):
@api.custom_vmap
def f(x): return jnp.sin(x)

Expand All @@ -7090,23 +7093,23 @@ def test_rule_abc(axis_size, in_batched, xs):
xs = jnp.arange(3)
self.assertRaisesRegex(
ValueError,
'structure of output values and output batching specification '
'structure of output value and output batching specification '
r'returned by custom vmap rule \(test_rule_abc\) do not match.*',
lambda: api.vmap(f)(xs))

def test_rule_output_array(self):
def test_rule_vs_call_output_mismatch(self):
@api.custom_vmap
def f(x): return jnp.sin(x)

@f.def_vmap
def rule(axis_size, in_batched, xs):
# common to overlook the need to box up single output value in a list
return jnp.cos(xs), in_batched
def test_rule_abc2(axis_size, in_batched, xs):
return [jnp.sin(xs)], in_batched

xs = jnp.arange(3)
self.assertRaisesRegex(
TypeError,
'custom vmap rule output values must be a sequence.*',
ValueError,
r'structure of output returned by custom vmap rule \(test_rule_abc2\) '
r'does not match that of original custom-vmapped function.*',
lambda: api.vmap(f)(xs))

def test_jvp_basic(self):
Expand All @@ -7117,7 +7120,7 @@ def f(x): return jnp.sin(x)
def rule(axis_size, in_batched, xs):
self.assertEqual(axis_size, 3)
self.assertEqual(in_batched, [True])
return [jnp.cos(xs)], in_batched
return jnp.cos(xs), in_batched[0]

f_jvp = lambda x, tx: api.jvp(f, [x], [tx])

Expand All @@ -7144,7 +7147,7 @@ def f(x, y): return jnp.sin(x) + y
def rule(axis_size, in_batched, xs, ys):
self.assertEqual(axis_size, 3)
self.assertEqual(in_batched, [True, True])
return [jnp.cos(xs) + ys], [True]
return jnp.cos(xs) + ys, True

f_jvp = lambda x, y, tx, ty: api.jvp(f, [x, y], [tx, ty])

Expand All @@ -7167,7 +7170,7 @@ def f(x): return jnp.sin(x)
def rule(axis_size, in_batched, xs):
self.assertEqual(axis_size, 3)
self.assertEqual(in_batched, [False])
return [jnp.cos(xs)], in_batched
return jnp.cos(xs), in_batched[0]

f_jvp = lambda x, tx: api.jvp(f, [x], [tx])

Expand All @@ -7186,7 +7189,7 @@ def f(x): return jnp.sin(x)
def rule(axis_size, in_batched, xs):
self.assertEqual(axis_size, 3)
self.assertEqual(in_batched, [False])
return [jnp.cos(xs)], in_batched
return jnp.cos(xs), in_batched[0]

x = jnp.arange(3.) + .72
j = api.jacfwd(f)(x)
Expand All @@ -7200,7 +7203,7 @@ def f(x): return jnp.sin(x)
def rule(axis_size, in_batched, xs):
self.assertEqual(axis_size, 3)
self.assertEqual(in_batched, [False])
return [jnp.cos(xs)], in_batched
return jnp.cos(xs), in_batched[0]

f_jvp = lambda x, tx: api.jvp(f, [x], [tx])

Expand All @@ -7223,14 +7226,14 @@ def f_linear(x): return 7. * x

@f_linear.def_vmap
def linear_rule(axis_size, in_batched, xs):
return [11. * xs], in_batched
return 11. * xs, in_batched[0]

@api.custom_vmap
def f_nonlinear(x): return jnp.sin(x)

@f_nonlinear.def_vmap
def nonlinear_rule(axis_size, in_batched, xs):
return [jnp.cos(xs)], in_batched
return jnp.cos(xs), in_batched[0]

f_lin_jvp = lambda x, tx: api.jvp(f_linear, [x], [tx])
f_non_jvp = lambda x, tx: api.jvp(f_nonlinear, [x], [tx])
Expand Down Expand Up @@ -7267,7 +7270,7 @@ def f(x): return jnp.sin(x)

@f.def_vmap
def rule(axis_size, in_batched, xs):
return [cos_with_invalid_dataflow_jvp(xs)], in_batched
return cos_with_invalid_dataflow_jvp(xs), in_batched[0]

f_jvp = lambda x, tx: api.jvp(f, [x], [tx])
x, txs = jnp.array(1.), 2. + jnp.arange(3.)
Expand Down Expand Up @@ -7300,7 +7303,7 @@ def rule(axis_size, in_batched, xs):
self.assertEqual(in_batched, [in_batched_ref])
sz, = set([z.shape[0] for z in tree_util.tree_leaves(xs)])
self.assertEqual(axis_size, sz)
return [tree_cos(xs)], in_batched
return tree_cos(xs), in_batched[0]

y = f(x)
self.assertAllClose(y, tree_sin(x))
Expand All @@ -7324,7 +7327,7 @@ def rule(axis_size, in_batched, xs):
self.assertEqual(in_batched, [in_batched_ref])
sz, = set([z.shape[0] for z in tree_util.tree_leaves(xs)])
self.assertEqual(axis_size, sz)
return [tree_cos(xs)], in_batched
return tree_cos(xs), in_batched[0]

y = f(x)
self.assertAllClose(y, tree_sin(x))
Expand All @@ -7339,7 +7342,7 @@ def f(x): return jnp.sin(x)
def rule(axis_size, in_batched, xs):
self.assertEqual(in_batched, [True])
self.assertEqual(axis_size, xs.shape[0])
return [jnp.cos(xs)], in_batched
return jnp.cos(xs), in_batched[0]

x, xs = jnp.array(1.), jnp.arange(3)
self.assertAllClose(f(x), jit(f)(x))
Expand Down

0 comments on commit 085d390

Please sign in to comment.