Support multiple inputs in flax lifted vjp/custom_vjp
PiperOrigin-RevId: 469841272
kho authored and Flax Authors committed Aug 24, 2022
1 parent 07e513f commit cbbf598
Showing 5 changed files with 80 additions and 78 deletions.
30 changes: 15 additions & 15 deletions flax/core/lift.py
@@ -336,9 +336,9 @@ def _split_in_out_axes(xs: Mapping[CollectionFilter, Any]):


def _bwd_wrapper(treedef, bwd_fn, tangent):
vars_grad, inputs_grad = bwd_fn(tangent)
vars_grad, *inputs_grad = bwd_fn(tangent)
vars_grad = treedef.unflatten(vars_grad)
return inputs_grad, vars_grad
return (vars_grad, *inputs_grad)


def vjp(
@@ -362,13 +362,13 @@ def vjp(
Example::
def learn_scale(scope, x):
def learn_scale(scope, x, y):
p = scope.param('scale', nn.initializers.zeros, ())
return p * x
def f(scope, x):
y, bwd = lift.vjp(learn_scale, scope, x)
params_grad, x_grad = bwd(jnp.ones(y.shape))
return y, params_grad, x_grad
return p * x * y
def f(scope, x, y):
z, bwd = lift.vjp(learn_scale, scope, x, y)
params_grad, x_grad, y_grad = bwd(jnp.ones(z.shape))
return z, params_grad, x_grad, y_grad
Args:
fn: Function to be differentiated. Its arguments should be arrays, scalars,
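
For reference, a runnable version of the updated docstring example above (a sketch; it assumes init, apply and lift are importable from flax.core as the core tests in this commit use them, and writes the zero initializer inline to avoid extra imports):

import jax.numpy as jnp
from jax import random
from flax.core import init, apply, lift

def learn_scale(scope, x, y):
  # Inline zero initializer; equivalent to nn.initializers.zeros for a scalar.
  p = scope.param('scale', lambda key: jnp.zeros(()))
  return p * x * y

def f(scope, x, y):
  z, bwd = lift.vjp(learn_scale, scope, x, y)
  params_grad, x_grad, y_grad = bwd(jnp.ones(z.shape))
  return z, params_grad, x_grad, y_grad

x = jnp.array([1., 2., 3.])
y = jnp.array([4., 5., 6.])
_, variables = init(f)(random.PRNGKey(0), x, y)
z, params_grad, x_grad, y_grad = apply(f)(variables, x, y)
# params_grad['params']['scale'] == sum(x * y) == 32.; x_grad and y_grad are
# zero here because 'scale' is initialized to zero.

The point of the change is visible in bwd's return: after the variable gradients, it now yields one gradient per differentiated input instead of a single input gradient.
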
@@ -1049,7 +1049,7 @@ def custom_vjp(fn: Callable[..., Any],
passed to `backward_fn`.
The `backward_fn` receives the nondiff arguments, residuals, and the output
tangents. It should return a tuple containing the input and variable tangents.
tangents. It should return a tuple containing the variable and input tangents.
Note that the vjp function returned by `lift.vjp` can be passed as residual
and used in the `backward_fn`. The scope is unavailable during the backward
@@ -1065,9 +1065,9 @@ def fwd(scope, x, features):
return y, vjp_fn
def bwd(features, vjp_fn, y_t):
input_t, params_t = vjp_fn(y_t)
params_t, *inputs_t = vjp_fn(y_t)
params_t = jax.tree_util.tree_map(jnp.sign, params_t)
return input_t, params_t
return (params_t, *inputs_t)
dense_sign_grad = lift.custom_vjp(
f, forward_fn=fwd, backward_fn=bwd, nondiff_argnums=(2,))
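
A minimal sketch of the new backward convention for lift.custom_vjp with two differentiable inputs (illustrative names, not the docstring's dense example; it assumes the default grad_vars='params' and that the vjp function can be used as the residual, as stated above):

import jax
import jax.numpy as jnp
from jax import random
from flax.core import init, apply, lift

def scale_fn(scope, x, y):
  p = scope.param('scale', lambda key: jnp.ones(()))
  return p * x * y

def fwd(scope, x, y):
  # The vjp function returned by lift.vjp is passed through as the residual.
  return lift.vjp(scale_fn, scope, x, y)

def bwd(vjp_fn, z_t):
  # Variable tangents first, then one tangent per differentiable input.
  vars_t, *inputs_t = vjp_fn(z_t)
  return (vars_t, *inputs_t)

scale_custom = lift.custom_vjp(scale_fn, forward_fn=fwd, backward_fn=bwd)

def run(scope, x, y):
  return scale_custom(scope, x, y)

x = jnp.ones((3,))
y = 2. * jnp.ones((3,))
_, variables = init(run)(random.PRNGKey(0), x, y)
x_grad = jax.grad(lambda x: apply(run)(variables, x, y).sum())(x)  # == p * y
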
@@ -1080,8 +1080,8 @@ def bwd(features, vjp_fn, y_t):
`backward_fn`.
backward_fn: arguments are passed as (*nondiff_args, residuals, tangents)
The function should return a tuple containing the tangents for the
input arguments (except the scope and nondiff args) and the variable
tangents for the collections specified by `grad_vars`.
variable in the collections specified by `grad_vars` and the input
arguments (except the scope and nondiff args).
grad_vars: The collections for which a vjp will be computed
(default: "params").
nondiff_argnums: arguments for which no vjp is computed.
@@ -1116,10 +1116,10 @@ def f_bwd(*args):
nondiff_args = args[:-2]
res, g = args[-2:] # pylint: disable=unbalanced-tuple-unpacking
g_y, _ = g
input_t, var_t = backward_fn(*nondiff_args, res, g_y)
var_t, *inputs_t = backward_fn(*nondiff_args, res, g_y)
assert scopes_treedef is not None, 'backward called before forward?!'
var_t = tuple(scopes_treedef.flatten_up_to(var_t))
return var_t, input_t
return (var_t, *inputs_t)

f.defvjp(f_fwd, f_bwd)

28 changes: 14 additions & 14 deletions flax/linen/transforms.py
@@ -876,16 +876,16 @@ def vjp(
class LearnScale(nn.Module):
@nn.compact
def __call__(self, x):
def __call__(self, x, y):
p = self.param('scale', nn.initializers.zeros, ())
return p * x
return p * x * y
class Foo(nn.Module):
@nn.compact
def __call__(self, x):
y, bwd = nn.vjp(lambda mdl, x: mdl(x), LearnScale(), x)
params_grad, x_grad = bwd(jnp.ones(y.shape))
return y, params_grad, x_grad
def __call__(self, x, y):
z, bwd = nn.vjp(lambda mdl, x, y: mdl(x, y), LearnScale(), x, y)
params_grad, x_grad, y_grad = bwd(jnp.ones(z.shape))
return z, params_grad, x_grad, y_grad
Args:
fn: Function to be differentiated. Its arguments should be arrays, scalars,
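
The linen example above, made runnable (a sketch that simply wraps the docstring's LearnScale/Foo definitions with init and apply):

import jax
import jax.numpy as jnp
import flax.linen as nn

class LearnScale(nn.Module):
  @nn.compact
  def __call__(self, x, y):
    p = self.param('scale', nn.initializers.zeros, ())
    return p * x * y

class Foo(nn.Module):
  @nn.compact
  def __call__(self, x, y):
    z, bwd = nn.vjp(lambda mdl, x, y: mdl(x, y), LearnScale(), x, y)
    params_grad, x_grad, y_grad = bwd(jnp.ones(z.shape))
    return z, params_grad, x_grad, y_grad

x = jnp.array([1., 2., 3.])
y = jnp.array([4., 5., 6.])
variables = Foo().init(jax.random.PRNGKey(0), x, y)
z, params_grad, x_grad, y_grad = Foo().apply(variables, x, y)
# params_grad['params']['scale'] == sum(x * y) == 32.; x_grad and y_grad are
# zero because 'scale' starts at zero.
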
@@ -1189,7 +1189,7 @@ def setup(self) -> None:
nn.Sequential([nn.Dense(11), nn.Dense(5)]),
nn.Dense(5),
]
@nn.compact
def __call__(self, x, index):
def head_fn(i):
@@ -1200,7 +1200,7 @@ def head_fn(i):
if self.is_mutable_collection('params'):
for branch in branches:
_ = branch(self, x)
return nn.switch(index, branches, self, x)
Args:
@@ -1249,7 +1249,7 @@ def custom_vjp(fn: Callable[..., Any],
passed to `backward_fn`.
The `backward_fn` receives the nondiff arguments, residuals, and the output
tangents. It should return a tuple containing the input and variable tangents.
tangents. It should return a tuple containing the variable and input tangents.
Note that the vjp function returned by `nn.vjp` can be passed as residual and
used in the `backward_fn`. The scope is unavailable during the backward pass.
@@ -1268,9 +1268,9 @@ def fwd(mdl, x):
return nn.vjp(f, mdl, x)
def bwd(vjp_fn, y_t):
input_t, params_t = vjp_fn(y_t)
params_t, *inputs_t = vjp_fn(y_t)
params_t = jax.tree_util.tree_map(jnp.sign, params_t)
return input_t, params_t
return (params_t, *inputs_t)
sign_grad = nn.custom_vjp(
f, forward_fn=fwd, backward_fn=bwd)
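
A hedged sketch of nn.custom_vjp under the new convention (illustrative module names rather than the docstring's own f; the parameter tangents come back first, followed by the input tangents):

import jax
import jax.numpy as jnp
import flax.linen as nn

class LearnedScale(nn.Module):
  @nn.compact
  def __call__(self, x):
    return self.param('scale', nn.initializers.ones, ()) * x

class StraightThrough(nn.Module):
  @nn.compact
  def __call__(self, x):
    def f(mdl, x):
      return mdl(x)

    def fwd(mdl, x):
      # The vjp function from nn.vjp serves as the residual for bwd.
      return nn.vjp(f, mdl, x)

    def bwd(vjp_fn, y_t):
      # Pass the tangents through unchanged; a real use could rewrite them.
      params_t, *inputs_t = vjp_fn(y_t)
      return (params_t, *inputs_t)

    f_custom = nn.custom_vjp(f, forward_fn=fwd, backward_fn=bwd)
    return f_custom(LearnedScale(), x)

x = jnp.ones((3,))
variables = StraightThrough().init(jax.random.PRNGKey(0), x)
x_grad = jax.grad(lambda x: StraightThrough().apply(variables, x).sum())(x)
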
@@ -1287,9 +1287,9 @@ def bwd(vjp_fn, y_t):
``backward_fn``.
backward_fn: arguments are passed as
``(*nondiff_args, residuals, tangents)`` The function should return a
tuple containing the tangents for the input arguments (except the module
and nondiff args) and the variable tangents for the collections specified
by `grad_vars`.
tuple containing the tangents for the variable in the collections
specified by `grad_vars` and the input arguments (except the module and
nondiff args).
grad_vars: The collections for which a vjp will be computed
(default: "params").
nondiff_argnums: arguments for which no vjp is computed.
53 changes: 27 additions & 26 deletions tests/core/core_lift_test.py
@@ -34,14 +34,14 @@ def f(scope):
def g(scopes, _):
scope, a = scopes
self.assertEqual(a.parent, scope)

lift.vmap(g, variable_axes={}, split_rngs={})((scope, a), jnp.ones((1,)))

init(f)(random.PRNGKey(0))

def test_undefined_param(self):
def f(scope):
dense = lift.vmap(nn.dense,
dense = lift.vmap(nn.dense,
in_axes=(0, None), out_axes=0,
variable_axes={'params': 0},
split_rngs={'params': True})
@@ -77,23 +77,24 @@ def f(scope, x):


def test_vjp(self):
def g(scope, x):
p = scope.param('test', nn.initializers.zeros, ())
def g(scope, x, y):
p = scope.param('test', nn.initializers.constant(0.5), ())
scope.variable('state', 'counter', lambda: 0)
return p * x
return p * x * y

def f(scope, x):
y, bwd = lift.vjp(g, scope, x)
params_grad, x_grad = bwd(jnp.ones(y.shape))
return params_grad, x_grad

x = jnp.ones((3,))
_, params = init(f)(random.PRNGKey(0), x)
x_grad, params_grad = apply(f)(params, x)
def f(scope, x, y):
z, bwd = lift.vjp(g, scope, x, y)
return bwd(jnp.ones(y.shape))

x = jnp.array([1., 2., 3.])
y = jnp.array([4., 5., 6.])
_, params = init(f)(random.PRNGKey(0), x, y)
params_grad, x_grad, y_grad = apply(f)(params, x, y)
self.assertEqual(params_grad, {
'params': FrozenDict({'test': 3.}),
'params': FrozenDict({'test': 32.}),
})
np.testing.assert_allclose(x_grad, 0. * x)
np.testing.assert_allclose(x_grad, [2., 2.5, 3.])
np.testing.assert_allclose(y_grad, [0.5, 1., 1.5])
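
The expected values in the assertions above follow directly from g(p, x, y) = p * x * y with p initialized to 0.5; a quick cross-check with plain jax, independent of flax:

import jax
import jax.numpy as jnp

x = jnp.array([1., 2., 3.])
y = jnp.array([4., 5., 6.])
p = 0.5

# bwd(jnp.ones(...)) is the vjp with an all-ones cotangent, i.e. the gradient
# of sum(p * x * y) with respect to each argument.
p_grad, x_grad, y_grad = jax.grad(
    lambda p, x, y: jnp.sum(p * x * y), argnums=(0, 1, 2))(p, x, y)
# p_grad == 4. + 10. + 18. == 32.
# x_grad == p * y == [2., 2.5, 3.]
# y_grad == p * x == [0.5, 1., 1.5]
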

def test_jvp(self):
def g(scope, x):
@@ -105,12 +106,12 @@ def f(scope, x):
vars_t = jax.tree_util.tree_map(jnp.ones_like, scope.variables().get('params', {}))
_, out_t = lift.jvp(g, scope, (x,), (jnp.zeros_like(x),), {'params': vars_t})
return out_t

x = jnp.ones((3,))
_, params = init(f)(random.PRNGKey(0), x)
y_t = apply(f)(params, x)
np.testing.assert_allclose(y_t, jnp.ones_like(x))

def test_while_loop(self):
def f(scope, x):
scope.param('inc', lambda _: 1)
@@ -137,7 +138,7 @@ def body_fn(scope, c):
self.assertEqual(c, 2 * x)
np.testing.assert_array_equal(vars['state']['rng_params'][0], vars['state']['rng_params'][1])
np.testing.assert_array_compare(operator.__ne__, vars['state']['rng_loop'][0], vars['state']['rng_loop'][1])

def test_cond(self):
def f(scope, x, pred):
scope.variable('state', 'true_count', lambda: 0)
@@ -149,16 +150,16 @@ def true_fn(scope, x):
def false_fn(scope, x):
scope.variable('state', 'false_count').value += 1
return -scope.child(nn.dense)(x, 2)

return lift.cond(pred, true_fn, false_fn, scope, x)

x = jnp.ones((1, 3))
y1, vars = init(f)(random.PRNGKey(0), x, True)
self.assertEqual(vars['state'].unfreeze(), {'true_count': 1, 'false_count': 0})
y2, vars = apply(f, mutable="state")(vars, x, False)
self.assertEqual(vars['state'].unfreeze(), {'true_count': 1, 'false_count': 1})
np.testing.assert_allclose(y1, -y2)

def test_switch(self):
def f(scope, x, index):
scope.variable('state', 'a_count', lambda: 0)
@@ -176,9 +177,9 @@ def b_fn(scope, x):
def c_fn(scope, x):
scope.variable('state', 'c_count').value += 1
return scope.child(nn.dense)(x, 2)

return lift.switch(index, [a_fn, b_fn, c_fn], scope, x)

x = jnp.ones((1, 3))
y1, vars = init(f)(random.PRNGKey(0), x, 0)
self.assertEqual(vars['state'].unfreeze(), {'a_count': 1, 'b_count': 0, 'c_count': 0})
@@ -190,14 +191,14 @@ def c_fn(scope, x):
vars = vars.copy(updates)
self.assertEqual(vars['state'].unfreeze(), {'a_count': 1, 'b_count': 1, 'c_count': 1})
np.testing.assert_allclose(y1, y3)

def test_subscope_var_aliasing(self):
def test(scope, x):
subscope = scope.push(name="a")
subscope.put_variable('state', 'x', 0.)
_ = lift.while_loop(
lambda scope, x: False,
lambda scope, x: x,
lambda scope, x: False,
lambda scope, x: x,
scope,
jnp.array(0, jnp.int32),
carry_variables=['state'],
4 changes: 2 additions & 2 deletions tests/core/design/core_custom_vjp_test.py
@@ -39,9 +39,9 @@ def fwd(scope, x, features):
def bwd(features, res, y_t):
del features
vjp_fn = res
input_t, params_t = vjp_fn(y_t)
params_t, *input_t = vjp_fn(y_t)
params_t = jax.tree_util.tree_map(jnp.sign, params_t)
return input_t, params_t
return (params_t, *input_t)

dense_custom_grad = lift.custom_vjp(
f, forward_fn=fwd, backward_fn=bwd, nondiff_argnums=(2,))
43 changes: 22 additions & 21 deletions tests/linen/linen_transforms_test.py
@@ -1046,25 +1046,26 @@ def __call__(self, x):
def test_vjp(self):
class Bar(nn.Module):
@nn.compact
def __call__(self, x):
p = self.param('test', nn.initializers.zeros, ())
def __call__(self, x, y):
p = self.param('test', nn.initializers.constant(0.5), ())
self.variable('state', 'counter', lambda: 0)
return p * x
return p * x * y

class Foo(nn.Module):
@nn.compact
def __call__(self, x):
y, bwd = nn.vjp(Bar.__call__, Bar(), x)
params_grad, x_grad = bwd(jnp.ones(y.shape))
return params_grad, x_grad
def __call__(self, x, y):
z, bwd = nn.vjp(Bar.__call__, Bar(), x, y)
return bwd(jnp.ones(z.shape))

x = jnp.ones((3,))
params = Foo().init(random.PRNGKey(0), x)
x_grad, params_grad = Foo().apply(params, x)
x = jnp.array([1., 2., 3.])
y = jnp.array([4., 5., 6.])
params = Foo().init(random.PRNGKey(0), x, y)
params_grad, x_grad, y_grad = Foo().apply(params, x, y)
self.assertEqual(params_grad, {
'params': nn.FrozenDict({'test': 3.}),
'params': nn.FrozenDict({'test': 32.}),
})
np.testing.assert_allclose(x_grad, 0. * x)
np.testing.assert_allclose(x_grad, [2., 2.5, 3.])
np.testing.assert_allclose(y_grad, [0.5, 1., 1.5])

def test_jvp(self):
class Bar(nn.Module):
@@ -1134,9 +1135,9 @@ def fwd(mdl, x):
return nn.vjp(f, mdl, x)

def bwd(vjp_fn, y_t):
input_t, params_t = vjp_fn(y_t)
params_t, input_t = vjp_fn(y_t)
params_t = jax.tree_util.tree_map(jnp.sign, params_t)
return input_t, params_t
return params_t, input_t

sign_grad = nn.custom_vjp(
f, forward_fn=fwd, backward_fn=bwd)
@@ -1229,7 +1230,7 @@ def helper(self, x, ms):
@nn.jit
def __call__(self, x):
return self.helper(x, self.inners)

k = random.PRNGKey(0)
x = jnp.ones((2,))

@@ -1415,9 +1416,9 @@ def true_fn(mdl, x):
def false_fn(mdl, x):
mdl.variable('state', 'false_count').value += 1
return -nn.Dense(2, name='dense')(x)

return nn.cond(pred, true_fn, false_fn, self, x)

def test_switch(self):
class Foo(nn.Module):
@nn.compact
@@ -1436,9 +1437,9 @@ def b_fn(mdl, x):
def c_fn(mdl, x):
mdl.variable('state', 'c_count').value += 1
return nn.Dense(2, name='dense')(x)

return nn.switch(pred, [a_fn, b_fn, c_fn], self, x)

x = jnp.ones((1, 3))
foo = Foo()
y1, vars = foo.init_with_output(random.PRNGKey(0), x, 0)
@@ -1474,9 +1475,9 @@ def fn(mdl, x):
if self.is_mutable_collection('params'):
for branch in branches:
_ = branch(self, x)

return nn.switch(index, branches, self, x)

x = jnp.ones((1, 3))
foo = Foo()
y1, vars = foo.init_with_output(random.PRNGKey(0), x, 0)
