Merge pull request #3479 from levskaya:fix_grad
PiperOrigin-RevId: 581839415
Flax Authors committed Nov 13, 2023
2 parents e959dbf + edad3e9 commit 2f6e5ff
Showing 3 changed files with 128 additions and 9 deletions.
93 changes: 93 additions & 0 deletions flax/core/lift.py
@@ -503,6 +503,99 @@ def wrapper(vjp_vars, *args):
  )(scope, *primals)


def value_and_grad(
    fn: Callable[..., Any],
    scope: Scope,
    *primals,
    has_aux: bool = False,
    reduce_axes=(),
    variables: CollectionFilter = True,
    rngs: PRNGSequenceFilter = True,
) -> Union[Tuple[Any, Callable[..., Any]], Tuple[Any, Callable[..., Any], Any]]:
  """A limited lifted version of ``jax.value_and_grad``.

  See ``jax.value_and_grad`` for the unlifted reverse mode gradient.

  Note that for this convenience function, gradients are only calculated for
  the function inputs (all function inputs), and not with respect to any scope
  variables. The target function must return a scalar-valued output.

  Example::

    def learn_scale(scope, x, y):
      p = scope.param('scale', nn.initializers.zeros_init(), ())
      return p * x * y

    def f(scope, x, y):
      z, x_grad, y_grad = lift.value_and_grad(learn_scale, scope, x, y)
      return z, x_grad, y_grad

  Args:
    fn: Function to be differentiated. Its arguments should be arrays, scalars,
      or standard Python containers of arrays or scalars. It should return an
      array, scalar, or standard Python container of arrays or scalars. It will
      receive the scope and primals as arguments.
    scope: The scope of which the variables will be differentiated.
    *primals: A sequence of primal values at which the Jacobian of ``fn``
      should be evaluated. The length of ``primals`` should be equal to the
      number of positional parameters to ``fn``. Each primal value should be a
      tuple of arrays, scalar, or standard Python containers thereof.
    has_aux: Optional, bool. Indicates whether ``fn`` returns a pair where the
      first element is considered the output of the mathematical function to be
      differentiated and the second element is auxiliary data. Default False.
    reduce_axes: Optional, tuple of axis names. If an axis is listed here, and
      ``fn`` implicitly broadcasts a value over that axis, the backward pass
      will perform a ``psum`` of the corresponding gradient. Otherwise, the
      VJP will be per-example over named axes. For example, if ``'batch'``
      is a named batch axis, ``vjp(f, *args, reduce_axes=('batch',))`` will
      create a VJP function that sums over the batch while ``vjp(f, *args)``
      will create a per-example VJP.
    variables: other variables collections that are available inside `fn` but
      do not receive a cotangent.
    rngs: the prngs that are available inside `fn`.

  Returns:
    If ``has_aux`` is ``False``, returns a ``(primals_out, grads)`` pair, where
    ``primals_out`` is ``fn(*primals)``.
    If ``has_aux`` is ``True``, returns a
    ``(primals_out, aux, grads)`` tuple where ``aux`` is the auxiliary data
    returned by ``fn``.
  """

  def inner(scope_fn, repack_fn, variable_groups, rng_groups, *args):
    @functools.wraps(fn)
    def wrapper(*args):
      scope = scope_fn(variable_groups, rng_groups)
      if has_aux:
        y, aux = fn(scope, *args)
      else:
        y = fn(scope, *args)
        aux = ()
      return y, (aux, repack_fn(scope))

    y, bwd, (aux, out_vars) = jax.vjp(
        wrapper,
        *args,
        has_aux=True,
        reduce_axes=reduce_axes,
    )

    inputs_grad = bwd(jax.numpy.ones_like(y))

    if has_aux:
      return (y, aux, inputs_grad), out_vars
    else:
      return (y, inputs_grad), out_vars

  return pack(
      inner,
      (variables,),
      (variables,),
      (rngs,),
      name='value_and_grad',
      enable_kwargs=False,
  )(scope, *primals)


def jvp(
    fn: Callable[..., Any],
    scope: Scope,
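For context, a minimal runnable sketch of how the new core-level ``lift.value_and_grad`` could be exercised, based on the ``learn_scale`` example from its docstring. The ``flax.core.init`` driver and the unpacking of the result as a ``(value, grads)`` pair, with ``grads`` a tuple of per-input cotangents, are assumptions drawn from the ``inner`` implementation above, not part of this commit:

import jax
import jax.numpy as jnp
from flax import core
from flax.core import lift
from flax import linen as nn


def learn_scale(scope, x, y):
  # Scalar-valued target: the lifted value_and_grad expects a scalar output.
  p = scope.param('scale', nn.initializers.zeros_init(), ())
  return p * x * y


def f(scope, x, y):
  # Gradients are computed for the inputs (x, y) only, not for 'scale'.
  z, (x_grad, y_grad) = lift.value_and_grad(learn_scale, scope, x, y)
  return z, x_grad, y_grad


x, y = jnp.float32(2.0), jnp.float32(3.0)
(z, x_grad, y_grad), params = core.init(f)(jax.random.key(0), x, y)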
15 changes: 6 additions & 9 deletions flax/linen/transforms.py
@@ -1203,36 +1203,33 @@ def __call__(self, x, y):
    returned by ``fn``.
  """

-  vjp_partial = functools.partial(
-      vjp,
-      fn,
+  grad_partial = functools.partial(
+      lift_direct_transform,
+      lift.value_and_grad,
+      (fn,),
      mdl,
      *primals,
      has_aux=has_aux,
      reduce_axes=reduce_axes,
-      vjp_variables=False,
      variables=variables,
      rngs=rngs,
      multi_scope=True,
  )

  if has_aux:
-    out, vjp_fun, aux = vjp_partial()
+    out, aux, argument_grads = grad_partial()
    if out.shape != ():
      raise ValueError(
          'grad can only work on functions with '
          f'scalar-valued outputs. out shape={out.shape}'
      )
-    _, *argument_grads = vjp_fun(jax.numpy.ones_like(out))
    return (out, aux), argument_grads
  else:
-    out, vjp_fun = vjp_partial()
+    out, argument_grads = grad_partial()
    if out.shape != ():
      raise ValueError(
          'grad can only work on functions with '
          f'scalar-valued outputs. out shape={out.shape}'
      )
-    _, *argument_grads = vjp_fun(jax.numpy.ones_like(out))
    return out, argument_grads


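Downstream of this rewiring, ``nn.value_and_grad`` is expected to be called as in the new test added below: pass the target function, the bound module, and the primals, and get back the scalar value plus a tuple of per-primal gradients. A minimal sketch with an illustrative module (names here are not from the commit):

import jax
import jax.numpy as jnp
from flax import linen as nn


class Foo(nn.Module):  # illustrative module, mirroring the test below
  dense: nn.Module

  @nn.compact
  def __call__(self, x, y):
    def loss_fn(mdl, x, y):
      # Scalar-valued output, as required by value_and_grad.
      return jnp.sum((y - mdl.dense(x)) ** 2)

    # Gradients come back only for the primal inputs (x, y),
    # one entry per primal; module variables receive no cotangent.
    loss, (x_grad, y_grad) = nn.value_and_grad(loss_fn, self, x, y)
    return loss, x_grad, y_grad


x = jnp.ones((4,))
y = jnp.ones((4,))
model = Foo(dense=nn.Dense(4))
variables = model.init(jax.random.key(0), x, y)
loss, x_grad, y_grad = model.apply(variables, x, y)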
29 changes: 29 additions & 0 deletions tests/linen/linen_transforms_test.py
@@ -2136,6 +2136,35 @@ def comparison_fn(x, y):
    self.assertTrue(tree_allclose(jax.grad(comparison_fn, 0)(x, y), x_grad))
    self.assertTrue(tree_allclose(jax.grad(comparison_fn, 1)(x, y), y_grad))

  def test_value_and_grad_multiscope_adopted(self):
    class Foo(nn.Module):
      bar: nn.Module
      qup: nn.Module

      @nn.compact
      def __call__(self, x, y):
        def fn(self, x, y):
          delta = y - self.bar(self.qup(x))
          return jnp.sum(delta**2)

        z, (x_grad, y_grad) = nn.value_and_grad(fn, self, x, y)
        return z, x_grad, y_grad

    x = random.uniform(random.key(1), (4,))
    y = random.uniform(random.key(2), (4,))
    vs = Foo(bar=nn.Dense(4), qup=nn.Dense(4)).init(random.key(0), x, y)
    z, x_grad, y_grad = Foo(bar=nn.Dense(4), qup=nn.Dense(4)).apply(vs, x, y)

    def comparison_fn(x, y):
      w1 = vs['params']['qup']['kernel']
      w2 = vs['params']['bar']['kernel']
      delta = y - jnp.dot(jnp.dot(x, w1), w2)
      return jnp.sum(delta**2)

    self.assertTrue(tree_allclose(comparison_fn(x, y), z))
    self.assertTrue(tree_allclose(jax.grad(comparison_fn, 0)(x, y), x_grad))
    self.assertTrue(tree_allclose(jax.grad(comparison_fn, 1)(x, y), y_grad))


if __name__ == '__main__':
  absltest.main()
