From edad3e9edd9dfc65de231331f583924e8c1a45da Mon Sep 17 00:00:00 2001
From: Anselm Levskaya
Date: Sun, 12 Nov 2023 20:21:05 -0800
Subject: [PATCH] fix nn.value_and_grad by implementing directly in core

---
 flax/core/lift.py                    | 93 ++++++++++++++++++++++++++++
 flax/linen/transforms.py             | 15 ++---
 tests/linen/linen_transforms_test.py | 29 +++++++++
 3 files changed, 128 insertions(+), 9 deletions(-)

diff --git a/flax/core/lift.py b/flax/core/lift.py
index 2f9627d54..3f2cb90f5 100644
--- a/flax/core/lift.py
+++ b/flax/core/lift.py
@@ -503,6 +503,99 @@ def wrapper(vjp_vars, *args):
   )(scope, *primals)


+def value_and_grad(
+    fn: Callable[..., Any],
+    scope: Scope,
+    *primals,
+    has_aux: bool = False,
+    reduce_axes=(),
+    variables: CollectionFilter = True,
+    rngs: PRNGSequenceFilter = True,
+) -> Union[Tuple[Any, Any], Tuple[Any, Any, Any]]:
+  """A limited lifted version of ``jax.value_and_grad``.
+
+  See ``jax.value_and_grad`` for the unlifted reverse-mode gradient.
+
+  Note that for this convenience function, gradients are calculated with
+  respect to all function inputs, and not with respect to any scope
+  variables. The target function must return a scalar-valued output.
+
+  Example::
+
+    def learn_scale(scope, x, y):
+      p = scope.param('scale', nn.initializers.zeros_init(), ())
+      return p * x * y
+    def f(scope, x, y):
+      z, (x_grad, y_grad) = lift.value_and_grad(learn_scale, scope, x, y)
+      return z, x_grad, y_grad
+
+  Args:
+    fn: Function to be differentiated. Its arguments should be arrays, scalars,
+      or standard Python containers of arrays or scalars. It should return an
+      array, scalar, or standard Python container of arrays or scalars. It will
+      receive the scope and primals as arguments.
+    scope: The scope whose variables are made available inside ``fn``.
+    *primals: A sequence of primal values at which the Jacobian of ``fn``
+      should be evaluated. The length of ``primals`` should be equal to the
+      number of positional parameters to ``fn``. Each primal value should be a
+      tuple of arrays, scalar, or standard Python containers thereof.
+    has_aux: Optional, bool. Indicates whether ``fn`` returns a pair where the
+      first element is considered the output of the mathematical function to be
+      differentiated and the second element is auxiliary data. Default False.
+    reduce_axes: Optional, tuple of axis names. If an axis is listed here, and
+      ``fn`` implicitly broadcasts a value over that axis, the backward pass
+      will perform a ``psum`` of the corresponding gradient. Otherwise, the
+      gradient will be per-example over named axes. For example, if ``'batch'``
+      is a named batch axis, ``value_and_grad(f, *args, reduce_axes=('batch',))``
+      will create a function that sums the gradient over the batch while
+      ``value_and_grad(f, *args)`` will compute per-example gradients.
+    variables: other variable collections that are available inside ``fn`` but
+      do not receive a cotangent.
+    rngs: the PRNGs that are available inside ``fn``.
+
+  Returns:
+    If ``has_aux`` is ``False``, returns a ``(primals_out, grads)`` pair, where
+    ``primals_out`` is ``fn(*primals)``.
+    If ``has_aux`` is ``True``, returns a
+    ``(primals_out, aux, grads)`` tuple where ``aux`` is the auxiliary data
+    returned by ``fn``.
+  """
+
+  def inner(scope_fn, repack_fn, variable_groups, rng_groups, *args):
+    @functools.wraps(fn)
+    def wrapper(*args):
+      scope = scope_fn(variable_groups, rng_groups)
+      if has_aux:
+        y, aux = fn(scope, *args)
+      else:
+        y = fn(scope, *args)
+        aux = ()
+      return y, (aux, repack_fn(scope))
+
+    y, bwd, (aux, out_vars) = jax.vjp(
+        wrapper,
+        *args,
+        has_aux=True,
+        reduce_axes=reduce_axes,
+    )
+
+    inputs_grad = bwd(jax.numpy.ones_like(y))
+
+    if has_aux:
+      return (y, aux, inputs_grad), out_vars
+    else:
+      return (y, inputs_grad), out_vars
+
+  return pack(
+      inner,
+      (variables,),
+      (variables,),
+      (rngs,),
+      name='value_and_grad',
+      enable_kwargs=False,
+  )(scope, *primals)
+
+
 def jvp(
     fn: Callable[..., Any],
     scope: Scope,
diff --git a/flax/linen/transforms.py b/flax/linen/transforms.py
index 6d3ff2b0b..b7cfafb6b 100644
--- a/flax/linen/transforms.py
+++ b/flax/linen/transforms.py
@@ -1203,36 +1203,33 @@ def __call__(self, x, y):
     returned by ``fn``.
   """

-  vjp_partial = functools.partial(
-      vjp,
-      fn,
+  grad_partial = functools.partial(
+      lift_direct_transform,
+      lift.value_and_grad,
+      (fn,),
       mdl,
       *primals,
       has_aux=has_aux,
       reduce_axes=reduce_axes,
-      vjp_variables=False,
       variables=variables,
       rngs=rngs,
-      multi_scope=True,
   )
   if has_aux:
-    out, vjp_fun, aux = vjp_partial()
+    out, aux, argument_grads = grad_partial()
     if out.shape != ():
       raise ValueError(
           'grad can only work on functions with '
           f'scalar-valued outputs. out shape={out.shape}'
       )
-    _, *argument_grads = vjp_fun(jax.numpy.ones_like(out))
     return (out, aux), argument_grads
   else:
-    out, vjp_fun = vjp_partial()
+    out, argument_grads = grad_partial()
     if out.shape != ():
       raise ValueError(
           'grad can only work on functions with '
           f'scalar-valued outputs. out shape={out.shape}'
       )
-    _, *argument_grads = vjp_fun(jax.numpy.ones_like(out))
     return out, argument_grads
diff --git a/tests/linen/linen_transforms_test.py b/tests/linen/linen_transforms_test.py
index a43264775..e22ef97a2 100644
--- a/tests/linen/linen_transforms_test.py
+++ b/tests/linen/linen_transforms_test.py
@@ -2136,6 +2136,35 @@ def comparison_fn(x, y):
     self.assertTrue(tree_allclose(jax.grad(comparison_fn, 0)(x, y), x_grad))
     self.assertTrue(tree_allclose(jax.grad(comparison_fn, 1)(x, y), y_grad))

+  def test_value_and_grad_multiscope_adopted(self):
+    class Foo(nn.Module):
+      bar: nn.Module
+      qup: nn.Module
+
+      @nn.compact
+      def __call__(self, x, y):
+        def fn(self, x, y):
+          delta = y - self.bar(self.qup(x))
+          return jnp.sum(delta**2)
+
+        z, (x_grad, y_grad) = nn.value_and_grad(fn, self, x, y)
+        return z, x_grad, y_grad
+
+    x = random.uniform(random.key(1), (4,))
+    y = random.uniform(random.key(2), (4,))
+    vs = Foo(bar=nn.Dense(4), qup=nn.Dense(4)).init(random.key(0), x, y)
+    z, x_grad, y_grad = Foo(bar=nn.Dense(4), qup=nn.Dense(4)).apply(vs, x, y)
+
+    def comparison_fn(x, y):
+      w1 = vs['params']['qup']['kernel']
+      w2 = vs['params']['bar']['kernel']
+      delta = y - jnp.dot(jnp.dot(x, w1), w2)
+      return jnp.sum(delta**2)
+
+    self.assertTrue(tree_allclose(comparison_fn(x, y), z))
+    self.assertTrue(tree_allclose(jax.grad(comparison_fn, 0)(x, y), x_grad))
+    self.assertTrue(tree_allclose(jax.grad(comparison_fn, 1)(x, y), y_grad))
+

 if __name__ == '__main__':
   absltest.main()
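
Usage note (illustrative, not part of the patch): a minimal sketch of how the
fixed ``nn.value_and_grad`` is called at the Linen level, modeled on the new
test above. ``Model`` and ``loss_fn`` are hypothetical names; gradients are
taken with respect to the inputs ``(x, y)``, not the parameters::

  import jax
  import jax.numpy as jnp
  import flax.linen as nn
  from jax import random

  class Model(nn.Module):    # hypothetical wrapper module
    dense: nn.Module         # adopted submodule, as in the test

    @nn.compact
    def __call__(self, x, y):
      def loss_fn(mdl, x, y):
        delta = y - mdl.dense(x)
        return jnp.sum(delta**2)   # target must be scalar-valued

      # value_and_grad returns (value, (grad wrt x, grad wrt y))
      z, (x_grad, y_grad) = nn.value_and_grad(loss_fn, self, x, y)
      return z, x_grad, y_grad

  x = random.uniform(random.key(1), (4,))
  y = random.uniform(random.key(2), (4,))
  variables = Model(dense=nn.Dense(4)).init(random.key(0), x, y)
  z, x_grad, y_grad = Model(dense=nn.Dense(4)).apply(variables, x, y)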