Document how jax.hessian and pytrees interact. (#2705)
hawkinsp authored Apr 28, 2020
1 parent e599a25 commit ae6a3fe
Showing 1 changed file with 26 additions and 1 deletion.
27 changes: 26 additions & 1 deletion jax/api.py
@@ -570,7 +570,10 @@ def hessian(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
"""Hessian of ``fun`` as a dense array.
Args:
fun: Function whose Hessian is to be computed.
fun: Function whose Hessian is to be computed. Its arguments at positions
specified by ``argnums`` should be arrays, scalars, or standard Python
containers thereof. It should return arrays, scalars, or standard Python
containers thereof.
argnums: Optional, integer or sequence of integers. Specifies which
positional argument(s) to differentiate with respect to (default ``0``).
holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
@@ -584,6 +587,28 @@ def hessian(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
>>> print(jax.hessian(g)(jax.numpy.array([1., 2.])))
[[ 6., -2.],
[ -2., -480.]]
:py:func:`hessian` is a generalization of the usual definition of the Hessian
that supports Python trees as inputs and outputs. The structure of the output
is a composition of the output Python tree structure with two nested copies
of the input Python tree structure. The Python tree structure of the output
captures the block-sparsity structure of the generated Hessian viewed in its
usual matrix form. For example:
>>> f = lambda inp: {"c": jnp.power(inp["a"], inp["b"])}
>>> print(jax.hessian(f)({"a": jnp.arange(2.) + 1., "b": jnp.arange(2.) + 2.}))
{'c': {'a': {'a': DeviceArray([[[ 2., 0.], [ 0., 0.]],
[[ 0., 0.], [ 0., 12.]]], dtype=float32),
'b': DeviceArray([[[ 1. , 0. ], [ 0. , 0. ]],
[[ 0. , 0. ], [ 0. , 12.317766]]], dtype=float32)},
'b': {'a': DeviceArray([[[ 1. , 0. ], [ 0. , 0. ]],
[[ 0. , 0. ], [ 0. , 12.317766]]], dtype=float32),
'b': DeviceArray([[[0. , 0. ], [0. , 0. ]],
[[0. , 0. ], [0. , 3.843624]]], dtype=float32)}}}
If the usual dense matrix form of the Hessian is desired, one can
flatten and concatenate the arguments into a single 1D array before
computing the Hessian.
"""
return jacfwd(jacrev(fun, argnums, holomorphic), argnums, holomorphic)
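The closing note in the docstring (flatten and concatenate the arguments into a single 1D array) could be sketched roughly as follows. This is an illustration, not part of the commit: it assumes `jax.flatten_util.ravel_pytree` as the flattening helper, and `f_flat`, `params`, `tree_hess` and `dense_hess` are hypothetical names introduced here. The function `f` and its inputs are the ones from the docstring example.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

def f(inp):
    # Same function as the docstring example: elementwise a ** b.
    return {"c": jnp.power(inp["a"], inp["b"])}

params = {"a": jnp.arange(2.) + 1., "b": jnp.arange(2.) + 2.}

# Pytree form: output tree composed with two nested copies of the input tree.
tree_hess = jax.hessian(f)(params)

# Dense form: flatten the input dict into one 1D vector first.
# ravel_pytree returns the flat vector and a function that rebuilds the dict.
flat_params, unravel = ravel_pytree(params)

def f_flat(x):
    # Hypothetical wrapper: rebuild the dict, evaluate f, return the bare array.
    return f(unravel(x))["c"]

# One array of shape (output size, flat input size, flat input size).
dense_hess = jax.hessian(f_flat)(flat_params)
print(dense_hess.shape)

# The pytree entries are blocks of the dense array. Dict leaves are flattened
# in sorted key order, so "a" occupies indices 0:2 and "b" indices 2:4.
print(jnp.allclose(dense_hess[:, :2, :2], tree_hess["c"]["a"]["a"]))
print(jnp.allclose(dense_hess[:, :2, 2:], tree_hess["c"]["a"]["b"]))
```

Under these assumptions, each leaf of the pytree result corresponds to one block of the dense array, which is the block structure the docstring describes.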
