require consistent output structure in custom vmap rules #9369

Merged · 1 commit · Mar 29, 2022
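After this change, a rule registered with `def_vmap` must return a pair `(out, out_batched)` whose pytree structures each match the output structure of the original function, instead of an arbitrary sequence of outputs. A minimal sketch of the new convention (illustrative, not part of the diff; `custom_vmap` is assumed importable from `jax.custom_batching`):

```python
import jax
import jax.numpy as jnp
from jax.custom_batching import custom_vmap  # assumed import path

@custom_vmap
def f(x):
  return jnp.sin(x)  # output structure: a single array

@f.def_vmap
def rule(axis_size, in_batched, xs):
  # The returned outputs and batching spec now mirror f's output structure:
  # a single array and a single bool, not [array] and [bool] as before.
  xs_batched, = in_batched
  return jnp.cos(xs), xs_batched

print(jax.vmap(f)(jnp.arange(3.)))  # == cos applied elementwise
```

The input-side convention is unchanged: `in_batched` remains a list with one entry per positional argument, as the updated `test_rule_input_signature` asserts.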
60 changes: 30 additions & 30 deletions jax/_src/custom_batching.py
@@ -14,7 +14,7 @@
 
 import functools
 import operator
-from typing import Callable, Optional, Sequence
+from typing import Callable, Optional
 
 import jax
 from jax import core
@@ -71,7 +71,8 @@ def __call__(self, *args, **kwargs):
     out_flat = custom_vmap_p.bind(*consts, *args_flat,
                                   call=closed_call,
                                   rule=self.vmap_rule,
-                                  in_tree=in_tree)
+                                  in_tree=in_tree,
+                                  out_tree=out_tree())
     return tree_unflatten(out_tree(), out_flat)


@@ -85,24 +86,21 @@ def rule_name(rule):
   return getattr(rule, '__name__', '<unnamed rule>')
 
 def call_rule(rule, axis_size, in_batched, args):
-  outs, out_batched = rule(axis_size, ensure_list(in_batched), *args)
-  if not isinstance(outs, Sequence):
-    raise TypeError(
-        'custom vmap rule output values must be a sequence, '
-        f'rule ({rule_name(rule)}) returned {type(outs)}')
-  if not isinstance(out_batched, Sequence):
-    raise TypeError(
-        'custom vmap rule output batching specification must be a sequence, '
-        f'rule ({rule_name(rule)}) returned {type(out_batched)}')
-  return ensure_list(outs), ensure_list(out_batched)
+  return rule(axis_size, ensure_list(in_batched), *args)
 
-def check_vmap_rule_trees(rule, out_tree, out_batched_tree):
+def check_vmap_rule_trees(rule, original_out_tree, out_tree, out_batched_tree):
   if out_tree != out_batched_tree:
     raise ValueError(
-        'structure of output values and output batching specification returned '
+        'structure of output value and output batching specification returned '
         f'by custom vmap rule ({rule_name(rule)}) do not match.\n'
         f'Output values: {out_tree}\n'
         f'Batching spec: {out_batched_tree}')
+  if out_tree != original_out_tree:
+    raise ValueError(
+        f'structure of output returned by custom vmap rule ({rule_name(rule)}) '
+        'does not match that of original custom-vmapped function.\n'
+        f'Original output: {original_out_tree}\n'
+        f'Rule output: {out_tree}')
 
 # Like batching.bdim_at_front, but doesn't broadcast if not mapped
 def maybe_bdim_at_front(x, bdim):
@@ -127,23 +125,23 @@ def vmap_unrestricted(f: lu.WrappedFun, *args, in_axes, axis_name, axis_size):
 ### custom_vmap_p rules
 
 
-def custom_vmap_impl(*args, call, rule, in_tree):
-  del rule, in_tree
+def custom_vmap_impl(*args, call, rule, in_tree, out_tree):
+  del rule, in_tree, out_tree
   return core.jaxpr_as_fun(call)(*args)
 
 
-def custom_vmap_batching(args_flat, dims, *, call, rule, in_tree):
+def custom_vmap_batching(args_flat, dims, *, call, rule, in_tree, out_tree):
   del call
   axis_size, = {x.shape[d] for x, d in zip(args_flat, dims) if d is not None}
   args_flat = map(maybe_bdim_at_front, args_flat, dims)
   flat_in_batched = [d is not not_mapped for d in dims]
 
   args = tree_unflatten(in_tree, args_flat)
   in_batched = tree_unflatten(in_tree, flat_in_batched)
-  outs, out_batched = call_rule(rule, axis_size, in_batched, args)
-  flat_outs, tree1 = tree_flatten(outs)
+  out, out_batched = call_rule(rule, axis_size, in_batched, args)
+  flat_outs, tree1 = tree_flatten(out)
   flat_out_batched, tree2 = tree_flatten(out_batched)
-  check_vmap_rule_trees(rule, tree1, tree2)
+  check_vmap_rule_trees(rule, out_tree, tree1, tree2)
   flat_out_dims = [0 if b else not_mapped for b in flat_out_batched]
   return flat_outs, flat_out_dims

@@ -152,7 +150,7 @@ def custom_vmap_abstract_eval(*in_avals, call, **_):
   return call.out_avals
 
 
-def custom_vmap_jvp(primals, tangents, *, call, rule, in_tree):
+def custom_vmap_jvp(primals, tangents, *, call, rule, in_tree, out_tree):
   def jvp_of_rule_rule(axis_size, in_batched, primals, tangents):
     in_batched_ps, in_batched_ts = in_batched

@@ -175,16 +173,16 @@ def jvp_of_rule_rule(axis_size, in_batched, primals, tangents):
     del tree_ps_ts2
 
     def to_jvp(*primals):
-      outs, out_batched = call_rule(rule, axis_size, mutually_batched, primals)
+      out, out_batched = call_rule(rule, axis_size, mutually_batched, primals)
       check_vmap_rule_trees(
-          rule, tree_structure(outs), tree_structure(out_batched))
+          rule, out_tree, tree_structure(out), tree_structure(out_batched))
       out_mutually_batched.store(out_batched)
-      return outs
+      return out
 
     def to_vmap_over_extra_batched_dims(primals, tangents):
       return jax.jvp(to_jvp, primals, tangents)
 
-    to_vmap_over_extra_batched_dims_flat, out_tree = flatten_fun_nokwargs(
+    to_vmap_over_extra_batched_dims_flat, out_tree2 = flatten_fun_nokwargs(
         lu.wrap_init(to_vmap_over_extra_batched_dims),
         tree_ps_ts)

@@ -203,9 +201,9 @@ def to_vmap_over_extra_batched_dims(primals, tangents):
     flat_out_extra_batched_ts = [d is not not_mapped for d in flat_out_axes_t]
 
     out_ps, out_ts = tree_unflatten(
-        out_tree(), [*flat_out_ps, *flat_out_ts])
+        out_tree2(), [*flat_out_ps, *flat_out_ts])
     out_extra_batched_ps, out_extra_batched_ts = tree_unflatten(
-        out_tree(), [*flat_out_extra_batched_ps, *flat_out_extra_batched_ts])
+        out_tree2(), [*flat_out_extra_batched_ps, *flat_out_extra_batched_ts])
 
     out_batched_ps = tree_map(
         operator.or_, out_mutually_batched.val, out_extra_batched_ps)
@@ -217,9 +215,11 @@ def to_vmap_over_extra_batched_dims(primals, tangents):
   tangents = map(ad.instantiate_zeros, tangents)
   jvp_call, _ = ad.jvp_jaxpr(call, [True] * len(primals), True)
   jvp_in_tree = treedef_tuple((in_tree, in_tree))
+  jvp_out_tree = treedef_tuple((out_tree, out_tree))
   outs = custom_vmap_p.bind(
       *primals, *tangents,
-      call=jvp_call, rule=jvp_of_rule_rule, in_tree=jvp_in_tree)
+      call=jvp_call, rule=jvp_of_rule_rule,
+      in_tree=jvp_in_tree, out_tree=jvp_out_tree)
   assert len(outs) % 2 == 0, len(outs)
   out_primals, out_tangents = util.split_list(outs, [len(outs) // 2])
   return out_primals, out_tangents
@@ -265,6 +265,6 @@ def to_map(mapped_args):
     mapped_args, bcast_args = tree_split(in_batched, list(args))
     out = jax.lax.map(to_map, mapped_args)
     out_batched = tree_map(lambda _: True, out)
-    return [out], [out_batched]
+    return out, out_batched
 
   return f
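For illustration (a hedged sketch under the same assumed import path, not part of the diff): a rule that keeps the old list-wrapping convention now trips the second check in `check_vmap_rule_trees`, since its output tree no longer matches that of the original function.

```python
import jax
import jax.numpy as jnp
from jax.custom_batching import custom_vmap  # assumed import path

@custom_vmap
def g(x):
  return jnp.sin(x)  # output structure: a single array

@g.def_vmap
def stale_rule(axis_size, in_batched, xs):
  # Old-style wrapping: the rule's output tree is a one-element list, which
  # no longer matches g's output tree (a bare array), so vmap now raises.
  return [jnp.cos(xs)], in_batched

try:
  jax.vmap(g)(jnp.arange(3.))
except ValueError as err:
  print(err)  # "structure of output returned by custom vmap rule ... does not match ..."
```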
81 changes: 42 additions & 39 deletions tests/api_test.py
@@ -6971,16 +6971,33 @@ def f(x): return jnp.sin(x)
 
     @f.def_vmap
     def rule(axis_size, in_batched, xs):
-      self.assertEqual(in_batched, [True])
+      xs_batched, = in_batched
+      self.assertEqual(xs_batched, True)
       self.assertEqual(axis_size, xs.shape[0])
-      return [jnp.cos(xs)], in_batched
+      return jnp.cos(xs), xs_batched
 
     x, xs = jnp.array(1.), jnp.arange(3)
     y = f(x)
     self.assertAllClose(y, jnp.sin(x))
     ys = api.vmap(f)(xs)
     self.assertAllClose(ys, jnp.cos(xs))
 
+  def test_rule_multi_output(self):
+    @api.custom_vmap
+    def f(x): return jnp.sin(x), jnp.cos(x)
+
+    @f.def_vmap
+    def rule(axis_size, in_batched, xs):
+      return (jnp.cos(xs), jnp.sin(xs)), tuple(in_batched * 2)
+
+    x, xs = jnp.array(1.), jnp.arange(3)
+    y1, y2 = f(x)
+    self.assertAllClose(y1, jnp.sin(x))
+    self.assertAllClose(y2, jnp.cos(x))
+    ys1, ys2 = api.vmap(f)(xs)
+    self.assertAllClose(ys1, jnp.cos(xs))
+    self.assertAllClose(ys2, jnp.sin(xs))
+
   def test_nary(self):
     @api.custom_vmap
     def f(x, y): return jnp.sin(x) + y ** 2.
@@ -6991,7 +7008,7 @@ def rule(axis_size, in_batched, xs, ys):
       self.assertEqual(axis_size, 3)
       self.assertEqual(axis_size, xs.shape[0])
       self.assertEqual(axis_size, ys.shape[0])
-      return [jnp.cos(xs) + ys ** 2.], [True]
+      return jnp.cos(xs) + ys ** 2., True
 
     xs, ys = jnp.arange(3), jnp.arange(3)
     zs = api.vmap(f)(xs, ys)
@@ -7029,7 +7046,7 @@ def vector_dot_vmap_rule(axis_size, in_batched, u, v):
         out = jnp.sum(u * v, axis=1)
       else:
         out = u @ v if u_batched else v @ u
-      return [out], [u_batched or v_batched]
+      return out, u_batched or v_batched
 
     f = vector_dot
     v = lambda *shape: jnp.ones(shape)
@@ -7056,30 +7073,16 @@ def f(x): return jnp.sin(x)
     @f.def_vmap
     def rule(axis_size, in_batched, xs):
       rule_args.append((axis_size, in_batched))
-      return [jnp.cos(xs)], in_batched
+      return jnp.cos(xs), in_batched[0]
 
     xs = jnp.arange(3)
     _ = api.vmap(f)(xs)
     (axis_size, in_batched), = rule_args
     self.assertIs(type(axis_size), int)
     self.assertIs(type(in_batched), list)
     self.assertEqual(len(in_batched), 1)
 
-  def test_rule_output_signature_any_sequence(self):
-    @api.custom_vmap
-    def f(x): return jnp.sin(x)
-
-    Box = collections.namedtuple('Box', 'value')
-
-    @f.def_vmap
-    def rule(axis_size, in_batched, xs):
-      # custom vmap machinery should handle any sequence type for either output
-      return Box(jnp.cos(xs)), tuple(in_batched)
-
-    xs = jnp.arange(3)
-    ys = api.vmap(f)(xs)
-    self.assertAllClose(ys, jnp.cos(xs))
-
-  def test_rule_output_mismatch(self):
+  def test_rule_output_vs_batching_output_mismatch(self):
     @api.custom_vmap
     def f(x): return jnp.sin(x)

@@ -7090,23 +7093,23 @@ def test_rule_abc(axis_size, in_batched, xs):
     xs = jnp.arange(3)
     self.assertRaisesRegex(
         ValueError,
-        'structure of output values and output batching specification '
+        'structure of output value and output batching specification '
         r'returned by custom vmap rule \(test_rule_abc\) do not match.*',
         lambda: api.vmap(f)(xs))
 
-  def test_rule_output_array(self):
+  def test_rule_vs_call_output_mismatch(self):
     @api.custom_vmap
     def f(x): return jnp.sin(x)
 
     @f.def_vmap
-    def rule(axis_size, in_batched, xs):
-      # common to overlook the need to box up single output value in a list
-      return jnp.cos(xs), in_batched
+    def test_rule_abc2(axis_size, in_batched, xs):
+      return [jnp.sin(xs)], in_batched
 
     xs = jnp.arange(3)
     self.assertRaisesRegex(
-        TypeError,
-        'custom vmap rule output values must be a sequence.*',
+        ValueError,
+        r'structure of output returned by custom vmap rule \(test_rule_abc2\) '
+        r'does not match that of original custom-vmapped function.*',
         lambda: api.vmap(f)(xs))
 
   def test_jvp_basic(self):
@@ -7117,7 +7120,7 @@ def f(x): return jnp.sin(x)
     def rule(axis_size, in_batched, xs):
       self.assertEqual(axis_size, 3)
       self.assertEqual(in_batched, [True])
-      return [jnp.cos(xs)], in_batched
+      return jnp.cos(xs), in_batched[0]
 
     f_jvp = lambda x, tx: api.jvp(f, [x], [tx])

@@ -7144,7 +7147,7 @@ def f(x, y): return jnp.sin(x) + y
     def rule(axis_size, in_batched, xs, ys):
       self.assertEqual(axis_size, 3)
       self.assertEqual(in_batched, [True, True])
-      return [jnp.cos(xs) + ys], [True]
+      return jnp.cos(xs) + ys, True
 
     f_jvp = lambda x, y, tx, ty: api.jvp(f, [x, y], [tx, ty])

@@ -7167,7 +7170,7 @@ def f(x): return jnp.sin(x)
     def rule(axis_size, in_batched, xs):
       self.assertEqual(axis_size, 3)
       self.assertEqual(in_batched, [False])
-      return [jnp.cos(xs)], in_batched
+      return jnp.cos(xs), in_batched[0]
 
     f_jvp = lambda x, tx: api.jvp(f, [x], [tx])

@@ -7186,7 +7189,7 @@ def f(x): return jnp.sin(x)
     def rule(axis_size, in_batched, xs):
       self.assertEqual(axis_size, 3)
       self.assertEqual(in_batched, [False])
-      return [jnp.cos(xs)], in_batched
+      return jnp.cos(xs), in_batched[0]
 
     x = jnp.arange(3.) + .72
     j = api.jacfwd(f)(x)
@@ -7200,7 +7203,7 @@ def f(x): return jnp.sin(x)
     def rule(axis_size, in_batched, xs):
       self.assertEqual(axis_size, 3)
       self.assertEqual(in_batched, [False])
-      return [jnp.cos(xs)], in_batched
+      return jnp.cos(xs), in_batched[0]
 
     f_jvp = lambda x, tx: api.jvp(f, [x], [tx])

@@ -7223,14 +7226,14 @@ def f_linear(x): return 7. * x
 
     @f_linear.def_vmap
     def linear_rule(axis_size, in_batched, xs):
-      return [11. * xs], in_batched
+      return 11. * xs, in_batched[0]
 
     @api.custom_vmap
     def f_nonlinear(x): return jnp.sin(x)
 
     @f_nonlinear.def_vmap
     def nonlinear_rule(axis_size, in_batched, xs):
-      return [jnp.cos(xs)], in_batched
+      return jnp.cos(xs), in_batched[0]
 
     f_lin_jvp = lambda x, tx: api.jvp(f_linear, [x], [tx])
     f_non_jvp = lambda x, tx: api.jvp(f_nonlinear, [x], [tx])
@@ -7267,7 +7270,7 @@ def f(x): return jnp.sin(x)
 
     @f.def_vmap
     def rule(axis_size, in_batched, xs):
-      return [cos_with_invalid_dataflow_jvp(xs)], in_batched
+      return cos_with_invalid_dataflow_jvp(xs), in_batched[0]
 
     f_jvp = lambda x, tx: api.jvp(f, [x], [tx])
     x, txs = jnp.array(1.), 2. + jnp.arange(3.)
@@ -7300,7 +7303,7 @@ def rule(axis_size, in_batched, xs):
       self.assertEqual(in_batched, [in_batched_ref])
       sz, = set([z.shape[0] for z in tree_util.tree_leaves(xs)])
       self.assertEqual(axis_size, sz)
-      return [tree_cos(xs)], in_batched
+      return tree_cos(xs), in_batched[0]
 
     y = f(x)
     self.assertAllClose(y, tree_sin(x))
@@ -7324,7 +7327,7 @@ def rule(axis_size, in_batched, xs):
       self.assertEqual(in_batched, [in_batched_ref])
       sz, = set([z.shape[0] for z in tree_util.tree_leaves(xs)])
       self.assertEqual(axis_size, sz)
-      return [tree_cos(xs)], in_batched
+      return tree_cos(xs), in_batched[0]
 
     y = f(x)
     self.assertAllClose(y, tree_sin(x))
@@ -7339,7 +7342,7 @@ def f(x): return jnp.sin(x)
     def rule(axis_size, in_batched, xs):
       self.assertEqual(in_batched, [True])
       self.assertEqual(axis_size, xs.shape[0])
-      return [jnp.cos(xs)], in_batched
+      return jnp.cos(xs), in_batched[0]
 
     x, xs = jnp.array(1.), jnp.arange(3)
     self.assertAllClose(f(x), jit(f)(x))